1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-07-21 08:36:15 +02:00

Compare commits

..

36 Commits

Author SHA1 Message Date
Nikolay Kim
57981ca04a update tests to async handlers 2019-11-22 11:49:35 +06:00
Nikolay Kim
e668acc596 update travis config 2019-11-22 10:13:32 +06:00
Nikolay Kim
512dd2be63 disable rustls support 2019-11-22 07:01:05 +06:00
Nikolay Kim
8683ba8bb0 rename .to_async() to .to() 2019-11-21 21:36:35 +06:00
Nikolay Kim
0b9e3d381b add test with custom connector 2019-11-21 17:36:18 +06:00
Nikolay Kim
1f0577f8d5 cleanup api doc examples 2019-11-21 16:02:17 +06:00
Nikolay Kim
53c5151692 use response instead of result for asyn c handlers 2019-11-21 16:02:17 +06:00
Nikolay Kim
55698f2524 migrade rest of middlewares 2019-11-21 16:02:17 +06:00
Nikolay Kim
471f82f0e0 migrate actix-multipart 2019-11-21 16:02:17 +06:00
Nikolay Kim
60ada97b3d migrate actix-session 2019-11-21 16:02:17 +06:00
Nikolay Kim
0de101bc4d update actix-web-codegen tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
95e2a0ef2e migrate actix-framed 2019-11-21 16:02:17 +06:00
Nikolay Kim
69cadcdedb migrate actix-files 2019-11-21 16:02:17 +06:00
Nikolay Kim
6ac4ac66b9 migrate actix-cors 2019-11-21 16:02:17 +06:00
Nikolay Kim
3646725cf6 migrate actix-identity 2019-11-21 16:02:17 +06:00
Nikolay Kim
ff62facc0d disable unmigrated crates 2019-11-21 16:02:17 +06:00
Nikolay Kim
b510527a9f update awc tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
3127dd4db6 migrate actix-web to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
d081e57316 fix h2 client send body 2019-11-21 16:02:17 +06:00
Nikolay Kim
1ffa7d18d3 drop unpin constraint 2019-11-21 16:02:17 +06:00
Nikolay Kim
687884fb94 update test-server tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
5ab29b2e62 migrate awc and test-server to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
a6a2d2f444 update ssl impls 2019-11-21 16:02:17 +06:00
Nikolay Kim
9e95efcc16 migrate client to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
8cba1170e6 make actix-http compile with std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
5cb2d500d1 update actix-web-actors 2019-11-14 08:58:24 +06:00
Nikolay Kim
0212c618c6 prepare actix-web release 2019-11-14 08:55:37 +06:00
Feiko Nanninga
88110ed268 Add security note to ConnectionInfo::remote() (#1158) 2019-11-14 08:32:47 +06:00
Nikolay Kim
fba02fdd8c prep awc release 2019-11-06 11:33:25 -08:00
Nikolay Kim
b2934ad8d2 prep actix-file release 2019-11-06 11:25:26 -08:00
Nikolay Kim
f7f410d033 fix test order dep 2019-11-06 11:20:47 -08:00
Nikolay Kim
885ff7396e prepare actox-http release 2019-11-06 10:35:13 -08:00
Erlend Langseth
61b38e8d0d Increase timeouts in test-server (#1153) 2019-11-06 06:09:22 -08:00
Hung-I Wang
edcde67076 Fix escaping/encoding problems in Content-Disposition header (#1151)
* Fix filename encoding in Content-Disposition of acitx_files::NamedFile

* Add more comments on how to use Content-Disposition header properly & Fix some trivial problems

* Improve Content-Disposition filename(*) parameters of actix_files::NamedFile

* Tweak Content-Disposition parse to accept empty param value in quoted-string

* Fix typos in comments in .../content_disposition.rs (pointed out by @JohnTitor)

* Update CHANGES.md

* Update CHANGES.md again
2019-11-06 06:08:37 -08:00
Jonathas Conceição
f0612f7570 awc: Add support for setting query from Serialize type for client request (#1130)
Signed-off-by: Jonathas-Conceicao <jadoliveira@inf.ufpel.edu.br>
2019-10-26 08:27:14 +03:00
Anton Lazarev
ace98e3a1e support Host guards when Host header is unset (#1129) 2019-10-15 05:05:54 +06:00
132 changed files with 10750 additions and 9669 deletions

View File

@@ -10,9 +10,9 @@ matrix:
include: include:
- rust: stable - rust: stable
- rust: beta - rust: beta
- rust: nightly-2019-08-10 - rust: nightly-2019-11-20
allow_failures: allow_failures:
- rust: nightly-2019-08-10 - rust: nightly-2019-11-20
env: env:
global: global:
@@ -25,7 +25,7 @@ before_install:
- sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev
before_cache: | before_cache: |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then
RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin
fi fi
@@ -37,8 +37,8 @@ script:
- cargo update - cargo update
- cargo check --all --no-default-features - cargo check --all --no-default-features
- cargo test --all-features --all -- --nocapture - cargo test --all-features --all -- --nocapture
- cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. # - cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
- cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd .. # - cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
# Upload docs # Upload docs
after_success: after_success:
@@ -51,7 +51,7 @@ after_success:
echo "Uploaded documentation" echo "Uploaded documentation"
fi fi
- | - |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then
taskset -c 0 cargo tarpaulin --out Xml --all --all-features taskset -c 0 cargo tarpaulin --out Xml --all --all-features
bash <(curl -s https://codecov.io/bash) bash <(curl -s https://codecov.io/bash)
echo "Uploaded code coverage" echo "Uploaded code coverage"

View File

@@ -1,11 +1,16 @@
# Changes # Changes
## [1.0.9] - 2019-xx-xx ## [1.0.9] - 2019-11-14
### Added ### Added
* Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) * Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110)
### Changed
* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129)
## [1.0.8] - 2019-09-25 ## [1.0.8] - 2019-09-25
### Added ### Added

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web" name = "actix-web"
version = "1.0.8" version = "2.0.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
readme = "README.md" readme = "README.md"
@@ -16,7 +16,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"]
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"] features = ["openssl", "rustls", "brotli", "flate2-zlib", "secure-cookies", "client"]
[badges] [badges]
travis-ci = { repository = "actix/actix-web", branch = "master" } travis-ci = { repository = "actix/actix-web", branch = "master" }
@@ -63,37 +63,35 @@ secure-cookies = ["actix-http/secure-cookies"]
fail = ["actix-http/fail"] fail = ["actix-http/fail"]
# openssl # openssl
ssl = ["openssl", "actix-server/ssl", "awc/ssl"] openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"]
# rustls # rustls
rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"] # rustls = ["rust-tls", "actix-server/rustls", "awc/rustls"]
# unix domain sockets support
uds = ["actix-server/uds"]
[dependencies] [dependencies]
actix-codec = "0.1.2" actix-codec = "0.2.0-alpha.1"
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
actix-utils = "0.4.4" actix-utils = "0.5.0-alpha.1"
actix-router = "0.1.5" actix-router = "0.1.5"
actix-rt = "0.2.4" actix-rt = "1.0.0-alpha.1"
actix-web-codegen = "0.1.2" actix-web-codegen = "0.2.0-alpha.1"
actix-http = "0.2.9" actix-http = "0.3.0-alpha.1"
actix-server = "0.6.1" actix-server = "0.8.0-alpha.1"
actix-server-config = "0.1.2" actix-server-config = "0.3.0-alpha.1"
actix-testing = "0.1.0" actix-testing = "0.3.0-alpha.1"
actix-threadpool = "0.1.1" actix-threadpool = "0.2.0-alpha.1"
awc = { version = "0.2.7", optional = true } awc = { version = "0.3.0-alpha.1", optional = true }
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.1.25" futures = "0.3.1"
hashbrown = "0.5.0" hashbrown = "0.6.3"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
net2 = "0.2.33" net2 = "0.2.33"
parking_lot = "0.9" parking_lot = "0.9"
pin-project = "0.4.5"
regex = "1.0" regex = "1.0"
serde = { version = "1.0", features=["derive"] } serde = { version = "1.0", features=["derive"] }
serde_json = "1.0" serde_json = "1.0"
@@ -102,17 +100,17 @@ time = "0.1.42"
url = "2.1" url = "2.1"
# ssl support # ssl support
openssl = { version="0.10", optional = true } open-ssl = { version="0.10", package="openssl", optional = true }
rustls = { version = "0.15", optional = true } rust-tls = { version = "0.16", package="rustls", optional = true }
[dev-dependencies] [dev-dependencies]
actix = "0.8.3" # actix = "0.8.3"
actix-connect = "0.2.2" actix-connect = "0.3.0-alpha.1"
actix-http-test = "0.2.4" actix-http-test = "0.3.0-alpha.1"
rand = "0.7" rand = "0.7"
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
tokio-timer = "0.2.8" tokio-timer = "0.3.0-alpha.6"
brotli2 = "0.3.2" brotli2 = "0.3.2"
flate2 = "1.0.2" flate2 = "1.0.2"
@@ -126,8 +124,28 @@ actix-web = { path = "." }
actix-http = { path = "actix-http" } actix-http = { path = "actix-http" }
actix-http-test = { path = "test-server" } actix-http-test = { path = "test-server" }
actix-web-codegen = { path = "actix-web-codegen" } actix-web-codegen = { path = "actix-web-codegen" }
actix-web-actors = { path = "actix-web-actors" } # actix-web-actors = { path = "actix-web-actors" }
actix-session = { path = "actix-session" } actix-session = { path = "actix-session" }
actix-files = { path = "actix-files" } actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" } actix-multipart = { path = "actix-multipart" }
awc = { path = "awc" } awc = { path = "awc" }
actix-codec = { git = "https://github.com/actix/actix-net.git" }
actix-connect = { git = "https://github.com/actix/actix-net.git" }
actix-rt = { git = "https://github.com/actix/actix-net.git" }
actix-server = { git = "https://github.com/actix/actix-net.git" }
actix-server-config = { git = "https://github.com/actix/actix-net.git" }
actix-service = { git = "https://github.com/actix/actix-net.git" }
actix-testing = { git = "https://github.com/actix/actix-net.git" }
actix-threadpool = { git = "https://github.com/actix/actix-net.git" }
actix-utils = { git = "https://github.com/actix/actix-net.git" }
# actix-codec = { path = "../actix-net/actix-codec" }
# actix-connect = { path = "../actix-net/actix-connect" }
# actix-rt = { path = "../actix-net/actix-rt" }
# actix-server = { path = "../actix-net/actix-server" }
# actix-server-config = { path = "../actix-net/actix-server-config" }
# actix-service = { path = "../actix-net/actix-service" }
# actix-testing = { path = "../actix-net/actix-testing" }
# actix-threadpool = { path = "../actix-net/actix-threadpool" }
# actix-utils = { path = "../actix-net/actix-utils" }

View File

@@ -1,3 +1,10 @@
## 2.0.0
* Sync handlers has been removed. `.to_async()` methtod has been renamed to `.to()`
replace `fn` with `async fn` to convert sync handler to async
## 1.0.1 ## 1.0.1
* Cors middleware has been moved to `actix-cors` crate * Cors middleware has been moved to `actix-cors` crate

View File

@@ -19,17 +19,16 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust.
* [User Guide](https://actix.rs/docs/) * [User Guide](https://actix.rs/docs/)
* [API Documentation (1.0)](https://docs.rs/actix-web/) * [API Documentation (1.0)](https://docs.rs/actix-web/)
* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/)
* [Chat on gitter](https://gitter.im/actix/actix) * [Chat on gitter](https://gitter.im/actix/actix)
* Cargo package: [actix-web](https://crates.io/crates/actix-web) * Cargo package: [actix-web](https://crates.io/crates/actix-web)
* Minimum supported Rust version: 1.36 or later * Minimum supported Rust version: 1.39 or later
## Example ## Example
```rust ```rust
use actix_web::{web, App, HttpServer, Responder}; use actix_web::{web, App, HttpServer, Responder};
fn index(info: web::Path<(u32, String)>) -> impl Responder { async fn index(info: web::Path<(u32, String)>) -> impl Responder {
format!("Hello {}! id:{}", info.1, info.0) format!("Hello {}! id:{}", info.1, info.0)
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-cors" name = "actix-cors"
version = "0.1.0" version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Cross-origin resource sharing (CORS) for Actix applications." description = "Cross-origin resource sharing (CORS) for Actix applications."
readme = "README.md" readme = "README.md"
@@ -10,14 +10,14 @@ repository = "https://github.com/actix/actix-web.git"
documentation = "https://docs.rs/actix-cors/" documentation = "https://docs.rs/actix-cors/"
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
edition = "2018" edition = "2018"
#workspace = ".." workspace = ".."
[lib] [lib]
name = "actix_cors" name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = "1.0.0" actix-web = "2.0.0-alpha.1"
actix-service = "0.4.0" actix-service = "1.0.0-alpha.1"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.1.25" futures = "0.3.1"

View File

@@ -11,7 +11,7 @@
//! use actix_cors::Cors; //! use actix_cors::Cors;
//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer};
//! //!
//! fn index(req: HttpRequest) -> &'static str { //! async fn index(req: HttpRequest) -> &'static str {
//! "Hello world" //! "Hello world"
//! } //! }
//! //!
@@ -23,7 +23,8 @@
//! .allowed_methods(vec!["GET", "POST"]) //! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE) //! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600)) //! .max_age(3600)
//! .finish())
//! .service( //! .service(
//! web::resource("/index.html") //! web::resource("/index.html")
//! .route(web::get().to(index)) //! .route(web::get().to(index))
@@ -41,16 +42,16 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{IntoTransform, Service, Transform}; use actix_service::{Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result}; use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use derive_more::Display; use derive_more::Display;
use futures::future::{ok, Either, Future, FutureResult}; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::Poll;
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Display)] #[derive(Debug, Display)]
@@ -456,25 +457,9 @@ impl Cors {
} }
self self
} }
}
fn cors<'a>( /// Construct cors middleware
parts: &'a mut Option<Inner>, pub fn finish(self) -> CorsFactory {
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl<S, B> IntoTransform<CorsFactory, S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn into_transform(self) -> CorsFactory {
let mut slf = if !self.methods { let mut slf = if !self.methods {
self.allowed_methods(vec![ self.allowed_methods(vec![
Method::GET, Method::GET,
@@ -521,6 +506,16 @@ where
} }
} }
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
/// `Middleware` for Cross-origin resource sharing support /// `Middleware` for Cross-origin resource sharing support
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
@@ -540,7 +535,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware { ok(CorsMiddleware {
@@ -682,12 +677,12 @@ where
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<
FutureResult<Self::Response, Error>, Ready<Result<Self::Response, Error>>,
Either<S::Future, Box<dyn Future<Item = Self::Response, Error = Error>>>, LocalBoxFuture<'static, Result<Self::Response, Error>>,
>; >;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready() self.service.poll_ready(cx)
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -698,7 +693,7 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head()))
{ {
return Either::A(ok(req.error_response(e))); return Either::Left(ok(req.error_response(e)));
} }
// allowed headers // allowed headers
@@ -751,39 +746,50 @@ where
.finish() .finish()
.into_body(); .into_body();
Either::A(ok(req.into_response(res))) Either::Left(ok(req.into_response(res)))
} else if req.headers().contains_key(&header::ORIGIN) { } else {
// Only check requests with a origin header. if req.headers().contains_key(&header::ORIGIN) {
if let Err(e) = self.inner.validate_origin(req.head()) { // Only check requests with a origin header.
return Either::A(ok(req.error_response(e))); if let Err(e) = self.inner.validate_origin(req.head()) {
return Either::Left(ok(req.error_response(e)));
}
} }
let inner = self.inner.clone(); let inner = self.inner.clone();
let has_origin = req.headers().contains_key(&header::ORIGIN);
let fut = self.service.call(req);
Either::B(Either::B(Box::new(self.service.call(req).and_then( Either::Right(
move |mut res| { async move {
if let Some(origin) = let res = fut.await;
inner.access_control_allow_origin(res.request().head())
{
res.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
};
if let Some(ref expose) = inner.expose_hdrs { if has_origin {
res.headers_mut().insert( let mut res = res?;
header::ACCESS_CONTROL_EXPOSE_HEADERS, if let Some(origin) =
HeaderValue::try_from(expose.as_str()).unwrap(), inner.access_control_allow_origin(res.request().head())
); {
} res.headers_mut().insert(
if inner.supports_credentials { header::ACCESS_CONTROL_ALLOW_ORIGIN,
res.headers_mut().insert( origin.clone(),
header::ACCESS_CONTROL_ALLOW_CREDENTIALS, );
HeaderValue::from_static("true"), };
);
} if let Some(ref expose) = inner.expose_hdrs {
if inner.vary_header { res.headers_mut().insert(
let value = header::ACCESS_CONTROL_EXPOSE_HEADERS,
if let Some(hdr) = res.headers_mut().get(&header::VARY) { HeaderValue::try_from(expose.as_str()).unwrap(),
);
}
if inner.supports_credentials {
res.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if inner.vary_header {
let value = if let Some(hdr) =
res.headers_mut().get(&header::VARY)
{
let mut val: Vec<u8> = let mut val: Vec<u8> =
Vec::with_capacity(hdr.as_bytes().len() + 8); Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
@@ -792,159 +798,153 @@ where
} else { } else {
HeaderValue::from_static("Origin") HeaderValue::from_static("Origin")
}; };
res.headers_mut().insert(header::VARY, value); res.headers_mut().insert(header::VARY, value);
}
Ok(res)
} else {
res
} }
Ok(res) }
}, .boxed_local(),
)))) )
} else {
Either::B(Either::A(self.service.call(req)))
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{IntoService, Transform}; use actix_service::{service_fn2, Transform};
use actix_web::test::{self, block_on, TestRequest}; use actix_web::test::{self, block_on, TestRequest};
use super::*; use super::*;
impl Cors {
fn finish<F, S, B>(self, srv: F) -> CorsMiddleware<S>
where
F: IntoService<S>,
S: Service<
Request = ServiceRequest,
Response = ServiceResponse<B>,
Error = Error,
> + 'static,
S::Future: 'static,
B: 'static,
{
block_on(
IntoTransform::<CorsFactory, S>::into_transform(self)
.new_transform(srv.into_service()),
)
.unwrap()
}
}
#[test] #[test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() { fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new() let _cors = Cors::new().supports_credentials().send_wildcard().finish();
.supports_credentials()
.send_wildcard()
.finish(test::ok_service());
} }
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let mut cors = Cors::new().finish(test::ok_service()); block_on(async {
let req = TestRequest::with_header("Origin", "https://www.example.com") let mut cors = Cors::new()
.to_srv_request(); .finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn default() { fn default() {
let mut cors = block_on(async {
block_on(Cors::default().new_transform(test::ok_service())).unwrap(); let mut cors = Cors::default()
let req = TestRequest::with_header("Origin", "https://www.example.com") .new_transform(test::ok_service())
.to_srv_request(); .await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_preflight() { fn test_preflight() {
let mut cors = Cors::new() block_on(async {
.send_wildcard() let mut cors = Cors::new()
.max_age(3600) .send_wildcard()
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .max_age(3600)
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_header(header::CONTENT_TYPE) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.finish(test::ok_service()); .allowed_header(header::CONTENT_TYPE)
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); assert!(cors.inner.validate_allowed_headers(req.head()).is_ok());
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header( .header(
header::ACCESS_CONTROL_REQUEST_HEADERS, header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT", "AUTHORIZATION,ACCEPT",
) )
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
.get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"3600"[..],
resp.headers()
.get(&header::ACCESS_CONTROL_MAX_AGE)
.unwrap()
.as_bytes()
);
let hdr = resp
.headers()
.get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap() .unwrap()
.as_bytes() .to_str()
); .unwrap();
assert_eq!( assert!(hdr.contains("authorization"));
&b"3600"[..], assert!(hdr.contains("accept"));
resp.headers() assert!(hdr.contains("content-type"));
.get(&header::ACCESS_CONTROL_MAX_AGE)
let methods = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.unwrap() .unwrap()
.as_bytes() .to_str()
); .unwrap();
let hdr = resp assert!(methods.contains("POST"));
.headers() assert!(methods.contains("GET"));
.get(&header::ACCESS_CONTROL_ALLOW_HEADERS) assert!(methods.contains("OPTIONS"));
.unwrap()
.to_str()
.unwrap();
assert!(hdr.contains("authorization"));
assert!(hdr.contains("accept"));
assert!(hdr.contains("content-type"));
let methods = resp Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.unwrap()
.to_str()
.unwrap();
assert!(methods.contains("POST"));
assert!(methods.contains("GET"));
assert!(methods.contains("OPTIONS"));
Rc::get_mut(&mut cors.inner).unwrap().preflight = false; let req = TestRequest::with_header("Origin", "https://www.example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS)
.to_srv_request();
let req = TestRequest::with_header("Origin", "https://www.example.com") let resp = test::call_service(&mut cors, req).await;
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") assert_eq!(resp.status(), StatusCode::OK);
.header( })
header::ACCESS_CONTROL_REQUEST_HEADERS,
"AUTHORIZATION,ACCEPT",
)
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK);
} }
// #[test] // #[test]
@@ -960,216 +960,254 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
let cors = Cors::new() block_on(async {
.allowed_origin("https://www.example.com") let cors = Cors::new()
.finish(test::ok_service()); .allowed_origin("https://www.example.com")
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
cors.inner.validate_origin(req.head()).unwrap(); cors.inner.validate_origin(req.head()).unwrap();
cors.inner.validate_allowed_method(req.head()).unwrap(); cors.inner.validate_allowed_method(req.head()).unwrap();
cors.inner.validate_allowed_headers(req.head()).unwrap(); cors.inner.validate_allowed_headers(req.head()).unwrap();
})
} }
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
let mut cors = Cors::new() block_on(async {
.allowed_origin("https://www.example.com") let mut cors = Cors::new()
.finish(test::ok_service()); .allowed_origin("https://www.example.com")
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
#[test] #[test]
fn test_no_origin_response() { fn test_no_origin_response() {
let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); block_on(async {
let mut cors = Cors::new()
.disable_preflight()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none());
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(
&b"https://www.example.com"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .is_none());
.as_bytes()
); let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://www.example.com"[..],
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
})
} }
#[test] #[test]
fn test_response() { fn test_response() {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; block_on(async {
let mut cors = Cors::new() let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
.send_wildcard() let mut cors = Cors::new()
.disable_preflight() .send_wildcard()
.max_age(3600) .disable_preflight()
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .max_age(3600)
.allowed_headers(exposed_headers.clone()) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.expose_headers(exposed_headers.clone()) .allowed_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .expose_headers(exposed_headers.clone())
.finish(test::ok_service()); .allowed_header(header::CONTENT_TYPE)
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.as_bytes()
);
assert_eq!(
&b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
{
let headers = resp
.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.unwrap()
.to_str()
.unwrap()
.split(',')
.map(|s| s.trim())
.collect::<Vec<&str>>();
for h in exposed_headers {
assert!(headers.contains(&h.as_str()));
}
}
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish()
.new_transform(service_fn2(|req: ServiceRequest| {
ok(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
))
}))
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let mut cors = Cors::new()
.disable_vary_header()
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request();
let resp = test::call_service(&mut cors, req).await;
let origins_str = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes()
);
assert_eq!(
&b"Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
{
let headers = resp
.headers()
.get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
.unwrap()
.to_str() .to_str()
.unwrap() .unwrap();
.split(',')
.map(|s| s.trim())
.collect::<Vec<&str>>();
for h in exposed_headers { assert_eq!("https://www.example.com", origins_str);
assert!(headers.contains(&h.as_str())); })
}
}
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new()
.send_wildcard()
.disable_preflight()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish(|req: ServiceRequest| {
req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
)
});
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
);
let mut cors = Cors::new()
.disable_vary_header()
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.finish(test::ok_service());
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let origins_str = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.to_str()
.unwrap();
assert_eq!("https://www.example.com", origins_str);
} }
#[test] #[test]
fn test_multiple_origins() { fn test_multiple_origins() {
let mut cors = Cors::new() block_on(async {
.allowed_origin("https://example.com") let mut cors = Cors::new()
.allowed_origin("https://example.org") .allowed_origin("https://example.com")
.allowed_methods(vec![Method::GET]) .allowed_origin("https://example.org")
.finish(test::ok_service()); .allowed_methods(vec![Method::GET])
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::with_header("Origin", "https://example.org")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
#[test] #[test]
fn test_multiple_origins_preflight() { fn test_multiple_origins_preflight() {
let mut cors = Cors::new() block_on(async {
.allowed_origin("https://example.com") let mut cors = Cors::new()
.allowed_origin("https://example.org") .allowed_origin("https://example.com")
.allowed_methods(vec![Method::GET]) .allowed_origin("https://example.org")
.finish(test::ok_service()); .allowed_methods(vec![Method::GET])
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
let req = TestRequest::with_header("Origin", "https://example.org") let req = TestRequest::with_header("Origin", "https://example.org")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req); let resp = test::call_service(&mut cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap() .unwrap()
.as_bytes() .as_bytes()
); );
})
} }
} }

View File

@@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.7] - 2019-11-06
* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151)
## [0.1.6] - 2019-10-14 ## [0.1.6] - 2019-10-14
* Add option to redirect to a slash-ended path `Files` #1132 * Add option to redirect to a slash-ended path `Files` #1132

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-files" name = "actix-files"
version = "0.1.6" version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Static files support for actix web." description = "Static files support for actix web."
readme = "README.md" readme = "README.md"
@@ -18,12 +18,12 @@ name = "actix_files"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "1.0.8", default-features = false } actix-web = { version = "2.0.0-alpha.1", default-features = false }
actix-http = "0.2.9" actix-http = "0.3.0-alpha.1"
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
bitflags = "1" bitflags = "1"
bytes = "0.4" bytes = "0.4"
futures = "0.1.25" futures = "0.3.1"
derive_more = "0.15.0" derive_more = "0.15.0"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
@@ -32,4 +32,4 @@ percent-encoding = "2.1"
v_htmlescape = "0.4" v_htmlescape = "0.4"
[dev-dependencies] [dev-dependencies]
actix-web = { version = "1.0.8", features=["ssl"] } actix-web = { version = "2.0.0-alpha.1", features=["openssl"] }

File diff suppressed because it is too large Load Diff

View File

@@ -13,11 +13,12 @@ use mime_guess::from_path;
use actix_http::body::SizedStream; use actix_http::body::SizedStream;
use actix_web::http::header::{ use actix_web::http::header::{
self, ContentDisposition, DispositionParam, DispositionType, self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue,
}; };
use actix_web::http::{ContentEncoding, StatusCode}; use actix_web::http::{ContentEncoding, StatusCode};
use actix_web::middleware::BodyEncoding; use actix_web::middleware::BodyEncoding;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder};
use futures::future::{ready, Ready};
use crate::range::HttpRange; use crate::range::HttpRange;
use crate::ChunkedReadFile; use crate::ChunkedReadFile;
@@ -93,9 +94,18 @@ impl NamedFile {
mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline, mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline,
_ => DispositionType::Attachment, _ => DispositionType::Attachment,
}; };
let mut parameters =
vec![DispositionParam::Filename(String::from(filename.as_ref()))];
if !filename.is_ascii() {
parameters.push(DispositionParam::FilenameExt(ExtendedValue {
charset: Charset::Ext(String::from("UTF-8")),
language_tag: None,
value: filename.into_owned().into_bytes(),
}))
}
let cd = ContentDisposition { let cd = ContentDisposition {
disposition: disposition_type, disposition: disposition_type,
parameters: vec![DispositionParam::Filename(filename.into_owned())], parameters: parameters,
}; };
(ct, cd) (ct, cd)
}; };
@@ -246,62 +256,8 @@ impl NamedFile {
pub(crate) fn last_modified(&self) -> Option<header::HttpDate> { pub(crate) fn last_modified(&self) -> Option<header::HttpDate> {
self.modified.map(|mtime| mtime.into()) self.modified.map(|mtime| mtime.into())
} }
}
impl Deref for NamedFile { pub fn into_response(self, req: &HttpRequest) -> Result<HttpResponse, Error> {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Result<HttpResponse, Error>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
if self.status_code != StatusCode::OK { if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code); let mut resp = HttpResponse::build(self.status_code);
resp.set(header::ContentType(self.content_type.clone())) resp.set(header::ContentType(self.content_type.clone()))
@@ -433,8 +389,67 @@ impl Responder for NamedFile {
counter: 0, counter: 0,
}; };
if offset != 0 || length != self.md.len() { if offset != 0 || length != self.md.len() {
return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)); Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader))
}; } else {
Ok(resp.body(SizedStream::new(length, reader))) Ok(resp.body(SizedStream::new(length, reader)))
}
}
}
impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Ready<Result<HttpResponse, Error>>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
ready(self.into_response(req))
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-framed" name = "actix-framed"
version = "0.2.1" version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix framed app server" description = "Actix framed app server"
readme = "README.md" readme = "README.md"
@@ -20,19 +20,20 @@ name = "actix_framed"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-codec = "0.1.2" actix-codec = "0.2.0-alpha.1"
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
actix-router = "0.1.2" actix-router = "0.1.2"
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"
actix-http = "0.2.7" actix-http = "0.3.0-alpha.1"
actix-server-config = "0.1.2" actix-server-config = "0.3.0-alpha.1"
bytes = "0.4" bytes = "0.4"
futures = "0.1.25" futures = "0.3.1"
pin-project = "0.4.6"
log = "0.4" log = "0.4"
[dev-dependencies] [dev-dependencies]
actix-server = { version = "0.6.0", features=["ssl"] } actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
actix-connect = { version = "0.2.0", features=["ssl"] } actix-connect = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.2.4", features=["ssl"] } actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-utils = "0.4.4" actix-utils = "0.5.0-alpha.1"

View File

@@ -1,21 +1,24 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::h1::{Codec, SendResponse}; use actix_http::h1::{Codec, SendResponse};
use actix_http::{Error, Request, Response}; use actix_http::{Error, Request, Response};
use actix_router::{Path, Router, Url}; use actix_router::{Path, Router, Url};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use futures::{Async, Future, Poll}; use futures::future::{ok, FutureExt, LocalBoxFuture};
use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService};
use crate::request::FramedRequest; use crate::request::FramedRequest;
use crate::state::State; use crate::state::State;
type BoxedResponse = Box<dyn Future<Item = (), Error = Error>>; type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>;
pub trait HttpServiceFactory { pub trait HttpServiceFactory {
type Factory: NewService; type Factory: ServiceFactory;
fn path(&self) -> &str; fn path(&self) -> &str;
@@ -48,19 +51,19 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
pub fn service<U>(mut self, factory: U) -> Self pub fn service<U>(mut self, factory: U) -> Self
where where
U: HttpServiceFactory, U: HttpServiceFactory,
U::Factory: NewService< U::Factory: ServiceFactory<
Config = (), Config = (),
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
> + 'static, > + 'static,
<U::Factory as NewService>::Future: 'static, <U::Factory as ServiceFactory>::Future: 'static,
<U::Factory as NewService>::Service: Service< <U::Factory as ServiceFactory>::Service: Service<
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
Future = Box<dyn Future<Item = (), Error = Error>>, Future = LocalBoxFuture<'static, Result<(), Error>>,
>, >,
{ {
let path = factory.path().to_string(); let path = factory.path().to_string();
@@ -70,12 +73,12 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
} }
} }
impl<T, S> IntoNewService<FramedAppFactory<T, S>> for FramedApp<T, S> impl<T, S> IntoServiceFactory<FramedAppFactory<T, S>> for FramedApp<T, S>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
S: 'static, S: 'static,
{ {
fn into_new_service(self) -> FramedAppFactory<T, S> { fn into_factory(self) -> FramedAppFactory<T, S> {
FramedAppFactory { FramedAppFactory {
state: self.state, state: self.state,
services: Rc::new(self.services), services: Rc::new(self.services),
@@ -89,9 +92,9 @@ pub struct FramedAppFactory<T, S> {
services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>, services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>,
} }
impl<T, S> NewService for FramedAppFactory<T, S> impl<T, S> ServiceFactory for FramedAppFactory<T, S>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
S: 'static, S: 'static,
{ {
type Config = ServerConfig; type Config = ServerConfig;
@@ -128,28 +131,30 @@ pub struct CreateService<T, S> {
enum CreateServiceItem<T, S> { enum CreateServiceItem<T, S> {
Future( Future(
Option<String>, Option<String>,
Box<dyn Future<Item = BoxedHttpService<FramedRequest<T, S>>, Error = ()>>, LocalBoxFuture<'static, Result<BoxedHttpService<FramedRequest<T, S>>, ()>>,
), ),
Service(String, BoxedHttpService<FramedRequest<T, S>>), Service(String, BoxedHttpService<FramedRequest<T, S>>),
} }
impl<S: 'static, T: 'static> Future for CreateService<T, S> impl<S: 'static, T: 'static> Future for CreateService<T, S>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite + Unpin,
{ {
type Item = FramedAppService<T, S>; type Output = Result<FramedAppService<T, S>, ()>;
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut done = true; let mut done = true;
// poll http services // poll http services
for item in &mut self.fut { for item in &mut self.fut {
let res = match item { let res = match item {
CreateServiceItem::Future(ref mut path, ref mut fut) => { CreateServiceItem::Future(ref mut path, ref mut fut) => {
match fut.poll()? { match Pin::new(fut).poll(cx) {
Async::Ready(service) => Some((path.take().unwrap(), service)), Poll::Ready(Ok(service)) => {
Async::NotReady => { Some((path.take().unwrap(), service))
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
done = false; done = false;
None None
} }
@@ -176,12 +181,12 @@ where
} }
router router
}); });
Ok(Async::Ready(FramedAppService { Poll::Ready(Ok(FramedAppService {
router: router.finish(), router: router.finish(),
state: self.state.clone(), state: self.state.clone(),
})) }))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -193,15 +198,15 @@ pub struct FramedAppService<T, S> {
impl<S: 'static, T: 'static> Service for FramedAppService<T, S> impl<S: 'static, T: 'static> Service for FramedAppService<T, S>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite + Unpin,
{ {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = BoxedResponse; type Future = BoxedResponse;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
@@ -210,8 +215,8 @@ where
if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { if let Some((srv, _info)) = self.router.recognize_mut(&mut path) {
return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); return srv.call(FramedRequest::new(req, framed, path, self.state.clone()));
} }
Box::new( SendResponse::new(framed, Response::NotFound().finish())
SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())), .then(|_| ok(()))
) .boxed_local()
} }
} }

View File

@@ -1,36 +1,38 @@
use std::task::{Context, Poll};
use actix_http::Error; use actix_http::Error;
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::{Future, Poll}; use futures::future::{FutureExt, LocalBoxFuture};
pub(crate) type BoxedHttpService<Req> = Box< pub(crate) type BoxedHttpService<Req> = Box<
dyn Service< dyn Service<
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
Future = Box<dyn Future<Item = (), Error = Error>>, Future = LocalBoxFuture<'static, Result<(), Error>>,
>, >,
>; >;
pub(crate) type BoxedHttpNewService<Req> = Box< pub(crate) type BoxedHttpNewService<Req> = Box<
dyn NewService< dyn ServiceFactory<
Config = (), Config = (),
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
Service = BoxedHttpService<Req>, Service = BoxedHttpService<Req>,
Future = Box<dyn Future<Item = BoxedHttpService<Req>, Error = ()>>, Future = LocalBoxFuture<'static, Result<BoxedHttpService<Req>, ()>>,
>, >,
>; >;
pub(crate) struct HttpNewService<T: NewService>(T); pub(crate) struct HttpNewService<T: ServiceFactory>(T);
impl<T> HttpNewService<T> impl<T> HttpNewService<T>
where where
T: NewService<Response = (), Error = Error>, T: ServiceFactory<Response = (), Error = Error>,
T::Response: 'static, T::Response: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static, T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
pub fn new(service: T) -> Self { pub fn new(service: T) -> Self {
@@ -38,12 +40,12 @@ where
} }
} }
impl<T> NewService for HttpNewService<T> impl<T> ServiceFactory for HttpNewService<T>
where where
T: NewService<Config = (), Response = (), Error = Error>, T: ServiceFactory<Config = (), Response = (), Error = Error>,
T::Request: 'static, T::Request: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static, T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
type Config = (); type Config = ();
@@ -52,13 +54,19 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = BoxedHttpService<T::Request>; type Service = BoxedHttpService<T::Request>;
type Future = Box<dyn Future<Item = Self::Service, Error = ()>>; type Future = LocalBoxFuture<'static, Result<Self::Service, ()>>;
fn new_service(&self, _: &()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| { let fut = self.0.new_service(&());
let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service });
Ok(service) async move {
})) fut.await.map_err(|_| ()).map(|service| {
let service: BoxedHttpService<_> =
Box::new(HttpServiceWrapper { service });
service
})
}
.boxed_local()
} }
} }
@@ -70,7 +78,7 @@ impl<T> Service for HttpServiceWrapper<T>
where where
T: Service< T: Service<
Response = (), Response = (),
Future = Box<dyn Future<Item = (), Error = Error>>, Future = LocalBoxFuture<'static, Result<(), Error>>,
Error = Error, Error = Error,
>, >,
T::Request: 'static, T::Request: 'static,
@@ -78,10 +86,10 @@ where
type Request = T::Request; type Request = T::Request;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = Box<dyn Future<Item = (), Error = Error>>; type Future = LocalBoxFuture<'static, Result<(), Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready() self.service.poll_ready(cx)
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {

View File

@@ -1,11 +1,12 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{http::Method, Error}; use actix_http::{http::Method, Error};
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::future::{ok, FutureResult}; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures::{Async, Future, IntoFuture, Poll};
use log::error; use log::error;
use crate::app::HttpServiceFactory; use crate::app::HttpServiceFactory;
@@ -15,11 +16,11 @@ use crate::request::FramedRequest;
/// ///
/// Route uses builder-like pattern for configuration. /// Route uses builder-like pattern for configuration.
/// If handler is not explicitly set, default *404 Not Found* handler is used. /// If handler is not explicitly set, default *404 Not Found* handler is used.
pub struct FramedRoute<Io, S, F = (), R = ()> { pub struct FramedRoute<Io, S, F = (), R = (), E = ()> {
handler: F, handler: F,
pattern: String, pattern: String,
methods: Vec<Method>, methods: Vec<Method>,
state: PhantomData<(Io, S, R)>, state: PhantomData<(Io, S, R, E)>,
} }
impl<Io, S> FramedRoute<Io, S> { impl<Io, S> FramedRoute<Io, S> {
@@ -53,12 +54,12 @@ impl<Io, S> FramedRoute<Io, S> {
self self
} }
pub fn to<F, R>(self, handler: F) -> FramedRoute<Io, S, F, R> pub fn to<F, R, E>(self, handler: F) -> FramedRoute<Io, S, F, R, E>
where where
F: FnMut(FramedRequest<Io, S>) -> R, F: FnMut(FramedRequest<Io, S>) -> R,
R: IntoFuture<Item = ()>, R: Future<Output = Result<(), E>> + 'static,
R::Future: 'static,
R::Error: fmt::Debug, E: fmt::Debug,
{ {
FramedRoute { FramedRoute {
handler, handler,
@@ -69,15 +70,14 @@ impl<Io, S> FramedRoute<Io, S> {
} }
} }
impl<Io, S, F, R> HttpServiceFactory for FramedRoute<Io, S, F, R> impl<Io, S, F, R, E> HttpServiceFactory for FramedRoute<Io, S, F, R, E>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>, R: Future<Output = Result<(), E>> + 'static,
R::Future: 'static, E: fmt::Display,
R::Error: fmt::Display,
{ {
type Factory = FramedRouteFactory<Io, S, F, R>; type Factory = FramedRouteFactory<Io, S, F, R, E>;
fn path(&self) -> &str { fn path(&self) -> &str {
&self.pattern &self.pattern
@@ -92,27 +92,26 @@ where
} }
} }
pub struct FramedRouteFactory<Io, S, F, R> { pub struct FramedRouteFactory<Io, S, F, R, E> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R)>, _t: PhantomData<(Io, S, R, E)>,
} }
impl<Io, S, F, R> NewService for FramedRouteFactory<Io, S, F, R> impl<Io, S, F, R, E> ServiceFactory for FramedRouteFactory<Io, S, F, R, E>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>, R: Future<Output = Result<(), E>> + 'static,
R::Future: 'static, E: fmt::Display,
R::Error: fmt::Display,
{ {
type Config = (); type Config = ();
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = FramedRouteService<Io, S, F, R>; type Service = FramedRouteService<Io, S, F, R, E>;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
ok(FramedRouteService { ok(FramedRouteService {
@@ -123,35 +122,38 @@ where
} }
} }
pub struct FramedRouteService<Io, S, F, R> { pub struct FramedRouteService<Io, S, F, R, E> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R)>, _t: PhantomData<(Io, S, R, E)>,
} }
impl<Io, S, F, R> Service for FramedRouteService<Io, S, F, R> impl<Io, S, F, R, E> Service for FramedRouteService<Io, S, F, R, E>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>, R: Future<Output = Result<(), E>> + 'static,
R::Future: 'static, E: fmt::Display,
R::Error: fmt::Display,
{ {
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = Box<dyn Future<Item = (), Error = Error>>; type Future = LocalBoxFuture<'static, Result<(), Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future { fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future {
Box::new((self.handler)(req).into_future().then(|res| { let fut = (self.handler)(req);
async move {
let res = fut.await;
if let Err(e) = res { if let Err(e) = res {
error!("Error in request handler: {}", e); error!("Error in request handler: {}", e);
} }
Ok(()) Ok(())
})) }
.boxed_local()
} }
} }

View File

@@ -1,4 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::body::BodySize; use actix_http::body::BodySize;
@@ -6,9 +8,9 @@ use actix_http::error::ResponseError;
use actix_http::h1::{Codec, Message}; use actix_http::h1::{Codec, Message};
use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{Request, Response}; use actix_http::{Request, Response};
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::future::{ok, Either, FutureResult}; use futures::future::{err, ok, Either, Ready};
use futures::{Async, Future, IntoFuture, Poll, Sink}; use futures::Future;
/// Service that verifies incoming request if it is valid websocket /// Service that verifies incoming request if it is valid websocket
/// upgrade request. In case of error returns `HandshakeError` /// upgrade request. In case of error returns `HandshakeError`
@@ -22,14 +24,14 @@ impl<T, C> Default for VerifyWebSockets<T, C> {
} }
} }
impl<T, C> NewService for VerifyWebSockets<T, C> { impl<T, C> ServiceFactory for VerifyWebSockets<T, C> {
type Config = C; type Config = C;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = VerifyWebSockets<T, C>; type Service = VerifyWebSockets<T, C>;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData }) ok(VerifyWebSockets { _t: PhantomData })
@@ -40,16 +42,16 @@ impl<T, C> Service for VerifyWebSockets<T, C> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>; type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(req.head()) { match verify_handshake(req.head()) {
Err(e) => Err((e, framed)).into_future(), Err(e) => err((e, framed)),
Ok(_) => Ok((req, framed)).into_future(), Ok(_) => ok((req, framed)),
} }
} }
} }
@@ -67,9 +69,9 @@ where
} }
} }
impl<T, R, E, C> NewService for SendError<T, R, E, C> impl<T, R, E, C> ServiceFactory for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
@@ -79,7 +81,7 @@ where
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = SendError<T, R, E, C>; type Service = SendError<T, R, E, C>;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(SendError(PhantomData)) ok(SendError(PhantomData))
@@ -88,25 +90,25 @@ where
impl<T, R, E, C> Service for SendError<T, R, E, C> impl<T, R, E, C> Service for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
type Request = Result<R, (E, Framed<T, Codec>)>; type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R; type Response = R;
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type Future = Either<FutureResult<R, (E, Framed<T, Codec>)>, SendErrorFut<T, R, E>>; type Future = Either<Ready<Result<R, (E, Framed<T, Codec>)>>, SendErrorFut<T, R, E>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future { fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future {
match req { match req {
Ok(r) => Either::A(ok(r)), Ok(r) => Either::Left(ok(r)),
Err((e, framed)) => { Err((e, framed)) => {
let res = e.error_response().drop_body(); let res = e.error_response().drop_body();
Either::B(SendErrorFut { Either::Right(SendErrorFut {
framed: Some(framed), framed: Some(framed),
res: Some((res, BodySize::Empty).into()), res: Some((res, BodySize::Empty).into()),
err: Some(e), err: Some(e),
@@ -117,6 +119,7 @@ where
} }
} }
#[pin_project::pin_project]
pub struct SendErrorFut<T, R, E> { pub struct SendErrorFut<T, R, E> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
@@ -127,23 +130,27 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E> impl<T, R, E> Future for SendErrorFut<T, R, E>
where where
E: ResponseError, E: ResponseError,
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite + Unpin,
{ {
type Item = R; type Output = Result<R, (E, Framed<T, Codec>)>;
type Error = (E, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(res) = self.res.take() { if let Some(res) = self.res.take() {
if self.framed.as_mut().unwrap().force_send(res).is_err() { if self.framed.as_mut().unwrap().write(res).is_err() {
return Err((self.err.take().unwrap(), self.framed.take().unwrap())); return Poll::Ready(Err((
self.err.take().unwrap(),
self.framed.take().unwrap(),
)));
} }
} }
match self.framed.as_mut().unwrap().poll_complete() { match self.framed.as_mut().unwrap().flush(cx) {
Ok(Async::Ready(_)) => { Poll::Ready(Ok(_)) => {
Err((self.err.take().unwrap(), self.framed.take().unwrap())) Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap())))
} }
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Ready(Err(_)) => {
Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())), Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap())))
}
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,4 +1,6 @@
//! Various helpers for Actix applications to use during testing. //! Various helpers for Actix applications to use during testing.
use std::future::Future;
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::h1::Codec; use actix_http::h1::Codec;
use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue};
@@ -6,7 +8,6 @@ use actix_http::http::{HttpTryFrom, Method, Uri, Version};
use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest};
use actix_router::{Path, Url}; use actix_router::{Path, Url};
use actix_rt::Runtime; use actix_rt::Runtime;
use futures::IntoFuture;
use crate::{FramedRequest, State}; use crate::{FramedRequest, State};
@@ -121,10 +122,10 @@ impl<S> TestRequest<S> {
pub fn run<F, R, I, E>(self, f: F) -> Result<I, E> pub fn run<F, R, I, E>(self, f: F) -> Result<I, E>
where where
F: FnOnce(FramedRequest<TestBuffer, S>) -> R, F: FnOnce(FramedRequest<TestBuffer, S>) -> R,
R: IntoFuture<Item = I, Error = E>, R: Future<Output = Result<I, E>>,
{ {
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
rt.block_on(f(self.finish()).into_future()) rt.block_on(f(self.finish()))
} }
} }

View File

@@ -1,30 +1,31 @@
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response};
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use actix_service::{IntoNewService, NewService}; use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory};
use actix_utils::framed::FramedTransport; use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::future::{self, ok}; use futures::{future, SinkExt, StreamExt};
use futures::{Future, Sink, Stream};
use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets};
fn ws_service<T: AsyncRead + AsyncWrite>( async fn ws_service<T: AsyncRead + AsyncWrite>(
req: FramedRequest<T>, req: FramedRequest<T>,
) -> impl Future<Item = (), Error = Error> { ) -> Result<(), Error> {
let (req, framed, _) = req.into_parts(); let (req, mut framed, _) = req.into_parts();
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.map_err(|_| panic!()) .await
.and_then(|framed| { .unwrap();
FramedTransport::new(framed.into_framed(ws::Codec::new()), service) FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!()) .await
}) .unwrap();
Ok(())
} }
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> { async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
@@ -34,108 +35,129 @@ fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
ok(msg) Ok(msg)
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
let mut srv = TestServer::new(|| { block_on(async {
HttpService::build() let mut srv = TestServer::start(|| {
.upgrade( HttpService::build()
FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)), .upgrade(
) FramedApp::new()
.finish(|_| future::ok::<_, Error>(Response::NotFound())) .service(FramedRoute::get("/index.html").to(ws_service)),
}); )
.finish(|_| future::ok::<_, Error>(Response::NotFound()))
});
assert!(srv.ws_at("/test").is_err()); assert!(srv.ws_at("/test").await.is_err());
// client service // client service
let framed = srv.ws_at("/index.html").unwrap(); let mut framed = srv.ws_at("/index.html").await.unwrap();
let framed = srv framed
.block_on(framed.send(ws::Message::Text("text".to_string()))) .send(ws::Message::Text("text".to_string()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv framed
.block_on(framed.send(ws::Message::Binary("text".into()))) .send(ws::Message::Binary("text".into()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!( let (item, mut framed) = framed.into_future().await;
item, assert_eq!(
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) item.unwrap().unwrap(),
); ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
let framed = srv framed.send(ws::Message::Ping("text".into())).await.unwrap();
.block_on(framed.send(ws::Message::Ping("text".into()))) let (item, mut framed) = framed.into_future().await;
.unwrap(); assert_eq!(
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); item.unwrap().unwrap(),
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); ws::Frame::Pong("text".to_string().into())
);
let framed = srv framed
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.unwrap(); .await
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); let (item, _) = framed.into_future().await;
assert_eq!( assert_eq!(
item, item.unwrap().unwrap(),
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
); );
})
} }
#[test] #[test]
fn test_service() { fn test_service() {
let mut srv = TestServer::new(|| { block_on(async {
actix_http::h1::OneRequest::new().map_err(|_| ()).and_then( let mut srv = TestServer::start(|| {
VerifyWebSockets::default() pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then(
.then(SendError::default()) pipeline_factory(
.map_err(|_| ()) pipeline_factory(VerifyWebSockets::default())
.then(SendError::default())
.map_err(|_| ()),
)
.and_then( .and_then(
FramedApp::new() FramedApp::new()
.service(FramedRoute::get("/index.html").to(ws_service)) .service(FramedRoute::get("/index.html").to(ws_service))
.into_new_service() .into_factory()
.map_err(|_| ()), .map_err(|_| ()),
), ),
) )
}); });
// non ws request // non ws request
let res = srv.block_on(srv.get("/index.html").send()).unwrap(); let res = srv.get("/index.html").send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.status(), StatusCode::BAD_REQUEST);
// not found // not found
assert!(srv.ws_at("/test").is_err()); assert!(srv.ws_at("/test").await.is_err());
// client service // client service
let framed = srv.ws_at("/index.html").unwrap(); let mut framed = srv.ws_at("/index.html").await.unwrap();
let framed = srv framed
.block_on(framed.send(ws::Message::Text("text".to_string()))) .send(ws::Message::Text("text".to_string()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv framed
.block_on(framed.send(ws::Message::Binary("text".into()))) .send(ws::Message::Binary("text".into()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!( let (item, mut framed) = framed.into_future().await;
item, assert_eq!(
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) item.unwrap().unwrap(),
); ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
let framed = srv framed.send(ws::Message::Ping("text".into())).await.unwrap();
.block_on(framed.send(ws::Message::Ping("text".into()))) let (item, mut framed) = framed.into_future().await;
.unwrap(); assert_eq!(
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); item.unwrap().unwrap(),
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); ws::Frame::Pong("text".to_string().into())
);
let framed = srv framed
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.unwrap(); .await
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); let (item, _) = framed.into_future().await;
assert_eq!( assert_eq!(
item, item.unwrap().unwrap(),
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
); );
})
} }

View File

@@ -1,11 +1,15 @@
# Changes # Changes
## Not released yet ## [0.2.11] - 2019-11-06
### Added ### Added
* Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() * Add support for serde_json::Value to be passed as argument to ResponseBuilder.body()
* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151)
* Allow to use `std::convert::Infallible` as `actix_http::error::Error`
### Fixed ### Fixed
* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118 * To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118
@@ -17,9 +21,6 @@
* Add support for sending HTTP requests with `Rc<RequestHead>` in addition to sending HTTP requests with `RequestHead` * Add support for sending HTTP requests with `Rc<RequestHead>` in addition to sending HTTP requests with `RequestHead`
* Allow to use `std::convert::Infallible` as `actix_http::error::Error`
### Fixed ### Fixed
* h2 will use error response #1080 * h2 will use error response #1080

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-http" name = "actix-http"
version = "0.2.10" version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http primitives" description = "Actix http primitives"
readme = "README.md" readme = "README.md"
@@ -16,7 +16,7 @@ edition = "2018"
workspace = ".." workspace = ".."
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] features = ["openssl", "fail", "brotli", "flate2-zlib", "secure-cookies"]
[lib] [lib]
name = "actix_http" name = "actix_http"
@@ -26,10 +26,10 @@ path = "src/lib.rs"
default = [] default = []
# openssl # openssl
ssl = ["openssl", "actix-connect/ssl"] openssl = ["open-ssl", "actix-connect/openssl", "tokio-openssl"]
# rustls support # rustls support
rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"] # rustls = ["rust-tls", "webpki-roots", "actix-connect/rustls"]
# brotli encoding, requires c compiler # brotli encoding, requires c compiler
brotli = ["brotli2"] brotli = ["brotli2"]
@@ -47,23 +47,24 @@ fail = ["failure"]
secure-cookies = ["ring"] secure-cookies = ["ring"]
[dependencies] [dependencies]
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
actix-codec = "0.1.2" actix-codec = "0.2.0-alpha.1"
actix-connect = "0.2.4" actix-connect = "1.0.0-alpha.1"
actix-utils = "0.4.4" actix-utils = "0.5.0-alpha.1"
actix-server-config = "0.1.2" actix-server-config = "0.3.0-alpha.1"
actix-threadpool = "0.1.1" actix-threadpool = "0.2.0-alpha.1"
base64 = "0.10" base64 = "0.10"
bitflags = "1.0" bitflags = "1.0"
bytes = "0.4" bytes = "0.4"
copyless = "0.1.4" copyless = "0.1.4"
chrono = "0.4.6"
derive_more = "0.15.0" derive_more = "0.15.0"
either = "1.5.2" either = "1.5.2"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.1.25" futures = "0.3.1"
hashbrown = "0.5.0" hashbrown = "0.6.3"
h2 = "0.1.16" h2 = "0.2.0-alpha.3"
http = "0.1.17" http = "0.1.17"
httparse = "1.3" httparse = "1.3"
indexmap = "1.2" indexmap = "1.2"
@@ -72,6 +73,7 @@ language-tags = "0.2"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
pin-project = "0.4.5"
rand = "0.7" rand = "0.7"
regex = "1.0" regex = "1.0"
serde = "1.0" serde = "1.0"
@@ -80,13 +82,16 @@ sha1 = "0.6"
slab = "0.4" slab = "0.4"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
time = "0.1.42" time = "0.1.42"
tokio-tcp = "0.1.3"
tokio-timer = "0.2.8" tokio = "=0.2.0-alpha.6"
tokio-current-thread = "0.1" tokio-io = "=0.2.0-alpha.6"
trust-dns-resolver = { version="0.11.1", default-features = false } tokio-net = "=0.2.0-alpha.6"
tokio-timer = "0.3.0-alpha.6"
tokio-executor = "=0.2.0-alpha.6"
trust-dns-resolver = { version="0.18.0-alpha.1", default-features = false }
# for secure cookie # for secure cookie
ring = { version = "0.14.6", optional = true } ring = { version = "0.16.9", optional = true }
# compression # compression
brotli2 = { version="0.3.2", optional = true } brotli2 = { version="0.3.2", optional = true }
@@ -94,17 +99,17 @@ flate2 = { version="1.0.7", optional = true, default-features = false }
# optional deps # optional deps
failure = { version = "0.1.5", optional = true } failure = { version = "0.1.5", optional = true }
openssl = { version="0.10", optional = true } open-ssl = { version="0.10", package="openssl", optional = true }
rustls = { version = "0.15.2", optional = true } tokio-openssl = { version = "0.4.0-alpha.6", optional = true }
webpki-roots = { version = "0.16", optional = true }
chrono = "0.4.6" rust-tls = { version = "0.16.0", package="rustls", optional = true }
webpki-roots = { version = "0.18", optional = true }
[dev-dependencies] [dev-dependencies]
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"
actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
actix-connect = { version = "0.2.0", features=["ssl"] } actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.2.4", features=["ssl"] } actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
openssl = { version="0.10" } open-ssl = { version="0.10", package="openssl" }
tokio-tcp = "0.1"

View File

@@ -1,9 +1,9 @@
use std::{env, io}; use std::{env, io};
use actix_http::{error::PayloadError, HttpService, Request, Response}; use actix_http::{Error, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::{Future, Stream}; use futures::StreamExt;
use http::header::HeaderValue; use http::header::HeaderValue;
use log::info; use log::info;
@@ -17,20 +17,22 @@ fn main() -> io::Result<()> {
.client_timeout(1000) .client_timeout(1000)
.client_disconnect(1000) .client_disconnect(1000)
.finish(|mut req: Request| { .finish(|mut req: Request| {
req.take_payload() async move {
.fold(BytesMut::new(), move |mut body, chunk| { let mut body = BytesMut::new();
body.extend_from_slice(&chunk); while let Some(item) = req.payload().next().await {
Ok::<_, PayloadError>(body) body.extend_from_slice(&item?);
}) }
.and_then(|bytes| {
info!("request body: {:?}", bytes); info!("request body: {:?}", body);
let mut res = Response::Ok(); Ok::<_, Error>(
res.header( Response::Ok()
"x-head", .header(
HeaderValue::from_static("dummy value!"), "x-head",
); HeaderValue::from_static("dummy value!"),
Ok(res.body(bytes)) )
}) .body(body),
)
}
}) })
})? })?
.run() .run()

View File

@@ -1,25 +1,22 @@
use std::{env, io}; use std::{env, io};
use actix_http::http::HeaderValue; use actix_http::http::HeaderValue;
use actix_http::{error::PayloadError, Error, HttpService, Request, Response}; use actix_http::{Error, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::{Future, Stream}; use futures::StreamExt;
use log::info; use log::info;
fn handle_request(mut req: Request) -> impl Future<Item = Response, Error = Error> { async fn handle_request(mut req: Request) -> Result<Response, Error> {
req.take_payload() let mut body = BytesMut::new();
.fold(BytesMut::new(), move |mut body, chunk| { while let Some(item) = req.payload().next().await {
body.extend_from_slice(&chunk); body.extend_from_slice(&item?)
Ok::<_, PayloadError>(body) }
})
.from_err() info!("request body: {:?}", body);
.and_then(|bytes| { Ok(Response::Ok()
info!("request body: {:?}", bytes); .header("x-head", HeaderValue::from_static("dummy value!"))
let mut res = Response::Ok(); .body(body))
res.header("x-head", HeaderValue::from_static("dummy value!"));
Ok(res.body(bytes))
})
} }
fn main() -> io::Result<()> { fn main() -> io::Result<()> {
@@ -28,7 +25,7 @@ fn main() -> io::Result<()> {
Server::build() Server::build()
.bind("echo", "127.0.0.1:8080", || { .bind("echo", "127.0.0.1:8080", || {
HttpService::build().finish(|_req: Request| handle_request(_req)) HttpService::build().finish(handle_request)
})? })?
.run() .run()
} }

View File

@@ -1,8 +1,11 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem}; use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll, Stream}; use futures::Stream;
use pin_project::{pin_project, project};
use crate::error::Error; use crate::error::Error;
@@ -32,7 +35,7 @@ impl BodySize {
pub trait MessageBody { pub trait MessageBody {
fn size(&self) -> BodySize; fn size(&self) -> BodySize;
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error>; fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>>;
} }
impl MessageBody for () { impl MessageBody for () {
@@ -40,8 +43,8 @@ impl MessageBody for () {
BodySize::Empty BodySize::Empty
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
Ok(Async::Ready(None)) Poll::Ready(None)
} }
} }
@@ -50,11 +53,12 @@ impl<T: MessageBody> MessageBody for Box<T> {
self.as_ref().size() self.as_ref().size()
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
self.as_mut().poll_next() self.as_mut().poll_next(cx)
} }
} }
#[pin_project]
pub enum ResponseBody<B> { pub enum ResponseBody<B> {
Body(B), Body(B),
Other(Body), Other(Body),
@@ -93,20 +97,24 @@ impl<B: MessageBody> MessageBody for ResponseBody<B> {
} }
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
match self { match self {
ResponseBody::Body(ref mut body) => body.poll_next(), ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(), ResponseBody::Other(ref mut body) => body.poll_next(cx),
} }
} }
} }
impl<B: MessageBody> Stream for ResponseBody<B> { impl<B: MessageBody> Stream for ResponseBody<B> {
type Item = Bytes; type Item = Result<Bytes, Error>;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { #[project]
self.poll_next() fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
#[project]
match self.project() {
ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(cx),
}
} }
} }
@@ -144,19 +152,19 @@ impl MessageBody for Body {
} }
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
match self { match self {
Body::None => Ok(Async::Ready(None)), Body::None => Poll::Ready(None),
Body::Empty => Ok(Async::Ready(None)), Body::Empty => Poll::Ready(None),
Body::Bytes(ref mut bin) => { Body::Bytes(ref mut bin) => {
let len = bin.len(); let len = bin.len();
if len == 0 { if len == 0 {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(mem::replace(bin, Bytes::new())))) Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new()))))
} }
} }
Body::Message(ref mut body) => body.poll_next(), Body::Message(ref mut body) => body.poll_next(cx),
} }
} }
} }
@@ -242,7 +250,7 @@ impl From<serde_json::Value> for Body {
impl<S> From<SizedStream<S>> for Body impl<S> From<SizedStream<S>> for Body
where where
S: Stream<Item = Bytes, Error = Error> + 'static, S: Stream<Item = Result<Bytes, Error>> + 'static,
{ {
fn from(s: SizedStream<S>) -> Body { fn from(s: SizedStream<S>) -> Body {
Body::from_message(s) Body::from_message(s)
@@ -251,7 +259,7 @@ where
impl<S, E> From<BodyStream<S, E>> for Body impl<S, E> From<BodyStream<S, E>> for Body
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
fn from(s: BodyStream<S, E>) -> Body { fn from(s: BodyStream<S, E>) -> Body {
@@ -264,11 +272,11 @@ impl MessageBody for Bytes {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(mem::replace(self, Bytes::new())))) Poll::Ready(Some(Ok(mem::replace(self, Bytes::new()))))
} }
} }
} }
@@ -278,13 +286,11 @@ impl MessageBody for BytesMut {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some( Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze())))
mem::replace(self, BytesMut::new()).freeze(),
)))
} }
} }
} }
@@ -294,11 +300,11 @@ impl MessageBody for &'static str {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(Bytes::from_static( Poll::Ready(Some(Ok(Bytes::from_static(
mem::replace(self, "").as_ref(), mem::replace(self, "").as_ref(),
)))) ))))
} }
@@ -310,13 +316,11 @@ impl MessageBody for &'static [u8] {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(Bytes::from_static(mem::replace( Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b"")))))
self, b"",
)))))
} }
} }
} }
@@ -326,14 +330,11 @@ impl MessageBody for Vec<u8> {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(Bytes::from(mem::replace( Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new())))))
self,
Vec::new(),
)))))
} }
} }
} }
@@ -343,11 +344,11 @@ impl MessageBody for String {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() { if self.is_empty() {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
Ok(Async::Ready(Some(Bytes::from( Poll::Ready(Some(Ok(Bytes::from(
mem::replace(self, String::new()).into_bytes(), mem::replace(self, String::new()).into_bytes(),
)))) ))))
} }
@@ -356,14 +357,16 @@ impl MessageBody for String {
/// Type represent streaming body. /// Type represent streaming body.
/// Response does not contain `content-length` header and appropriate transfer encoding is used. /// Response does not contain `content-length` header and appropriate transfer encoding is used.
#[pin_project]
pub struct BodyStream<S, E> { pub struct BodyStream<S, E> {
#[pin]
stream: S, stream: S,
_t: PhantomData<E>, _t: PhantomData<E>,
} }
impl<S, E> BodyStream<S, E> impl<S, E> BodyStream<S, E>
where where
S: Stream<Item = Bytes, Error = E>, S: Stream<Item = Result<Bytes, E>>,
E: Into<Error>, E: Into<Error>,
{ {
pub fn new(stream: S) -> Self { pub fn new(stream: S) -> Self {
@@ -376,28 +379,34 @@ where
impl<S, E> MessageBody for BodyStream<S, E> impl<S, E> MessageBody for BodyStream<S, E>
where where
S: Stream<Item = Bytes, Error = E>, S: Stream<Item = Result<Bytes, E>>,
E: Into<Error>, E: Into<Error>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Stream BodySize::Stream
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
self.stream.poll().map_err(std::convert::Into::into) unsafe { Pin::new_unchecked(self) }
.project()
.stream
.poll_next(cx)
.map(|res| res.map(|res| res.map_err(std::convert::Into::into)))
} }
} }
/// Type represent streaming body. This body implementation should be used /// Type represent streaming body. This body implementation should be used
/// if total size of stream is known. Data get sent as is without using transfer encoding. /// if total size of stream is known. Data get sent as is without using transfer encoding.
#[pin_project]
pub struct SizedStream<S> { pub struct SizedStream<S> {
size: u64, size: u64,
#[pin]
stream: S, stream: S,
} }
impl<S> SizedStream<S> impl<S> SizedStream<S>
where where
S: Stream<Item = Bytes, Error = Error>, S: Stream<Item = Result<Bytes, Error>>,
{ {
pub fn new(size: u64, stream: S) -> Self { pub fn new(size: u64, stream: S) -> Self {
SizedStream { size, stream } SizedStream { size, stream }
@@ -406,20 +415,25 @@ where
impl<S> MessageBody for SizedStream<S> impl<S> MessageBody for SizedStream<S>
where where
S: Stream<Item = Bytes, Error = Error>, S: Stream<Item = Result<Bytes, Error>>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Sized64(self.size) BodySize::Sized64(self.size)
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
self.stream.poll() unsafe { Pin::new_unchecked(self) }
.project()
.stream
.poll_next(cx)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use actix_http_test::block_on;
use futures::future::{lazy, poll_fn};
impl Body { impl Body {
pub(crate) fn get_ref(&self) -> &[u8] { pub(crate) fn get_ref(&self) -> &[u8] {
@@ -447,8 +461,8 @@ mod tests {
assert_eq!("test".size(), BodySize::Sized(4)); assert_eq!("test".size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
"test".poll_next().unwrap(), block_on(poll_fn(|cx| "test".poll_next(cx))).unwrap().ok(),
Async::Ready(Some(Bytes::from("test"))) Some(Bytes::from("test"))
); );
} }
@@ -464,8 +478,10 @@ mod tests {
assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); assert_eq!((&b"test"[..]).size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
(&b"test"[..]).poll_next().unwrap(), block_on(poll_fn(|cx| (&b"test"[..]).poll_next(cx)))
Async::Ready(Some(Bytes::from("test"))) .unwrap()
.ok(),
Some(Bytes::from("test"))
); );
} }
@@ -476,8 +492,10 @@ mod tests {
assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
Vec::from("test").poll_next().unwrap(), block_on(poll_fn(|cx| Vec::from("test").poll_next(cx)))
Async::Ready(Some(Bytes::from("test"))) .unwrap()
.ok(),
Some(Bytes::from("test"))
); );
} }
@@ -489,8 +507,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
b.poll_next().unwrap(), block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Async::Ready(Some(Bytes::from("test"))) Some(Bytes::from("test"))
); );
} }
@@ -502,8 +520,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
b.poll_next().unwrap(), block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Async::Ready(Some(Bytes::from("test"))) Some(Bytes::from("test"))
); );
} }
@@ -517,22 +535,22 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
b.poll_next().unwrap(), block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Async::Ready(Some(Bytes::from("test"))) Some(Bytes::from("test"))
); );
} }
#[test] #[test]
fn test_unit() { fn test_unit() {
assert_eq!(().size(), BodySize::Empty); assert_eq!(().size(), BodySize::Empty);
assert_eq!(().poll_next().unwrap(), Async::Ready(None)); assert!(block_on(poll_fn(|cx| ().poll_next(cx))).is_none());
} }
#[test] #[test]
fn test_box() { fn test_box() {
let mut val = Box::new(()); let mut val = Box::new(());
assert_eq!(val.size(), BodySize::Empty); assert_eq!(val.size(), BodySize::Empty);
assert_eq!(val.poll_next().unwrap(), Async::Ready(None)); assert!(block_on(poll_fn(|cx| val.poll_next(cx))).is_none());
} }
#[test] #[test]

View File

@@ -4,7 +4,7 @@ use std::rc::Rc;
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::ServerConfig as SrvConfig; use actix_server_config::ServerConfig as SrvConfig;
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::{KeepAlive, ServiceConfig}; use crate::config::{KeepAlive, ServiceConfig};
@@ -32,9 +32,10 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static,
{ {
/// Create instance of `ServiceConfigBuilder` /// Create instance of `ServiceConfigBuilder`
pub fn new() -> Self { pub fn new() -> Self {
@@ -52,19 +53,22 @@ where
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U> impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>, <S::Service as Service>::Future: 'static,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService< <X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
>, >,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
/// Set server keep-alive setting. /// Set server keep-alive setting.
/// ///
@@ -108,16 +112,17 @@ where
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U> pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
where where
F: IntoNewService<X1>, F: IntoServiceFactory<X1>,
X1: NewService<Config = SrvConfig, Request = Request, Response = Request>, X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
expect: expect.into_new_service(), expect: expect.into_factory(),
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
@@ -130,21 +135,22 @@ where
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1> pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where where
F: IntoNewService<U1>, F: IntoServiceFactory<U1>,
U1: NewService< U1: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
expect: self.expect, expect: self.expect,
upgrade: Some(upgrade.into_new_service()), upgrade: Some(upgrade.into_factory()),
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
} }
@@ -166,8 +172,8 @@ where
/// Finish service configuration and create *http service* for HTTP/1 protocol. /// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U> pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where where
B: MessageBody + 'static, B: MessageBody,
F: IntoNewService<S>, F: IntoServiceFactory<S>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -177,7 +183,7 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
H1Service::with_config(cfg, service.into_new_service()) H1Service::with_config(cfg, service.into_factory())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)
@@ -187,10 +193,10 @@ where
pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B> pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoNewService<S>, F: IntoServiceFactory<S>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
@@ -198,18 +204,17 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
H2Service::with_config(cfg, service.into_new_service()) H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
.on_connect(self.on_connect)
} }
/// Finish service configuration and create `HttpService` instance. /// Finish service configuration and create `HttpService` instance.
pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U> pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoNewService<S>, F: IntoServiceFactory<S>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
@@ -217,7 +222,7 @@ where
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
); );
HttpService::with_config(cfg, service.into_new_service()) HttpService::with_config(cfg, service.into_factory())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)

View File

@@ -1,10 +1,12 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, time}; use std::{fmt, io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures::future::{err, Either, Future, FutureResult}; use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready};
use futures::Poll;
use h2::client::SendRequest; use h2::client::SendRequest;
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::h1::ClientCodec; use crate::h1::ClientCodec;
@@ -21,8 +23,8 @@ pub(crate) enum ConnectionType<Io> {
} }
pub trait Connection { pub trait Connection {
type Io: AsyncRead + AsyncWrite; type Io: AsyncRead + AsyncWrite + Unpin;
type Future: Future<Item = (ResponseHead, Payload), Error = SendRequestError>; type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol; fn protocol(&self) -> Protocol;
@@ -34,8 +36,7 @@ pub trait Connection {
) -> Self::Future; ) -> Self::Future;
type TunnelFuture: Future< type TunnelFuture: Future<
Item = (ResponseHead, Framed<Self::Io, ClientCodec>), Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
@@ -71,7 +72,7 @@ where
} }
} }
impl<T: AsyncRead + AsyncWrite> IoConnection<T> { impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
pub(crate) fn new( pub(crate) fn new(
io: ConnectionType<T>, io: ConnectionType<T>,
created: time::Instant, created: time::Instant,
@@ -91,11 +92,11 @@ impl<T: AsyncRead + AsyncWrite> IoConnection<T> {
impl<T> Connection for IoConnection<T> impl<T> Connection for IoConnection<T>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Io = T; type Io = T;
type Future = type Future =
Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>; LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self.io { match self.io {
@@ -111,38 +112,30 @@ where
body: B, body: B,
) -> Self::Future { ) -> Self::Future {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => Box::new(h1proto::send_request( ConnectionType::H1(io) => {
io, h1proto::send_request(io, head.into(), body, self.created, self.pool)
head.into(), .boxed_local()
body, }
self.created, ConnectionType::H2(io) => {
self.pool, h2proto::send_request(io, head.into(), body, self.created, self.pool)
)), .boxed_local()
ConnectionType::H2(io) => Box::new(h2proto::send_request( }
io,
head.into(),
body,
self.created,
self.pool,
)),
} }
} }
type TunnelFuture = Either< type TunnelFuture = Either<
Box< LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<Self::Io, ClientCodec>), Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
>, >,
FutureResult<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => {
Either::A(Box::new(h1proto::open_tunnel(io, head.into()))) Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local())
} }
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -152,7 +145,7 @@ where
None, None,
)); ));
} }
Either::B(err(SendRequestError::TunnelNotSupported)) Either::Right(err(SendRequestError::TunnelNotSupported))
} }
} }
} }
@@ -166,12 +159,12 @@ pub(crate) enum EitherConnection<A, B> {
impl<A, B> Connection for EitherConnection<A, B> impl<A, B> Connection for EitherConnection<A, B>
where where
A: AsyncRead + AsyncWrite + 'static, A: AsyncRead + AsyncWrite + Unpin + 'static,
B: AsyncRead + AsyncWrite + 'static, B: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Io = EitherIo<A, B>; type Io = EitherIo<A, B>;
type Future = type Future =
Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>; LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self { match self {
@@ -191,44 +184,30 @@ where
} }
} }
type TunnelFuture = Box< type TunnelFuture = LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<Self::Io, ClientCodec>), Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
match self { match self {
EitherConnection::A(con) => Box::new( EitherConnection::A(con) => con
con.open_tunnel(head) .open_tunnel(head)
.map(|(head, framed)| (head, framed.map_io(EitherIo::A))), .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A))))
), .boxed_local(),
EitherConnection::B(con) => Box::new( EitherConnection::B(con) => con
con.open_tunnel(head) .open_tunnel(head)
.map(|(head, framed)| (head, framed.map_io(EitherIo::B))), .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B))))
), .boxed_local(),
} }
} }
} }
#[pin_project]
pub enum EitherIo<A, B> { pub enum EitherIo<A, B> {
A(A), A(#[pin] A),
B(B), B(#[pin] B),
}
impl<A, B> io::Read for EitherIo<A, B>
where
A: io::Read,
B: io::Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.read(buf),
EitherIo::B(ref mut val) => val.read(buf),
}
}
} }
impl<A, B> AsyncRead for EitherIo<A, B> impl<A, B> AsyncRead for EitherIo<A, B>
@@ -236,6 +215,19 @@ where
A: AsyncRead, A: AsyncRead,
B: AsyncRead, B: AsyncRead,
{ {
#[project]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_read(cx, buf),
EitherIo::B(val) => val.poll_read(cx, buf),
}
}
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self { match self {
EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
@@ -244,45 +236,58 @@ where
} }
} }
impl<A, B> io::Write for EitherIo<A, B>
where
A: io::Write,
B: io::Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.write(buf),
EitherIo::B(ref mut val) => val.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
EitherIo::A(ref mut val) => val.flush(),
EitherIo::B(ref mut val) => val.flush(),
}
}
}
impl<A, B> AsyncWrite for EitherIo<A, B> impl<A, B> AsyncWrite for EitherIo<A, B>
where where
A: AsyncWrite, A: AsyncWrite,
B: AsyncWrite, B: AsyncWrite,
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { #[project]
match self { fn poll_write(
EitherIo::A(ref mut val) => val.shutdown(), self: Pin<&mut Self>,
EitherIo::B(ref mut val) => val.shutdown(), cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_write(cx, buf),
EitherIo::B(val) => val.poll_write(cx, buf),
} }
} }
fn write_buf<U: Buf>(&mut self, buf: &mut U) -> Poll<usize, io::Error> #[project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_flush(cx),
EitherIo::B(val) => val.poll_flush(cx),
}
}
#[project]
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_shutdown(cx),
EitherIo::B(val) => val.poll_shutdown(cx),
}
}
#[project]
fn poll_write_buf<U: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut U,
) -> Poll<Result<usize, io::Error>>
where where
Self: Sized, Self: Sized,
{ {
match self { #[project]
EitherIo::A(ref mut val) => val.write_buf(buf), match self.project() {
EitherIo::B(ref mut val) => val.write_buf(buf), EitherIo::A(val) => val.poll_write_buf(cx, buf),
EitherIo::B(val) => val.poll_write_buf(cx, buf),
} }
} }
} }

View File

@@ -1,37 +1,41 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_connect::{ use actix_connect::{
default_connector, Connect as TcpConnect, Connection as TcpConnection, default_connector, Connect as TcpConnect, Connection as TcpConnection,
}; };
use actix_service::{apply_fn, Service, ServiceExt}; use actix_service::{apply_fn, Service};
use actix_utils::timeout::{TimeoutError, TimeoutService}; use actix_utils::timeout::{TimeoutError, TimeoutService};
use futures::future::Ready;
use http::Uri; use http::Uri;
use tokio_tcp::TcpStream; use tokio_net::tcp::TcpStream;
use super::connection::Connection; use super::connection::Connection;
use super::error::ConnectError; use super::error::ConnectError;
use super::pool::{ConnectionPool, Protocol}; use super::pool::{ConnectionPool, Protocol};
use super::Connect; use super::Connect;
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
use openssl::ssl::SslConnector as OpensslConnector; use open_ssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
use rustls::ClientConfig; use rust_tls::ClientConfig;
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
use std::sync::Arc; use std::sync::Arc;
#[cfg(any(feature = "ssl", feature = "rust-tls"))] #[cfg(any(feature = "openssl", feature = "rustls"))]
enum SslConnector { enum SslConnector {
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
Openssl(OpensslConnector), Openssl(OpensslConnector),
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
Rustls(Arc<ClientConfig>), Rustls(Arc<ClientConfig>),
} }
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] #[cfg(not(any(feature = "openssl", feature = "rustls")))]
type SslConnector = (); type SslConnector = ();
/// Manages http client network connectivity /// Manages http client network connectivity
@@ -58,8 +62,8 @@ pub struct Connector<T, U> {
_t: PhantomData<U>, _t: PhantomData<U>,
} }
trait Io: AsyncRead + AsyncWrite {} trait Io: AsyncRead + AsyncWrite + Unpin {}
impl<T: AsyncRead + AsyncWrite> Io for T {} impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
impl Connector<(), ()> { impl Connector<(), ()> {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
@@ -72,9 +76,9 @@ impl Connector<(), ()> {
TcpStream, TcpStream,
> { > {
let ssl = { let ssl = {
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
{ {
use openssl::ssl::SslMethod; use open_ssl::ssl::SslMethod;
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
let _ = ssl let _ = ssl
@@ -82,7 +86,7 @@ impl Connector<(), ()> {
.map_err(|e| error!("Can not set alpn protocol: {:?}", e)); .map_err(|e| error!("Can not set alpn protocol: {:?}", e));
SslConnector::Openssl(ssl.build()) SslConnector::Openssl(ssl.build())
} }
#[cfg(all(not(feature = "ssl"), feature = "rust-tls"))] #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
{ {
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
@@ -92,7 +96,7 @@ impl Connector<(), ()> {
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
SslConnector::Rustls(Arc::new(config)) SslConnector::Rustls(Arc::new(config))
} }
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] #[cfg(not(any(feature = "openssl", feature = "rustls")))]
{} {}
}; };
@@ -113,7 +117,7 @@ impl<T, U> Connector<T, U> {
/// Use custom connector. /// Use custom connector.
pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1> pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
where where
U1: AsyncRead + AsyncWrite + fmt::Debug, U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
T1: Service< T1: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U1>, Response = TcpConnection<Uri, U1>,
@@ -135,7 +139,7 @@ impl<T, U> Connector<T, U> {
impl<T, U> Connector<T, U> impl<T, U> Connector<T, U>
where where
U: AsyncRead + AsyncWrite + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
T: Service< T: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U>, Response = TcpConnection<Uri, U>,
@@ -150,14 +154,14 @@ where
self self
} }
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
/// Use custom `SslConnector` instance. /// Use custom `SslConnector` instance.
pub fn ssl(mut self, connector: OpensslConnector) -> Self { pub fn ssl(mut self, connector: OpensslConnector) -> Self {
self.ssl = SslConnector::Openssl(connector); self.ssl = SslConnector::Openssl(connector);
self self
} }
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self { pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
self.ssl = SslConnector::Rustls(connector); self.ssl = SslConnector::Rustls(connector);
self self
@@ -212,8 +216,8 @@ where
pub fn finish( pub fn finish(
self, self,
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError> ) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
+ Clone { + Clone {
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] #[cfg(not(any(feature = "openssl", feature = "rustls")))]
{ {
let connector = TimeoutService::new( let connector = TimeoutService::new(
self.timeout, self.timeout,
@@ -238,32 +242,32 @@ where
), ),
} }
} }
#[cfg(any(feature = "ssl", feature = "rust-tls"))] #[cfg(any(feature = "openssl", feature = "rustls"))]
{ {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
use actix_connect::ssl::OpensslConnector; use actix_connect::ssl::OpensslConnector;
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
use actix_connect::ssl::RustlsConnector; use actix_connect::ssl::RustlsConnector;
use actix_service::boxed::service; use actix_service::{boxed::service, pipeline};
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
use rustls::Session; use rust_tls::Session;
let ssl_service = TimeoutService::new( let ssl_service = TimeoutService::new(
self.timeout, self.timeout,
apply_fn(self.connector.clone(), |msg: Connect, srv| { pipeline(
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) apply_fn(self.connector.clone(), |msg: Connect, srv| {
}) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
.map_err(ConnectError::from) })
.map_err(ConnectError::from),
)
.and_then(match self.ssl { .and_then(match self.ssl {
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
SslConnector::Openssl(ssl) => service( SslConnector::Openssl(ssl) => service(
OpensslConnector::service(ssl) OpensslConnector::service(ssl)
.map_err(ConnectError::from)
.map(|stream| { .map(|stream| {
let sock = stream.into_parts().0; let sock = stream.into_parts().0;
let h2 = sock let h2 = sock
.get_ref()
.ssl() .ssl()
.selected_alpn_protocol() .selected_alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2)) .map(|protos| protos.windows(2).any(|w| w == H2))
@@ -273,9 +277,10 @@ where
} else { } else {
(Box::new(sock) as Box<dyn Io>, Protocol::Http1) (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
} }
}), })
.map_err(ConnectError::from),
), ),
#[cfg(feature = "rust-tls")] #[cfg(feature = "rustls")]
SslConnector::Rustls(ssl) => service( SslConnector::Rustls(ssl) => service(
RustlsConnector::service(ssl) RustlsConnector::service(ssl)
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -303,7 +308,7 @@ where
let tcp_service = TimeoutService::new( let tcp_service = TimeoutService::new(
self.timeout, self.timeout,
apply_fn(self.connector.clone(), |msg: Connect, srv| { apply_fn(self.connector, |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
}) })
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -334,19 +339,20 @@ where
} }
} }
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))] #[cfg(not(any(feature = "openssl", feature = "rustls")))]
mod connect_impl { mod connect_impl {
use futures::future::{err, Either, FutureResult}; use std::task::{Context, Poll};
use futures::Poll;
use futures::future::{err, Either, Ready};
use futures::ready;
use super::*; use super::*;
use crate::client::connection::IoConnection; use crate::client::connection::IoConnection;
pub(crate) struct InnerConnector<T, Io> pub(crate) struct InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) tcp_pool: ConnectionPool<T, Io>, pub(crate) tcp_pool: ConnectionPool<T, Io>,
@@ -354,9 +360,8 @@ mod connect_impl {
impl<T, Io> Clone for InnerConnector<T, Io> impl<T, Io> Clone for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -368,9 +373,8 @@ mod connect_impl {
impl<T, Io> Service for InnerConnector<T, Io> impl<T, Io> Service for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
@@ -378,38 +382,38 @@ mod connect_impl {
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
<ConnectionPool<T, Io> as Service>::Future, <ConnectionPool<T, Io> as Service>::Future,
FutureResult<IoConnection<Io>, ConnectError>, Ready<Result<IoConnection<Io>, ConnectError>>,
>; >;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.tcp_pool.poll_ready() self.tcp_pool.poll_ready(cx)
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => { Some("https") | Some("wss") => {
Either::B(err(ConnectError::SslIsNotSupported)) Either::Right(err(ConnectError::SslIsNotSupported))
} }
_ => Either::A(self.tcp_pool.call(req)), _ => Either::Left(self.tcp_pool.call(req)),
} }
} }
} }
} }
#[cfg(any(feature = "ssl", feature = "rust-tls"))] #[cfg(any(feature = "openssl", feature = "rustls"))]
mod connect_impl { mod connect_impl {
use std::marker::PhantomData; use std::marker::PhantomData;
use futures::future::{Either, FutureResult}; use futures::future::Either;
use futures::{Async, Future, Poll}; use futures::ready;
use super::*; use super::*;
use crate::client::connection::EitherConnection; use crate::client::connection::EitherConnection;
pub(crate) struct InnerConnector<T1, T2, Io1, Io2> pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>, T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>, T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
{ {
@@ -419,13 +423,11 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -438,53 +440,47 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = EitherConnection<Io1, Io2>; type Response = EitherConnection<Io1, Io2>;
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
FutureResult<Self::Response, Self::Error>, InnerConnectorResponseA<T1, Io1, Io2>,
Either< InnerConnectorResponseB<T2, Io1, Io2>,
InnerConnectorResponseA<T1, Io1, Io2>,
InnerConnectorResponseB<T2, Io1, Io2>,
>,
>; >;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.tcp_pool.poll_ready() self.tcp_pool.poll_ready(cx)
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => { Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB {
Either::B(Either::B(InnerConnectorResponseB { fut: self.ssl_pool.call(req),
fut: self.ssl_pool.call(req), _t: PhantomData,
_t: PhantomData, }),
})) _ => Either::Left(InnerConnectorResponseA {
}
_ => Either::B(Either::A(InnerConnectorResponseA {
fut: self.tcp_pool.call(req), fut: self.tcp_pool.call(req),
_t: PhantomData, _t: PhantomData,
})), }),
} }
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseA<T, Io1, Io2> pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io1> as Service>::Future, fut: <ConnectionPool<T, Io1> as Service>::Future,
_t: PhantomData<Io2>, _t: PhantomData<Io2>,
} }
@@ -492,29 +488,28 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Item = EitherConnection<Io1, Io2>; type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.fut.poll()? { Poll::Ready(
Async::NotReady => Ok(Async::NotReady), ready!(Pin::new(&mut self.get_mut().fut).poll(cx))
Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))), .map(|res| EitherConnection::A(res)),
} )
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseB<T, Io1, Io2> pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
where where
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io2> as Service>::Future, fut: <ConnectionPool<T, Io2> as Service>::Future,
_t: PhantomData<Io1>, _t: PhantomData<Io1>,
} }
@@ -522,19 +517,17 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Item = EitherConnection<Io1, Io2>; type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.fut.poll()? { Poll::Ready(
Async::NotReady => Ok(Async::NotReady), ready!(Pin::new(&mut self.get_mut().fut).poll(cx))
Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))), .map(|res| EitherConnection::B(res)),
} )
} }
} }
} }

View File

@@ -3,8 +3,8 @@ use std::io;
use derive_more::{Display, From}; use derive_more::{Display, From};
use trust_dns_resolver::error::ResolveError; use trust_dns_resolver::error::ResolveError;
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
use openssl::ssl::{Error as SslError, HandshakeError}; use open_ssl::ssl::{Error as SslError, HandshakeError};
use crate::error::{Error, ParseError, ResponseError}; use crate::error::{Error, ParseError, ResponseError};
use crate::http::Error as HttpError; use crate::http::Error as HttpError;
@@ -18,7 +18,7 @@ pub enum ConnectError {
SslIsNotSupported, SslIsNotSupported,
/// SSL error /// SSL error
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
#[display(fmt = "{}", _0)] #[display(fmt = "{}", _0)]
SslError(SslError), SslError(SslError),
@@ -63,7 +63,7 @@ impl From<actix_connect::ConnectError> for ConnectError {
} }
} }
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
impl<T> From<HandshakeError<T>> for ConnectError { impl<T> From<HandshakeError<T>> for ConnectError {
fn from(err: HandshakeError<T>) -> ConnectError { fn from(err: HandshakeError<T>) -> ConnectError {
match err { match err {

View File

@@ -1,10 +1,13 @@
use std::future::Future;
use std::io::Write; use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, time}; use std::{io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, Either}; use futures::future::{ok, poll_fn, Either};
use futures::{Async, Future, Poll, Sink, Stream}; use futures::{Sink, SinkExt, Stream, StreamExt};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@@ -18,15 +21,15 @@ use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired; use super::pool::Acquired;
use crate::body::{BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
pub(crate) fn send_request<T, B>( pub(crate) async fn send_request<T, B>(
io: T, io: T,
mut head: RequestHeadType, mut head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError> ) -> Result<(ResponseHead, Payload), SendRequestError>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody, B: MessageBody,
{ {
// set request host header // set request host header
@@ -62,68 +65,99 @@ where
io: Some(io), io: Some(io),
}; };
let len = body.size();
// create Framed and send request // create Framed and send request
Framed::new(io, h1::ClientCodec::default()) let mut framed = Framed::new(io, h1::ClientCodec::default());
.send((head, len).into()) framed.send((head, body.size()).into()).await?;
.from_err()
// send request body // send request body
.and_then(move |framed| match body.size() { match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => { BodySize::None | BodySize::Empty | BodySize::Sized(0) => (),
Either::A(ok(framed)) _ => send_body(body, &mut framed).await?,
} };
_ => Either::B(SendBody::new(body, framed)),
}) // read response and init read body
// read response and init read body let res = framed.into_future().await;
.and_then(|framed| { let (head, framed) = if let (Some(result), framed) = res {
framed let item = result.map_err(SendRequestError::from)?;
.into_future() (item, framed)
.map_err(|(e, _)| SendRequestError::from(e)) } else {
.and_then(|(item, framed)| { return Err(SendRequestError::from(ConnectError::Disconnected));
if let Some(res) = item { };
match framed.get_codec().message_type() {
h1::MessageType::None => { match framed.get_codec().message_type() {
let force_close = !framed.get_codec().keepalive(); h1::MessageType::None => {
release_connection(framed, force_close); let force_close = !framed.get_codec().keepalive();
Ok((res, Payload::None)) release_connection(framed, force_close);
} Ok((head, Payload::None))
_ => { }
let pl: PayloadStream = Box::new(PlStream::new(framed)); _ => {
Ok((res, pl.into())) let pl: PayloadStream = PlStream::new(framed).boxed_local();
} Ok((head, pl.into()))
} }
} else { }
Err(ConnectError::Disconnected.into())
}
})
})
} }
pub(crate) fn open_tunnel<T>( pub(crate) async fn open_tunnel<T>(
io: T, io: T,
head: RequestHeadType, head: RequestHeadType,
) -> impl Future<Item = (ResponseHead, Framed<T, h1::ClientCodec>), Error = SendRequestError> ) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
// create Framed and send request // create Framed and send request
Framed::new(io, h1::ClientCodec::default()) let mut framed = Framed::new(io, h1::ClientCodec::default());
.send((head, BodySize::None).into()) framed.send((head, BodySize::None).into()).await?;
.from_err()
// read response // read response
.and_then(|framed| { if let (Some(result), framed) = framed.into_future().await {
framed let head = result.map_err(SendRequestError::from)?;
.into_future() Ok((head, framed))
.map_err(|(e, _)| SendRequestError::from(e)) } else {
.and_then(|(head, framed)| { Err(SendRequestError::from(ConnectError::Disconnected))
if let Some(head) = head { }
Ok((head, framed)) }
/// send request body to the peer
pub(crate) async fn send_body<I, B>(
mut body: B,
framed: &mut Framed<I, h1::ClientCodec>,
) -> Result<(), SendRequestError>
where
I: ConnectionLifetime,
B: MessageBody,
{
let mut eof = false;
while !eof {
while !eof && !framed.is_write_buf_full() {
match poll_fn(|cx| body.poll_next(cx)).await {
Some(result) => {
framed.write(h1::Message::Chunk(Some(result?)))?;
}
None => {
eof = true;
framed.write(h1::Message::Chunk(None))?;
}
}
}
if !framed.is_write_buf_empty() {
poll_fn(|cx| match framed.flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.is_write_buf_full() {
Poll::Ready(Ok(()))
} else { } else {
Err(SendRequestError::from(ConnectError::Disconnected)) Poll::Pending
} }
}) }
}) })
.await?;
}
}
SinkExt::flush(framed).await?;
Ok(())
} }
#[doc(hidden)] #[doc(hidden)]
@@ -134,7 +168,10 @@ pub struct H1Connection<T> {
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
} }
impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T> { impl<T> ConnectionLifetime for H1Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close connection /// Close connection
fn close(&mut self) { fn close(&mut self) {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -162,98 +199,41 @@ impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T>
} }
} }
impl<T: AsyncRead + AsyncWrite + 'static> io::Read for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.io.as_mut().unwrap().read(buf) self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
} }
} }
impl<T: AsyncRead + AsyncWrite + 'static> AsyncRead for H1Connection<T> {} impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> {
fn poll_write(
impl<T: AsyncRead + AsyncWrite + 'static> io::Write for H1Connection<T> { mut self: Pin<&mut Self>,
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { cx: &mut Context<'_>,
self.io.as_mut().unwrap().write(buf) buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(
self.io.as_mut().unwrap().flush() mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
} }
}
impl<T: AsyncRead + AsyncWrite + 'static> AsyncWrite for H1Connection<T> { fn poll_shutdown(
fn shutdown(&mut self) -> Poll<(), io::Error> { mut self: Pin<&mut Self>,
self.io.as_mut().unwrap().shutdown() cx: &mut Context,
} ) -> Poll<Result<(), io::Error>> {
} Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)
/// Future responsible for sending request body to the peer
pub(crate) struct SendBody<I, B> {
body: Option<B>,
framed: Option<Framed<I, h1::ClientCodec>>,
flushed: bool,
}
impl<I, B> SendBody<I, B>
where
I: AsyncRead + AsyncWrite + 'static,
B: MessageBody,
{
pub(crate) fn new(body: B, framed: Framed<I, h1::ClientCodec>) -> Self {
SendBody {
body: Some(body),
framed: Some(framed),
flushed: true,
}
}
}
impl<I, B> Future for SendBody<I, B>
where
I: ConnectionLifetime,
B: MessageBody,
{
type Item = Framed<I, h1::ClientCodec>;
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut body_ready = true;
loop {
while body_ready
&& self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self.body.as_mut().unwrap().poll_next()? {
Async::Ready(item) => {
// check if body is done
if item.is_none() {
let _ = self.body.take();
}
self.flushed = false;
self.framed
.as_mut()
.unwrap()
.force_send(h1::Message::Chunk(item))?;
break;
}
Async::NotReady => body_ready = false,
}
}
if !self.flushed {
match self.framed.as_mut().unwrap().poll_complete()? {
Async::Ready(_) => {
self.flushed = true;
continue;
}
Async::NotReady => return Ok(Async::NotReady),
}
}
if self.body.is_none() {
return Ok(Async::Ready(self.framed.take().unwrap()));
}
return Ok(Async::NotReady);
}
} }
} }
@@ -270,23 +250,24 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
} }
impl<Io: ConnectionLifetime> Stream for PlStream<Io> { impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self.framed.as_mut().unwrap().poll()? { let this = self.get_mut();
Async::NotReady => Ok(Async::NotReady),
Async::Ready(Some(chunk)) => { match this.framed.as_mut().unwrap().next_item(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
Ok(Async::Ready(Some(chunk))) Poll::Ready(Some(Ok(chunk)))
} else { } else {
let framed = self.framed.take().unwrap(); let framed = this.framed.take().unwrap();
let force_close = !framed.get_codec().keepalive(); let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close); release_connection(framed, force_close);
Ok(Async::Ready(None)) Poll::Ready(None)
} }
} }
Async::Ready(None) => Ok(Async::Ready(None)), Poll::Ready(None) => Poll::Ready(None),
} }
} }
} }

View File

@@ -1,9 +1,11 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time; use std::time;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{err, Either}; use futures::future::{err, poll_fn, Either};
use futures::{Async, Future, Poll};
use h2::{client::SendRequest, SendStream}; use h2::{client::SendRequest, SendStream};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{request::Request, HttpTryFrom, Method, Version}; use http::{request::Request, HttpTryFrom, Method, Version};
@@ -17,15 +19,15 @@ use super::connection::{ConnectionType, IoConnection};
use super::error::SendRequestError; use super::error::SendRequestError;
use super::pool::Acquired; use super::pool::Acquired;
pub(crate) fn send_request<T, B>( pub(crate) async fn send_request<T, B>(
io: SendRequest<Bytes>, mut io: SendRequest<Bytes>,
head: RequestHeadType, head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError> ) -> Result<(ResponseHead, Payload), SendRequestError>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody, B: MessageBody,
{ {
trace!("Sending client request: {:?} {:?}", head, body.size()); trace!("Sending client request: {:?} {:?}", head, body.size());
@@ -36,158 +38,140 @@ where
_ => false, _ => false,
}; };
io.ready() let mut req = Request::new(());
.map_err(SendRequestError::from) *req.uri_mut() = head.as_ref().uri.clone();
.and_then(move |mut io| { *req.method_mut() = head.as_ref().method.clone();
let mut req = Request::new(()); *req.version_mut() = Version::HTTP_2;
*req.uri_mut() = head.as_ref().uri.clone();
*req.method_mut() = head.as_ref().method.clone();
*req.version_mut() = Version::HTTP_2;
let mut skip_len = true; let mut skip_len = true;
// let mut has_date = false; // let mut has_date = false;
// Content length // Content length
let _ = match length { let _ = match length {
BodySize::None => None, BodySize::None => None,
BodySize::Stream => { BodySize::Stream => {
skip_len = false; skip_len = false;
None None
} }
BodySize::Empty => req BodySize::Empty => req
.headers_mut() .headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from_static("0")), .insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
BodySize::Sized(len) => req.headers_mut().insert( BodySize::Sized(len) => req.headers_mut().insert(
CONTENT_LENGTH, CONTENT_LENGTH,
HeaderValue::try_from(format!("{}", len)).unwrap(), HeaderValue::try_from(format!("{}", len)).unwrap(),
), ),
BodySize::Sized64(len) => req.headers_mut().insert( BodySize::Sized64(len) => req.headers_mut().insert(
CONTENT_LENGTH, CONTENT_LENGTH,
HeaderValue::try_from(format!("{}", len)).unwrap(), HeaderValue::try_from(format!("{}", len)).unwrap(),
), ),
}; };
// Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate.
let (head, extra_headers) = match head { let (head, extra_headers) = match head {
RequestHeadType::Owned(head) => { RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()),
(RequestHeadType::Owned(head), HeaderMap::new()) RequestHeadType::Rc(head, extra_headers) => (
} RequestHeadType::Rc(head, None),
RequestHeadType::Rc(head, extra_headers) => ( extra_headers.unwrap_or_else(HeaderMap::new),
RequestHeadType::Rc(head, None), ),
extra_headers.unwrap_or_else(HeaderMap::new), };
),
};
// merging headers from head and extra headers. // merging headers from head and extra headers.
let headers = head let headers = head
.as_ref() .as_ref()
.headers .headers
.iter() .iter()
.filter(|(name, _)| !extra_headers.contains_key(*name)) .filter(|(name, _)| !extra_headers.contains_key(*name))
.chain(extra_headers.iter()); .chain(extra_headers.iter());
// copy headers // copy headers
for (key, value) in headers { for (key, value) in headers {
match *key { match *key {
CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONNECTION | TRANSFER_ENCODING => continue, // http2 specific
CONTENT_LENGTH if skip_len => continue, CONTENT_LENGTH if skip_len => continue,
// DATE => has_date = true, // DATE => has_date = true,
_ => (), _ => (),
} }
req.headers_mut().append(key, value.clone()); req.headers_mut().append(key, value.clone());
}
let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res {
release(io, pool, created, e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
release(io, pool, created, false);
if !eof {
send_body(body, send).await?;
} }
fut.await.map_err(SendRequestError::from)?
}
Err(e) => {
release(io, pool, created, e.is_io());
return Err(e.into());
}
};
match io.send_request(req, eof) { let (parts, body) = resp.into_parts();
Ok((res, send)) => { let payload = if head_req { Payload::None } else { body.into() };
release(io, pool, created, false);
if !eof { let mut head = ResponseHead::new(parts.status);
Either::A(Either::B( head.version = parts.version;
SendBody { head.headers = parts.headers.into();
body, Ok((head, payload))
send,
buf: None,
}
.and_then(move |_| res.map_err(SendRequestError::from)),
))
} else {
Either::B(res.map_err(SendRequestError::from))
}
}
Err(e) => {
release(io, pool, created, e.is_io());
Either::A(Either::A(err(e.into())))
}
}
})
.and_then(move |resp| {
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };
let mut head = ResponseHead::new(parts.status);
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
})
.from_err()
} }
struct SendBody<B: MessageBody> { async fn send_body<B: MessageBody>(
body: B, mut body: B,
send: SendStream<Bytes>, mut send: SendStream<Bytes>,
buf: Option<Bytes>, ) -> Result<(), SendRequestError> {
} let mut buf = None;
loop {
impl<B: MessageBody> Future for SendBody<B> { if buf.is_none() {
type Item = (); match poll_fn(|cx| body.poll_next(cx)).await {
type Error = SendRequestError; Some(Ok(b)) => {
send.reserve_capacity(b.len());
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { buf = Some(b);
loop {
if self.buf.is_none() {
match self.body.poll_next() {
Ok(Async::Ready(Some(buf))) => {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
}
Ok(Async::Ready(None)) => {
if let Err(e) = self.send.send_data(Bytes::new(), true) {
return Err(e.into());
}
self.send.reserve_capacity(0);
return Ok(Async::Ready(()));
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => return Err(e.into()),
} }
} Some(Err(e)) => return Err(e.into()),
None => {
match self.send.poll_capacity() { if let Err(e) = send.send_data(Bytes::new(), true) {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
Ok(Async::Ready(Some(cap))) => {
let mut buf = self.buf.take().unwrap();
let len = buf.len();
let bytes = buf.split_to(std::cmp::min(cap, len));
if let Err(e) = self.send.send_data(bytes, false) {
return Err(e.into()); return Err(e.into());
} else {
if !buf.is_empty() {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
}
continue;
} }
send.reserve_capacity(0);
return Ok(());
} }
Err(e) => return Err(e.into()),
} }
} }
match poll_fn(|cx| send.poll_capacity(cx)).await {
None => return Ok(()),
Some(Ok(cap)) => {
let b = buf.as_mut().unwrap();
let len = b.len();
let bytes = b.split_to(std::cmp::min(cap, len));
if let Err(e) = send.send_data(bytes, false) {
return Err(e.into());
} else {
if !b.is_empty() {
send.reserve_capacity(b.len());
} else {
buf = None;
}
continue;
}
}
Some(Err(e)) => return Err(e.into()),
}
} }
} }
// release SendRequest object // release SendRequest object
fn release<T: AsyncRead + AsyncWrite + 'static>( fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
io: SendRequest<Bytes>, io: SendRequest<Bytes>,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
created: time::Instant, created: time::Instant,

View File

@@ -1,22 +1,23 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::io; use std::io;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::Service; use actix_service::Service;
use actix_utils::{oneshot, task::LocalWaker};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{err, ok, Either, FutureResult}; use futures::future::{err, ok, poll_fn, Either, FutureExt, LocalBoxFuture, Ready};
use futures::task::AtomicTask; use h2::client::{handshake, Connection, SendRequest};
use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use h2::client::{handshake, Handshake};
use hashbrown::HashMap; use hashbrown::HashMap;
use http::uri::Authority; use http::uri::Authority;
use indexmap::IndexSet; use indexmap::IndexSet;
use slab::Slab; use slab::Slab;
use tokio_timer::{sleep, Delay}; use tokio_timer::{delay_for, Delay};
use super::connection::{ConnectionType, IoConnection}; use super::connection::{ConnectionType, IoConnection};
use super::error::ConnectError; use super::error::ConnectError;
@@ -41,16 +42,12 @@ impl From<Authority> for Key {
} }
/// Connections pool /// Connections pool
pub(crate) struct ConnectionPool<T, Io: AsyncRead + AsyncWrite + 'static>( pub(crate) struct ConnectionPool<T, Io: 'static>(Rc<RefCell<T>>, Rc<RefCell<Inner<Io>>>);
T,
Rc<RefCell<Inner<Io>>>,
);
impl<T, Io> ConnectionPool<T, Io> impl<T, Io> ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) fn new( pub(crate) fn new(
@@ -61,7 +58,7 @@ where
limit: usize, limit: usize,
) -> Self { ) -> Self {
ConnectionPool( ConnectionPool(
connector, Rc::new(RefCell::new(connector)),
Rc::new(RefCell::new(Inner { Rc::new(RefCell::new(Inner {
conn_lifetime, conn_lifetime,
conn_keep_alive, conn_keep_alive,
@@ -71,7 +68,7 @@ where
waiters: Slab::new(), waiters: Slab::new(),
waiters_queue: IndexSet::new(), waiters_queue: IndexSet::new(),
available: HashMap::new(), available: HashMap::new(),
task: None, waker: LocalWaker::new(),
})), })),
) )
} }
@@ -79,8 +76,7 @@ where
impl<T, Io> Clone for ConnectionPool<T, Io> impl<T, Io> Clone for ConnectionPool<T, Io>
where where
T: Clone, Io: 'static,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ConnectionPool(self.0.clone(), self.1.clone()) ConnectionPool(self.0.clone(), self.1.clone())
@@ -89,86 +85,116 @@ where
impl<T, Io> Service for ConnectionPool<T, Io> impl<T, Io> Service for ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = IoConnection<Io>; type Response = IoConnection<Io>;
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>;
FutureResult<Self::Response, Self::Error>,
Either<WaitForConnection<Io>, OpenConnection<T::Future, Io>>,
>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready() self.0.poll_ready(cx)
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
let key = if let Some(authority) = req.uri.authority_part() { // start support future
authority.clone().into() tokio_executor::current_thread::spawn(ConnectorPoolSupport {
} else { connector: self.0.clone(),
return Either::A(err(ConnectError::Unresolverd)); inner: self.1.clone(),
});
let mut connector = self.0.clone();
let inner = self.1.clone();
let fut = async move {
let key = if let Some(authority) = req.uri.authority_part() {
authority.clone().into()
} else {
return Err(ConnectError::Unresolverd);
};
// acquire connection
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
Acquire::Acquired(io, created) => {
// use existing connection
return Ok(IoConnection::new(
io,
created,
Some(Acquired(key, Some(inner))),
));
}
Acquire::Available => {
// open tcp connection
let (io, proto) = connector.call(req).await?;
let guard = OpenGuard::new(key, inner);
if proto == Protocol::Http1 {
Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(guard.consume()),
))
} else {
let (snd, connection) = handshake(io).await?;
tokio_executor::current_thread::spawn(connection.map(|_| ()));
Ok(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(guard.consume()),
))
}
}
_ => {
// connection is not available, wait
let (rx, token) = inner.borrow_mut().wait_for(req);
let guard = WaiterGuard::new(key, token, inner);
let res = match rx.await {
Err(_) => Err(ConnectError::Disconnected),
Ok(res) => res,
};
guard.consume();
res
}
}
}; };
// acquire connection fut.boxed_local()
match self.1.as_ref().borrow_mut().acquire(&key) {
Acquire::Acquired(io, created) => {
// use existing connection
return Either::A(ok(IoConnection::new(
io,
created,
Some(Acquired(key, Some(self.1.clone()))),
)));
}
Acquire::Available => {
// open new connection
return Either::B(Either::B(OpenConnection::new(
key,
self.1.clone(),
self.0.call(req),
)));
}
_ => (),
}
// connection is not available, wait
let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req);
// start support future
if !support {
self.1.as_ref().borrow_mut().task = Some(AtomicTask::new());
tokio_current_thread::spawn(ConnectorPoolSupport {
connector: self.0.clone(),
inner: self.1.clone(),
})
}
Either::B(Either::A(WaitForConnection {
rx,
key,
token,
inner: Some(self.1.clone()),
}))
} }
} }
#[doc(hidden)] struct WaiterGuard<Io>
pub struct WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
key: Key, key: Key,
token: usize, token: usize,
rx: oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<Io> Drop for WaitForConnection<Io> impl<Io> WaiterGuard<Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(key: Key, token: usize, inner: Rc<RefCell<Inner<Io>>>) -> Self {
Self {
key,
token,
inner: Some(inner),
}
}
fn consume(mut self) {
let _ = self.inner.take();
}
}
impl<Io> Drop for WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(i) = self.inner.take() { if let Some(i) = self.inner.take() {
@@ -179,113 +205,43 @@ where
} }
} }
impl<Io> Future for WaitForConnection<Io> struct OpenGuard<Io>
where where
Io: AsyncRead + AsyncWrite, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.rx.poll() {
Ok(Async::Ready(item)) => match item {
Err(err) => Err(err),
Ok(conn) => {
let _ = self.inner.take();
Ok(Async::Ready(conn))
}
},
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => {
let _ = self.inner.take();
Err(ConnectError::Disconnected)
}
}
}
}
#[doc(hidden)]
pub struct OpenConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
{
fut: F,
key: Key, key: Key,
h2: Option<Handshake<Io, Bytes>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<F, Io> OpenConnection<F, Io> impl<Io> OpenGuard<Io>
where where
F: Future<Item = (Io, Protocol), Error = ConnectError>, Io: AsyncRead + AsyncWrite + Unpin + 'static,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>, fut: F) -> Self { fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>) -> Self {
OpenConnection { Self {
key, key,
fut,
inner: Some(inner), inner: Some(inner),
h2: None,
} }
} }
fn consume(mut self) -> Acquired<Io> {
Acquired(self.key.clone(), self.inner.take())
}
} }
impl<F, Io> Drop for OpenConnection<F, Io> impl<Io> Drop for OpenGuard<Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(inner) = self.inner.take() { if let Some(i) = self.inner.take() {
let mut inner = inner.as_ref().borrow_mut(); let mut inner = i.as_ref().borrow_mut();
inner.release(); inner.release();
inner.check_availibility(); inner.check_availibility();
} }
} }
} }
impl<F, Io> Future for OpenConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite,
{
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => {
tokio_current_thread::spawn(connection.map_err(|_| ()));
Ok(Async::Ready(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e.into()),
};
}
match self.fut.poll() {
Err(err) => Err(err),
Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 {
Ok(Async::Ready(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
} else {
self.h2 = Some(handshake(io));
self.poll()
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
}
}
}
enum Acquire<T> { enum Acquire<T> {
Acquired(ConnectionType<T>, Instant), Acquired(ConnectionType<T>, Instant),
Available, Available,
@@ -312,7 +268,7 @@ pub(crate) struct Inner<Io> {
)>, )>,
>, >,
waiters_queue: IndexSet<(Key, usize)>, waiters_queue: IndexSet<(Key, usize)>,
task: Option<AtomicTask>, waker: LocalWaker,
} }
impl<Io> Inner<Io> { impl<Io> Inner<Io> {
@@ -332,7 +288,7 @@ impl<Io> Inner<Io> {
impl<Io> Inner<Io> impl<Io> Inner<Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
/// connection is not available, wait /// connection is not available, wait
fn wait_for( fn wait_for(
@@ -341,7 +297,6 @@ where
) -> ( ) -> (
oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>, oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
usize, usize,
bool,
) { ) {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
@@ -351,10 +306,10 @@ where
entry.insert(Some((connect, tx))); entry.insert(Some((connect, tx)));
assert!(self.waiters_queue.insert((key, token))); assert!(self.waiters_queue.insert((key, token)));
(rx, token, self.task.is_some()) (rx, token)
} }
fn acquire(&mut self, key: &Key) -> Acquire<Io> { fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire<Io> {
// check limits // check limits
if self.limit > 0 && self.acquired >= self.limit { if self.limit > 0 && self.acquired >= self.limit {
return Acquire::NotAvailable; return Acquire::NotAvailable;
@@ -373,7 +328,7 @@ where
{ {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = conn.io { if let ConnectionType::H1(io) = conn.io {
tokio_current_thread::spawn(CloseConnection::new( tokio_executor::current_thread::spawn(CloseConnection::new(
io, timeout, io, timeout,
)) ))
} }
@@ -382,19 +337,19 @@ where
let mut io = conn.io; let mut io = conn.io;
let mut buf = [0; 2]; let mut buf = [0; 2];
if let ConnectionType::H1(ref mut s) = io { if let ConnectionType::H1(ref mut s) = io {
match s.read(&mut buf) { match Pin::new(s).poll_read(cx, &mut buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), Poll::Pending => (),
Ok(n) if n > 0 => { Poll::Ready(Ok(n)) if n > 0 => {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
tokio_current_thread::spawn( tokio_executor::current_thread::spawn(
CloseConnection::new(io, timeout), CloseConnection::new(io, timeout),
) )
} }
} }
continue; continue;
} }
Ok(_) | Err(_) => continue, _ => continue,
} }
} }
return Acquire::Acquired(io, conn.created); return Acquire::Acquired(io, conn.created);
@@ -421,7 +376,7 @@ where
self.acquired -= 1; self.acquired -= 1;
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
tokio_current_thread::spawn(CloseConnection::new(io, timeout)) tokio_executor::current_thread::spawn(CloseConnection::new(io, timeout))
} }
} }
self.check_availibility(); self.check_availibility();
@@ -429,9 +384,7 @@ where
fn check_availibility(&self) { fn check_availibility(&self) {
if !self.waiters_queue.is_empty() && self.acquired < self.limit { if !self.waiters_queue.is_empty() && self.acquired < self.limit {
if let Some(t) = self.task.as_ref() { self.waker.wake();
t.notify()
}
} }
} }
} }
@@ -443,29 +396,30 @@ struct CloseConnection<T> {
impl<T> CloseConnection<T> impl<T> CloseConnection<T>
where where
T: AsyncWrite, T: AsyncWrite + Unpin,
{ {
fn new(io: T, timeout: Duration) -> Self { fn new(io: T, timeout: Duration) -> Self {
CloseConnection { CloseConnection {
io, io,
timeout: sleep(timeout), timeout: delay_for(timeout),
} }
} }
} }
impl<T> Future for CloseConnection<T> impl<T> Future for CloseConnection<T>
where where
T: AsyncWrite, T: AsyncWrite + Unpin,
{ {
type Item = (); type Output = ();
type Error = ();
fn poll(&mut self) -> Poll<(), ()> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
match self.timeout.poll() { let this = self.get_mut();
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
Ok(Async::NotReady) => match self.io.shutdown() { match Pin::new(&mut this.timeout).poll(cx) {
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), Poll::Ready(_) => Poll::Ready(()),
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}, },
} }
} }
@@ -473,7 +427,7 @@ where
struct ConnectorPoolSupport<T, Io> struct ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
connector: T, connector: T,
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
@@ -481,16 +435,17 @@ where
impl<T, Io> Future for ConnectorPoolSupport<T, Io> impl<T, Io> Future for ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>, T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>,
T::Future: 'static, T::Future: 'static,
{ {
type Item = (); type Output = ();
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut inner = self.inner.as_ref().borrow_mut(); let this = unsafe { self.get_unchecked_mut() };
inner.task.as_ref().unwrap().register();
let mut inner = this.inner.as_ref().borrow_mut();
inner.waker.register(cx.waker());
// check waiters // check waiters
loop { loop {
@@ -505,14 +460,14 @@ where
continue; continue;
} }
match inner.acquire(&key) { match inner.acquire(&key, cx) {
Acquire::NotAvailable => break, Acquire::NotAvailable => break,
Acquire::Acquired(io, created) => { Acquire::Acquired(io, created) => {
let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1;
if let Err(conn) = tx.send(Ok(IoConnection::new( if let Err(conn) = tx.send(Ok(IoConnection::new(
io, io,
created, created,
Some(Acquired(key.clone(), Some(self.inner.clone()))), Some(Acquired(key.clone(), Some(this.inner.clone()))),
))) { ))) {
let (io, created) = conn.unwrap().into_inner(); let (io, created) = conn.unwrap().into_inner();
inner.release_conn(&key, io, created); inner.release_conn(&key, io, created);
@@ -524,33 +479,38 @@ where
OpenWaitingConnection::spawn( OpenWaitingConnection::spawn(
key.clone(), key.clone(),
tx, tx,
self.inner.clone(), this.inner.clone(),
self.connector.call(connect), this.connector.call(connect),
); );
} }
} }
let _ = inner.waiters_queue.swap_remove_index(0); let _ = inner.waiters_queue.swap_remove_index(0);
} }
Ok(Async::NotReady) Poll::Pending
} }
} }
struct OpenWaitingConnection<F, Io> struct OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fut: F, fut: F,
key: Key, key: Key,
h2: Option<Handshake<Io, Bytes>>, h2: Option<
LocalBoxFuture<
'static,
Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
>,
>,
rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>, rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<F, Io> OpenWaitingConnection<F, Io> impl<F, Io> OpenWaitingConnection<F, Io>
where where
F: Future<Item = (Io, Protocol), Error = ConnectError> + 'static, F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static,
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn spawn( fn spawn(
key: Key, key: Key,
@@ -558,7 +518,7 @@ where
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
fut: F, fut: F,
) { ) {
tokio_current_thread::spawn(OpenWaitingConnection { tokio_executor::current_thread::spawn(OpenWaitingConnection {
key, key,
fut, fut,
h2: None, h2: None,
@@ -570,7 +530,7 @@ where
impl<F, Io> Drop for OpenWaitingConnection<F, Io> impl<F, Io> Drop for OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(inner) = self.inner.take() { if let Some(inner) = self.inner.take() {
@@ -583,59 +543,60 @@ where
impl<F, Io> Future for OpenWaitingConnection<F, Io> impl<F, Io> Future for OpenWaitingConnection<F, Io>
where where
F: Future<Item = (Io, Protocol), Error = ConnectError>, F: Future<Output = Result<(Io, Protocol), ConnectError>>,
Io: AsyncRead + AsyncWrite, Io: AsyncRead + AsyncWrite + Unpin,
{ {
type Item = (); type Output = ();
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut h2) = self.h2 { let this = unsafe { self.get_unchecked_mut() };
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => { if let Some(ref mut h2) = this.h2 {
tokio_current_thread::spawn(connection.map_err(|_| ())); return match Pin::new(h2).poll(cx) {
let rx = self.rx.take().unwrap(); Poll::Ready(Ok((snd, connection))) => {
tokio_executor::current_thread::spawn(connection.map(|_| ()));
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(snd), ConnectionType::H2(snd),
Instant::now(), Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())), Some(Acquired(this.key.clone(), this.inner.take())),
))); )));
Ok(Async::Ready(())) Poll::Ready(())
} }
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Pending => Poll::Pending,
Err(err) => { Poll::Ready(Err(err)) => {
let _ = self.inner.take(); let _ = this.inner.take();
if let Some(rx) = self.rx.take() { if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(ConnectError::H2(err))); let _ = rx.send(Err(ConnectError::H2(err)));
} }
Err(()) Poll::Ready(())
} }
}; };
} }
match self.fut.poll() { match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) {
Err(err) => { Poll::Ready(Err(err)) => {
let _ = self.inner.take(); let _ = this.inner.take();
if let Some(rx) = self.rx.take() { if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(err)); let _ = rx.send(Err(err));
} }
Err(()) Poll::Ready(())
} }
Ok(Async::Ready((io, proto))) => { Poll::Ready(Ok((io, proto))) => {
if proto == Protocol::Http1 { if proto == Protocol::Http1 {
let rx = self.rx.take().unwrap(); let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io), ConnectionType::H1(io),
Instant::now(), Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())), Some(Acquired(this.key.clone(), this.inner.take())),
))); )));
Ok(Async::Ready(())) Poll::Ready(())
} else { } else {
self.h2 = Some(handshake(io)); this.h2 = Some(handshake(io).boxed_local());
self.poll() unsafe { Pin::new_unchecked(this) }.poll(cx)
} }
} }
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Pending => Poll::Pending,
} }
} }
} }
@@ -644,7 +605,7 @@ pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
impl<T> Acquired<T> impl<T> Acquired<T>
where where
T: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
pub(crate) fn close(&mut self, conn: IoConnection<T>) { pub(crate) fn close(&mut self, conn: IoConnection<T>) {
if let Some(inner) = self.1.take() { if let Some(inner) = self.1.take() {

View File

@@ -1,8 +1,8 @@
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service; use actix_service::Service;
use futures::Poll;
#[doc(hidden)] #[doc(hidden)]
/// Service that allows to turn non-clone service to a service with `Clone` impl /// Service that allows to turn non-clone service to a service with `Clone` impl
@@ -32,8 +32,8 @@ where
type Error = T::Error; type Error = T::Error;
type Future = T::Future; type Future = T::Future;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
unsafe { &mut *self.0.as_ref().get() }.poll_ready() unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx)
} }
fn call(&mut self, req: T::Request) -> Self::Future { fn call(&mut self, req: T::Request) -> Self::Future {

View File

@@ -5,9 +5,9 @@ use std::rc::Rc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::BytesMut; use bytes::BytesMut;
use futures::{future, Future}; use futures::{future, Future, FutureExt};
use time; use time;
use tokio_timer::{sleep, Delay}; use tokio_timer::{delay, delay_for, Delay};
// "Sun, 06 Nov 1994 08:49:37 GMT".len() // "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29; const DATE_VALUE_LENGTH: usize = 29;
@@ -104,10 +104,10 @@ impl ServiceConfig {
#[inline] #[inline]
/// Client timeout for first request. /// Client timeout for first request.
pub fn client_timer(&self) -> Option<Delay> { pub fn client_timer(&self) -> Option<Delay> {
let delay = self.0.client_timeout; let delay_time = self.0.client_timeout;
if delay != 0 { if delay_time != 0 {
Some(Delay::new( Some(delay(
self.0.timer.now() + Duration::from_millis(delay), self.0.timer.now() + Duration::from_millis(delay_time),
)) ))
} else { } else {
None None
@@ -138,7 +138,7 @@ impl ServiceConfig {
/// Return keep-alive timer delay is configured. /// Return keep-alive timer delay is configured.
pub fn keep_alive_timer(&self) -> Option<Delay> { pub fn keep_alive_timer(&self) -> Option<Delay> {
if let Some(ka) = self.0.keep_alive { if let Some(ka) = self.0.keep_alive {
Some(Delay::new(self.0.timer.now() + ka)) Some(delay(self.0.timer.now() + ka))
} else { } else {
None None
} }
@@ -242,12 +242,12 @@ impl DateService {
// periodic date update // periodic date update
let s = self.clone(); let s = self.clone();
tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then( tokio_executor::current_thread::spawn(
move |_| { delay_for(Duration::from_millis(500)).then(move |_| {
s.0.reset(); s.0.reset();
future::ok(()) future::ready(())
}, }),
)); );
} }
} }
@@ -277,7 +277,7 @@ mod tests {
fn test_date() { fn test_date() {
let mut rt = System::new("test"); let mut rt = System::new("test");
let _ = rt.block_on(future::lazy(|| { let _ = rt.block_on(future::lazy(|_| {
let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); let settings = ServiceConfig::new(KeepAlive::Os, 0, 0);
let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10);
settings.set_date(&mut buf1); settings.set_date(&mut buf1);

View File

@@ -1,13 +1,12 @@
use ring::digest::{Algorithm, SHA256}; use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256};
use ring::hkdf::expand; use ring::hmac;
use ring::hmac::SigningKey;
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::private::KEY_LEN as PRIVATE_KEY_LEN; use super::private::KEY_LEN as PRIVATE_KEY_LEN;
use super::signed::KEY_LEN as SIGNED_KEY_LEN; use super::signed::KEY_LEN as SIGNED_KEY_LEN;
static HKDF_DIGEST: &Algorithm = &SHA256; static HKDF_DIGEST: Algorithm = HKDF_SHA256;
const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"; const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"];
/// A cryptographic master key for use with `Signed` and/or `Private` jars. /// A cryptographic master key for use with `Signed` and/or `Private` jars.
/// ///
@@ -25,6 +24,13 @@ pub struct Key {
encryption_key: [u8; PRIVATE_KEY_LEN], encryption_key: [u8; PRIVATE_KEY_LEN],
} }
impl KeyType for &Key {
#[inline]
fn len(&self) -> usize {
SIGNED_KEY_LEN + PRIVATE_KEY_LEN
}
}
impl Key { impl Key {
/// Derives new signing/encryption keys from a master key. /// Derives new signing/encryption keys from a master key.
/// ///
@@ -56,21 +62,26 @@ impl Key {
); );
} }
// Expand the user's key into two. // An empty `Key` structure; will be filled in with HKDF derived keys.
let prk = SigningKey::new(HKDF_DIGEST, key); let mut output_key = Key {
signing_key: [0; SIGNED_KEY_LEN],
encryption_key: [0; PRIVATE_KEY_LEN],
};
// Expand the master key into two HKDF generated keys.
let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN];
expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys); let prk = Prk::new_less_safe(HKDF_DIGEST, key);
let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand");
okm.fill(&mut both_keys).expect("fill keys");
// Copy the keys into their respective arrays. // Copy the key parts into their respective fields.
let mut signing_key = [0; SIGNED_KEY_LEN]; output_key
let mut encryption_key = [0; PRIVATE_KEY_LEN]; .signing_key
signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); .copy_from_slice(&both_keys[..SIGNED_KEY_LEN]);
encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); output_key
.encryption_key
Key { .copy_from_slice(&both_keys[SIGNED_KEY_LEN..]);
signing_key, output_key
encryption_key,
}
} }
/// Generates signing/encryption keys from a secure, random source. Keys are /// Generates signing/encryption keys from a secure, random source. Keys are

View File

@@ -1,8 +1,8 @@
use std::str; use std::str;
use log::warn; use log::warn;
use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM}; use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM};
use ring::aead::{OpeningKey, SealingKey}; use ring::aead::{LessSafeKey, UnboundKey};
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::Key; use super::Key;
@@ -10,7 +10,7 @@ use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `private` docs as // Keep these in sync, and keep the key len synced with the `private` docs as
// well as the `KEYS_INFO` const in secure::Key. // well as the `KEYS_INFO` const in secure::Key.
static ALGO: &Algorithm = &AES_256_GCM; static ALGO: &'static Algorithm = &AES_256_GCM;
const NONCE_LEN: usize = 12; const NONCE_LEN: usize = 12;
pub const KEY_LEN: usize = 32; pub const KEY_LEN: usize = 32;
@@ -53,11 +53,14 @@ impl<'a> PrivateJar<'a> {
} }
let ad = Aad::from(name.as_bytes()); let ad = Aad::from(name.as_bytes());
let key = OpeningKey::new(ALGO, &self.key).expect("opening key"); let key = LessSafeKey::new(
let (nonce, sealed) = data.split_at_mut(NONCE_LEN); UnboundKey::new(&ALGO, &self.key).expect("matching key length"),
);
let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN);
let nonce = let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
let unsealed = open_in_place(&key, nonce, ad, 0, sealed) let unsealed = key
.open_in_place(nonce, ad, &mut sealed)
.map_err(|_| "invalid key/nonce/value: bad seal")?; .map_err(|_| "invalid key/nonce/value: bad seal")?;
if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { if let Ok(unsealed_utf8) = str::from_utf8(unsealed) {
@@ -196,30 +199,33 @@ Please change it as soon as possible."
fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> { fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> {
// Create the `SealingKey` structure. // Create the `SealingKey` structure.
let key = SealingKey::new(ALGO, key).expect("sealing key creation"); let unbound = UnboundKey::new(&ALGO, key).expect("matching key length");
let key = LessSafeKey::new(unbound);
// Create a vec to hold the [nonce | cookie value | overhead]. // Create a vec to hold the [nonce | cookie value | overhead].
let overhead = ALGO.tag_len(); let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()];
let mut data = vec![0; NONCE_LEN + value.len() + overhead];
// Randomly generate the nonce, then copy the cookie value as input. // Randomly generate the nonce, then copy the cookie value as input.
let (nonce, in_out) = data.split_at_mut(NONCE_LEN); let (nonce, in_out) = data.split_at_mut(NONCE_LEN);
let (in_out, tag) = in_out.split_at_mut(value.len());
in_out.copy_from_slice(value);
// Randomly generate the nonce into the nonce piece.
SystemRandom::new() SystemRandom::new()
.fill(nonce) .fill(nonce)
.expect("couldn't random fill nonce"); .expect("couldn't random fill nonce");
in_out[..value.len()].copy_from_slice(value); let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length");
let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
// Use cookie's name as associated data to prevent value swapping. // Use cookie's name as associated data to prevent value swapping.
let ad = Aad::from(name); let ad = Aad::from(name);
let ad_tag = key
.seal_in_place_separate_tag(nonce, ad, in_out)
.expect("in-place seal");
// Perform the actual sealing operation and get the output length. // Copy the tag into the tag piece.
let output_len = tag.copy_from_slice(ad_tag.as_ref());
seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal");
// Remove the overhead and return the sealed content. // Remove the overhead and return the sealed content.
data.truncate(NONCE_LEN + output_len);
data data
} }

View File

@@ -1,12 +1,11 @@
use ring::digest::{Algorithm, SHA256}; use ring::hmac::{self, sign, verify};
use ring::hmac::{sign, verify_with_own_key as verify, SigningKey};
use super::Key; use super::Key;
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `signed` docs as // Keep these in sync, and keep the key len synced with the `signed` docs as
// well as the `KEYS_INFO` const in secure::Key. // well as the `KEYS_INFO` const in secure::Key.
static HMAC_DIGEST: &Algorithm = &SHA256; static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256;
const BASE64_DIGEST_LEN: usize = 44; const BASE64_DIGEST_LEN: usize = 44;
pub const KEY_LEN: usize = 32; pub const KEY_LEN: usize = 32;
@@ -21,7 +20,7 @@ pub const KEY_LEN: usize = 32;
/// This type is only available when the `secure` feature is enabled. /// This type is only available when the `secure` feature is enabled.
pub struct SignedJar<'a> { pub struct SignedJar<'a> {
parent: &'a mut CookieJar, parent: &'a mut CookieJar,
key: SigningKey, key: hmac::Key,
} }
impl<'a> SignedJar<'a> { impl<'a> SignedJar<'a> {
@@ -32,7 +31,7 @@ impl<'a> SignedJar<'a> {
pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> {
SignedJar { SignedJar {
parent, parent,
key: SigningKey::new(HMAC_DIGEST, key.signing()), key: hmac::Key::new(HMAC_DIGEST, key.signing()),
} }
} }

View File

@@ -1,4 +1,7 @@
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
@@ -6,7 +9,7 @@ use brotli2::write::BrotliDecoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzDecoder, ZlibDecoder}; use flate2::write::{GzDecoder, ZlibDecoder};
use futures::{try_ready, Async, Future, Poll, Stream}; use futures::{ready, Stream};
use super::Writer; use super::Writer;
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -18,12 +21,12 @@ pub struct Decoder<S> {
decoder: Option<ContentDecoder>, decoder: Option<ContentDecoder>,
stream: S, stream: S,
eof: bool, eof: bool,
fut: Option<CpuFuture<(Option<Bytes>, ContentDecoder), io::Error>>, fut: Option<CpuFuture<Result<(Option<Bytes>, ContentDecoder), io::Error>>>,
} }
impl<S> Decoder<S> impl<S> Decoder<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>>,
{ {
/// Construct a decoder. /// Construct a decoder.
#[inline] #[inline]
@@ -71,34 +74,41 @@ where
impl<S> Stream for Decoder<S> impl<S> Stream for Decoder<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{ {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
loop { loop {
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let (chunk, decoder) = try_ready!(fut.poll()); let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
self.decoder = Some(decoder); self.decoder = Some(decoder);
self.fut.take(); self.fut.take();
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} }
if self.eof { if self.eof {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} }
match self.stream.poll()? { match Pin::new(&mut self.stream).poll_next(cx) {
Async::Ready(Some(chunk)) => { Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(chunk))) => {
if let Some(mut decoder) = self.decoder.take() { if let Some(mut decoder) = self.decoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
let chunk = decoder.feed_data(chunk)?; let chunk = decoder.feed_data(chunk)?;
self.decoder = Some(decoder); self.decoder = Some(decoder);
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -108,21 +118,25 @@ where
} }
continue; continue;
} else { } else {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} }
Async::Ready(None) => { Poll::Ready(None) => {
self.eof = true; self.eof = true;
return if let Some(mut decoder) = self.decoder.take() { return if let Some(mut decoder) = self.decoder.take() {
Ok(Async::Ready(decoder.feed_eof()?)) match decoder.feed_eof() {
Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
Ok(None) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(err.into()))),
}
} else { } else {
Ok(Async::Ready(None)) Poll::Ready(None)
}; };
} }
Async::NotReady => break, Poll::Pending => break,
} }
} }
Ok(Async::NotReady) Poll::Pending
} }
} }

View File

@@ -1,5 +1,8 @@
//! Stream encoder //! Stream encoder
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
@@ -7,7 +10,6 @@ use brotli2::write::BrotliEncoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzEncoder, ZlibEncoder}; use flate2::write::{GzEncoder, ZlibEncoder};
use futures::{Async, Future, Poll};
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING};
@@ -22,7 +24,7 @@ pub struct Encoder<B> {
eof: bool, eof: bool,
body: EncoderBody<B>, body: EncoderBody<B>,
encoder: Option<ContentEncoder>, encoder: Option<ContentEncoder>,
fut: Option<CpuFuture<ContentEncoder, io::Error>>, fut: Option<CpuFuture<Result<ContentEncoder, io::Error>>>,
} }
impl<B: MessageBody> Encoder<B> { impl<B: MessageBody> Encoder<B> {
@@ -94,43 +96,46 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
} }
} }
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> { fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
loop { loop {
if self.eof { if self.eof {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} }
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let mut encoder = futures::try_ready!(fut.poll()); let mut encoder = match futures::ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
self.fut.take(); self.fut.take();
if !chunk.is_empty() { if !chunk.is_empty() {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} }
let result = match self.body { let result = match self.body {
EncoderBody::Bytes(ref mut b) => { EncoderBody::Bytes(ref mut b) => {
if b.is_empty() { if b.is_empty() {
Async::Ready(None) Poll::Ready(None)
} else { } else {
Async::Ready(Some(std::mem::replace(b, Bytes::new()))) Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new()))))
} }
} }
EncoderBody::Stream(ref mut b) => b.poll_next()?, EncoderBody::Stream(ref mut b) => b.poll_next(cx),
EncoderBody::BoxedStream(ref mut b) => b.poll_next()?, EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx),
}; };
match result { match result {
Async::NotReady => return Ok(Async::NotReady), Poll::Ready(Some(Ok(chunk))) => {
Async::Ready(Some(chunk)) => {
if let Some(mut encoder) = self.encoder.take() { if let Some(mut encoder) = self.encoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
encoder.write(&chunk)?; encoder.write(&chunk)?;
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
if !chunk.is_empty() { if !chunk.is_empty() {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -139,22 +144,23 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
})); }));
} }
} else { } else {
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} }
Async::Ready(None) => { Poll::Ready(None) => {
if let Some(encoder) = self.encoder.take() { if let Some(encoder) = self.encoder.take() {
let chunk = encoder.finish()?; let chunk = encoder.finish()?;
if chunk.is_empty() { if chunk.is_empty() {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} else { } else {
self.eof = true; self.eof = true;
return Ok(Async::Ready(Some(chunk))); return Poll::Ready(Some(Ok(chunk)));
} }
} else { } else {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} }
} }
val => return val,
} }
} }
} }

View File

@@ -6,11 +6,10 @@ use std::str::Utf8Error;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use std::{fmt, io, result}; use std::{fmt, io, result};
pub use actix_threadpool::BlockingError;
use actix_utils::timeout::TimeoutError; use actix_utils::timeout::TimeoutError;
use bytes::BytesMut; use bytes::BytesMut;
use derive_more::{Display, From}; use derive_more::{Display, From};
use futures::Canceled; pub use futures::channel::oneshot::Canceled;
use http::uri::InvalidUri; use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode}; use http::{header, Error as HttpError, StatusCode};
use httparse; use httparse;
@@ -182,13 +181,13 @@ impl ResponseError for FormError {}
/// `InternalServerError` for `TimerError` /// `InternalServerError` for `TimerError`
impl ResponseError for TimerError {} impl ResponseError for TimerError {}
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
/// `InternalServerError` for `openssl::ssl::Error` /// `InternalServerError` for `openssl::ssl::Error`
impl ResponseError for openssl::ssl::Error {} impl ResponseError for open_ssl::ssl::Error {}
#[cfg(feature = "ssl")] #[cfg(feature = "openssl")]
/// `InternalServerError` for `openssl::ssl::HandshakeError` /// `InternalServerError` for `openssl::ssl::HandshakeError`
impl ResponseError for openssl::ssl::HandshakeError<tokio_tcp::TcpStream> {} impl<T: std::fmt::Debug> ResponseError for open_ssl::ssl::HandshakeError<T> {}
/// Return `BAD_REQUEST` for `de::value::Error` /// Return `BAD_REQUEST` for `de::value::Error`
impl ResponseError for DeError { impl ResponseError for DeError {
@@ -197,8 +196,8 @@ impl ResponseError for DeError {
} }
} }
/// `InternalServerError` for `BlockingError` /// `InternalServerError` for `Canceled`
impl<E: fmt::Debug> ResponseError for BlockingError<E> {} impl ResponseError for Canceled {}
/// Return `BAD_REQUEST` for `Utf8Error` /// Return `BAD_REQUEST` for `Utf8Error`
impl ResponseError for Utf8Error { impl ResponseError for Utf8Error {
@@ -236,9 +235,6 @@ impl ResponseError for header::InvalidHeaderValueBytes {
} }
} }
/// `InternalServerError` for `futures::Canceled`
impl ResponseError for Canceled {}
/// A set of errors that can occur during parsing HTTP streams /// A set of errors that can occur during parsing HTTP streams
#[derive(Debug, Display)] #[derive(Debug, Display)]
pub enum ParseError { pub enum ParseError {
@@ -365,15 +361,12 @@ impl From<io::Error> for PayloadError {
} }
} }
impl From<BlockingError<io::Error>> for PayloadError { impl From<Canceled> for PayloadError {
fn from(err: BlockingError<io::Error>) -> Self { fn from(_: Canceled) -> Self {
match err { PayloadError::Io(io::Error::new(
BlockingError::Error(e) => PayloadError::Io(e), io::ErrorKind::Other,
BlockingError::Canceled => PayloadError::Io(io::Error::new( "Operation is canceled",
io::ErrorKind::Other, ))
"Thread pool is gone",
)),
}
} }
} }

View File

@@ -1,10 +1,12 @@
use std::future::Future;
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Decoder; use actix_codec::Decoder;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version};
use httparse; use httparse;
@@ -442,9 +444,10 @@ impl Decoder for PayloadDecoder {
loop { loop {
let mut buf = None; let mut buf = None;
// advances the chunked state // advances the chunked state
*state = match state.step(src, size, &mut buf)? { *state = match state.step(src, size, &mut buf) {
Async::NotReady => return Ok(None), Poll::Pending => return Ok(None),
Async::Ready(state) => state, Poll::Ready(Ok(state)) => state,
Poll::Ready(Err(e)) => return Err(e),
}; };
if *state == ChunkedState::End { if *state == ChunkedState::End {
trace!("End of chunked stream"); trace!("End of chunked stream");
@@ -476,7 +479,7 @@ macro_rules! byte (
$rdr.split_to(1); $rdr.split_to(1);
b b
} else { } else {
return Ok(Async::NotReady) return Poll::Pending
} }
}) })
); );
@@ -487,7 +490,7 @@ impl ChunkedState {
body: &mut BytesMut, body: &mut BytesMut,
size: &mut u64, size: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<ChunkedState, io::Error> { ) -> Poll<Result<ChunkedState, io::Error>> {
use self::ChunkedState::*; use self::ChunkedState::*;
match *self { match *self {
Size => ChunkedState::read_size(body, size), Size => ChunkedState::read_size(body, size),
@@ -499,10 +502,14 @@ impl ChunkedState {
BodyLf => ChunkedState::read_body_lf(body), BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body), EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body), EndLf => ChunkedState::read_end_lf(body),
End => Ok(Async::Ready(ChunkedState::End)), End => Poll::Ready(Ok(ChunkedState::End)),
} }
} }
fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll<ChunkedState, io::Error> {
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16; let radix = 16;
match byte!(rdr) { match byte!(rdr) {
b @ b'0'..=b'9' => { b @ b'0'..=b'9' => {
@@ -517,48 +524,49 @@ impl ChunkedState {
*size *= radix; *size *= radix;
*size += u64::from(b + 10 - b'A'); *size += u64::from(b + 10 - b'A');
} }
b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)), b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => return Ok(Async::Ready(ChunkedState::Extension)), b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)), b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => { _ => {
return Err(io::Error::new( return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size", "Invalid chunk size line: Invalid Size",
)); )));
} }
} }
Ok(Async::Ready(ChunkedState::Size)) Poll::Ready(Ok(ChunkedState::Size))
} }
fn read_size_lws(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws"); trace!("read_size_lws");
match byte!(rdr) { match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come // LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)), b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => Ok(Async::Ready(ChunkedState::Extension)), b';' => Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space", "Invalid chunk size linear white space",
)), ))),
} }
} }
fn read_extension(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> { fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
} }
} }
fn read_size_lf( fn read_size_lf(
rdr: &mut BytesMut, rdr: &mut BytesMut,
size: &mut u64, size: &mut u64,
) -> Poll<ChunkedState, io::Error> { ) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)),
b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)), b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size LF", "Invalid chunk size LF",
)), ))),
} }
} }
@@ -566,12 +574,12 @@ impl ChunkedState {
rdr: &mut BytesMut, rdr: &mut BytesMut,
rem: &mut u64, rem: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<ChunkedState, io::Error> { ) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Chunked read, remaining={:?}", rem); trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64; let len = rdr.len() as u64;
if len == 0 { if len == 0 {
Ok(Async::Ready(ChunkedState::Body)) Poll::Ready(Ok(ChunkedState::Body))
} else { } else {
let slice; let slice;
if *rem > len { if *rem > len {
@@ -583,47 +591,47 @@ impl ChunkedState {
} }
*buf = Some(slice); *buf = Some(slice);
if *rem > 0 { if *rem > 0 {
Ok(Async::Ready(ChunkedState::Body)) Poll::Ready(Ok(ChunkedState::Body))
} else { } else {
Ok(Async::Ready(ChunkedState::BodyCr)) Poll::Ready(Ok(ChunkedState::BodyCr))
} }
} }
} }
fn read_body_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> { fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)), b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body CR", "Invalid chunk body CR",
)), ))),
} }
} }
fn read_body_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> { fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Ok(Async::Ready(ChunkedState::Size)), b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body LF", "Invalid chunk body LF",
)), ))),
} }
} }
fn read_end_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> { fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::EndLf)), b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end CR", "Invalid chunk end CR",
)), ))),
} }
} }
fn read_end_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> { fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Ok(Async::Ready(ChunkedState::End)), b'\n' => Poll::Ready(Ok(ChunkedState::End)),
_ => Err(io::Error::new( _ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end LF", "Invalid chunk end LF",
)), ))),
} }
} }
} }

View File

@@ -1,15 +1,17 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant; use std::time::Instant;
use std::{fmt, io, net}; use std::{fmt, io, io::Write, net};
use actix_codec::{Decoder, Encoder, Framed, FramedParts}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
use actix_server_config::IoStream; use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use futures::{Async, Future, Poll};
use log::{error, trace}; use log::{error, trace};
use tokio_timer::Delay; use tokio_timer::{delay, Delay};
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@@ -261,14 +263,14 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
fn can_read(&self) -> bool { fn can_read(&self, cx: &mut Context) -> bool {
if self if self
.flags .flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{ {
false false
} else if let Some(ref info) = self.payload { } else if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read info.need_read(cx) == PayloadStatus::Read
} else { } else {
true true
} }
@@ -287,7 +289,7 @@ where
/// ///
/// true - got whouldblock /// true - got whouldblock
/// false - didnt get whouldblock /// false - didnt get whouldblock
fn poll_flush(&mut self) -> Result<bool, DispatchError> { fn poll_flush(&mut self, cx: &mut Context) -> Result<bool, DispatchError> {
if self.write_buf.is_empty() { if self.write_buf.is_empty() {
return Ok(false); return Ok(false);
} }
@@ -295,31 +297,31 @@ where
let len = self.write_buf.len(); let len = self.write_buf.len();
let mut written = 0; let mut written = 0;
while written < len { while written < len {
match self.io.write(&self.write_buf[written..]) { match unsafe { Pin::new_unchecked(&mut self.io) }
Ok(0) => { .poll_write(cx, &self.write_buf[written..])
{
Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new( return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"", "",
))); )));
} }
Ok(n) => { Poll::Ready(Ok(n)) => {
written += n; written += n;
} }
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { Poll::Pending => {
if written > 0 { if written > 0 {
let _ = self.write_buf.split_to(written); let _ = self.write_buf.split_to(written);
} }
return Ok(true); return Ok(true);
} }
Err(err) => return Err(DispatchError::Io(err)), Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
} }
} }
if written > 0 { if written == self.write_buf.len() {
if written == self.write_buf.len() { unsafe { self.write_buf.set_len(0) }
unsafe { self.write_buf.set_len(0) } } else {
} else { let _ = self.write_buf.split_to(written);
let _ = self.write_buf.split_to(written);
}
} }
Ok(false) Ok(false)
} }
@@ -350,12 +352,15 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
} }
fn poll_response(&mut self) -> Result<PollResponse, DispatchError> { fn poll_response(
&mut self,
cx: &mut Context,
) -> Result<PollResponse, DispatchError> {
loop { loop {
let state = match self.state { let state = match self.state {
State::None => match self.messages.pop_front() { State::None => match self.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => { Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req)?) Some(self.handle_request(req, cx)?)
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
@@ -365,54 +370,58 @@ where
} }
None => None, None => None,
}, },
State::ExpectCall(ref mut fut) => match fut.poll() { State::ExpectCall(ref mut fut) => {
Ok(Async::Ready(req)) => { match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
self.send_continue(); Poll::Ready(Ok(req)) => {
self.state = State::ServiceCall(self.service.call(req)); self.send_continue();
continue; self.state = State::ServiceCall(self.service.call(req));
continue;
}
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
Poll::Pending => None,
} }
Ok(Async::NotReady) => None, }
Err(e) => { State::ServiceCall(ref mut fut) => {
let res: Response = e.into().into(); match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
let (res, body) = res.replace_body(()); Poll::Ready(Ok(res)) => {
Some(self.send_response(res, body.into_body())?) let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?;
continue;
}
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
Poll::Pending => None,
} }
}, }
State::ServiceCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?;
continue;
}
Ok(Async::NotReady) => None,
Err(e) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
},
State::SendPayload(ref mut stream) => { State::SendPayload(ref mut stream) => {
loop { loop {
if self.write_buf.len() < HW_BUFFER_SIZE { if self.write_buf.len() < HW_BUFFER_SIZE {
match stream match stream.poll_next(cx) {
.poll_next() Poll::Ready(Some(Ok(item))) => {
.map_err(|_| DispatchError::Unknown)?
{
Async::Ready(Some(item)) => {
self.codec.encode( self.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
continue; continue;
} }
Async::Ready(None) => { Poll::Ready(None) => {
self.codec.encode( self.codec.encode(
Message::Chunk(None), Message::Chunk(None),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
self.state = State::None; self.state = State::None;
} }
Async::NotReady => return Ok(PollResponse::DoNothing), Poll::Ready(Some(Err(_))) => {
return Err(DispatchError::Unknown)
}
Poll::Pending => return Ok(PollResponse::DoNothing),
} }
} else { } else {
return Ok(PollResponse::DrainWriteBuf); return Ok(PollResponse::DrainWriteBuf);
@@ -433,7 +442,7 @@ where
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more data and retry // we may read more data and retry
if self.state.is_call() { if self.state.is_call() {
if self.poll_request()? { if self.poll_request(cx)? {
continue; continue;
} }
} else if !self.messages.is_empty() { } else if !self.messages.is_empty() {
@@ -446,17 +455,21 @@ where
Ok(PollResponse::DoNothing) Ok(PollResponse::DoNothing)
} }
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> { fn handle_request(
&mut self,
req: Request,
cx: &mut Context,
) -> Result<State<S, B, X>, DispatchError> {
// Handle `EXPECT: 100-Continue` header // Handle `EXPECT: 100-Continue` header
let req = if req.head().expect() { let req = if req.head().expect() {
let mut task = self.expect.call(req); let mut task = self.expect.call(req);
match task.poll() { match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Ok(Async::Ready(req)) => { Poll::Ready(Ok(req)) => {
self.send_continue(); self.send_continue();
req req
} }
Ok(Async::NotReady) => return Ok(State::ExpectCall(task)), Poll::Pending => return Ok(State::ExpectCall(task)),
Err(e) => { Poll::Ready(Err(e)) => {
let e = e.into(); let e = e.into();
let res: Response = e.into(); let res: Response = e.into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
@@ -469,13 +482,13 @@ where
// Call service // Call service
let mut task = self.service.call(req); let mut task = self.service.call(req);
match task.poll() { match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Ok(Async::Ready(res)) => { Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.send_response(res, body) self.send_response(res, body)
} }
Ok(Async::NotReady) => Ok(State::ServiceCall(task)), Poll::Pending => Ok(State::ServiceCall(task)),
Err(e) => { Poll::Ready(Err(e)) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.send_response(res, body.into_body()) self.send_response(res, body.into_body())
@@ -484,9 +497,12 @@ where
} }
/// Process one incoming requests /// Process one incoming requests
pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> { pub(self) fn poll_request(
&mut self,
cx: &mut Context,
) -> Result<bool, DispatchError> {
// limit a mount of non processed requests // limit a mount of non processed requests
if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() { if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) {
return Ok(false); return Ok(false);
} }
@@ -521,7 +537,7 @@ where
// handle request early // handle request early
if self.state.is_empty() { if self.state.is_empty() {
self.state = self.handle_request(req)?; self.state = self.handle_request(req, cx)?;
} else { } else {
self.messages.push_back(DispatcherMessage::Item(req)); self.messages.push_back(DispatcherMessage::Item(req));
} }
@@ -587,12 +603,12 @@ where
} }
/// keep-alive timer /// keep-alive timer
fn poll_keepalive(&mut self) -> Result<(), DispatchError> { fn poll_keepalive(&mut self, cx: &mut Context) -> Result<(), DispatchError> {
if self.ka_timer.is_none() { if self.ka_timer.is_none() {
// shutdown timeout // shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.codec.config().client_disconnect_timer() { if let Some(interval) = self.codec.config().client_disconnect_timer() {
self.ka_timer = Some(Delay::new(interval)); self.ka_timer = Some(delay(interval));
} else { } else {
self.flags.insert(Flags::READ_DISCONNECT); self.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
@@ -605,11 +621,8 @@ where
} }
} }
match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) {
error!("Timer error {:?}", e); Poll::Ready(()) => {
DispatchError::Unknown
})? {
Async::Ready(_) => {
// if we get timeout during shutdown, drop connection // if we get timeout during shutdown, drop connection
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
@@ -624,9 +637,9 @@ where
if let Some(deadline) = if let Some(deadline) =
self.codec.config().client_disconnect_timer() self.codec.config().client_disconnect_timer()
{ {
if let Some(timer) = self.ka_timer.as_mut() { if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = timer.poll(); let _ = Pin::new(&mut timer).poll(cx);
} }
} else { } else {
// no shutdown timeout, drop socket // no shutdown timeout, drop socket
@@ -650,23 +663,37 @@ where
} else if let Some(deadline) = } else if let Some(deadline) =
self.codec.config().keep_alive_expire() self.codec.config().keep_alive_expire()
{ {
if let Some(timer) = self.ka_timer.as_mut() { if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = timer.poll(); let _ = Pin::new(&mut timer).poll(cx);
} }
} }
} else if let Some(timer) = self.ka_timer.as_mut() { } else if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(self.ka_expire); timer.reset(self.ka_expire);
let _ = timer.poll(); let _ = Pin::new(&mut timer).poll(cx);
} }
} }
Async::NotReady => (), Poll::Pending => (),
} }
Ok(()) Ok(())
} }
} }
impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
@@ -679,27 +706,28 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Item = (); type Output = Result<(), DispatchError>;
type Error = DispatchError;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.inner { match self.as_mut().inner {
DispatcherState::Normal(ref mut inner) => { DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive()?; inner.poll_keepalive(cx)?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} else { } else {
// flush buffer // flush buffer
inner.poll_flush()?; inner.poll_flush(cx)?;
if !inner.write_buf.is_empty() { if !inner.write_buf.is_empty() {
Ok(Async::NotReady) Poll::Pending
} else { } else {
match inner.io.shutdown()? { match Pin::new(&mut inner.io).poll_shutdown(cx) {
Async::Ready(_) => Ok(Async::Ready(())), Poll::Ready(res) => {
Async::NotReady => Ok(Async::NotReady), Poll::Ready(res.map_err(DispatchError::from))
}
Poll::Pending => Poll::Pending,
} }
} }
} }
@@ -707,12 +735,12 @@ where
// read socket into a buf // read socket into a buf
let should_disconnect = let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
read_available(&mut inner.io, &mut inner.read_buf)? read_available(cx, &mut inner.io, &mut inner.read_buf)?
} else { } else {
None None
}; };
inner.poll_request()?; inner.poll_request(cx)?;
if let Some(true) = should_disconnect { if let Some(true) = should_disconnect {
inner.flags.insert(Flags::READ_DISCONNECT); inner.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = inner.payload.take() { if let Some(mut payload) = inner.payload.take() {
@@ -724,7 +752,7 @@ where
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE); inner.write_buf.reserve(HW_BUFFER_SIZE);
} }
let result = inner.poll_response()?; let result = inner.poll_response(cx)?;
let drain = result == PollResponse::DrainWriteBuf; let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler // switch to upgrade handler
@@ -742,7 +770,7 @@ where
self.inner = DispatcherState::Upgrade( self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)), inner.upgrade.unwrap().call((req, framed)),
); );
return self.poll(); return self.poll(cx);
} else { } else {
panic!() panic!()
} }
@@ -751,14 +779,14 @@ where
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX) // so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck // and we have to write again otherwise response can get stuck
if inner.poll_flush()? || !drain { if inner.poll_flush(cx)? || !drain {
break; break;
} }
} }
// client is gone // client is gone
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
return Ok(Async::Ready(())); return Poll::Ready(Ok(()));
} }
let is_empty = inner.state.is_empty(); let is_empty = inner.state.is_empty();
@@ -771,38 +799,44 @@ where
// keep-alive and stream errors // keep-alive and stream errors
if is_empty && inner.write_buf.is_empty() { if is_empty && inner.write_buf.is_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner.error.take() {
Err(err) Poll::Ready(Err(err))
} }
// disconnect if keep-alive is not enabled // disconnect if keep-alive is not enabled
else if inner.flags.contains(Flags::STARTED) else if inner.flags.contains(Flags::STARTED)
&& !inner.flags.intersects(Flags::KEEPALIVE) && !inner.flags.intersects(Flags::KEEPALIVE)
{ {
inner.flags.insert(Flags::SHUTDOWN); inner.flags.insert(Flags::SHUTDOWN);
self.poll() self.poll(cx)
} }
// disconnect if shutdown // disconnect if shutdown
else if inner.flags.contains(Flags::SHUTDOWN) { else if inner.flags.contains(Flags::SHUTDOWN) {
self.poll() self.poll(cx)
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| { DispatcherState::Upgrade(ref mut fut) => {
error!("Upgrade handler error: {}", e); unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
DispatchError::Upgrade error!("Upgrade handler error: {}", e);
}), DispatchError::Upgrade
})
}
DispatcherState::None => panic!(), DispatcherState::None => panic!(),
} }
} }
} }
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error> fn read_available<T>(
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Result<Option<bool>, io::Error>
where where
T: io::Read, T: AsyncRead + Unpin,
{ {
let mut read_some = false; let mut read_some = false;
loop { loop {
@@ -810,19 +844,18 @@ where
buf.reserve(HW_BUFFER_SIZE); buf.reserve(HW_BUFFER_SIZE);
} }
let read = unsafe { io.read(buf.bytes_mut()) }; match read(cx, io, buf) {
match read { Poll::Pending => {
Ok(n) => { return if read_some { Ok(Some(false)) } else { Ok(None) };
}
Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
return Ok(Some(true)); return Ok(Some(true));
} else { } else {
read_some = true; read_some = true;
unsafe {
buf.advance_mut(n);
}
} }
} }
Err(e) => { Poll::Ready(Err(e)) => {
return if e.kind() == io::ErrorKind::WouldBlock { return if e.kind() == io::ErrorKind::WouldBlock {
if read_some { if read_some {
Ok(Some(false)) Ok(Some(false))
@@ -833,12 +866,23 @@ where
Ok(Some(true)) Ok(Some(true))
} else { } else {
Err(e) Err(e)
}; }
} }
} }
} }
} }
fn read<T>(
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Poll<Result<usize, io::Error>>
where
T: AsyncRead + Unpin,
{
Pin::new(io).poll_read_buf(cx, buf)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::IntoService; use actix_service::IntoService;
@@ -852,7 +896,7 @@ mod tests {
#[test] #[test]
fn test_req_parse_err() { fn test_req_parse_err() {
let mut sys = actix_rt::System::new("test"); let mut sys = actix_rt::System::new("test");
let _ = sys.block_on(lazy(|| { let _ = sys.block_on(lazy(|cx| {
let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new( let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
@@ -865,7 +909,10 @@ mod tests {
None, None,
None, None,
); );
assert!(h1.poll().is_err()); match Pin::new(&mut h1).poll(cx) {
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),
}
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal(ref inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert!(inner.flags.contains(Flags::READ_DISCONNECT));

View File

@@ -548,10 +548,11 @@ mod tests {
ConnectionType::Close, ConnectionType::Close,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
assert_eq!( let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
bytes.take().freeze(), assert!(data.contains("Content-Length: 0\r\n"));
Bytes::from_static(b"\r\nContent-Length: 0\r\nConnection: close\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") assert!(data.contains("Connection: close\r\n"));
); assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n"));
let _ = head.encode_headers( let _ = head.encode_headers(
&mut bytes, &mut bytes,
@@ -560,10 +561,10 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
assert_eq!( let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
bytes.take().freeze(), assert!(data.contains("Transfer-Encoding: chunked\r\n"));
Bytes::from_static(b"\r\nTransfer-Encoding: chunked\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") assert!(data.contains("Content-Type: plain/text\r\n"));
); assert!(data.contains("Date: date\r\n"));
let _ = head.encode_headers( let _ = head.encode_headers(
&mut bytes, &mut bytes,
@@ -572,10 +573,10 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
assert_eq!( let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
bytes.take().freeze(), assert!(data.contains("Content-Length: 100\r\n"));
Bytes::from_static(b"\r\nContent-Length: 100\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n") assert!(data.contains("Content-Type: plain/text\r\n"));
); assert!(data.contains("Date: date\r\n"));
let mut head = RequestHead::default(); let mut head = RequestHead::default();
head.set_camel_case_headers(false); head.set_camel_case_headers(false);
@@ -586,7 +587,6 @@ mod tests {
.append(CONTENT_TYPE, HeaderValue::from_static("xml")); .append(CONTENT_TYPE, HeaderValue::from_static("xml"));
let mut head = RequestHeadType::Owned(head); let mut head = RequestHeadType::Owned(head);
let _ = head.encode_headers( let _ = head.encode_headers(
&mut bytes, &mut bytes,
Version::HTTP_11, Version::HTTP_11,
@@ -594,10 +594,11 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
assert_eq!( let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
bytes.take().freeze(), assert!(data.contains("transfer-encoding: chunked\r\n"));
Bytes::from_static(b"\r\ntransfer-encoding: chunked\r\ndate: date\r\ncontent-type: xml\r\ncontent-type: plain/text\r\n\r\n") assert!(data.contains("content-type: xml\r\n"));
); assert!(data.contains("content-type: plain/text\r\n"));
assert!(data.contains("date: date\r\n"));
} }
#[test] #[test]
@@ -626,9 +627,10 @@ mod tests {
ConnectionType::Close, ConnectionType::Close,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
assert_eq!( let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
bytes.take().freeze(), assert!(data.contains("content-length: 0\r\n"));
Bytes::from_static(b"\r\ncontent-length: 0\r\nconnection: close\r\nauthorization: another authorization\r\ndate: date\r\n\r\n") assert!(data.contains("connection: close\r\n"));
); assert!(data.contains("authorization: another authorization\r\n"));
assert!(data.contains("date: date\r\n"));
} }
} }

View File

@@ -1,21 +1,24 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::future::{ok, FutureResult}; use futures::future::{ok, Ready};
use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::request::Request; use crate::request::Request;
pub struct ExpectHandler; pub struct ExpectHandler;
impl NewService for ExpectHandler { impl ServiceFactory for ExpectHandler {
type Config = ServerConfig; type Config = ServerConfig;
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Service = ExpectHandler; type Service = ExpectHandler;
type InitError = Error; type InitError = Error;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &ServerConfig) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
ok(ExpectHandler) ok(ExpectHandler)
@@ -26,10 +29,10 @@ impl Service for ExpectHandler {
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Future = FutureResult<Self::Response, Self::Error>; type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {

View File

@@ -1,12 +1,14 @@
//! Payload stream //! Payload stream
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::rc::{Rc, Weak}; use std::rc::{Rc, Weak};
use std::task::{Context, Poll};
use actix_utils::task::LocalWaker;
use bytes::Bytes; use bytes::Bytes;
use futures::task::current as current_task; use futures::Stream;
use futures::task::Task;
use futures::{Async, Poll, Stream};
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -77,15 +79,24 @@ impl Payload {
pub fn unread_data(&mut self, data: Bytes) { pub fn unread_data(&mut self, data: Bytes) {
self.inner.borrow_mut().unread_data(data); self.inner.borrow_mut().unread_data(data);
} }
#[inline]
pub fn readany(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
}
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
#[inline] fn poll_next(
fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> { self: Pin<&mut Self>,
self.inner.borrow_mut().readany() cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
} }
} }
@@ -117,19 +128,14 @@ impl PayloadSender {
} }
#[inline] #[inline]
pub fn need_read(&self) -> PayloadStatus { pub fn need_read(&self, cx: &mut Context) -> PayloadStatus {
// we check need_read only if Payload (other side) is alive, // we check need_read only if Payload (other side) is alive,
// otherwise always return true (consume payload) // otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() { if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read { if shared.borrow().need_read {
PayloadStatus::Read PayloadStatus::Read
} else { } else {
#[cfg(not(test))] shared.borrow_mut().io_task.register(cx.waker());
{
if shared.borrow_mut().io_task.is_none() {
shared.borrow_mut().io_task = Some(current_task());
}
}
PayloadStatus::Pause PayloadStatus::Pause
} }
} else { } else {
@@ -145,8 +151,8 @@ struct Inner {
err: Option<PayloadError>, err: Option<PayloadError>,
need_read: bool, need_read: bool,
items: VecDeque<Bytes>, items: VecDeque<Bytes>,
task: Option<Task>, task: LocalWaker,
io_task: Option<Task>, io_task: LocalWaker,
} }
impl Inner { impl Inner {
@@ -157,8 +163,8 @@ impl Inner {
err: None, err: None,
items: VecDeque::new(), items: VecDeque::new(),
need_read: true, need_read: true,
task: None, task: LocalWaker::new(),
io_task: None, io_task: LocalWaker::new(),
} }
} }
@@ -178,7 +184,7 @@ impl Inner {
self.items.push_back(data); self.items.push_back(data);
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
task.notify() task.wake()
} }
} }
@@ -187,34 +193,28 @@ impl Inner {
self.len self.len
} }
fn readany(&mut self) -> Poll<Option<Bytes>, PayloadError> { fn readany(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
if let Some(data) = self.items.pop_front() { if let Some(data) = self.items.pop_front() {
self.len -= data.len(); self.len -= data.len();
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if self.need_read && self.task.is_none() && !self.eof { if self.need_read && !self.eof {
self.task = Some(current_task()); self.task.register(cx.waker());
} }
if let Some(task) = self.io_task.take() { self.io_task.wake();
task.notify() Poll::Ready(Some(Ok(data)))
}
Ok(Async::Ready(Some(data)))
} else if let Some(err) = self.err.take() { } else if let Some(err) = self.err.take() {
Err(err) Poll::Ready(Some(Err(err)))
} else if self.eof { } else if self.eof {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
self.need_read = true; self.need_read = true;
#[cfg(not(test))] self.task.register(cx.waker());
{ self.io_task.wake();
if self.task.is_none() { Poll::Pending
self.task = Some(current_task());
}
if let Some(task) = self.io_task.take() {
task.notify()
}
}
Ok(Async::NotReady)
} }
} }
@@ -228,27 +228,23 @@ impl Inner {
mod tests { mod tests {
use super::*; use super::*;
use actix_rt::Runtime; use actix_rt::Runtime;
use futures::future::{lazy, result}; use futures::future::{poll_fn, ready};
#[test] #[test]
fn test_unread_data() { fn test_unread_data() {
Runtime::new() Runtime::new().unwrap().block_on(async {
.unwrap() let (_, mut payload) = Payload::create(false);
.block_on(lazy(|| {
let (_, mut payload) = Payload::create(false);
payload.unread_data(Bytes::from("data")); payload.unread_data(Bytes::from("data"));
assert!(!payload.is_empty()); assert!(!payload.is_empty());
assert_eq!(payload.len(), 4); assert_eq!(payload.len(), 4);
assert_eq!( assert_eq!(
Async::Ready(Some(Bytes::from("data"))), Bytes::from("data"),
payload.poll().ok().unwrap() poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap()
); );
let res: Result<(), ()> = Ok(()); ready(())
result(res) });
}))
.unwrap();
} }
} }

View File

@@ -1,12 +1,15 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use futures::future::{ok, FutureResult}; use futures::future::{ok, Ready};
use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; use futures::{ready, Stream};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@@ -20,7 +23,7 @@ use super::codec::Codec;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
use super::{ExpectHandler, Message, UpgradeHandler}; use super::{ExpectHandler, Message, UpgradeHandler};
/// `NewService` implementation for HTTP1 transport /// `ServiceFactory` implementation for HTTP1 transport
pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> { pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -32,19 +35,19 @@ pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, P, S, B> H1Service<T, P, S, B> impl<T, P, S, B> H1Service<T, P, S, B>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
{ {
/// Create new `HttpService` instance with default config. /// Create new `HttpService` instance with default config.
pub fn new<F: IntoNewService<S>>(service: F) -> Self { pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H1Service { H1Service {
cfg, cfg,
srv: service.into_new_service(), srv: service.into_factory(),
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -53,10 +56,13 @@ where
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self { pub fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
) -> Self {
H1Service { H1Service {
cfg, cfg,
srv: service.into_new_service(), srv: service.into_factory(),
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -67,7 +73,7 @@ where
impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U> impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
@@ -75,7 +81,7 @@ where
{ {
pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U>
where where
X1: NewService<Request = Request, Response = Request>, X1: ServiceFactory<Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
{ {
@@ -91,7 +97,7 @@ where
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1>
where where
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>, U1: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
{ {
@@ -115,18 +121,18 @@ where
} }
} }
impl<T, P, S, B, X, U> NewService for H1Service<T, P, S, B, X, U> impl<T, P, S, B, X, U> ServiceFactory for H1Service<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>, X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService< U: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
@@ -144,7 +150,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H1ServiceResponse { H1ServiceResponse {
fut: self.srv.new_service(cfg).into_future(), fut: self.srv.new_service(cfg),
fut_ex: Some(self.expect.new_service(cfg)), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
@@ -157,20 +163,24 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project]
pub struct H1ServiceResponse<T, P, S, B, X, U> pub struct H1ServiceResponse<T, P, S, B, X, U>
where where
S: NewService<Request = Request>, S: ServiceFactory<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
X: NewService<Request = Request, Response = Request>, X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>, U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
@@ -182,49 +192,57 @@ where
impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U> impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: NewService<Request = Request>, S: ServiceFactory<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: NewService<Request = Request, Response = Request>, X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>, U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
type Item = H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>; type Output =
type Error = (); Result<H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut_ex { let mut this = self.as_mut().project();
let expect = try_ready!(fut
.poll() if let Some(fut) = this.fut_ex.as_pin_mut() {
.map_err(|e| log::error!("Init http service error: {:?}", e))); let expect = ready!(fut
self.expect = Some(expect); .poll(cx)
self.fut_ex.take(); .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
} }
if let Some(ref mut fut) = self.fut_upg { if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = try_ready!(fut let upgrade = ready!(fut
.poll() .poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
self.upgrade = Some(upgrade); this = self.as_mut().project();
self.fut_ex.take(); *this.upgrade = Some(upgrade);
this.fut_ex.set(None);
} }
let service = try_ready!(self let result = ready!(this
.fut .fut
.poll() .poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(H1ServiceHandler::new(
self.cfg.take().unwrap(), Poll::Ready(result.map(|service| {
service, let this = self.as_mut().project();
self.expect.take().unwrap(), H1ServiceHandler::new(
self.upgrade.take(), this.cfg.take().unwrap(),
self.on_connect.clone(), service,
))) this.expect.take().unwrap(),
this.upgrade.take(),
this.on_connect.clone(),
)
}))
} }
} }
@@ -284,10 +302,10 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>; type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let ready = self let ready = self
.expect .expect
.poll_ready() .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -297,7 +315,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready() .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -307,9 +325,9 @@ where
&& ready; && ready;
if ready { if ready {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
@@ -333,7 +351,7 @@ where
} }
} }
/// `NewService` implementation for `OneRequestService` service /// `ServiceFactory` implementation for `OneRequestService` service
#[derive(Default)] #[derive(Default)]
pub struct OneRequest<T, P> { pub struct OneRequest<T, P> {
config: ServiceConfig, config: ServiceConfig,
@@ -353,7 +371,7 @@ where
} }
} }
impl<T, P> NewService for OneRequest<T, P> impl<T, P> ServiceFactory for OneRequest<T, P>
where where
T: IoStream, T: IoStream,
{ {
@@ -363,7 +381,7 @@ where
type Error = ParseError; type Error = ParseError;
type InitError = (); type InitError = ();
type Service = OneRequestService<T, P>; type Service = OneRequestService<T, P>;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &SrvConfig) -> Self::Future { fn new_service(&self, _: &SrvConfig) -> Self::Future {
ok(OneRequestService { ok(OneRequestService {
@@ -389,8 +407,8 @@ where
type Error = ParseError; type Error = ParseError;
type Future = OneRequestServiceResponse<T>; type Future = OneRequestServiceResponse<T>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
@@ -415,19 +433,19 @@ impl<T> Future for OneRequestServiceResponse<T>
where where
T: IoStream, T: IoStream,
{ {
type Item = (Request, Framed<T, Codec>); type Output = Result<(Request, Framed<T, Codec>), ParseError>;
type Error = ParseError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.framed.as_mut().unwrap().poll()? { match self.framed.as_mut().unwrap().next_item(cx) {
Async::Ready(Some(req)) => match req { Poll::Ready(Some(Ok(req))) => match req {
Message::Item(req) => { Message::Item(req) => {
Ok(Async::Ready((req, self.framed.take().unwrap()))) Poll::Ready(Ok((req, self.framed.take().unwrap())))
} }
Message::Chunk(_) => unreachable!("Something is wrong"), Message::Chunk(_) => unreachable!("Something is wrong"),
}, },
Async::Ready(None) => Err(ParseError::Incomplete), Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)),
Async::NotReady => Ok(Async::NotReady), Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)),
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,10 +1,12 @@
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::{NewService, Service}; use actix_service::{Service, ServiceFactory};
use futures::future::FutureResult; use futures::future::Ready;
use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::h1::Codec; use crate::h1::Codec;
@@ -12,14 +14,14 @@ use crate::request::Request;
pub struct UpgradeHandler<T>(PhantomData<T>); pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> NewService for UpgradeHandler<T> { impl<T> ServiceFactory for UpgradeHandler<T> {
type Config = ServerConfig; type Config = ServerConfig;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Service = UpgradeHandler<T>; type Service = UpgradeHandler<T>;
type InitError = Error; type InitError = Error;
type Future = FutureResult<Self::Service, Self::InitError>; type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &ServerConfig) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
unimplemented!() unimplemented!()
@@ -30,10 +32,10 @@ impl<T> Service for UpgradeHandler<T> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = FutureResult<Self::Response, Self::Error>; type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _: Self::Request) -> Self::Future { fn call(&mut self, _: Self::Request) -> Self::Future {

View File

@@ -1,5 +1,9 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use futures::{Async, Future, Poll, Sink}; use futures::Sink;
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::error::Error; use crate::error::Error;
@@ -7,6 +11,7 @@ use crate::h1::{Codec, Message};
use crate::response::Response; use crate::response::Response;
/// Send http/1 response /// Send http/1 response
#[pin_project::pin_project]
pub struct SendResponse<T, B> { pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
body: Option<ResponseBody<B>>, body: Option<ResponseBody<B>>,
@@ -33,60 +38,61 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
B: MessageBody, B: MessageBody,
{ {
type Item = Framed<T, Codec>; type Output = Result<Framed<T, Codec>, Error>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
let mut body_ready = self.body.is_some(); let mut body_ready = this.body.is_some();
let framed = self.framed.as_mut().unwrap(); let framed = this.framed.as_mut().unwrap();
// send body // send body
if self.res.is_none() && self.body.is_some() { if this.res.is_none() && this.body.is_some() {
while body_ready && self.body.is_some() && !framed.is_write_buf_full() { while body_ready && this.body.is_some() && !framed.is_write_buf_full() {
match self.body.as_mut().unwrap().poll_next()? { match this.body.as_mut().unwrap().poll_next(cx)? {
Async::Ready(item) => { Poll::Ready(item) => {
// body is done // body is done
if item.is_none() { if item.is_none() {
let _ = self.body.take(); let _ = this.body.take();
} }
framed.force_send(Message::Chunk(item))?; framed.write(Message::Chunk(item))?;
} }
Async::NotReady => body_ready = false, Poll::Pending => body_ready = false,
} }
} }
} }
// flush write buffer // flush write buffer
if !framed.is_write_buf_empty() { if !framed.is_write_buf_empty() {
match framed.poll_complete()? { match framed.flush(cx)? {
Async::Ready(_) => { Poll::Ready(_) => {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Ok(Async::NotReady); return Poll::Pending;
} }
} }
Async::NotReady => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
} }
} }
// send response // send response
if let Some(res) = self.res.take() { if let Some(res) = this.res.take() {
framed.force_send(res)?; framed.write(res)?;
continue; continue;
} }
if self.body.is_some() { if this.body.is_some() {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Ok(Async::NotReady); return Poll::Pending;
} }
} else { } else {
break; break;
} }
} }
Ok(Async::Ready(self.framed.take().unwrap())) Poll::Ready(Ok(this.framed.take().unwrap()))
} }
} }

View File

@@ -1,5 +1,8 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant; use std::time::Instant;
use std::{fmt, mem, net}; use std::{fmt, mem, net};
@@ -8,7 +11,7 @@ use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{try_ready, Async, Future, Poll, Sink, Stream}; use futures::{ready, Sink, Stream};
use h2::server::{Connection, SendResponse}; use h2::server::{Connection, SendResponse};
use h2::{RecvStream, SendStream}; use h2::{RecvStream, SendStream};
use http::header::{ use http::header::{
@@ -32,6 +35,7 @@ use crate::response::Response;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol /// Dispatcher for HTTP/2 protocol
#[pin_project::pin_project]
pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> { pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> {
service: CloneableService<S>, service: CloneableService<S>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
@@ -48,9 +52,9 @@ where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Future: 'static, // S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody,
{ {
pub(crate) fn new( pub(crate) fn new(
service: CloneableService<S>, service: CloneableService<S>,
@@ -93,61 +97,75 @@ impl<T, S, B> Future for Dispatcher<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Item = (); type Output = Result<(), DispatchError>;
type Error = DispatchError;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
loop { loop {
match self.connection.poll()? { match Pin::new(&mut this.connection).poll_accept(cx) {
Async::Ready(None) => return Ok(Async::Ready(())), Poll::Ready(None) => return Poll::Ready(Ok(())),
Async::Ready(Some((req, res))) => { Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
Poll::Ready(Some(Ok((req, res)))) => {
// update keep-alive expire // update keep-alive expire
if self.ka_timer.is_some() { if this.ka_timer.is_some() {
if let Some(expire) = self.config.keep_alive_expire() { if let Some(expire) = this.config.keep_alive_expire() {
self.ka_expire = expire; this.ka_expire = expire;
} }
} }
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
let mut req = Request::with_payload(body.into()); let mut req = Request::with_payload(Payload::<
crate::payload::PayloadStream,
>::H2(
crate::h2::Payload::new(body)
));
let head = &mut req.head_mut(); let head = &mut req.head_mut();
head.uri = parts.uri; head.uri = parts.uri;
head.method = parts.method; head.method = parts.method;
head.version = parts.version; head.version = parts.version;
head.headers = parts.headers.into(); head.headers = parts.headers.into();
head.peer_addr = self.peer_addr; head.peer_addr = this.peer_addr;
// set on_connect data // set on_connect data
if let Some(ref on_connect) = self.on_connect { if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
tokio_current_thread::spawn(ServiceResponse::<S::Future, B> { tokio_executor::current_thread::spawn(ServiceResponse::<
S::Future,
S::Response,
S::Error,
B,
> {
state: ServiceResponseState::ServiceCall( state: ServiceResponseState::ServiceCall(
self.service.call(req), this.service.call(req),
Some(res), Some(res),
), ),
config: self.config.clone(), config: this.config.clone(),
buffer: None, buffer: None,
}) _t: PhantomData,
});
} }
Async::NotReady => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
} }
} }
} }
} }
struct ServiceResponse<F, B> { #[pin_project::pin_project]
struct ServiceResponse<F, I, E, B> {
state: ServiceResponseState<F, B>, state: ServiceResponseState<F, B>,
config: ServiceConfig, config: ServiceConfig,
buffer: Option<Bytes>, buffer: Option<Bytes>,
_t: PhantomData<(I, E)>,
} }
enum ServiceResponseState<F, B> { enum ServiceResponseState<F, B> {
@@ -155,12 +173,12 @@ enum ServiceResponseState<F, B> {
SendPayload(SendStream<Bytes>, ResponseBody<B>), SendPayload(SendStream<Bytes>, ResponseBody<B>),
} }
impl<F, B> ServiceResponse<F, B> impl<F, I, E, B> ServiceResponse<F, I, E, B>
where where
F: Future, F: Future<Output = Result<I, E>>,
F::Error: Into<Error>, E: Into<Error>,
F::Item: Into<Response<B>>, I: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody,
{ {
fn prepare_response( fn prepare_response(
&self, &self,
@@ -223,109 +241,121 @@ where
} }
} }
impl<F, B> Future for ServiceResponse<F, B> impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
where where
F: Future, F: Future<Output = Result<I, E>>,
F::Error: Into<Error>, E: Into<Error>,
F::Item: Into<Response<B>>, I: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody,
{ {
type Item = (); type Output = ();
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.state { let mut this = self.as_mut().project();
match this.state {
ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
match call.poll() { match unsafe { Pin::new_unchecked(call) }.poll(cx) {
Ok(Async::Ready(res)) => { Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = self.prepare_response(res.head(), &mut size); let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = let stream = match send.send_response(h2_res, size.is_eof()) {
send.send_response(h2_res, size.is_eof()).map_err(|e| { Err(e) => {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
})?; return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Ok(Async::Ready(())) Poll::Ready(())
} else { } else {
self.state = ServiceResponseState::SendPayload(stream, body); *this.state =
self.poll() ServiceResponseState::SendPayload(stream, body);
self.poll(cx)
} }
} }
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Pending => Poll::Pending,
Err(e) => { Poll::Ready(Err(e)) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = self.prepare_response(res.head(), &mut size); let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = let stream = match send.send_response(h2_res, size.is_eof()) {
send.send_response(h2_res, size.is_eof()).map_err(|e| { Err(e) => {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
})?; return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Ok(Async::Ready(())) Poll::Ready(())
} else { } else {
self.state = ServiceResponseState::SendPayload( *this.state = ServiceResponseState::SendPayload(
stream, stream,
body.into_body(), body.into_body(),
); );
self.poll() self.poll(cx)
} }
} }
} }
} }
ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
loop { loop {
if let Some(ref mut buffer) = self.buffer { if let Some(ref mut buffer) = this.buffer {
match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? { match stream.poll_capacity(cx) {
Async::NotReady => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
Async::Ready(None) => return Ok(Async::Ready(())), Poll::Ready(None) => return Poll::Ready(()),
Async::Ready(Some(cap)) => { Poll::Ready(Some(Ok(cap))) => {
let len = buffer.len(); let len = buffer.len();
let bytes = buffer.split_to(std::cmp::min(cap, len)); let bytes = buffer.split_to(std::cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) { if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e); warn!("{:?}", e);
return Err(()); return Poll::Ready(());
} else if !buffer.is_empty() { } else if !buffer.is_empty() {
let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap); stream.reserve_capacity(cap);
} else { } else {
self.buffer.take(); this.buffer.take();
} }
} }
Poll::Ready(Some(Err(e))) => {
warn!("{:?}", e);
return Poll::Ready(());
}
} }
} else { } else {
match body.poll_next() { match body.poll_next(cx) {
Ok(Async::NotReady) => { Poll::Pending => return Poll::Pending,
return Ok(Async::NotReady); Poll::Ready(None) => {
}
Ok(Async::Ready(None)) => {
if let Err(e) = stream.send_data(Bytes::new(), true) { if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e); warn!("{:?}", e);
return Err(());
} else {
return Ok(Async::Ready(()));
} }
return Poll::Ready(());
} }
Ok(Async::Ready(Some(chunk))) => { Poll::Ready(Some(Ok(chunk))) => {
stream.reserve_capacity(std::cmp::min( stream.reserve_capacity(std::cmp::min(
chunk.len(), chunk.len(),
CHUNK_SIZE, CHUNK_SIZE,
)); ));
self.buffer = Some(chunk); *this.buffer = Some(chunk);
} }
Err(e) => { Poll::Ready(Some(Err(e))) => {
error!("Response payload stream error: {:?}", e); error!("Response payload stream error: {:?}", e);
return Err(()); return Poll::Ready(());
} }
} }
} }

View File

@@ -1,9 +1,11 @@
#![allow(dead_code, unused_imports)] #![allow(dead_code, unused_imports)]
use std::fmt; use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Poll, Stream}; use futures::Stream;
use h2::RecvStream; use h2::RecvStream;
mod dispatcher; mod dispatcher;
@@ -25,22 +27,23 @@ impl Payload {
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self.pl.poll() { let this = self.get_mut();
Ok(Async::Ready(Some(chunk))) => {
match Pin::new(&mut this.pl).poll_data(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let len = chunk.len(); let len = chunk.len();
if let Err(err) = self.pl.release_capacity().release_capacity(len) { if let Err(err) = this.pl.release_capacity().release_capacity(len) {
Err(err.into()) Poll::Ready(Some(Err(err.into())))
} else { } else {
Ok(Async::Ready(Some(chunk))) Poll::Ready(Some(Ok(chunk)))
} }
} }
Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Pending => Poll::Pending,
Err(err) => Err(err.into()), Poll::Ready(None) => Poll::Ready(None),
} }
} }
} }

View File

@@ -1,13 +1,16 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, net, rc}; use std::{io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{ok, FutureResult}; use futures::future::{ok, Ready};
use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; use futures::{ready, Stream};
use h2::server::{self, Connection, Handshake}; use h2::server::{self, Connection, Handshake};
use h2::RecvStream; use h2::RecvStream;
use log::error; use log::error;
@@ -23,7 +26,7 @@ use crate::response::Response;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
/// `NewService` implementation for HTTP2 transport /// `ServiceFactory` implementation for HTTP2 transport
pub struct H2Service<T, P, S, B> { pub struct H2Service<T, P, S, B> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -33,30 +36,33 @@ pub struct H2Service<T, P, S, B> {
impl<T, P, S, B> H2Service<T, P, S, B> impl<T, P, S, B> H2Service<T, P, S, B>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self { pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
srv: service.into_new_service(), srv: service.into_factory(),
_t: PhantomData, _t: PhantomData,
} }
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self { pub fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
) -> Self {
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
srv: service.into_new_service(), srv: service.into_factory(),
_t: PhantomData, _t: PhantomData,
} }
} }
@@ -71,12 +77,12 @@ where
} }
} }
impl<T, P, S, B> NewService for H2Service<T, P, S, B> impl<T, P, S, B> ServiceFactory for H2Service<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
@@ -90,7 +96,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H2ServiceResponse { H2ServiceResponse {
fut: self.srv.new_service(cfg).into_future(), fut: self.srv.new_service(cfg),
cfg: Some(self.cfg.clone()), cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
_t: PhantomData, _t: PhantomData,
@@ -99,8 +105,10 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
pub struct H2ServiceResponse<T, P, S: NewService, B> { #[pin_project::pin_project]
fut: <S::Future as IntoFuture>::Future, pub struct H2ServiceResponse<T, P, S: ServiceFactory, B> {
#[pin]
fut: S::Future,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, P, B)>, _t: PhantomData<(T, P, B)>,
@@ -109,22 +117,25 @@ pub struct H2ServiceResponse<T, P, S: NewService, B> {
impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B> impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Item = H2ServiceHandler<T, P, S::Service, B>; type Output = Result<H2ServiceHandler<T, P, S::Service, B>, S::InitError>;
type Error = S::InitError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let service = try_ready!(self.fut.poll()); let this = self.as_mut().project();
Ok(Async::Ready(H2ServiceHandler::new(
self.cfg.take().unwrap(), Poll::Ready(ready!(this.fut.poll(cx)).map(|service| {
self.on_connect.clone(), let this = self.as_mut().project();
service, H2ServiceHandler::new(
))) this.cfg.take().unwrap(),
this.on_connect.clone(),
service,
)
}))
} }
} }
@@ -139,9 +150,9 @@ pub struct H2ServiceHandler<T, P, S, B> {
impl<T, P, S, B> H2ServiceHandler<T, P, S, B> impl<T, P, S, B> H2ServiceHandler<T, P, S, B>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
fn new( fn new(
@@ -162,9 +173,9 @@ impl<T, P, S, B> Service for H2ServiceHandler<T, P, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Request = Io<T, P>; type Request = Io<T, P>;
@@ -172,8 +183,8 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = H2ServiceHandlerResponse<T, S, B>; type Future = H2ServiceHandlerResponse<T, S, B>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.srv.poll_ready().map_err(|e| { self.srv.poll_ready(cx).map_err(|e| {
let e = e.into(); let e = e.into();
error!("Service readiness error: {:?}", e); error!("Service readiness error: {:?}", e);
DispatchError::Service(e) DispatchError::Service(e)
@@ -219,9 +230,9 @@ pub struct H2ServiceHandlerResponse<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
state: State<T, S, B>, state: State<T, S, B>,
@@ -231,25 +242,24 @@ impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody, B: MessageBody,
{ {
type Item = (); type Output = Result<(), DispatchError>;
type Error = DispatchError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.state { match self.state {
State::Incoming(ref mut disp) => disp.poll(), State::Incoming(ref mut disp) => Pin::new(disp).poll(cx),
State::Handshake( State::Handshake(
ref mut srv, ref mut srv,
ref mut config, ref mut config,
ref peer_addr, ref peer_addr,
ref mut on_connect, ref mut on_connect,
ref mut handshake, ref mut handshake,
) => match handshake.poll() { ) => match Pin::new(handshake).poll(cx) {
Ok(Async::Ready(conn)) => { Poll::Ready(Ok(conn)) => {
self.state = State::Incoming(Dispatcher::new( self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(), srv.take().unwrap(),
conn, conn,
@@ -258,13 +268,13 @@ where
None, None,
*peer_addr, *peer_addr,
)); ));
self.poll() self.poll(cx)
} }
Ok(Async::NotReady) => Ok(Async::NotReady), Poll::Ready(Err(err)) => {
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
Err(err.into()) Poll::Ready(Err(err.into()))
} }
Poll::Pending => Poll::Pending,
}, },
} }
} }

View File

@@ -76,6 +76,11 @@ pub enum DispositionParam {
/// the form. /// the form.
Name(String), Name(String),
/// A plain file name. /// A plain file name.
///
/// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any
/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where
/// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead
/// in case there are Unicode characters in file names.
Filename(String), Filename(String),
/// An extended file name. It must not exist for `ContentType::Formdata` according to /// An extended file name. It must not exist for `ContentType::Formdata` according to
/// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2).
@@ -220,7 +225,16 @@ impl DispositionParam {
/// ext-token = <the characters in token, followed by "*"> /// ext-token = <the characters in token, followed by "*">
/// ``` /// ```
/// ///
/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within /// # Note
///
/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any
/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where
/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file
/// names.
/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded
/// directly in a *Content-Disposition* header for *multipart/form-data*, though.
///
/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within
/// *multipart/form-data*. /// *multipart/form-data*.
/// ///
/// # Example /// # Example
@@ -251,6 +265,22 @@ impl DispositionParam {
/// }; /// };
/// assert_eq!(cd2.get_name(), Some("file")); // field name /// assert_eq!(cd2.get_name(), Some("file")); // field name
/// assert_eq!(cd2.get_filename(), Some("bill.odt")); /// assert_eq!(cd2.get_filename(), Some("bill.odt"));
///
/// // HTTP response header with Unicode characters in file names
/// let cd3 = ContentDisposition {
/// disposition: DispositionType::Attachment,
/// parameters: vec![
/// DispositionParam::FilenameExt(ExtendedValue {
/// charset: Charset::Ext(String::from("UTF-8")),
/// language_tag: None,
/// value: String::from("\u{1f600}.svg").into_bytes(),
/// }),
/// // fallback for better compatibility
/// DispositionParam::Filename(String::from("Grinning-Face-Emoji.svg"))
/// ],
/// };
/// assert_eq!(cd3.get_filename_ext().map(|ev| ev.value.as_ref()),
/// Some("\u{1f600}.svg".as_bytes()));
/// ``` /// ```
/// ///
/// # WARN /// # WARN
@@ -333,15 +363,17 @@ impl ContentDisposition {
// token: won't contains semicolon according to RFC 2616 Section 2.2 // token: won't contains semicolon according to RFC 2616 Section 2.2
let (token, new_left) = split_once_and_trim(left, ';'); let (token, new_left) = split_once_and_trim(left, ';');
left = new_left; left = new_left;
if token.is_empty() {
// quoted-string can be empty, but token cannot be empty
return Err(crate::error::ParseError::Header);
}
token.to_owned() token.to_owned()
}; };
if value.is_empty() {
return Err(crate::error::ParseError::Header);
}
let param = if param_name.eq_ignore_ascii_case("name") { let param = if param_name.eq_ignore_ascii_case("name") {
DispositionParam::Name(value) DispositionParam::Name(value)
} else if param_name.eq_ignore_ascii_case("filename") { } else if param_name.eq_ignore_ascii_case("filename") {
// See also comments in test_from_raw_uncessary_percent_decode.
DispositionParam::Filename(value) DispositionParam::Filename(value)
} else { } else {
DispositionParam::Unknown(param_name.to_owned(), value) DispositionParam::Unknown(param_name.to_owned(), value)
@@ -466,11 +498,40 @@ impl fmt::Display for DispositionType {
impl fmt::Display for DispositionParam { impl fmt::Display for DispositionParam {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and
// backslash should be escaped in quoted-string (i.e. "foobar"). // backslash should be escaped in quoted-string (i.e. "foobar").
// Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... . // Ref: RFC6266 S4.1 -> RFC2616 S3.6
// filename-parm = "filename" "=" value
// value = token | quoted-string
// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
// qdtext = <any TEXT except <">>
// quoted-pair = "\" CHAR
// TEXT = <any OCTET except CTLs,
// but including LWS>
// LWS = [CRLF] 1*( SP | HT )
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character
// (octets 0 - 31) and DEL (127)>
//
// Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1
// parameter := attribute "=" value
// attribute := token
// ; Matching of attributes
// ; is ALWAYS case-insensitive.
// value := token / quoted-string
// token := 1*<any (US-ASCII) CHAR except SPACE, CTLs,
// or tspecials>
// tspecials := "(" / ")" / "<" / ">" / "@" /
// "," / ";" / ":" / "\" / <">
// "/" / "[" / "]" / "?" / "="
// ; Must be in quoted-string,
// ; to use within parameter values
//
//
// See also comments in test_from_raw_uncessary_percent_decode.
lazy_static! { lazy_static! {
static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap(); static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap();
} }
match self { match self {
DispositionParam::Name(ref value) => write!(f, "name={}", value), DispositionParam::Name(ref value) => write!(f, "name={}", value),
@@ -774,8 +835,18 @@ mod tests {
#[test] #[test]
fn test_from_raw_uncessary_percent_decode() { fn test_from_raw_uncessary_percent_decode() {
// In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with
// non-ASCII characters MAY be percent-encoded.
// On the contrary, RFC6266 or other RFCs related to Content-Disposition response header
// do not mention such percent-encoding.
// So, it appears to be undecidable whether to percent-decode or not without
// knowing the usage scenario (multipart/form-data v.s. HTTP response header) and
// inevitable to unnecessarily percent-decode filename with %XX in the former scenario.
// Fortunately, it seems that almost all mainstream browsers just send UTF-8 encoded file
// names in quoted-string format (tested on Edge, IE11, Chrome and Firefox) without
// percent-encoding. So we do not bother to attempt to percent-decode.
let a = HeaderValue::from_static( let a = HeaderValue::from_static(
"form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded! "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"",
); );
let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap();
let b = ContentDisposition { let b = ContentDisposition {
@@ -811,6 +882,9 @@ mod tests {
let a = HeaderValue::from_static("inline; filename= "); let a = HeaderValue::from_static("inline; filename= ");
assert!(ContentDisposition::from_raw(&a).is_err()); assert!(ContentDisposition::from_raw(&a).is_err());
let a = HeaderValue::from_static("inline; filename=\"\"");
assert!(ContentDisposition::from_raw(&a).expect("parse cd").get_filename().expect("filename").is_empty());
} }
#[test] #[test]

View File

@@ -4,7 +4,8 @@
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::new_without_default, clippy::new_without_default,
clippy::borrow_interior_mutable_const, clippy::borrow_interior_mutable_const,
clippy::write_with_newline clippy::write_with_newline,
unused_imports
)] )]
#[macro_use] #[macro_use]

View File

@@ -388,6 +388,12 @@ impl BoxedResponseHead {
pub fn new(status: StatusCode) -> Self { pub fn new(status: StatusCode) -> Self {
RESPONSE_POOL.with(|p| p.get_message(status)) RESPONSE_POOL.with(|p| p.get_message(status))
} }
pub(crate) fn take(&mut self) -> Self {
BoxedResponseHead {
head: self.head.take(),
}
}
} }
impl std::ops::Deref for BoxedResponseHead { impl std::ops::Deref for BoxedResponseHead {
@@ -406,7 +412,9 @@ impl std::ops::DerefMut for BoxedResponseHead {
impl Drop for BoxedResponseHead { impl Drop for BoxedResponseHead {
fn drop(&mut self) { fn drop(&mut self) {
RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap())) if let Some(head) = self.head.take() {
RESPONSE_POOL.with(move |p| p.release(head))
}
} }
} }

View File

@@ -1,11 +1,15 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Poll, Stream}; use futures::Stream;
use h2::RecvStream; use h2::RecvStream;
use crate::error::PayloadError; use crate::error::PayloadError;
/// Type represent boxed payload /// Type represent boxed payload
pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>; pub type PayloadStream = Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>>>>;
/// Type represent streaming payload /// Type represent streaming payload
pub enum Payload<S = PayloadStream> { pub enum Payload<S = PayloadStream> {
@@ -48,18 +52,17 @@ impl<S> Payload<S> {
impl<S> Stream for Payload<S> impl<S> Stream for Payload<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{ {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self { match self.get_mut() {
Payload::None => Ok(Async::Ready(None)), Payload::None => Poll::Ready(None),
Payload::H1(ref mut pl) => pl.poll(), Payload::H1(ref mut pl) => pl.readany(cx),
Payload::H2(ref mut pl) => pl.poll(), Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx),
Payload::Stream(ref mut pl) => pl.poll(), Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx),
} }
} }
} }

View File

@@ -80,6 +80,11 @@ impl<P> Request<P> {
) )
} }
/// Get request's payload
pub fn payload(&mut self) -> &mut Payload<P> {
&mut self.payload
}
/// Get request's payload /// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> { pub fn take_payload(&mut self) -> Payload<P> {
std::mem::replace(&mut self.payload, Payload::None) std::mem::replace(&mut self.payload, Payload::None)
@@ -199,7 +204,6 @@ mod tests {
assert_eq!(req.uri().query(), Some("q=1")); assert_eq!(req.uri().query(), Some("q=1"));
let s = format!("{:?}", req); let s = format!("{:?}", req);
println!("T: {:?}", s);
assert!(s.contains("Request HTTP/1.1 GET:/index.html")); assert!(s.contains("Request HTTP/1.1 GET:/index.html"));
} }
} }

View File

@@ -1,11 +1,14 @@
//! Http response //! Http response
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::future::Future;
use std::io::Write; use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, str}; use std::{fmt, str};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, FutureResult, IntoFuture}; use futures::future::{ok, Ready};
use futures::Stream; use futures::stream::Stream;
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
@@ -280,13 +283,15 @@ impl<B: MessageBody> fmt::Debug for Response<B> {
} }
} }
impl IntoFuture for Response { impl Future for Response {
type Item = Response; type Output = Result<Response, Error>;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn into_future(self) -> Self::Future { fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> {
ok(self) Poll::Ready(Ok(Response {
head: self.head.take(),
body: self.body.take_body(),
error: self.error.take(),
}))
} }
} }
@@ -635,7 +640,7 @@ impl ResponseBuilder {
/// `ResponseBuilder` can not be used after this call. /// `ResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Response pub fn streaming<S, E>(&mut self, stream: S) -> Response
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
self.body(Body::from_message(BodyStream::new(stream))) self.body(Body::from_message(BodyStream::new(stream)))
@@ -757,13 +762,11 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder {
} }
} }
impl IntoFuture for ResponseBuilder { impl Future for ResponseBuilder {
type Item = Response; type Output = Result<Response, Error>;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn into_future(mut self) -> Self::Future { fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> {
ok(self.finish()) Poll::Ready(Ok(self.finish()))
} }
} }

View File

@@ -1,14 +1,17 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, net, rc}; use std::{fmt, io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{ use actix_server_config::{
Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig, Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig,
}; };
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{try_ready, Async, Future, IntoFuture, Poll}; use futures::{ready, Future};
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder; use crate::builder::HttpServiceBuilder;
@@ -20,7 +23,7 @@ use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{h1, h2::Dispatcher}; use crate::{h1, h2::Dispatcher};
/// `NewService` HTTP1.1/HTTP2 transport implementation /// `ServiceFactory` HTTP1.1/HTTP2 transport implementation
pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> { pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
@@ -32,10 +35,10 @@ pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler
impl<T, S, B> HttpService<T, (), S, B> impl<T, S, B> HttpService<T, (), S, B>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
@@ -47,20 +50,20 @@ where
impl<T, P, S, B> HttpService<T, P, S, B> impl<T, P, S, B> HttpService<T, P, S, B>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self { pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
HttpService { HttpService {
cfg, cfg,
srv: service.into_new_service(), srv: service.into_factory(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -69,13 +72,13 @@ where
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoNewService<S>>( pub(crate) fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig, cfg: ServiceConfig,
service: F, service: F,
) -> Self { ) -> Self {
HttpService { HttpService {
cfg, cfg,
srv: service.into_new_service(), srv: service.into_factory(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -86,10 +89,11 @@ where
impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U> impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U>
where where
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody, B: MessageBody,
{ {
/// Provide service for `EXPECT: 100-Continue` support. /// Provide service for `EXPECT: 100-Continue` support.
@@ -99,9 +103,10 @@ where
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U>
where where
X1: NewService<Config = SrvConfig, Request = Request, Response = Request>, X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
expect, expect,
@@ -119,13 +124,14 @@ where
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1>
where where
U1: NewService< U1: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
upgrade, upgrade,
@@ -147,25 +153,27 @@ where
} }
} }
impl<T, P, S, B, X, U> NewService for HttpService<T, P, S, B, X, U> impl<T, P, S, B, X, U> ServiceFactory for HttpService<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>, S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>, X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService< <X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = SrvConfig, Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Config = SrvConfig; type Config = SrvConfig;
type Request = ServerIo<T, P>; type Request = ServerIo<T, P>;
@@ -177,7 +185,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
HttpServiceResponse { HttpServiceResponse {
fut: self.srv.new_service(cfg).into_future(), fut: self.srv.new_service(cfg),
fut_ex: Some(self.expect.new_service(cfg)), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
@@ -190,9 +198,20 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewService> { #[pin_project]
pub struct HttpServiceResponse<
T,
P,
S: ServiceFactory,
B,
X: ServiceFactory,
U: ServiceFactory,
> {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
@@ -204,50 +223,59 @@ pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewServ
impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U> impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: NewService<Request = Request>, S: ServiceFactory<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: NewService<Request = Request, Response = Request>, X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, h1::Codec>), Response = ()>, <X::Service as Service>::Future: 'static,
U: ServiceFactory<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Item = HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>; type Output =
type Error = (); Result<HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut_ex { let mut this = self.as_mut().project();
let expect = try_ready!(fut
.poll() if let Some(fut) = this.fut_ex.as_pin_mut() {
.map_err(|e| log::error!("Init http service error: {:?}", e))); let expect = ready!(fut
self.expect = Some(expect); .poll(cx)
self.fut_ex.take(); .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
} }
if let Some(ref mut fut) = self.fut_upg { if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = try_ready!(fut let upgrade = ready!(fut
.poll() .poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
self.upgrade = Some(upgrade); this = self.as_mut().project();
self.fut_ex.take(); *this.upgrade = Some(upgrade);
this.fut_ex.set(None);
} }
let service = try_ready!(self let result = ready!(this
.fut .fut
.poll() .poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(HttpServiceHandler::new( Poll::Ready(result.map(|service| {
self.cfg.take().unwrap(), let this = self.as_mut().project();
service, HttpServiceHandler::new(
self.expect.take().unwrap(), this.cfg.take().unwrap(),
self.upgrade.take(), service,
self.on_connect.clone(), this.expect.take().unwrap(),
))) this.upgrade.take(),
this.on_connect.clone(),
)
}))
} }
} }
@@ -264,9 +292,9 @@ pub struct HttpServiceHandler<T, P, S, B, X, U> {
impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U> impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@@ -295,9 +323,9 @@ impl<T, P, S, B, X, U> Service for HttpServiceHandler<T, P, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@@ -309,10 +337,10 @@ where
type Error = DispatchError; type Error = DispatchError;
type Future = HttpServiceHandlerResponse<T, S, B, X, U>; type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let ready = self let ready = self
.expect .expect
.poll_ready() .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -322,7 +350,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready() .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -332,9 +360,9 @@ where
&& ready; && ready;
if ready { if ready {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
@@ -389,6 +417,7 @@ where
} }
} }
#[pin_project]
enum State<T, S, B, X, U> enum State<T, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
@@ -401,8 +430,8 @@ where
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
H1(h1::Dispatcher<T, S, B, X, U>), H1(#[pin] h1::Dispatcher<T, S, B, X, U>),
H2(Dispatcher<Io<T>, S, B>), H2(#[pin] Dispatcher<Io<T>, S, B>),
Unknown( Unknown(
Option<( Option<(
T, T,
@@ -425,19 +454,21 @@ where
), ),
} }
#[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U> pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
#[pin]
state: State<T, S, B, X, U>, state: State<T, S, B, X, U>,
} }
@@ -447,30 +478,51 @@ impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: IoStream, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error> + 'static,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>> + 'static,
B: MessageBody, B: MessageBody,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Item = (); type Output = Result<(), DispatchError>;
type Error = DispatchError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.state { self.project().state.poll(cx)
State::H1(ref mut disp) => disp.poll(), }
State::H2(ref mut disp) => disp.poll(), }
impl<T, S, B, X, U> State<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[project]
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), DispatchError>> {
#[project]
match self.as_mut().project() {
State::H1(disp) => disp.poll(cx),
State::H2(disp) => disp.poll(cx),
State::Unknown(ref mut data) => { State::Unknown(ref mut data) => {
if let Some(ref mut item) = data { if let Some(ref mut item) = data {
loop { loop {
// Safety - we only write to the returned slice. // Safety - we only write to the returned slice.
let b = unsafe { item.1.bytes_mut() }; let b = unsafe { item.1.bytes_mut() };
let n = try_ready!(item.0.poll_read(b)); let n = ready!(Pin::new(&mut item.0).poll_read(cx, b))?;
if n == 0 { if n == 0 {
return Ok(Async::Ready(())); return Poll::Ready(Ok(()));
} }
// Safety - we know that 'n' bytes have // Safety - we know that 'n' bytes have
// been initialized via the contract of // been initialized via the contract of
@@ -491,15 +543,15 @@ where
inner: io, inner: io,
unread: Some(buf), unread: Some(buf),
}; };
self.state = State::Handshake(Some(( self.set(State::Handshake(Some((
server::handshake(io), server::handshake(io),
cfg, cfg,
srv, srv,
peer_addr, peer_addr,
on_connect, on_connect,
))); ))));
} else { } else {
self.state = State::H1(h1::Dispatcher::with_timeout( self.set(State::H1(h1::Dispatcher::with_timeout(
io, io,
h1::Codec::new(cfg.clone()), h1::Codec::new(cfg.clone()),
cfg, cfg,
@@ -509,36 +561,38 @@ where
expect, expect,
upgrade, upgrade,
on_connect, on_connect,
)) )))
} }
self.poll() self.poll(cx)
} }
State::Handshake(ref mut data) => { State::Handshake(ref mut data) => {
let conn = if let Some(ref mut item) = data { let conn = if let Some(ref mut item) = data {
match item.0.poll() { match Pin::new(&mut item.0).poll(cx) {
Ok(Async::Ready(conn)) => conn, Poll::Ready(Ok(conn)) => conn,
Ok(Async::NotReady) => return Ok(Async::NotReady), Poll::Ready(Err(err)) => {
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
return Err(err.into()); return Poll::Ready(Err(err.into()));
} }
Poll::Pending => return Poll::Pending,
} }
} else { } else {
panic!() panic!()
}; };
let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap();
self.state = State::H2(Dispatcher::new( self.set(State::H2(Dispatcher::new(
srv, conn, on_connect, cfg, None, peer_addr, srv, conn, on_connect, cfg, None, peer_addr,
)); )));
self.poll() self.poll(cx)
} }
} }
} }
} }
/// Wrapper for `AsyncRead + AsyncWrite` types /// Wrapper for `AsyncRead + AsyncWrite` types
#[pin_project::pin_project]
struct Io<T> { struct Io<T> {
unread: Option<BytesMut>, unread: Option<BytesMut>,
#[pin]
inner: T, inner: T,
} }
@@ -568,21 +622,65 @@ impl<T: io::Write> io::Write for Io<T> {
} }
impl<T: AsyncRead> AsyncRead for Io<T> { impl<T: AsyncRead> AsyncRead for Io<T> {
// unsafe fn initializer(&self) -> io::Initializer {
// self.get_mut().inner.initializer()
// }
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf) self.inner.prepare_uninitialized_buffer(buf)
} }
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
if let Some(mut bytes) = this.unread.take() {
let size = std::cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
*this.unread = Some(bytes);
}
Poll::Ready(Ok(size))
} else {
this.inner.poll_read(cx, buf)
}
}
// fn poll_read_vectored(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// bufs: &mut [io::IoSliceMut<'_>],
// ) -> Poll<io::Result<usize>> {
// self.get_mut().inner.poll_read_vectored(cx, bufs)
// }
} }
impl<T: AsyncWrite> AsyncWrite for Io<T> { impl<T: AsyncWrite> tokio_io::AsyncWrite for Io<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(
self.inner.shutdown() self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
} }
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.inner.write_buf(buf) fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
} }
} }
impl<T: IoStream> IoStream for Io<T> { impl<T: IoStream> actix_server_config::IoStream for Io<T> {
#[inline] #[inline]
fn peer_addr(&self) -> Option<net::SocketAddr> { fn peer_addr(&self) -> Option<net::SocketAddr> {
self.inner.peer_addr() self.inner.peer_addr()

View File

@@ -1,12 +1,13 @@
//! Test Various helpers for Actix applications to use during testing. //! Test Various helpers for Actix applications to use during testing.
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io; use std::io::{self, Read, Write};
use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_server_config::IoStream; use actix_server_config::IoStream;
use bytes::{Buf, Bytes, BytesMut}; use bytes::{Buf, Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use http::{HttpTryFrom, Method, Uri, Version}; use http::{HttpTryFrom, Method, Uri, Version};
use percent_encoding::percent_encode; use percent_encoding::percent_encode;
@@ -244,14 +245,31 @@ impl io::Write for TestBuffer {
} }
} }
impl AsyncRead for TestBuffer {} impl AsyncRead for TestBuffer {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().read(buf))
}
}
impl AsyncWrite for TestBuffer { impl AsyncWrite for TestBuffer {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(
Ok(Async::Ready(())) self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().write(buf))
} }
fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
Ok(Async::NotReady) fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
} }
} }

View File

@@ -1,7 +1,10 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use actix_utils::framed::{FramedTransport, FramedTransportError}; use actix_utils::framed::{FramedTransport, FramedTransportError};
use futures::{Future, Poll};
use super::{Codec, Frame, Message}; use super::{Codec, Frame, Message};
@@ -40,10 +43,9 @@ where
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
type Item = (); type Output = Result<(), FramedTransportError<S::Error, Codec>>;
type Error = FramedTransportError<S::Error, Codec>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.inner.poll() Pin::new(&mut self.inner).poll(cx)
} }
} }

View File

@@ -1,9 +1,9 @@
use actix_service::NewService; use actix_service::ServiceFactory;
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, ok}; use futures::future::{self, ok};
use actix_http::{http, HttpService, Request, Response}; use actix_http::{http, HttpService, Request, Response};
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
@@ -29,55 +29,63 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[test] #[test]
fn test_h1_v2() { fn test_h1_v2() {
env_logger::init(); block_on(async {
let mut srv = TestServer::new(move || { let srv = TestServer::start(move || {
HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) HttpService::build()
}); .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
let response = srv.block_on(srv.get("/").send()).unwrap(); });
assert!(response.status().is_success());
let request = srv.get("/").header("x-test", "111").send(); let response = srv.get("/").send().await.unwrap();
let response = srv.block_on(request).unwrap(); assert!(response.status().is_success());
assert!(response.status().is_success());
// read response let request = srv.get("/").header("x-test", "111").send();
let bytes = srv.load_body(response).unwrap(); let mut response = request.await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert!(response.status().is_success());
let response = srv.block_on(srv.post("/").send()).unwrap(); // read response
assert!(response.status().is_success()); let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// read response let mut response = srv.post("/").send().await.unwrap();
let bytes = srv.load_body(response).unwrap(); assert!(response.status().is_success());
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// read response
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
} }
#[test] #[test]
fn test_connection_close() { fn test_connection_close() {
let mut srv = TestServer::new(move || { block_on(async {
HttpService::build() let srv = TestServer::start(move || {
.finish(|_| ok::<_, ()>(Response::Ok().body(STR))) HttpService::build()
.map(|_| ()) .finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
}); .map(|_| ())
let response = srv.block_on(srv.get("/").force_close().send()).unwrap(); });
assert!(response.status().is_success());
let response = srv.get("/").force_close().send().await.unwrap();
assert!(response.status().is_success());
})
} }
#[test] #[test]
fn test_with_query_parameter() { fn test_with_query_parameter() {
let mut srv = TestServer::new(move || { block_on(async {
HttpService::build() let srv = TestServer::start(move || {
.finish(|req: Request| { HttpService::build()
if req.uri().query().unwrap().contains("qp=") { .finish(|req: Request| {
ok::<_, ()>(Response::Ok().finish()) if req.uri().query().unwrap().contains("qp=") {
} else { ok::<_, ()>(Response::Ok().finish())
ok::<_, ()>(Response::BadRequest().finish()) } else {
} ok::<_, ()>(Response::BadRequest().finish())
}) }
.map(|_| ()) })
}); .map(|_| ())
});
let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send(); let request = srv.request(http::Method::GET, srv.url("/?qp=5"));
let response = srv.block_on(request).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }

View File

@@ -0,0 +1,545 @@
#![cfg(feature = "openssl")]
use std::io;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{err, ok, ready};
use futures::stream::{once, Stream, StreamExt};
use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
let body = stream
.map(|res| match res {
Ok(chunk) => chunk,
Err(_) => panic!(),
})
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
ready(body)
})
.await;
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}
#[test]
fn test_h2_on_connect() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -0,0 +1,474 @@
#![cfg(feature = "rustls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{self, err, ok};
use futures::stream::{once, Stream, StreamExt};
use rust_tls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{self, BufReader};
async fn load_body<S>(mut stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
let mut body = BytesMut::new();
while let Some(item) = stream.next().await {
body.extend_from_slice(&item?)
}
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body1() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok()
.content_length(STR.len() as u64)
.body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(
STR.len() as u64,
body,
)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(
http::header::CONTENT_TYPE,
broken_header,
)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}

View File

@@ -1,462 +0,0 @@
#![cfg(feature = "rust-tls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok, Future};
use futures::stream::{once, Stream};
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{BufReader, Result};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,480 +0,0 @@
#![cfg(feature = "ssl")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{ok, Future};
use futures::stream::{once, Stream};
use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::io::Result;
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[test]
fn test_h2_on_connect() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
}

View File

@@ -1,26 +1,27 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use actix_utils::framed::FramedTransport; use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::future::{self, ok}; use futures::future;
use futures::{Future, Sink, Stream}; use futures::{SinkExt, StreamExt};
fn ws_service<T: AsyncRead + AsyncWrite>( async fn ws_service<T: AsyncRead + AsyncWrite + Unpin>(
(req, framed): (Request, Framed<T, h1::Codec>), (req, mut framed): (Request, Framed<T, h1::Codec>),
) -> impl Future<Item = (), Error = Error> { ) -> Result<(), Error> {
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.await
.unwrap();
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.await
.map_err(|_| panic!()) .map_err(|_| panic!())
.and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!())
})
} }
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> { async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
@@ -30,47 +31,56 @@ fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
ok(msg) Ok(msg)
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
let mut srv = TestServer::new(|| { block_on(async {
HttpService::build() let mut srv = TestServer::start(|| {
.upgrade(ws_service) HttpService::build()
.finish(|_| future::ok::<_, ()>(Response::NotFound())) .upgrade(actix_service::service_fn(ws_service))
}); .finish(|_| future::ok::<_, ()>(Response::NotFound()))
});
// client service // client service
let framed = srv.ws().unwrap(); let mut framed = srv.ws().await.unwrap();
let framed = srv framed
.block_on(framed.send(ws::Message::Text("text".to_string()))) .send(ws::Message::Text("text".to_string()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv framed
.block_on(framed.send(ws::Message::Binary("text".into()))) .send(ws::Message::Binary("text".into()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!( let (item, mut framed) = framed.into_future().await;
item, assert_eq!(
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) item.unwrap().unwrap(),
); ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
let framed = srv framed.send(ws::Message::Ping("text".into())).await.unwrap();
.block_on(framed.send(ws::Message::Ping("text".into()))) let (item, mut framed) = framed.into_future().await;
.unwrap(); assert_eq!(
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); item.unwrap().unwrap(),
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); ws::Frame::Pong("text".to_string().into())
);
let framed = srv framed
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.unwrap(); .await
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); let (item, _framed) = framed.into_future().await;
assert_eq!( assert_eq!(
item, item.unwrap().unwrap(),
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
); );
})
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-identity" name = "actix-identity"
version = "0.1.0" version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Identity service for actix web framework." description = "Identity service for actix web framework."
readme = "README.md" readme = "README.md"
@@ -17,14 +17,14 @@ name = "actix_identity"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] } actix-web = { version = "2.0.0-alpha.1", default-features = false, features = ["secure-cookies"] }
actix-service = "0.4.0" actix-service = "1.0.0-alpha.1"
futures = "0.1.25" futures = "0.3.1"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.1.42" time = "0.1.42"
[dev-dependencies] [dev-dependencies]
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"
actix-http = "0.2.3" actix-http = "0.3.0-alpha.1"
bytes = "0.4" bytes = "0.4"

View File

@@ -16,7 +16,7 @@
//! use actix_web::*; //! use actix_web::*;
//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService};
//! //!
//! fn index(id: Identity) -> String { //! async fn index(id: Identity) -> String {
//! // access request identity //! // access request identity
//! if let Some(id) = id.identity() { //! if let Some(id) = id.identity() {
//! format!("Welcome! {}", id) //! format!("Welcome! {}", id)
@@ -25,12 +25,12 @@
//! } //! }
//! } //! }
//! //!
//! fn login(id: Identity) -> HttpResponse { //! async fn login(id: Identity) -> HttpResponse {
//! id.remember("User1".to_owned()); // <- remember identity //! id.remember("User1".to_owned()); // <- remember identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
//! //!
//! fn logout(id: Identity) -> HttpResponse { //! async fn logout(id: Identity) -> HttpResponse {
//! id.forget(); // <- remove identity //! id.forget(); // <- remove identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
@@ -47,12 +47,13 @@
//! } //! }
//! ``` //! ```
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::SystemTime; use std::time::SystemTime;
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, Either, FutureResult}; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures::{Future, IntoFuture, Poll};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::Duration; use time::Duration;
@@ -165,21 +166,21 @@ where
impl FromRequest for Identity { impl FromRequest for Identity {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Result<Identity, Error>; type Future = Ready<Result<Identity, Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
Ok(Identity(req.clone())) ok(Identity(req.clone()))
} }
} }
/// Identity policy definition. /// Identity policy definition.
pub trait IdentityPolicy: Sized + 'static { pub trait IdentityPolicy: Sized + 'static {
/// The return type of the middleware /// The return type of the middleware
type Future: IntoFuture<Item = Option<String>, Error = Error>; type Future: Future<Output = Result<Option<String>, Error>>;
/// The return type of the middleware /// The return type of the middleware
type ResponseFuture: IntoFuture<Item = (), Error = Error>; type ResponseFuture: Future<Output = Result<(), Error>>;
/// Parse the session from request and load data from a service identity. /// Parse the session from request and load data from a service identity.
fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; fn from_request(&self, request: &mut ServiceRequest) -> Self::Future;
@@ -234,7 +235,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = IdentityServiceMiddleware<S, T>; type Transform = IdentityServiceMiddleware<S, T>;
type Future = FutureResult<Self::Transform, Self::InitError>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(IdentityServiceMiddleware { ok(IdentityServiceMiddleware {
@@ -261,46 +262,39 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.borrow_mut().poll_ready() self.service.borrow_mut().poll_ready(cx)
} }
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&mut self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = self.service.clone();
let backend = self.backend.clone(); let backend = self.backend.clone();
let fut = self.backend.from_request(&mut req);
Box::new( async move {
self.backend.from_request(&mut req).into_future().then( match fut.await {
move |res| match res { Ok(id) => {
Ok(id) => { req.extensions_mut()
req.extensions_mut() .insert(IdentityItem { id, changed: false });
.insert(IdentityItem { id, changed: false });
Either::A(srv.borrow_mut().call(req).and_then(move |mut res| { let mut res = srv.borrow_mut().call(req).await?;
let id = let id = res.request().extensions_mut().remove::<IdentityItem>();
res.request().extensions_mut().remove::<IdentityItem>();
if let Some(id) = id { if let Some(id) = id {
Either::A( match backend.to_response(id.id, id.changed, &mut res).await {
backend Ok(_) => Ok(res),
.to_response(id.id, id.changed, &mut res) Err(e) => Ok(res.error_response(e)),
.into_future() }
.then(move |t| match t { } else {
Ok(_) => Ok(res), Ok(res)
Err(e) => Ok(res.error_response(e)),
}),
)
} else {
Either::B(ok(res))
}
}))
} }
Err(err) => Either::B(ok(req.error_response(err))), }
}, Err(err) => Ok(req.error_response(err)),
), }
) }
.boxed_local()
} }
} }
@@ -547,11 +541,11 @@ impl CookieIdentityPolicy {
} }
impl IdentityPolicy for CookieIdentityPolicy { impl IdentityPolicy for CookieIdentityPolicy {
type Future = Result<Option<String>, Error>; type Future = Ready<Result<Option<String>, Error>>;
type ResponseFuture = Result<(), Error>; type ResponseFuture = Ready<Result<(), Error>>;
fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { fn from_request(&self, req: &mut ServiceRequest) -> Self::Future {
Ok(self.0.load(req).map( ok(self.0.load(req).map(
|CookieValue { |CookieValue {
identity, identity,
login_timestamp, login_timestamp,
@@ -603,7 +597,7 @@ impl IdentityPolicy for CookieIdentityPolicy {
} else { } else {
Ok(()) Ok(())
}; };
Ok(()) ok(())
} }
} }
@@ -613,7 +607,7 @@ mod tests {
use super::*; use super::*;
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use actix_web::test::{self, TestRequest}; use actix_web::test::{self, block_on, TestRequest};
use actix_web::{web, App, Error, HttpResponse}; use actix_web::{web, App, Error, HttpResponse};
const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; const COOKIE_KEY_MASTER: [u8; 32] = [0; 32];
@@ -622,115 +616,138 @@ mod tests {
#[test] #[test]
fn test_identity() { fn test_identity() {
let mut srv = test::init_service( block_on(async {
App::new() let mut srv = test::init_service(
.wrap(IdentityService::new( App::new()
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .wrap(IdentityService::new(
.domain("www.rust-lang.org") CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
.name(COOKIE_NAME) .domain("www.rust-lang.org")
.path("/") .name(COOKIE_NAME)
.secure(true), .path("/")
)) .secure(true),
.service(web::resource("/index").to(|id: Identity| { ))
if id.identity().is_some() { .service(web::resource("/index").to(|id: Identity| {
HttpResponse::Created() if id.identity().is_some() {
} else { HttpResponse::Created()
} else {
HttpResponse::Ok()
}
}))
.service(web::resource("/login").to(|id: Identity| {
id.remember(COOKIE_LOGIN.to_string());
HttpResponse::Ok() HttpResponse::Ok()
} }))
})) .service(web::resource("/logout").to(|id: Identity| {
.service(web::resource("/login").to(|id: Identity| { if id.identity().is_some() {
id.remember(COOKIE_LOGIN.to_string()); id.forget();
HttpResponse::Ok() HttpResponse::Ok()
})) } else {
.service(web::resource("/logout").to(|id: Identity| { HttpResponse::BadRequest()
if id.identity().is_some() { }
id.forget(); })),
HttpResponse::Ok() )
} else { .await;
HttpResponse::BadRequest() let resp = test::call_service(
} &mut srv,
})), TestRequest::with_uri("/index").to_request(),
); )
let resp = .await;
test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.status(), StatusCode::OK);
let resp = let resp = test::call_service(
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); &mut srv,
assert_eq!(resp.status(), StatusCode::OK); TestRequest::with_uri("/login").to_request(),
let c = resp.response().cookies().next().unwrap().to_owned(); )
.await;
assert_eq!(resp.status(), StatusCode::OK);
let c = resp.response().cookies().next().unwrap().to_owned();
let resp = test::call_service( let resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/index") TestRequest::with_uri("/index")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
); )
assert_eq!(resp.status(), StatusCode::CREATED); .await;
assert_eq!(resp.status(), StatusCode::CREATED);
let resp = test::call_service( let resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/logout") TestRequest::with_uri("/logout")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
); )
assert_eq!(resp.status(), StatusCode::OK); .await;
assert!(resp.headers().contains_key(header::SET_COOKIE)) assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE))
})
} }
#[test] #[test]
fn test_identity_max_age_time() { fn test_identity_max_age_time() {
let duration = Duration::days(1); block_on(async {
let mut srv = test::init_service( let duration = Duration::days(1);
App::new() let mut srv = test::init_service(
.wrap(IdentityService::new( App::new()
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .wrap(IdentityService::new(
.domain("www.rust-lang.org") CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
.name(COOKIE_NAME) .domain("www.rust-lang.org")
.path("/") .name(COOKIE_NAME)
.max_age_time(duration) .path("/")
.secure(true), .max_age_time(duration)
)) .secure(true),
.service(web::resource("/login").to(|id: Identity| { ))
id.remember("test".to_string()); .service(web::resource("/login").to(|id: Identity| {
HttpResponse::Ok() id.remember("test".to_string());
})), HttpResponse::Ok()
); })),
let resp = )
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); .await;
assert_eq!(resp.status(), StatusCode::OK); let resp = test::call_service(
assert!(resp.headers().contains_key(header::SET_COOKIE)); &mut srv,
let c = resp.response().cookies().next().unwrap().to_owned(); TestRequest::with_uri("/login").to_request(),
assert_eq!(duration, c.max_age().unwrap()); )
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(duration, c.max_age().unwrap());
})
} }
#[test] #[test]
fn test_identity_max_age() { fn test_identity_max_age() {
let seconds = 60; block_on(async {
let mut srv = test::init_service( let seconds = 60;
App::new() let mut srv = test::init_service(
.wrap(IdentityService::new( App::new()
CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) .wrap(IdentityService::new(
.domain("www.rust-lang.org") CookieIdentityPolicy::new(&COOKIE_KEY_MASTER)
.name(COOKIE_NAME) .domain("www.rust-lang.org")
.path("/") .name(COOKIE_NAME)
.max_age(seconds) .path("/")
.secure(true), .max_age(seconds)
)) .secure(true),
.service(web::resource("/login").to(|id: Identity| { ))
id.remember("test".to_string()); .service(web::resource("/login").to(|id: Identity| {
HttpResponse::Ok() id.remember("test".to_string());
})), HttpResponse::Ok()
); })),
let resp = )
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()); .await;
assert_eq!(resp.status(), StatusCode::OK); let resp = test::call_service(
assert!(resp.headers().contains_key(header::SET_COOKIE)); &mut srv,
let c = resp.response().cookies().next().unwrap().to_owned(); TestRequest::with_uri("/login").to_request(),
assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); )
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap());
})
} }
fn create_identity_server< async fn create_identity_server<
F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static,
>( >(
f: F, f: F,
@@ -747,13 +764,16 @@ mod tests {
.secure(false) .secure(false)
.name(COOKIE_NAME)))) .name(COOKIE_NAME))))
.service(web::resource("/").to(|id: Identity| { .service(web::resource("/").to(|id: Identity| {
let identity = id.identity(); async move {
if identity.is_none() { let identity = id.identity();
id.remember(COOKIE_LOGIN.to_string()) if identity.is_none() {
id.remember(COOKIE_LOGIN.to_string())
}
web::Json(identity)
} }
web::Json(identity)
})), })),
) )
.await
} }
fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> {
@@ -786,15 +806,8 @@ mod tests {
jar.get(COOKIE_NAME).unwrap().clone() jar.get(COOKIE_NAME).unwrap().clone()
} }
fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) { async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) {
use bytes::BytesMut; let bytes = test::read_body(response).await;
use futures::Stream;
let bytes =
test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| {
b.extend(c);
Ok::<_, Error>(b)
}))
.unwrap();
let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap(); let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap();
assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); assert_eq!(resp.as_ref().map(|s| s.borrow()), identity);
} }
@@ -874,183 +887,221 @@ mod tests {
#[test] #[test]
fn test_identity_legacy_cookie_is_set() { fn test_identity_legacy_cookie_is_set() {
let mut srv = create_identity_server(|c| c); block_on(async {
let mut resp = let mut srv = create_identity_server(|c| c).await;
test::call_service(&mut srv, TestRequest::with_uri("/").to_request()); let mut resp =
assert_logged_in(&mut resp, None); test::call_service(&mut srv, TestRequest::with_uri("/").to_request())
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); .await;
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_works() { fn test_identity_legacy_cookie_works() {
let mut srv = create_identity_server(|c| c); block_on(async {
let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut srv = create_identity_server(|c| c).await;
let mut resp = test::call_service( let cookie = legacy_login_cookie(COOKIE_LOGIN);
&mut srv, let mut resp = test::call_service(
TestRequest::with_uri("/") &mut srv,
.cookie(cookie.clone()) TestRequest::with_uri("/")
.to_request(), .cookie(cookie.clone())
); .to_request(),
assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); )
assert_no_login_cookie(&mut resp); .await;
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); block_on(async {
let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut srv =
let mut resp = test::call_service( create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
&mut srv, let cookie = legacy_login_cookie(COOKIE_LOGIN);
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NoTimestamp, &mut resp,
VisitTimeStampCheck::NewTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); block_on(async {
let cookie = legacy_login_cookie(COOKIE_LOGIN); let mut srv =
let mut resp = test::call_service( create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
&mut srv, let cookie = legacy_login_cookie(COOKIE_LOGIN);
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NewTimestamp, &mut resp,
VisitTimeStampCheck::NoTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_login_timestamp_needed() { fn test_identity_cookie_rejected_if_login_timestamp_needed() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); block_on(async {
let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); let mut srv =
let mut resp = test::call_service( create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
&mut srv, let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now()));
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NewTimestamp, &mut resp,
VisitTimeStampCheck::NoTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); block_on(async {
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let mut srv =
let mut resp = test::call_service( create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
&mut srv, let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NoTimestamp, &mut resp,
VisitTimeStampCheck::NewTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_login_timestamp_too_old() { fn test_identity_cookie_rejected_if_login_timestamp_too_old() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); block_on(async {
let cookie = login_cookie( let mut srv =
COOKIE_LOGIN, create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), let cookie = login_cookie(
None, COOKIE_LOGIN,
); Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
let mut resp = test::call_service( None,
&mut srv, );
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NewTimestamp, &mut resp,
VisitTimeStampCheck::NoTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { fn test_identity_cookie_rejected_if_visit_timestamp_too_old() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90))); block_on(async {
let cookie = login_cookie( let mut srv =
COOKIE_LOGIN, create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
None, let cookie = login_cookie(
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), COOKIE_LOGIN,
); None,
let mut resp = test::call_service( Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
&mut srv, );
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, None); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::NoTimestamp, &mut resp,
VisitTimeStampCheck::NewTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
} }
#[test] #[test]
fn test_identity_cookie_not_updated_on_login_deadline() { fn test_identity_cookie_not_updated_on_login_deadline() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90))); block_on(async {
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let mut srv =
let mut resp = test::call_service( create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
&mut srv, let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); .to_request(),
assert_no_login_cookie(&mut resp); )
.await;
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
#[test] #[test]
fn test_identity_cookie_updated_on_visit_deadline() { fn test_identity_cookie_updated_on_visit_deadline() {
let mut srv = create_identity_server(|c| { block_on(async {
c.visit_deadline(Duration::days(90)) let mut srv = create_identity_server(|c| {
.login_deadline(Duration::days(90)) c.visit_deadline(Duration::days(90))
}); .login_deadline(Duration::days(90))
let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); })
let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); .await;
let mut resp = test::call_service( let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap();
&mut srv, let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp));
TestRequest::with_uri("/") let mut resp = test::call_service(
.cookie(cookie.clone()) &mut srv,
.to_request(), TestRequest::with_uri("/")
); .cookie(cookie.clone())
assert_logged_in(&mut resp, Some(COOKIE_LOGIN)); .to_request(),
assert_login_cookie( )
&mut resp, .await;
COOKIE_LOGIN, assert_login_cookie(
LoginTimestampCheck::OldTimestamp(timestamp), &mut resp,
VisitTimeStampCheck::NewTimestamp, COOKIE_LOGIN,
); LoginTimestampCheck::OldTimestamp(timestamp),
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-multipart" name = "actix-multipart"
version = "0.1.4" version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Multipart support for actix web framework." description = "Multipart support for actix web framework."
readme = "README.md" readme = "README.md"
@@ -18,17 +18,18 @@ name = "actix_multipart"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "1.0.0", default-features = false } actix-web = { version = "2.0.0-alpha.1", default-features = false }
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
actix-utils = "0.5.0-alpha.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
httparse = "1.3" httparse = "1.3"
futures = "0.1.25" futures = "0.3.1"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
time = "0.1" time = "0.1"
twoway = "0.2" twoway = "0.2"
[dev-dependencies] [dev-dependencies]
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"
actix-http = "0.2.4" actix-http = "0.3.0-alpha.1"

View File

@@ -1,5 +1,6 @@
//! Multipart payload support //! Multipart payload support
use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; use actix_web::{dev::Payload, Error, FromRequest, HttpRequest};
use futures::future::{ok, Ready};
use crate::server::Multipart; use crate::server::Multipart;
@@ -10,33 +11,31 @@ use crate::server::Multipart;
/// ## Server example /// ## Server example
/// ///
/// ```rust /// ```rust
/// # use futures::{Future, Stream}; /// use futures::{Stream, StreamExt};
/// # use futures::future::{ok, result, Either};
/// use actix_web::{web, HttpResponse, Error}; /// use actix_web::{web, HttpResponse, Error};
/// use actix_multipart as mp; /// use actix_multipart as mp;
/// ///
/// fn index(payload: mp::Multipart) -> impl Future<Item = HttpResponse, Error = Error> { /// async fn index(mut payload: mp::Multipart) -> Result<HttpResponse, Error> {
/// payload.from_err() // <- get multipart stream for current request /// // iterate over multipart stream
/// .and_then(|field| { // <- iterate over multipart items /// while let Some(item) = payload.next().await {
/// let mut field = item?;
///
/// // Field in turn is stream of *Bytes* object /// // Field in turn is stream of *Bytes* object
/// field.from_err() /// while let Some(chunk) = field.next().await {
/// .fold((), |_, chunk| { /// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?));
/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk)); /// }
/// Ok::<_, Error>(()) /// }
/// }) /// Ok(HttpResponse::Ok().into())
/// })
/// .fold((), |_, _| Ok::<_, Error>(()))
/// .map(|_| HttpResponse::Ok().into())
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
impl FromRequest for Multipart { impl FromRequest for Multipart {
type Error = Error; type Error = Error;
type Future = Result<Multipart, Error>; type Future = Ready<Result<Multipart, Error>>;
type Config = (); type Config = ();
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
Ok(Multipart::new(req.headers(), payload.take())) ok(Multipart::new(req.headers(), payload.take()))
} }
} }

View File

@@ -1,15 +1,17 @@
//! Multipart payload support //! Multipart payload support
use std::cell::{Cell, RefCell, RefMut}; use std::cell::{Cell, RefCell, RefMut};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::{cmp, fmt}; use std::{cmp, fmt};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::task::{current as current_task, Task}; use futures::stream::{LocalBoxStream, Stream, StreamExt};
use futures::{Async, Poll, Stream};
use httparse; use httparse;
use mime; use mime;
use actix_utils::task::LocalWaker;
use actix_web::error::{ParseError, PayloadError}; use actix_web::error::{ParseError, PayloadError};
use actix_web::http::header::{ use actix_web::http::header::{
self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, self, ContentDisposition, HeaderMap, HeaderName, HeaderValue,
@@ -60,7 +62,7 @@ impl Multipart {
/// Create multipart instance for boundary. /// Create multipart instance for boundary.
pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart
where where
S: Stream<Item = Bytes, Error = PayloadError> + 'static, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
{ {
match Self::boundary(headers) { match Self::boundary(headers) {
Ok(boundary) => Multipart { Ok(boundary) => Multipart {
@@ -104,22 +106,25 @@ impl Multipart {
} }
impl Stream for Multipart { impl Stream for Multipart {
type Item = Field; type Item = Result<Field, MultipartError>;
type Error = MultipartError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
Err(err) Poll::Ready(Some(Err(err)))
} else if self.safety.current() { } else if self.safety.current() {
let mut inner = self.inner.as_mut().unwrap().borrow_mut(); let this = self.get_mut();
if let Some(mut payload) = inner.payload.get_mut(&self.safety) { let mut inner = this.inner.as_mut().unwrap().borrow_mut();
payload.poll_stream()?; if let Some(mut payload) = inner.payload.get_mut(&this.safety) {
payload.poll_stream(cx)?;
} }
inner.poll(&self.safety) inner.poll(&this.safety, cx)
} else if !self.safety.is_clean() { } else if !self.safety.is_clean() {
Err(MultipartError::NotConsumed) Poll::Ready(Some(Err(MultipartError::NotConsumed)))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -238,9 +243,13 @@ impl InnerMultipart {
Ok(Some(eof)) Ok(Some(eof))
} }
fn poll(&mut self, safety: &Safety) -> Poll<Option<Field>, MultipartError> { fn poll(
&mut self,
safety: &Safety,
cx: &mut Context,
) -> Poll<Option<Result<Field, MultipartError>>> {
if self.state == InnerState::Eof { if self.state == InnerState::Eof {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
// release field // release field
loop { loop {
@@ -249,10 +258,13 @@ impl InnerMultipart {
if safety.current() { if safety.current() {
let stop = match self.item { let stop = match self.item {
InnerMultipartItem::Field(ref mut field) => { InnerMultipartItem::Field(ref mut field) => {
match field.borrow_mut().poll(safety)? { match field.borrow_mut().poll(safety) {
Async::NotReady => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
Async::Ready(Some(_)) => continue, Poll::Ready(Some(Ok(_))) => continue,
Async::Ready(None) => true, Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => true,
} }
} }
InnerMultipartItem::None => false, InnerMultipartItem::None => false,
@@ -277,12 +289,12 @@ impl InnerMultipart {
Some(eof) => { Some(eof) => {
if eof { if eof {
self.state = InnerState::Eof; self.state = InnerState::Eof;
return Ok(Async::Ready(None)); return Poll::Ready(None);
} else { } else {
self.state = InnerState::Headers; self.state = InnerState::Headers;
} }
} }
None => return Ok(Async::NotReady), None => return Poll::Pending,
} }
} }
// read boundary // read boundary
@@ -291,11 +303,11 @@ impl InnerMultipart {
&mut *payload, &mut *payload,
&self.boundary, &self.boundary,
)? { )? {
None => return Ok(Async::NotReady), None => return Poll::Pending,
Some(eof) => { Some(eof) => {
if eof { if eof {
self.state = InnerState::Eof; self.state = InnerState::Eof;
return Ok(Async::Ready(None)); return Poll::Ready(None);
} else { } else {
self.state = InnerState::Headers; self.state = InnerState::Headers;
} }
@@ -311,14 +323,14 @@ impl InnerMultipart {
self.state = InnerState::Boundary; self.state = InnerState::Boundary;
headers headers
} else { } else {
return Ok(Async::NotReady); return Poll::Pending;
} }
} else { } else {
unreachable!() unreachable!()
} }
} else { } else {
log::debug!("NotReady: field is in flight"); log::debug!("NotReady: field is in flight");
return Ok(Async::NotReady); return Poll::Pending;
}; };
// content type // content type
@@ -335,7 +347,7 @@ impl InnerMultipart {
// nested multipart stream // nested multipart stream
if mt.type_() == mime::MULTIPART { if mt.type_() == mime::MULTIPART {
Err(MultipartError::Nested) Poll::Ready(Some(Err(MultipartError::Nested)))
} else { } else {
let field = Rc::new(RefCell::new(InnerField::new( let field = Rc::new(RefCell::new(InnerField::new(
self.payload.clone(), self.payload.clone(),
@@ -344,12 +356,7 @@ impl InnerMultipart {
)?)); )?));
self.item = InnerMultipartItem::Field(Rc::clone(&field)); self.item = InnerMultipartItem::Field(Rc::clone(&field));
Ok(Async::Ready(Some(Field::new( Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field))))
safety.clone(),
headers,
mt,
field,
))))
} }
} }
} }
@@ -409,23 +416,21 @@ impl Field {
} }
impl Stream for Field { impl Stream for Field {
type Item = Bytes; type Item = Result<Bytes, MultipartError>;
type Error = MultipartError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if self.safety.current() { if self.safety.current() {
let mut inner = self.inner.borrow_mut(); let mut inner = self.inner.borrow_mut();
if let Some(mut payload) = if let Some(mut payload) =
inner.payload.as_ref().unwrap().get_mut(&self.safety) inner.payload.as_ref().unwrap().get_mut(&self.safety)
{ {
payload.poll_stream()?; payload.poll_stream(cx)?;
} }
inner.poll(&self.safety) inner.poll(&self.safety)
} else if !self.safety.is_clean() { } else if !self.safety.is_clean() {
Err(MultipartError::NotConsumed) Poll::Ready(Some(Err(MultipartError::NotConsumed)))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -482,9 +487,9 @@ impl InnerField {
fn read_len( fn read_len(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
size: &mut u64, size: &mut u64,
) -> Poll<Option<Bytes>, MultipartError> { ) -> Poll<Option<Result<Bytes, MultipartError>>> {
if *size == 0 { if *size == 0 {
Ok(Async::Ready(None)) Poll::Ready(None)
} else { } else {
match payload.read_max(*size)? { match payload.read_max(*size)? {
Some(mut chunk) => { Some(mut chunk) => {
@@ -494,13 +499,13 @@ impl InnerField {
if !chunk.is_empty() { if !chunk.is_empty() {
payload.unprocessed(chunk); payload.unprocessed(chunk);
} }
Ok(Async::Ready(Some(ch))) Poll::Ready(Some(Ok(ch)))
} }
None => { None => {
if payload.eof && (*size != 0) { if payload.eof && (*size != 0) {
Err(MultipartError::Incomplete) Poll::Ready(Some(Err(MultipartError::Incomplete)))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -512,15 +517,15 @@ impl InnerField {
fn read_stream( fn read_stream(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
boundary: &str, boundary: &str,
) -> Poll<Option<Bytes>, MultipartError> { ) -> Poll<Option<Result<Bytes, MultipartError>>> {
let mut pos = 0; let mut pos = 0;
let len = payload.buf.len(); let len = payload.buf.len();
if len == 0 { if len == 0 {
return if payload.eof { return if payload.eof {
Err(MultipartError::Incomplete) Poll::Ready(Some(Err(MultipartError::Incomplete)))
} else { } else {
Ok(Async::NotReady) Poll::Pending
}; };
} }
@@ -537,10 +542,10 @@ impl InnerField {
if let Some(b_len) = b_len { if let Some(b_len) = b_len {
let b_size = boundary.len() + b_len; let b_size = boundary.len() + b_len;
if len < b_size { if len < b_size {
return Ok(Async::NotReady); return Poll::Pending;
} else if &payload.buf[b_len..b_size] == boundary.as_bytes() { } else if &payload.buf[b_len..b_size] == boundary.as_bytes() {
// found boundary // found boundary
return Ok(Async::Ready(None)); return Poll::Ready(None);
} }
} }
} }
@@ -552,9 +557,9 @@ impl InnerField {
// check if we have enough data for boundary detection // check if we have enough data for boundary detection
if cur + 4 > len { if cur + 4 > len {
if cur > 0 { if cur > 0 {
Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} else { } else {
// check boundary // check boundary
@@ -565,7 +570,7 @@ impl InnerField {
{ {
if cur != 0 { if cur != 0 {
// return buffer // return buffer
Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze()))) Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
} else { } else {
pos = cur + 1; pos = cur + 1;
continue; continue;
@@ -577,49 +582,51 @@ impl InnerField {
} }
} }
} else { } else {
Ok(Async::Ready(Some(payload.buf.take().freeze()))) Poll::Ready(Some(Ok(payload.buf.take().freeze())))
}; };
} }
} }
fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> { fn poll(&mut self, s: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> {
if self.payload.is_none() { if self.payload.is_none() {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} }
let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s)
{ {
if !self.eof { if !self.eof {
let res = if let Some(ref mut len) = self.length { let res = if let Some(ref mut len) = self.length {
InnerField::read_len(&mut *payload, len)? InnerField::read_len(&mut *payload, len)
} else { } else {
InnerField::read_stream(&mut *payload, &self.boundary)? InnerField::read_stream(&mut *payload, &self.boundary)
}; };
match res { match res {
Async::NotReady => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
Async::Ready(Some(bytes)) => return Ok(Async::Ready(Some(bytes))), Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
Async::Ready(None) => self.eof = true, Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => self.eof = true,
} }
} }
match payload.readline()? { match payload.readline() {
None => Async::Ready(None), Ok(None) => Poll::Ready(None),
Some(line) => { Ok(Some(line)) => {
if line.as_ref() != b"\r\n" { if line.as_ref() != b"\r\n" {
log::warn!("multipart field did not read all the data or it is malformed"); log::warn!("multipart field did not read all the data or it is malformed");
} }
Async::Ready(None) Poll::Ready(None)
} }
Err(e) => Poll::Ready(Some(Err(e))),
} }
} else { } else {
Async::NotReady Poll::Pending
}; };
if Async::Ready(None) == result { if let Poll::Ready(None) = result {
self.payload.take(); self.payload.take();
} }
Ok(result) result
} }
} }
@@ -659,7 +666,7 @@ impl Clone for PayloadRef {
/// most task. /// most task.
#[derive(Debug)] #[derive(Debug)]
struct Safety { struct Safety {
task: Option<Task>, task: LocalWaker,
level: usize, level: usize,
payload: Rc<PhantomData<bool>>, payload: Rc<PhantomData<bool>>,
clean: Rc<Cell<bool>>, clean: Rc<Cell<bool>>,
@@ -669,7 +676,7 @@ impl Safety {
fn new() -> Safety { fn new() -> Safety {
let payload = Rc::new(PhantomData); let payload = Rc::new(PhantomData);
Safety { Safety {
task: None, task: LocalWaker::new(),
level: Rc::strong_count(&payload), level: Rc::strong_count(&payload),
clean: Rc::new(Cell::new(true)), clean: Rc::new(Cell::new(true)),
payload, payload,
@@ -683,17 +690,17 @@ impl Safety {
fn is_clean(&self) -> bool { fn is_clean(&self) -> bool {
self.clean.get() self.clean.get()
} }
}
impl Clone for Safety { fn clone(&self, cx: &mut Context) -> Safety {
fn clone(&self) -> Safety {
let payload = Rc::clone(&self.payload); let payload = Rc::clone(&self.payload);
Safety { let s = Safety {
task: Some(current_task()), task: LocalWaker::new(),
level: Rc::strong_count(&payload), level: Rc::strong_count(&payload),
clean: self.clean.clone(), clean: self.clean.clone(),
payload, payload,
} };
s.task.register(cx.waker());
s
} }
} }
@@ -704,7 +711,7 @@ impl Drop for Safety {
self.clean.set(true); self.clean.set(true);
} }
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
task.notify() task.wake()
} }
} }
} }
@@ -713,31 +720,32 @@ impl Drop for Safety {
struct PayloadBuffer { struct PayloadBuffer {
eof: bool, eof: bool,
buf: BytesMut, buf: BytesMut,
stream: Box<dyn Stream<Item = Bytes, Error = PayloadError>>, stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
} }
impl PayloadBuffer { impl PayloadBuffer {
/// Create new `PayloadBuffer` instance /// Create new `PayloadBuffer` instance
fn new<S>(stream: S) -> Self fn new<S>(stream: S) -> Self
where where
S: Stream<Item = Bytes, Error = PayloadError> + 'static, S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{ {
PayloadBuffer { PayloadBuffer {
eof: false, eof: false,
buf: BytesMut::new(), buf: BytesMut::new(),
stream: Box::new(stream), stream: stream.boxed_local(),
} }
} }
fn poll_stream(&mut self) -> Result<(), PayloadError> { fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> {
loop { loop {
match self.stream.poll()? { match Pin::new(&mut self.stream).poll_next(cx) {
Async::Ready(Some(data)) => self.buf.extend_from_slice(&data), Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
Async::Ready(None) => { Poll::Ready(Some(Err(e))) => return Err(e),
Poll::Ready(None) => {
self.eof = true; self.eof = true;
return Ok(()); return Ok(());
} }
Async::NotReady => return Ok(()), Poll::Pending => return Ok(()),
} }
} }
} }
@@ -800,13 +808,14 @@ impl PayloadBuffer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_http::h1::Payload;
use bytes::Bytes;
use futures::unsync::mpsc;
use super::*; use super::*;
use actix_http::h1::Payload;
use actix_utils::mpsc;
use actix_web::http::header::{DispositionParam, DispositionType}; use actix_web::http::header::{DispositionParam, DispositionType};
use actix_web::test::run_on; use actix_web::test::block_on;
use bytes::Bytes;
use futures::future::lazy;
#[test] #[test]
fn test_boundary() { fn test_boundary() {
@@ -852,12 +861,12 @@ mod tests {
} }
fn create_stream() -> ( fn create_stream() -> (
mpsc::UnboundedSender<Result<Bytes, PayloadError>>, mpsc::Sender<Result<Bytes, PayloadError>>,
impl Stream<Item = Bytes, Error = PayloadError>, impl Stream<Item = Result<Bytes, PayloadError>>,
) { ) {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::channel();
(tx, rx.map_err(|_| panic!()).and_then(|res| res)) (tx, rx.map(|res| res.map_err(|_| panic!())))
} }
fn create_simple_request_with_header() -> (Bytes, HeaderMap) { fn create_simple_request_with_header() -> (Bytes, HeaderMap) {
@@ -884,28 +893,28 @@ mod tests {
#[test] #[test]
fn test_multipart_no_end_crlf() { fn test_multipart_no_end_crlf() {
run_on(|| { block_on(async {
let (sender, payload) = create_stream(); let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf
sender.unbounded_send(Ok(bytes_stripped)).unwrap(); sender.send(Ok(bytes_stripped)).unwrap();
drop(sender); // eof drop(sender); // eof
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() { match multipart.next().await.unwrap() {
Async::Ready(Some(_)) => (), Ok(_) => (),
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await.unwrap() {
Async::Ready(Some(_)) => (), Ok(_) => (),
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await {
Async::Ready(None) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
}) })
@@ -913,15 +922,15 @@ mod tests {
#[test] #[test]
fn test_multipart() { fn test_multipart() {
run_on(|| { block_on(async {
let (sender, payload) = create_stream(); let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
sender.unbounded_send(Ok(bytes)).unwrap(); sender.send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() { match multipart.next().await {
Async::Ready(Some(mut field)) => { Some(Ok(mut field)) => {
let cd = field.content_disposition().unwrap(); let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -929,37 +938,37 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll().unwrap() { match field.next().await.unwrap() {
Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), Ok(chunk) => assert_eq!(chunk, "test"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.poll().unwrap() { match field.next().await {
Async::Ready(None) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await.unwrap() {
Async::Ready(Some(mut field)) => { Ok(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() { match field.next().await {
Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), Some(Ok(chunk)) => assert_eq!(chunk, "data"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.poll() { match field.next().await {
Ok(Async::Ready(None)) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await {
Async::Ready(None) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
}); });
@@ -967,15 +976,15 @@ mod tests {
#[test] #[test]
fn test_stream() { fn test_stream() {
run_on(|| { block_on(async {
let (sender, payload) = create_stream(); let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header(); let (bytes, headers) = create_simple_request_with_header();
sender.unbounded_send(Ok(bytes)).unwrap(); sender.send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload); let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() { match multipart.next().await.unwrap() {
Async::Ready(Some(mut field)) => { Ok(mut field) => {
let cd = field.content_disposition().unwrap(); let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -983,37 +992,37 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll().unwrap() { match field.next().await.unwrap() {
Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"), Ok(chunk) => assert_eq!(chunk, "test"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.poll().unwrap() { match field.next().await {
Async::Ready(None) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await {
Async::Ready(Some(mut field)) => { Some(Ok(mut field)) => {
assert_eq!(field.content_type().type_(), mime::TEXT); assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN); assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() { match field.next().await {
Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"), Some(Ok(chunk)) => assert_eq!(chunk, "data"),
_ => unreachable!(), _ => unreachable!(),
} }
match field.poll() { match field.next().await {
Ok(Async::Ready(None)) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
} }
_ => unreachable!(), _ => unreachable!(),
} }
match multipart.poll().unwrap() { match multipart.next().await {
Async::Ready(None) => (), None => (),
_ => unreachable!(), _ => unreachable!(),
} }
}); });
@@ -1021,26 +1030,26 @@ mod tests {
#[test] #[test]
fn test_basic() { fn test_basic() {
run_on(|| { block_on(async {
let (_, payload) = Payload::create(false); let (_, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(payload.buf.len(), 0); assert_eq!(payload.buf.len(), 0);
payload.poll_stream().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(None, payload.read_max(1).unwrap()); assert_eq!(None, payload.read_max(1).unwrap());
}) })
} }
#[test] #[test]
fn test_eof() { fn test_eof() {
run_on(|| { block_on(async {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(4).unwrap()); assert_eq!(None, payload.read_max(4).unwrap());
sender.feed_data(Bytes::from("data")); sender.feed_data(Bytes::from("data"));
sender.feed_eof(); sender.feed_eof();
payload.poll_stream().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap());
assert_eq!(payload.buf.len(), 0); assert_eq!(payload.buf.len(), 0);
@@ -1051,24 +1060,24 @@ mod tests {
#[test] #[test]
fn test_err() { fn test_err() {
run_on(|| { block_on(async {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(1).unwrap()); assert_eq!(None, payload.read_max(1).unwrap());
sender.set_error(PayloadError::Incomplete(None)); sender.set_error(PayloadError::Incomplete(None));
payload.poll_stream().err().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.err().unwrap();
}) })
} }
#[test] #[test]
fn test_readmax() { fn test_readmax() {
run_on(|| { block_on(async {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(payload.buf.len(), 10); assert_eq!(payload.buf.len(), 10);
assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap());
@@ -1081,7 +1090,7 @@ mod tests {
#[test] #[test]
fn test_readexactly() { fn test_readexactly() {
run_on(|| { block_on(async {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
@@ -1089,7 +1098,7 @@ mod tests {
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2));
assert_eq!(payload.buf.len(), 8); assert_eq!(payload.buf.len(), 8);
@@ -1101,7 +1110,7 @@ mod tests {
#[test] #[test]
fn test_readuntil() { fn test_readuntil() {
run_on(|| { block_on(async {
let (mut sender, payload) = Payload::create(false); let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload); let mut payload = PayloadBuffer::new(payload);
@@ -1109,7 +1118,7 @@ mod tests {
sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2")); sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap(); lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!( assert_eq!(
Some(Bytes::from("line")), Some(Bytes::from("line")),

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-session" name = "actix-session"
version = "0.2.0" version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Session for actix web framework." description = "Session for actix web framework."
readme = "README.md" readme = "README.md"
@@ -24,15 +24,15 @@ default = ["cookie-session"]
cookie-session = ["actix-web/secure-cookies"] cookie-session = ["actix-web/secure-cookies"]
[dependencies] [dependencies]
actix-web = "1.0.0" actix-web = "2.0.0-alpha.1"
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.1.25" futures = "0.3.1"
hashbrown = "0.5.0" hashbrown = "0.6.3"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.1.42" time = "0.1.42"
[dev-dependencies] [dev-dependencies]
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"

View File

@@ -17,6 +17,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
@@ -24,8 +25,7 @@ use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::http::{header::SET_COOKIE, HeaderValue};
use actix_web::{Error, HttpMessage, ResponseError}; use actix_web::{Error, HttpMessage, ResponseError};
use derive_more::{Display, From}; use derive_more::{Display, From};
use futures::future::{ok, Future, FutureResult}; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures::Poll;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use crate::{Session, SessionStatus}; use crate::{Session, SessionStatus};
@@ -284,7 +284,7 @@ where
type Error = S::Error; type Error = S::Error;
type InitError = (); type InitError = ();
type Transform = CookieSessionMiddleware<S>; type Transform = CookieSessionMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CookieSessionMiddleware { ok(CookieSessionMiddleware {
@@ -309,10 +309,10 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = S::Error; type Error = S::Error;
type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready() self.service.poll_ready(cx)
} }
/// On first request, a new session cookie is returned in response, regardless /// On first request, a new session cookie is returned in response, regardless
@@ -325,29 +325,36 @@ where
let (is_new, state) = self.inner.load(&req); let (is_new, state) = self.inner.load(&req);
Session::set_session(state.into_iter(), &mut req); Session::set_session(state.into_iter(), &mut req);
Box::new(self.service.call(req).map(move |mut res| { let fut = self.service.call(req);
match Session::get_changes(&mut res) {
(SessionStatus::Changed, Some(state)) async move {
| (SessionStatus::Renewed, Some(state)) => { fut.await.map(|mut res| {
res.checked_expr(|res| inner.set_cookie(res, state)) match Session::get_changes(&mut res) {
} (SessionStatus::Changed, Some(state))
(SessionStatus::Unchanged, _) => | (SessionStatus::Renewed, Some(state)) => {
// set a new session cookie upon first request (new client) res.checked_expr(|res| inner.set_cookie(res, state))
{ }
if is_new { (SessionStatus::Unchanged, _) =>
let state: HashMap<String, String> = HashMap::new(); // set a new session cookie upon first request (new client)
res.checked_expr(|res| inner.set_cookie(res, state.into_iter())) {
} else { if is_new {
let state: HashMap<String, String> = HashMap::new();
res.checked_expr(|res| {
inner.set_cookie(res, state.into_iter())
})
} else {
res
}
}
(SessionStatus::Purged, _) => {
let _ = inner.remove_cookie(&mut res);
res res
} }
_ => res,
} }
(SessionStatus::Purged, _) => { })
let _ = inner.remove_cookie(&mut res); }
res .boxed_local()
}
_ => res,
}
}))
} }
} }
@@ -359,101 +366,123 @@ mod tests {
#[test] #[test]
fn cookie_session() { fn cookie_session() {
let mut app = test::init_service( test::block_on(async {
App::new() let mut app = test::init_service(
.wrap(CookieSession::signed(&[0; 32]).secure(false)) App::new()
.service(web::resource("/").to(|ses: Session| { .wrap(CookieSession::signed(&[0; 32]).secure(false))
let _ = ses.set("counter", 100); .service(web::resource("/").to(|ses: Session| {
"test" async move {
})), let _ = ses.set("counter", 100);
); "test"
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap(); let response = app.call(request).await.unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn private_cookie() { fn private_cookie() {
let mut app = test::init_service( test::block_on(async {
App::new() let mut app = test::init_service(
.wrap(CookieSession::private(&[0; 32]).secure(false)) App::new()
.service(web::resource("/").to(|ses: Session| { .wrap(CookieSession::private(&[0; 32]).secure(false))
let _ = ses.set("counter", 100); .service(web::resource("/").to(|ses: Session| {
"test" async move {
})), let _ = ses.set("counter", 100);
); "test"
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap(); let response = app.call(request).await.unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn cookie_session_extractor() { fn cookie_session_extractor() {
let mut app = test::init_service( test::block_on(async {
App::new() let mut app = test::init_service(
.wrap(CookieSession::signed(&[0; 32]).secure(false)) App::new()
.service(web::resource("/").to(|ses: Session| { .wrap(CookieSession::signed(&[0; 32]).secure(false))
let _ = ses.set("counter", 100); .service(web::resource("/").to(|ses: Session| {
"test" async move {
})), let _ = ses.set("counter", 100);
); "test"
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap(); let response = app.call(request).await.unwrap();
assert!(response assert!(response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.is_some()); .is_some());
})
} }
#[test] #[test]
fn basics() { fn basics() {
let mut app = test::init_service( test::block_on(async {
App::new() let mut app = test::init_service(
.wrap( App::new()
CookieSession::signed(&[0; 32]) .wrap(
.path("/test/") CookieSession::signed(&[0; 32])
.name("actix-test") .path("/test/")
.domain("localhost") .name("actix-test")
.http_only(true) .domain("localhost")
.same_site(SameSite::Lax) .http_only(true)
.max_age(100), .same_site(SameSite::Lax)
) .max_age(100),
.service(web::resource("/").to(|ses: Session| { )
let _ = ses.set("counter", 100); .service(web::resource("/").to(|ses: Session| {
"test" async move {
})) let _ = ses.set("counter", 100);
.service(web::resource("/test/").to(|ses: Session| { "test"
let val: usize = ses.get("counter").unwrap().unwrap(); }
format!("counter: {}", val) }))
})), .service(web::resource("/test/").to(|ses: Session| {
); async move {
let val: usize = ses.get("counter").unwrap().unwrap();
format!("counter: {}", val)
}
})),
)
.await;
let request = test::TestRequest::get().to_request(); let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap(); let response = app.call(request).await.unwrap();
let cookie = response let cookie = response
.response() .response()
.cookies() .cookies()
.find(|c| c.name() == "actix-test") .find(|c| c.name() == "actix-test")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(cookie.path().unwrap(), "/test/"); assert_eq!(cookie.path().unwrap(), "/test/");
let request = test::TestRequest::with_uri("/test/") let request = test::TestRequest::with_uri("/test/")
.cookie(cookie) .cookie(cookie)
.to_request(); .to_request();
let body = test::read_response(&mut app, request); let body = test::read_response(&mut app, request).await;
assert_eq!(body, Bytes::from_static(b"counter: 100")); assert_eq!(body, Bytes::from_static(b"counter: 100"));
})
} }
} }

View File

@@ -47,6 +47,7 @@ use std::rc::Rc;
use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse};
use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; use actix_web::{Error, FromRequest, HttpMessage, HttpRequest};
use futures::future::{ok, Ready};
use hashbrown::HashMap; use hashbrown::HashMap;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@@ -230,12 +231,12 @@ impl Session {
/// ``` /// ```
impl FromRequest for Session { impl FromRequest for Session {
type Error = Error; type Error = Error;
type Future = Result<Session, Error>; type Future = Ready<Result<Session, Error>>;
type Config = (); type Config = ();
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
Ok(Session::get_session(&mut *req.extensions_mut())) ok(Session::get_session(&mut *req.extensions_mut()))
} }
} }

View File

@@ -1,5 +1,9 @@
# Changes # Changes
## [1.0.3] - 2019-11-14
* Update actix-web and actix-http dependencies
## [1.0.2] - 2019-07-20 ## [1.0.2] - 2019-07-20
* Add `ws::start_with_addr()`, returning the address of the created actor, along * Add `ws::start_with_addr()`, returning the address of the created actor, along

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web-actors" name = "actix-web-actors"
version = "1.0.2" version = "1.0.3"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix actors support for actix web framework." description = "Actix actors support for actix web framework."
readme = "README.md" readme = "README.md"
@@ -19,8 +19,8 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix = "0.8.3" actix = "0.8.3"
actix-web = "1.0.3" actix-web = "1.0.9"
actix-http = "0.2.5" actix-http = "0.2.11"
actix-codec = "0.1.2" actix-codec = "0.1.2"
bytes = "0.4" bytes = "0.4"
futures = "0.1.25" futures = "0.1.25"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web-codegen" name = "actix-web-codegen"
version = "0.1.3" version = "0.2.0-alpha.1"
description = "Actix web proc macros" description = "Actix web proc macros"
readme = "README.md" readme = "README.md"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
@@ -17,7 +17,7 @@ syn = { version = "1", features = ["full", "parsing"] }
proc-macro2 = "1" proc-macro2 = "1"
[dev-dependencies] [dev-dependencies]
actix-web = { version = "1.0.0" } actix-web = { version = "2.0.0-alph.a" }
actix-http = { version = "0.2.4", features=["ssl"] } actix-http = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.2.0", features=["ssl"] } actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
futures = { version = "0.1" } futures = { version = "0.3.1" }

View File

@@ -35,8 +35,8 @@
//! use futures::{future, Future}; //! use futures::{future, Future};
//! //!
//! #[get("/test")] //! #[get("/test")]
//! fn async_test() -> impl Future<Item=HttpResponse, Error=actix_web::Error> { //! async fn async_test() -> Result<HttpResponse, actix_web::Error> {
//! future::ok(HttpResponse::Ok().finish()) //! Ok(HttpResponse::Ok().finish())
//! } //! }
//! ``` //! ```

View File

@@ -13,7 +13,7 @@ enum ResourceType {
impl ToTokens for ResourceType { impl ToTokens for ResourceType {
fn to_tokens(&self, stream: &mut TokenStream2) { fn to_tokens(&self, stream: &mut TokenStream2) {
let ident = match self { let ident = match self {
ResourceType::Async => "to_async", ResourceType::Async => "to",
ResourceType::Sync => "to", ResourceType::Sync => "to",
}; };
let ident = Ident::new(ident, Span::call_site()); let ident = Ident::new(ident, Span::call_site());

View File

@@ -1,157 +1,163 @@
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use actix_web::{http, web::Path, App, HttpResponse, Responder}; use actix_web::{http, web::Path, App, HttpResponse, Responder};
use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace};
use futures::{future, Future}; use futures::{future, Future};
#[get("/test")] #[get("/test")]
fn test() -> impl Responder { async fn test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[put("/test")] #[put("/test")]
fn put_test() -> impl Responder { async fn put_test() -> impl Responder {
HttpResponse::Created() HttpResponse::Created()
} }
#[patch("/test")] #[patch("/test")]
fn patch_test() -> impl Responder { async fn patch_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[post("/test")] #[post("/test")]
fn post_test() -> impl Responder { async fn post_test() -> impl Responder {
HttpResponse::NoContent() HttpResponse::NoContent()
} }
#[head("/test")] #[head("/test")]
fn head_test() -> impl Responder { async fn head_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[connect("/test")] #[connect("/test")]
fn connect_test() -> impl Responder { async fn connect_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[options("/test")] #[options("/test")]
fn options_test() -> impl Responder { async fn options_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[trace("/test")] #[trace("/test")]
fn trace_test() -> impl Responder { async fn trace_test() -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[get("/test")] #[get("/test")]
fn auto_async() -> impl Future<Item = HttpResponse, Error = actix_web::Error> { fn auto_async() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> {
future::ok(HttpResponse::Ok().finish()) future::ok(HttpResponse::Ok().finish())
} }
#[get("/test")] #[get("/test")]
fn auto_sync() -> impl Future<Item = HttpResponse, Error = actix_web::Error> { fn auto_sync() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> {
future::ok(HttpResponse::Ok().finish()) future::ok(HttpResponse::Ok().finish())
} }
#[put("/test/{param}")] #[put("/test/{param}")]
fn put_param_test(_: Path<String>) -> impl Responder { async fn put_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Created() HttpResponse::Created()
} }
#[delete("/test/{param}")] #[delete("/test/{param}")]
fn delete_param_test(_: Path<String>) -> impl Responder { async fn delete_param_test(_: Path<String>) -> impl Responder {
HttpResponse::NoContent() HttpResponse::NoContent()
} }
#[get("/test/{param}")] #[get("/test/{param}")]
fn get_param_test(_: Path<String>) -> impl Responder { async fn get_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[test] #[test]
fn test_params() { fn test_params() {
let mut srv = TestServer::new(|| { block_on(async {
HttpService::new( let srv = TestServer::start(|| {
App::new() HttpService::new(
.service(get_param_test) App::new()
.service(put_param_test) .service(get_param_test)
.service(delete_param_test), .service(put_param_test)
) .service(delete_param_test),
}); )
});
let request = srv.request(http::Method::GET, srv.url("/test/it")); let request = srv.request(http::Method::GET, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::OK); assert_eq!(response.status(), http::StatusCode::OK);
let request = srv.request(http::Method::PUT, srv.url("/test/it")); let request = srv.request(http::Method::PUT, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::CREATED); assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::DELETE, srv.url("/test/it")); let request = srv.request(http::Method::DELETE, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::NO_CONTENT); assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
})
} }
#[test] #[test]
fn test_body() { fn test_body() {
let mut srv = TestServer::new(|| { block_on(async {
HttpService::new( let srv = TestServer::start(|| {
App::new() HttpService::new(
.service(post_test) App::new()
.service(put_test) .service(post_test)
.service(head_test) .service(put_test)
.service(connect_test) .service(head_test)
.service(options_test) .service(connect_test)
.service(trace_test) .service(options_test)
.service(patch_test) .service(trace_test)
.service(test), .service(patch_test)
) .service(test),
}); )
let request = srv.request(http::Method::GET, srv.url("/test")); });
let response = srv.block_on(request.send()).unwrap(); let request = srv.request(http::Method::GET, srv.url("/test"));
assert!(response.status().is_success()); let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::HEAD, srv.url("/test")); let request = srv.request(http::Method::HEAD, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::CONNECT, srv.url("/test")); let request = srv.request(http::Method::CONNECT, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::OPTIONS, srv.url("/test")); let request = srv.request(http::Method::OPTIONS, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::TRACE, srv.url("/test")); let request = srv.request(http::Method::TRACE, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::PATCH, srv.url("/test")); let request = srv.request(http::Method::PATCH, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.request(http::Method::PUT, srv.url("/test")); let request = srv.request(http::Method::PUT, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::CREATED); assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::POST, srv.url("/test")); let request = srv.request(http::Method::POST, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::NO_CONTENT); assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
let request = srv.request(http::Method::GET, srv.url("/test")); let request = srv.request(http::Method::GET, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }
#[test] #[test]
fn test_auto_async() { fn test_auto_async() {
let mut srv = TestServer::new(|| HttpService::new(App::new().service(auto_async))); block_on(async {
let srv = TestServer::start(|| HttpService::new(App::new().service(auto_async)));
let request = srv.request(http::Method::GET, srv.url("/test")); let request = srv.request(http::Method::GET, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
})
} }

View File

@@ -1,5 +1,9 @@
# Changes # Changes
## [0.2.8] - 2019-11-06
* Add support for setting query from Serialize type for client request.
## [0.2.7] - 2019-09-25 ## [0.2.7] - 2019-09-25

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "awc" name = "awc"
version = "0.2.7" version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http client." description = "Actix http client."
readme = "README.md" readme = "README.md"
@@ -21,16 +21,16 @@ name = "awc"
path = "src/lib.rs" path = "src/lib.rs"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["ssl", "brotli", "flate2-zlib"] features = ["openssl", "brotli", "flate2-zlib"]
[features] [features]
default = ["brotli", "flate2-zlib"] default = ["brotli", "flate2-zlib"]
# openssl # openssl
ssl = ["openssl", "actix-http/ssl"] openssl = ["open-ssl", "actix-http/openssl"]
# rustls # rustls
rust-tls = ["rustls", "actix-http/rust-tls"] # rustls = ["rust-tls", "actix-http/rustls"]
# brotli encoding, requires c compiler # brotli encoding, requires c compiler
brotli = ["actix-http/brotli"] brotli = ["actix-http/brotli"]
@@ -42,13 +42,14 @@ flate2-zlib = ["actix-http/flate2-zlib"]
flate2-rust = ["actix-http/flate2-rust"] flate2-rust = ["actix-http/flate2-rust"]
[dependencies] [dependencies]
actix-codec = "0.1.2" actix-codec = "0.2.0-alpha.1"
actix-service = "0.4.1" actix-service = "1.0.0-alpha.1"
actix-http = "0.2.10" actix-http = "0.3.0-alpha.1"
base64 = "0.10.1" base64 = "0.10.1"
bytes = "0.4" bytes = "0.4"
derive_more = "0.15.0" derive_more = "0.15.0"
futures = "0.1.25" futures = "0.3.1"
log =" 0.4" log =" 0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
@@ -56,21 +57,21 @@ rand = "0.7"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
tokio-timer = "0.2.8" tokio-timer = "0.3.0-alpha.6"
openssl = { version="0.10", optional = true } open-ssl = { version="0.10", package="openssl", optional = true }
rustls = { version = "0.15.2", optional = true } rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] }
[dev-dependencies] [dev-dependencies]
actix-rt = "0.2.2" actix-rt = "1.0.0-alpha.1"
actix-web = { version = "1.0.0", features=["ssl"] } actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] }
actix-http = { version = "0.2.10", features=["ssl"] } actix-web = { version = "2.0.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.2.0", features=["ssl"] } actix-http = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-utils = "0.4.1" actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] } actix-utils = "0.5.0-alpha.1"
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
brotli2 = { version="0.3.2" } brotli2 = { version="0.3.2" }
flate2 = { version="1.0.2" } flate2 = { version="1.0.2" }
env_logger = "0.6" env_logger = "0.6"
rand = "0.7" rand = "0.7"
tokio-tcp = "0.1" tokio-tcp = "0.1"
webpki = "0.19" webpki = { version = "0.21" }
rustls = { version = "0.15.2", features = ["dangerous_configuration"] }

View File

@@ -1,4 +1,6 @@
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, io, net}; use std::{fmt, io, net};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
@@ -10,7 +12,7 @@ use actix_http::h1::ClientCodec;
use actix_http::http::HeaderMap; use actix_http::http::HeaderMap;
use actix_http::{RequestHead, RequestHeadType, ResponseHead}; use actix_http::{RequestHead, RequestHeadType, ResponseHead};
use actix_service::Service; use actix_service::Service;
use futures::{Future, Poll}; use futures::future::{FutureExt, LocalBoxFuture};
use crate::response::ClientResponse; use crate::response::ClientResponse;
@@ -22,7 +24,7 @@ pub(crate) trait Connect {
head: RequestHead, head: RequestHead,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>; ) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>;
fn send_request_extra( fn send_request_extra(
&mut self, &mut self,
@@ -30,18 +32,16 @@ pub(crate) trait Connect {
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>; ) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel( fn open_tunnel(
&mut self, &mut self,
head: RequestHead, head: RequestHead,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box< ) -> LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>), Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
>; >;
/// Send request and extra headers, returns Response and Framed /// Send request and extra headers, returns Response and Framed
@@ -50,11 +50,9 @@ pub(crate) trait Connect {
head: Rc<RequestHead>, head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box< ) -> LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>), Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
>; >;
} }
@@ -72,21 +70,23 @@ where
head: RequestHead, head: RequestHead,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> { ) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> {
Box::new( // connect to the host
self.0 let fut = self.0.call(ClientConnect {
// connect to the host uri: head.uri.clone(),
.call(ClientConnect { addr,
uri: head.uri.clone(), });
addr,
}) async move {
.from_err() let connection = fut.await?;
// send request
.and_then(move |connection| { // send request
connection.send_request(RequestHeadType::from(head), body) connection
}) .send_request(RequestHeadType::from(head), body)
.map(|(head, payload)| ClientResponse::new(head, payload)), .await
) .map(|(head, payload)| ClientResponse::new(head, payload))
}
.boxed_local()
} }
fn send_request_extra( fn send_request_extra(
@@ -95,51 +95,51 @@ where
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
body: Body, body: Body,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> { ) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> {
Box::new( // connect to the host
self.0 let fut = self.0.call(ClientConnect {
// connect to the host uri: head.uri.clone(),
.call(ClientConnect { addr,
uri: head.uri.clone(), });
addr,
}) async move {
.from_err() let connection = fut.await?;
// send request
.and_then(move |connection| { // send request
connection let (head, payload) = connection
.send_request(RequestHeadType::Rc(head, extra_headers), body) .send_request(RequestHeadType::Rc(head, extra_headers), body)
}) .await?;
.map(|(head, payload)| ClientResponse::new(head, payload)),
) Ok(ClientResponse::new(head, payload))
}
.boxed_local()
} }
fn open_tunnel( fn open_tunnel(
&mut self, &mut self,
head: RequestHead, head: RequestHead,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box< ) -> LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>), Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
> { > {
Box::new( // connect to the host
self.0 let fut = self.0.call(ClientConnect {
// connect to the host uri: head.uri.clone(),
.call(ClientConnect { addr,
uri: head.uri.clone(), });
addr,
}) async move {
.from_err() let connection = fut.await?;
// send request
.and_then(move |connection| { // send request
connection.open_tunnel(RequestHeadType::from(head)) let (head, framed) =
}) connection.open_tunnel(RequestHeadType::from(head)).await?;
.map(|(head, framed)| {
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
(head, framed) Ok((head, framed))
}), }
) .boxed_local()
} }
fn open_tunnel_extra( fn open_tunnel_extra(
@@ -147,48 +147,47 @@ where
head: Rc<RequestHead>, head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>, extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>, addr: Option<net::SocketAddr>,
) -> Box< ) -> LocalBoxFuture<
dyn Future< 'static,
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>), Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
Error = SendRequestError,
>,
> { > {
Box::new( // connect to the host
self.0 let fut = self.0.call(ClientConnect {
// connect to the host uri: head.uri.clone(),
.call(ClientConnect { addr,
uri: head.uri.clone(), });
addr,
}) async move {
.from_err() let connection = fut.await?;
// send request
.and_then(move |connection| { // send request
connection.open_tunnel(RequestHeadType::Rc(head, extra_headers)) let (head, framed) = connection
}) .open_tunnel(RequestHeadType::Rc(head, extra_headers))
.map(|(head, framed)| { .await?;
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
(head, framed) let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
}), Ok((head, framed))
) }
.boxed_local()
} }
} }
trait AsyncSocket { trait AsyncSocket {
fn as_read(&self) -> &dyn AsyncRead; fn as_read(&self) -> &(dyn AsyncRead + Unpin);
fn as_read_mut(&mut self) -> &mut dyn AsyncRead; fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin);
fn as_write(&mut self) -> &mut dyn AsyncWrite; fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin);
} }
struct Socket<T: AsyncRead + AsyncWrite>(T); struct Socket<T: AsyncRead + AsyncWrite + Unpin>(T);
impl<T: AsyncRead + AsyncWrite> AsyncSocket for Socket<T> { impl<T: AsyncRead + AsyncWrite + Unpin> AsyncSocket for Socket<T> {
fn as_read(&self) -> &dyn AsyncRead { fn as_read(&self) -> &(dyn AsyncRead + Unpin) {
&self.0 &self.0
} }
fn as_read_mut(&mut self) -> &mut dyn AsyncRead { fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) {
&mut self.0 &mut self.0
} }
fn as_write(&mut self) -> &mut dyn AsyncWrite { fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) {
&mut self.0 &mut self.0
} }
} }
@@ -201,30 +200,37 @@ impl fmt::Debug for BoxedSocket {
} }
} }
impl io::Read for BoxedSocket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.as_read_mut().read(buf)
}
}
impl AsyncRead for BoxedSocket { impl AsyncRead for BoxedSocket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.as_read().prepare_uninitialized_buffer(buf) self.0.as_read().prepare_uninitialized_buffer(buf)
} }
}
impl io::Write for BoxedSocket { fn poll_read(
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { self: Pin<&mut Self>,
self.0.as_write().write(buf) cx: &mut Context<'_>,
} buf: &mut [u8],
) -> Poll<io::Result<usize>> {
fn flush(&mut self) -> io::Result<()> { Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf)
self.0.as_write().flush()
} }
} }
impl AsyncWrite for BoxedSocket { impl AsyncWrite for BoxedSocket {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(
self.0.as_write().shutdown() self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx)
} }
} }

View File

@@ -82,7 +82,7 @@ impl FrozenClientRequest {
/// Send a streaming body. /// Send a streaming body.
pub fn send_stream<S, E>(&self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(&self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
RequestSender::Rc(self.head.clone(), None).send_stream( RequestSender::Rc(self.head.clone(), None).send_stream(
@@ -203,7 +203,7 @@ impl FrozenSendBuilder {
/// Complete request construction and send a streaming body. /// Complete request construction and send a streaming body.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
if let Some(e) = self.err { if let Some(e) = self.err {

View File

@@ -7,18 +7,18 @@
//! use awc::Client; //! use awc::Client;
//! //!
//! fn main() { //! fn main() {
//! System::new("test").block_on(lazy(|| { //! System::new("test").block_on(async {
//! let mut client = Client::default(); //! let mut client = Client::default();
//! //!
//! client.get("http://www.rust-lang.org") // <- Create request builder //! client.get("http://www.rust-lang.org") // <- Create request builder
//! .header("User-Agent", "Actix-web") //! .header("User-Agent", "Actix-web")
//! .send() // <- Send http request //! .send() // <- Send http request
//! .map_err(|_| ()) //! .await
//! .and_then(|response| { // <- server http response //! .and_then(|response| { // <- server http response
//! println!("Response: {:?}", response); //! println!("Response: {:?}", response);
//! Ok(()) //! Ok(())
//! }) //! })
//! })); //! });
//! } //! }
//! ``` //! ```
use std::cell::RefCell; use std::cell::RefCell;
@@ -52,23 +52,22 @@ use self::connect::{Connect, ConnectorWrapper};
/// An HTTP Client /// An HTTP Client
/// ///
/// ```rust /// ```rust
/// # use futures::future::{Future, lazy};
/// use actix_rt::System; /// use actix_rt::System;
/// use awc::Client; /// use awc::Client;
/// ///
/// fn main() { /// fn main() {
/// System::new("test").block_on(lazy(|| { /// System::new("test").block_on(async {
/// let mut client = Client::default(); /// let mut client = Client::default();
/// ///
/// client.get("http://www.rust-lang.org") // <- Create request builder /// client.get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web") /// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request /// .send() // <- Send http request
/// .map_err(|_| ()) /// .await
/// .and_then(|response| { // <- server http response /// .and_then(|response| { // <- server http response
/// println!("Response: {:?}", response); /// println!("Response: {:?}", response);
/// Ok(()) /// Ok(())
/// }) /// })
/// })); /// });
/// } /// }
/// ``` /// ```
#[derive(Clone)] #[derive(Clone)]

View File

@@ -37,21 +37,21 @@ const HTTPS_ENCODING: &str = "gzip, deflate";
/// builder-like pattern. /// builder-like pattern.
/// ///
/// ```rust /// ```rust
/// use futures::future::{Future, lazy};
/// use actix_rt::System; /// use actix_rt::System;
/// ///
/// fn main() { /// fn main() {
/// System::new("test").block_on(lazy(|| { /// System::new("test").block_on(async {
/// awc::Client::new() /// let response = awc::Client::new()
/// .get("http://www.rust-lang.org") // <- Create request builder /// .get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web") /// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request /// .send() // <- Send http request
/// .map_err(|_| ()) /// .await;
/// .and_then(|response| { // <- server http response ///
/// println!("Response: {:?}", response); /// response.and_then(|response| { // <- server http response
/// Ok(()) /// println!("Response: {:?}", response);
/// Ok(())
/// }) /// })
/// })); /// });
/// } /// }
/// ``` /// ```
pub struct ClientRequest { pub struct ClientRequest {
@@ -158,7 +158,7 @@ impl ClientRequest {
/// ///
/// ```rust /// ```rust
/// fn main() { /// fn main() {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { /// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| {
/// let req = awc::Client::new() /// let req = awc::Client::new()
/// .get("http://www.rust-lang.org") /// .get("http://www.rust-lang.org")
/// .set(awc::http::header::Date::now()) /// .set(awc::http::header::Date::now())
@@ -186,13 +186,13 @@ impl ClientRequest {
/// use awc::{http, Client}; /// use awc::{http, Client};
/// ///
/// fn main() { /// fn main() {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|| { /// # actix_rt::System::new("test").block_on(async {
/// let req = Client::new() /// let req = Client::new()
/// .get("http://www.rust-lang.org") /// .get("http://www.rust-lang.org")
/// .header("X-TEST", "value") /// .header("X-TEST", "value")
/// .header(http::header::CONTENT_TYPE, "application/json"); /// .header(http::header::CONTENT_TYPE, "application/json");
/// # Ok::<_, ()>(()) /// # Ok::<_, ()>(())
/// # })); /// # });
/// } /// }
/// ``` /// ```
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
@@ -309,9 +309,8 @@ impl ClientRequest {
/// ///
/// ```rust /// ```rust
/// # use actix_rt::System; /// # use actix_rt::System;
/// # use futures::future::{lazy, Future};
/// fn main() { /// fn main() {
/// System::new("test").block_on(lazy(|| { /// System::new("test").block_on(async {
/// awc::Client::new().get("https://www.rust-lang.org") /// awc::Client::new().get("https://www.rust-lang.org")
/// .cookie( /// .cookie(
/// awc::http::Cookie::build("name", "value") /// awc::http::Cookie::build("name", "value")
@@ -322,12 +321,12 @@ impl ClientRequest {
/// .finish(), /// .finish(),
/// ) /// )
/// .send() /// .send()
/// .map_err(|_| ()) /// .await
/// .and_then(|response| { /// .and_then(|response| {
/// println!("Response: {:?}", response); /// println!("Response: {:?}", response);
/// Ok(()) /// Ok(())
/// }) /// })
/// })); /// });
/// } /// }
/// ``` /// ```
pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
@@ -382,6 +381,27 @@ impl ClientRequest {
} }
} }
/// Sets the query part of the request
pub fn query<T: Serialize>(
mut self,
query: &T,
) -> Result<Self, serde_urlencoded::ser::Error> {
let mut parts = self.head.uri.clone().into_parts();
if let Some(path_and_query) = parts.path_and_query {
let query = serde_urlencoded::to_string(query)?;
let path = path_and_query.path();
parts.path_and_query = format!("{}?{}", path, query).parse().ok();
match Uri::from_parts(parts) {
Ok(uri) => self.head.uri = uri,
Err(e) => self.err = Some(e.into()),
}
}
Ok(self)
}
/// Freeze request builder and construct `FrozenClientRequest`, /// Freeze request builder and construct `FrozenClientRequest`,
/// which could be used for sending same request multiple times. /// which could be used for sending same request multiple times.
pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> { pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
@@ -457,7 +477,7 @@ impl ClientRequest {
/// Set an streaming body and generate `ClientRequest`. /// Set an streaming body and generate `ClientRequest`.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
let slf = match self.prep_for_sending() { let slf = match self.prep_for_sending() {
@@ -690,4 +710,13 @@ mod tests {
"Bearer someS3cr3tAutht0k3n" "Bearer someS3cr3tAutht0k3n"
); );
} }
#[test]
fn client_query() {
let req = Client::new()
.get("/")
.query(&[("key1", "val1"), ("key2", "val2")])
.unwrap();
assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2");
}
} }

View File

@@ -1,9 +1,11 @@
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Future, Poll, Stream}; use futures::{ready, Future, Stream};
use actix_http::cookie::Cookie; use actix_http::cookie::Cookie;
use actix_http::error::{CookieParseError, PayloadError}; use actix_http::error::{CookieParseError, PayloadError};
@@ -104,7 +106,7 @@ impl<S> ClientResponse<S> {
impl<S> ClientResponse<S> impl<S> ClientResponse<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>>,
{ {
/// Loads http response's body. /// Loads http response's body.
pub fn body(&mut self) -> MessageBody<S> { pub fn body(&mut self) -> MessageBody<S> {
@@ -125,13 +127,12 @@ where
impl<S> Stream for ClientResponse<S> impl<S> Stream for ClientResponse<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{ {
type Item = Bytes; type Item = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.payload.poll() Pin::new(&mut self.get_mut().payload).poll_next(cx)
} }
} }
@@ -155,7 +156,7 @@ pub struct MessageBody<S> {
impl<S> MessageBody<S> impl<S> MessageBody<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>>,
{ {
/// Create `MessageBody` for request. /// Create `MessageBody` for request.
pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> { pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> {
@@ -198,23 +199,24 @@ where
impl<S> Future for MessageBody<S> impl<S> Future for MessageBody<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{ {
type Item = Bytes; type Output = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(err) = self.err.take() { let this = self.get_mut();
return Err(err);
if let Some(err) = this.err.take() {
return Poll::Ready(Err(err));
} }
if let Some(len) = self.length.take() { if let Some(len) = this.length.take() {
if len > self.fut.as_ref().unwrap().limit { if len > this.fut.as_ref().unwrap().limit {
return Err(PayloadError::Overflow); return Poll::Ready(Err(PayloadError::Overflow));
} }
} }
self.fut.as_mut().unwrap().poll() Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
} }
} }
@@ -233,7 +235,7 @@ pub struct JsonBody<S, U> {
impl<S, U> JsonBody<S, U> impl<S, U> JsonBody<S, U>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>>,
U: DeserializeOwned, U: DeserializeOwned,
{ {
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
@@ -279,27 +281,35 @@ where
} }
} }
impl<T, U> Future for JsonBody<T, U> impl<T, U> Unpin for JsonBody<T, U>
where where
T: Stream<Item = Bytes, Error = PayloadError>, T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned, U: DeserializeOwned,
{ {
type Item = U; }
type Error = JsonPayloadError;
fn poll(&mut self) -> Poll<U, JsonPayloadError> { impl<T, U> Future for JsonBody<T, U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
type Output = Result<U, JsonPayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(err) = self.err.take() { if let Some(err) = self.err.take() {
return Err(err); return Poll::Ready(Err(err));
} }
if let Some(len) = self.length.take() { if let Some(len) = self.length.take() {
if len > self.fut.as_ref().unwrap().limit { if len > self.fut.as_ref().unwrap().limit {
return Err(JsonPayloadError::Payload(PayloadError::Overflow)); return Poll::Ready(Err(JsonPayloadError::Payload(
PayloadError::Overflow,
)));
} }
} }
let body = futures::try_ready!(self.fut.as_mut().unwrap().poll()); let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Ok(Async::Ready(serde_json::from_slice::<U>(&body)?)) Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
} }
} }
@@ -321,24 +331,25 @@ impl<S> ReadBody<S> {
impl<S> Future for ReadBody<S> impl<S> Future for ReadBody<S>
where where
S: Stream<Item = Bytes, Error = PayloadError>, S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{ {
type Item = Bytes; type Output = Result<Bytes, PayloadError>;
type Error = PayloadError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
return match self.stream.poll()? { return match Pin::new(&mut this.stream).poll_next(cx)? {
Async::Ready(Some(chunk)) => { Poll::Ready(Some(chunk)) => {
if (self.buf.len() + chunk.len()) > self.limit { if (this.buf.len() + chunk.len()) > this.limit {
Err(PayloadError::Overflow) Poll::Ready(Err(PayloadError::Overflow))
} else { } else {
self.buf.extend_from_slice(&chunk); this.buf.extend_from_slice(&chunk);
continue; continue;
} }
} }
Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())), Poll::Ready(None) => Poll::Ready(Ok(this.buf.take().freeze())),
Async::NotReady => Ok(Async::NotReady), Poll::Pending => Poll::Pending,
}; };
} }
} }
@@ -348,41 +359,40 @@ where
mod tests { mod tests {
use super::*; use super::*;
use actix_http_test::block_on; use actix_http_test::block_on;
use futures::Async;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{http::header, test::TestResponse}; use crate::{http::header, test::TestResponse};
#[test] #[test]
fn test_body() { fn test_body() {
let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); block_on(async {
match req.body().poll().err().unwrap() { let mut req =
PayloadError::UnknownLength => (), TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
_ => unreachable!("error"), match req.body().await.err().unwrap() {
} PayloadError::UnknownLength => (),
_ => unreachable!("error"),
}
let mut req = let mut req =
TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish();
match req.body().poll().err().unwrap() { match req.body().await.err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let mut req = TestResponse::default() let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"test")) .set_payload(Bytes::from_static(b"test"))
.finish(); .finish();
match req.body().poll().ok().unwrap() { assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test"));
Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")),
_ => unreachable!("error"),
}
let mut req = TestResponse::default() let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"11111111111111")) .set_payload(Bytes::from_static(b"11111111111111"))
.finish(); .finish();
match req.body().limit(5).poll().err().unwrap() { match req.body().limit(5).await.err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
})
} }
#[derive(Serialize, Deserialize, PartialEq, Debug)] #[derive(Serialize, Deserialize, PartialEq, Debug)]
@@ -406,54 +416,56 @@ mod tests {
#[test] #[test]
fn test_json_body() { fn test_json_body() {
let mut req = TestResponse::default().finish(); block_on(async {
let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); let mut req = TestResponse::default().finish();
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"), header::HeaderValue::from_static("application/text"),
) )
.finish(); .finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
) )
.header( .header(
header::CONTENT_LENGTH, header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"), header::HeaderValue::from_static("10000"),
) )
.finish(); .finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100)); let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await;
assert!(json_eq( assert!(json_eq(
json.err().unwrap(), json.err().unwrap(),
JsonPayloadError::Payload(PayloadError::Overflow) JsonPayloadError::Payload(PayloadError::Overflow)
)); ));
let mut req = TestResponse::default() let mut req = TestResponse::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
) )
.header( .header(
header::CONTENT_LENGTH, header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"), header::HeaderValue::from_static("16"),
) )
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.finish(); .finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req)); let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert_eq!( assert_eq!(
json.ok().unwrap(), json.ok().unwrap(),
MyObject { MyObject {
name: "test".to_owned() name: "test".to_owned()
} }
); );
})
} }
} }

View File

@@ -1,13 +1,15 @@
use std::net; use std::net;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::time::{Duration, Instant}; use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes; use bytes::Bytes;
use derive_more::From; use derive_more::From;
use futures::{try_ready, Async, Future, Poll, Stream}; use futures::{future::LocalBoxFuture, ready, Future, Stream};
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
use tokio_timer::Delay; use tokio_timer::{delay_for, Delay};
use actix_http::body::{Body, BodyStream}; use actix_http::body::{Body, BodyStream};
use actix_http::encoding::Decoder; use actix_http::encoding::Decoder;
@@ -47,7 +49,7 @@ impl Into<SendRequestError> for PrepForSendingError {
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
pub enum SendClientRequest { pub enum SendClientRequest {
Fut( Fut(
Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>, LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>,
Option<Delay>, Option<Delay>,
bool, bool,
), ),
@@ -56,41 +58,51 @@ pub enum SendClientRequest {
impl SendClientRequest { impl SendClientRequest {
pub(crate) fn new( pub(crate) fn new(
send: Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>, send: LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>,
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> SendClientRequest { ) -> SendClientRequest {
let delay = timeout.map(|t| Delay::new(Instant::now() + t)); let delay = timeout.map(|t| delay_for(t));
SendClientRequest::Fut(send, delay, response_decompress) SendClientRequest::Fut(send, delay, response_decompress)
} }
} }
impl Future for SendClientRequest { impl Future for SendClientRequest {
type Item = ClientResponse<Decoder<Payload<PayloadStream>>>; type Output =
type Error = SendRequestError; Result<ClientResponse<Decoder<Payload<PayloadStream>>>, SendRequestError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self { let this = self.get_mut();
match this {
SendClientRequest::Fut(send, delay, response_decompress) => { SendClientRequest::Fut(send, delay, response_decompress) => {
if delay.is_some() { if delay.is_some() {
match delay.poll() { match Pin::new(delay.as_mut().unwrap()).poll(cx) {
Ok(Async::NotReady) => (), Poll::Pending => (),
_ => return Err(SendRequestError::Timeout), _ => return Poll::Ready(Err(SendRequestError::Timeout)),
} }
} }
let res = try_ready!(send.poll()).map_body(|head, payload| { let res = ready!(Pin::new(send).poll(cx)).map(|res| {
if *response_decompress { res.map_body(|head, payload| {
Payload::Stream(Decoder::from_headers(payload, &head.headers)) if *response_decompress {
} else { Payload::Stream(Decoder::from_headers(
Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) payload,
} &head.headers,
))
} else {
Payload::Stream(Decoder::new(
payload,
ContentEncoding::Identity,
))
}
})
}); });
Ok(Async::Ready(res)) Poll::Ready(res)
} }
SendClientRequest::Err(ref mut e) => match e.take() { SendClientRequest::Err(ref mut e) => match e.take() {
Some(e) => Err(e), Some(e) => Poll::Ready(Err(e)),
None => panic!("Attempting to call completed future"), None => panic!("Attempting to call completed future"),
}, },
} }
@@ -223,7 +235,7 @@ impl RequestSender {
stream: S, stream: S,
) -> SendClientRequest ) -> SendClientRequest
where where
S: Stream<Item = Bytes, Error = E> + 'static, S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
self.send_body( self.send_body(

View File

@@ -7,7 +7,6 @@ use std::{fmt, str};
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::cookie::{Cookie, CookieJar}; use actix_http::cookie::{Cookie, CookieJar};
use actix_http::{ws, Payload, RequestHead}; use actix_http::{ws, Payload, RequestHead};
use futures::future::{err, Either, Future};
use percent_encoding::percent_encode; use percent_encoding::percent_encode;
use tokio_timer::Timeout; use tokio_timer::Timeout;
@@ -210,27 +209,26 @@ impl WebsocketsRequest {
} }
/// Complete request construction and connect to a websockets server. /// Complete request construction and connect to a websockets server.
pub fn connect( pub async fn connect(
mut self, mut self,
) -> impl Future<Item = (ClientResponse, Framed<BoxedSocket, Codec>), Error = WsClientError> ) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
{
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
return Either::A(err(e.into())); return Err(e.into());
} }
// validate uri // validate uri
let uri = &self.head.uri; let uri = &self.head.uri;
if uri.host().is_none() { if uri.host().is_none() {
return Either::A(err(InvalidUrl::MissingHost.into())); return Err(InvalidUrl::MissingHost.into());
} else if uri.scheme_part().is_none() { } else if uri.scheme_part().is_none() {
return Either::A(err(InvalidUrl::MissingScheme.into())); return Err(InvalidUrl::MissingScheme.into());
} else if let Some(scheme) = uri.scheme_part() { } else if let Some(scheme) = uri.scheme_part() {
match scheme.as_str() { match scheme.as_str() {
"http" | "ws" | "https" | "wss" => (), "http" | "ws" | "https" | "wss" => (),
_ => return Either::A(err(InvalidUrl::UnknownScheme.into())), _ => return Err(InvalidUrl::UnknownScheme.into()),
} }
} else { } else {
return Either::A(err(InvalidUrl::UnknownScheme.into())); return Err(InvalidUrl::UnknownScheme.into());
} }
if !self.head.headers.contains_key(header::HOST) { if !self.head.headers.contains_key(header::HOST) {
@@ -294,90 +292,83 @@ impl WebsocketsRequest {
.config .config
.connector .connector
.borrow_mut() .borrow_mut()
.open_tunnel(head, self.addr) .open_tunnel(head, self.addr);
.from_err()
.and_then(move |(head, framed)| {
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
}
} else {
log::trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader);
}
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
let encoded = ws::hash_key(key.as_ref());
if hdr_key.as_bytes() != encoded.as_bytes() {
log::trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded,
key
);
return Err(WsClientError::InvalidChallengeResponse(
encoded,
hdr_key.clone(),
));
}
} else {
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
}
}),
))
});
// set request timeout // set request timeout
if let Some(timeout) = self.config.timeout { let (head, framed) = if let Some(timeout) = self.config.timeout {
Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| { Timeout::new(fut, timeout)
if let Some(e) = e.into_inner() { .await
e .map_err(|_| SendRequestError::Timeout.into())
} else { .and_then(|res| res)?
SendRequestError::Timeout.into()
}
})))
} else { } else {
Either::B(Either::B(fut)) fut.await?
};
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
} }
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader);
}
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
let encoded = ws::hash_key(key.as_ref());
if hdr_key.as_bytes() != encoded.as_bytes() {
log::trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded,
key
);
return Err(WsClientError::InvalidChallengeResponse(
encoded,
hdr_key.clone(),
));
}
} else {
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
}
}),
))
} }
} }
@@ -398,6 +389,8 @@ impl fmt::Debug for WebsocketsRequest {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_web::test::block_on;
use super::*; use super::*;
use crate::Client; use crate::Client;
@@ -472,35 +465,33 @@ mod tests {
#[test] #[test]
fn basics() { fn basics() {
let req = Client::new() block_on(async {
.ws("http://localhost/") let req = Client::new()
.origin("test-origin") .ws("http://localhost/")
.max_frame_size(100) .origin("test-origin")
.server_mode() .max_frame_size(100)
.protocols(&["v1", "v2"]) .server_mode()
.set_header_if_none(header::CONTENT_TYPE, "json") .protocols(&["v1", "v2"])
.set_header_if_none(header::CONTENT_TYPE, "text") .set_header_if_none(header::CONTENT_TYPE, "json")
.cookie(Cookie::build("cookie1", "value1").finish()); .set_header_if_none(header::CONTENT_TYPE, "text")
assert_eq!( .cookie(Cookie::build("cookie1", "value1").finish());
req.origin.as_ref().unwrap().to_str().unwrap(), assert_eq!(
"test-origin" req.origin.as_ref().unwrap().to_str().unwrap(),
); "test-origin"
assert_eq!(req.max_size, 100); );
assert_eq!(req.server_mode, true); assert_eq!(req.max_size, 100);
assert_eq!(req.protocols, Some("v1,v2".to_string())); assert_eq!(req.server_mode, true);
assert_eq!( assert_eq!(req.protocols, Some("v1,v2".to_string()));
req.head.headers.get(header::CONTENT_TYPE).unwrap(), assert_eq!(
header::HeaderValue::from_static("json") req.head.headers.get(header::CONTENT_TYPE).unwrap(),
); header::HeaderValue::from_static("json")
);
let _ = actix_http_test::block_fn(move || req.connect()); let _ = req.connect().await;
assert!(Client::new().ws("/").connect().poll().is_err()); assert!(Client::new().ws("/").connect().await.is_err());
assert!(Client::new().ws("http:///test").connect().poll().is_err()); assert!(Client::new().ws("http:///test").connect().await.is_err());
assert!(Client::new() assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
.ws("hmm://test.com/") })
.connect()
.poll()
.is_err());
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,96 +1,109 @@
#![cfg(feature = "rust-tls")] #![cfg(feature = "rustls")]
use rustls::{ use rust_tls::ClientConfig;
internal::pemfile::{certs, pkcs8_private_keys},
ClientConfig, NoClientAuth,
};
use std::fs::File; use std::io::Result;
use std::io::{BufReader, Result};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use actix_server::ssl::RustlsAcceptor; use actix_server::ssl::OpensslAcceptor;
use actix_service::{service_fn, NewService}; use actix_service::{pipeline_factory, ServiceFactory};
use actix_web::http::Version; use actix_web::http::Version;
use actix_web::{web, App, HttpResponse}; use actix_web::{web, App, HttpResponse};
use futures::future::ok;
use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> { fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
use rustls::ServerConfig;
// load ssl keys // load ssl keys
let mut config = ServerConfig::new(NoClientAuth::new()); let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true);
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); builder
let cert_chain = certs(cert_file).unwrap(); .set_private_key_file("../tests/key.pem", SslFiletype::PEM)
let mut keys = pkcs8_private_keys(key_file).unwrap(); .unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); builder
let protos = vec![b"h2".to_vec()]; .set_certificate_chain_file("../tests/cert.pem")
config.set_protocols(&protos); .unwrap();
Ok(RustlsAcceptor::new(config)) builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(open_ssl::ssl::AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(actix_server::ssl::OpensslAcceptor::new(builder.build()))
} }
mod danger { mod danger {
pub struct NoCertificateVerification {} pub struct NoCertificateVerification {}
impl rustls::ServerCertVerifier for NoCertificateVerification { impl rust_tls::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert( fn verify_server_cert(
&self, &self,
_roots: &rustls::RootCertStore, _roots: &rust_tls::RootCertStore,
_presented_certs: &[rustls::Certificate], _presented_certs: &[rust_tls::Certificate],
_dns_name: webpki::DNSNameRef<'_>, _dns_name: webpki::DNSNameRef<'_>,
_ocsp: &[u8], _ocsp: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> { ) -> Result<rust_tls::ServerCertVerified, rust_tls::TLSError> {
Ok(rustls::ServerCertVerified::assertion()) Ok(rust_tls::ServerCertVerified::assertion())
} }
} }
} }
#[test] // #[test]
fn test_connection_reuse_h2() { fn _test_connection_reuse_h2() {
let rustls = ssl_acceptor().unwrap(); block_on(async {
let num = Arc::new(AtomicUsize::new(0)); let openssl = ssl_acceptor().unwrap();
let num2 = num.clone(); let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || { let srv = TestServer::start(move || {
let num2 = num2.clone(); let num2 = num2.clone();
service_fn(move |io| { pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed); num2.fetch_add(1, Ordering::Relaxed);
Ok(io) ok(io)
}) })
.and_then(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) .and_then(
.and_then( openssl
HttpService::build() .clone()
.h2(App::new() .map_err(|e| println!("Openssl error: {}", e)),
.service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) )
.map_err(|_| ()), .and_then(
) HttpService::build()
}); .h2(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok())),
))
.map_err(|_| ()),
)
});
// disable ssl verification // disable ssl verification
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
config.set_protocols(&protos); config.set_protocols(&protos);
config config
.dangerous() .dangerous()
.set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {}));
let client = awc::Client::build() let client = awc::Client::build()
.connector(awc::Connector::new().rustls(Arc::new(config)).finish()) .connector(awc::Connector::new().rustls(Arc::new(config)).finish())
.finish(); .finish();
// req 1 // req 1
let request = client.get(srv.surl("/")).send(); let request = client.get(srv.surl("/")).send();
let response = srv.block_on(request).unwrap(); let response = request.await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// req 2 // req 2
let req = client.post(srv.surl("/")); let req = client.post(srv.surl("/"));
let response = srv.block_on_fn(move || req.send()).unwrap(); let response = req.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2); assert_eq!(response.version(), Version::HTTP_2);
// one connection // one connection
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
})
} }

View File

@@ -1,5 +1,5 @@
#![cfg(feature = "ssl")] #![cfg(feature = "openssl")]
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use std::io::Result; use std::io::Result;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@@ -7,11 +7,12 @@ use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService; use actix_http::HttpService;
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor; use actix_server::ssl::OpensslAcceptor;
use actix_service::{service_fn, NewService}; use actix_service::{pipeline_factory, ServiceFactory};
use actix_web::http::Version; use actix_web::http::Version;
use actix_web::{web, App, HttpResponse}; use actix_web::{web, App, HttpResponse};
use futures::future::ok;
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> { fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys // load ssl keys
@@ -27,7 +28,7 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
if protos.windows(3).any(|window| window == H2) { if protos.windows(3).any(|window| window == H2) {
Ok(b"h2") Ok(b"h2")
} else { } else {
Err(openssl::ssl::AlpnError::NOACK) Err(open_ssl::ssl::AlpnError::NOACK)
} }
}); });
builder.set_alpn_protos(b"\x02h2")?; builder.set_alpn_protos(b"\x02h2")?;
@@ -36,51 +37,54 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
#[test] #[test]
fn test_connection_reuse_h2() { fn test_connection_reuse_h2() {
let openssl = ssl_acceptor().unwrap(); block_on(async {
let num = Arc::new(AtomicUsize::new(0)); let openssl = ssl_acceptor().unwrap();
let num2 = num.clone(); let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || { let srv = TestServer::start(move || {
let num2 = num2.clone(); let num2 = num2.clone();
service_fn(move |io| { pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed); num2.fetch_add(1, Ordering::Relaxed);
Ok(io) ok(io)
}) })
.and_then( .and_then(
openssl openssl
.clone() .clone()
.map_err(|e| println!("Openssl error: {}", e)), .map_err(|e| println!("Openssl error: {}", e)),
) )
.and_then( .and_then(
HttpService::build() HttpService::build()
.h2(App::new() .h2(App::new().service(
.service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) web::resource("/").route(web::to(|| HttpResponse::Ok())),
.map_err(|_| ()), ))
) .map_err(|_| ()),
}); )
});
// disable ssl verification // disable ssl verification
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE); builder.set_verify(SslVerifyMode::NONE);
let _ = builder let _ = builder
.set_alpn_protos(b"\x02h2\x08http/1.1") .set_alpn_protos(b"\x02h2\x08http/1.1")
.map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e));
let client = awc::Client::build() let client = awc::Client::build()
.connector(awc::Connector::new().ssl(builder.build()).finish()) .connector(awc::Connector::new().ssl(builder.build()).finish())
.finish(); .finish();
// req 1 // req 1
let request = client.get(srv.surl("/")).send(); let request = client.get(srv.surl("/")).send();
let response = srv.block_on(request).unwrap(); let response = request.await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// req 2 // req 2
let req = client.post(srv.surl("/")); let req = client.post(srv.surl("/"));
let response = srv.block_on_fn(move || req.send()).unwrap(); let response = req.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2); assert_eq!(response.version(), Version::HTTP_2);
// one connection // one connection
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
})
} }

View File

@@ -2,81 +2,82 @@ use std::io;
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer; use actix_http_test::{block_on, TestServer};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::future::ok; use futures::future::ok;
use futures::{Future, Sink, Stream}; use futures::{SinkExt, StreamExt};
fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> { async fn ws_service(req: ws::Frame) -> Result<ws::Message, io::Error> {
match req { match req {
ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)), ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
let text = if let Some(pl) = text { let text = if let Some(pl) = text {
String::from_utf8(Vec::from(pl.as_ref())).unwrap() String::from_utf8(Vec::from(pl.as_ref())).unwrap()
} else { } else {
String::new() String::new()
}; };
ok(ws::Message::Text(text)) Ok(ws::Message::Text(text))
} }
ws::Frame::Binary(bin) => ok(ws::Message::Binary( ws::Frame::Binary(bin) => Ok(ws::Message::Binary(
bin.map(|e| e.freeze()) bin.map(|e| e.freeze())
.unwrap_or_else(|| Bytes::from("")) .unwrap_or_else(|| Bytes::from(""))
.into(), .into(),
)), )),
ws::Frame::Close(reason) => ok(ws::Message::Close(reason)), ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)),
_ => ok(ws::Message::Close(None)), _ => Ok(ws::Message::Close(None)),
} }
} }
#[test] #[test]
fn test_simple() { fn test_simple() {
let mut srv = TestServer::new(|| { block_on(async {
HttpService::build() let mut srv = TestServer::start(|| {
.upgrade(|(req, framed): (Request, Framed<_, _>)| { HttpService::build()
let res = ws::handshake_response(req.head()).finish(); .upgrade(|(req, mut framed): (Request, Framed<_, _>)| {
// send handshake response async move {
framed let res = ws::handshake_response(req.head()).finish();
.send(h1::Message::Item((res.drop_body(), BodySize::None))) // send handshake response
.map_err(|e: io::Error| e.into()) framed
.and_then(|framed| { .send(h1::Message::Item((res.drop_body(), BodySize::None)))
.await?;
// start websocket service // start websocket service
let framed = framed.into_framed(ws::Codec::new()); let framed = framed.into_framed(ws::Codec::new());
ws::Transport::with(framed, ws_service) ws::Transport::with(framed, ws_service).await
}) }
}) })
.finish(|_| ok::<_, Error>(Response::NotFound())) .finish(|_| ok::<_, Error>(Response::NotFound()))
}); });
// client service // client service
let framed = srv.ws().unwrap(); let mut framed = srv.ws().await.unwrap();
let framed = srv framed
.block_on(framed.send(ws::Message::Text("text".to_string()))) .send(ws::Message::Text("text".to_string()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); let item = framed.next().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Text(Some(BytesMut::from("text"))));
let framed = srv framed
.block_on(framed.send(ws::Message::Binary("text".into()))) .send(ws::Message::Binary("text".into()))
.unwrap(); .await
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); .unwrap();
assert_eq!( let item = framed.next().await.unwrap().unwrap();
item, assert_eq!(
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) item,
); ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
let framed = srv framed.send(ws::Message::Ping("text".into())).await.unwrap();
.block_on(framed.send(ws::Message::Ping("text".into()))) let item = framed.next().await.unwrap().unwrap();
.unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into()));
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv framed
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) .send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.unwrap(); .await
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); let item = framed.next().await.unwrap().unwrap();
assert_eq!( assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
item, })
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
);
} }

View File

@@ -1,22 +1,18 @@
use futures::IntoFuture; use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer};
use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
};
#[get("/resource1/{name}/index.html")] #[get("/resource1/{name}/index.html")]
fn index(req: HttpRequest, name: web::Path<String>) -> String { async fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name) format!("Hello: {}!\r\n", name)
} }
fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> { async fn index_async(req: HttpRequest) -> &'static str {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
Ok("Hello world!\r\n") "Hello world!\r\n"
} }
#[get("/")] #[get("/")]
fn no_params() -> &'static str { async fn no_params() -> &'static str {
"Hello world!\r\n" "Hello world!\r\n"
} }
@@ -39,9 +35,9 @@ fn main() -> std::io::Result<()> {
.default_service( .default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()), web::route().to(|| HttpResponse::MethodNotAllowed()),
) )
.route(web::get().to_async(index_async)), .route(web::get().to(index_async)),
) )
.service(web::resource("/test1.html").to(|| "Test\r\n")) .service(web::resource("/test1.html").to(|| async { "Test\r\n" }))
}) })
.bind("127.0.0.1:8080")? .bind("127.0.0.1:8080")?
.workers(1) .workers(1)

View File

@@ -1,26 +1,27 @@
use actix_http::Error; use actix_http::Error;
use actix_rt::System; use actix_rt::System;
use futures::{future::lazy, Future};
fn main() -> Result<(), Error> { fn main() -> Result<(), Error> {
std::env::set_var("RUST_LOG", "actix_http=trace"); std::env::set_var("RUST_LOG", "actix_http=trace");
env_logger::init(); env_logger::init();
System::new("test").block_on(lazy(|| { System::new("test").block_on(async {
awc::Client::new() let client = awc::Client::new();
.get("https://www.rust-lang.org/") // <- Create request builder
.header("User-Agent", "Actix-web")
.send() // <- Send http request
.from_err()
.and_then(|mut response| {
// <- server http response
println!("Response: {:?}", response);
// read response body // Create request builder, configure request and send
response let mut response = client
.body() .get("https://www.rust-lang.org/")
.from_err() .header("User-Agent", "Actix-web")
.map(|body| println!("Downloaded: {:?} bytes", body.len())) .send()
}) .await?;
}))
// server http response
println!("Response: {:?}", response);
// read response body
let body = response.body().await?;
println!("Downloaded: {:?} bytes", body.len());
Ok(())
})
} }

View File

@@ -1,26 +1,24 @@
use futures::IntoFuture;
use actix_web::{ use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
}; };
#[get("/resource1/{name}/index.html")] #[get("/resource1/{name}/index.html")]
fn index(req: HttpRequest, name: web::Path<String>) -> String { async fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name) format!("Hello: {}!\r\n", name)
} }
fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> { async fn index_async(req: HttpRequest) -> Result<&'static str, Error> {
println!("REQ: {:?}", req); println!("REQ: {:?}", req);
Ok("Hello world!\r\n") Ok("Hello world!\r\n")
} }
#[get("/")] #[get("/")]
fn no_params() -> &'static str { async fn no_params() -> &'static str {
"Hello world!\r\n" "Hello world!\r\n"
} }
#[cfg(feature = "uds")] #[cfg(unix)]
fn main() -> std::io::Result<()> { fn main() -> std::io::Result<()> {
std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info");
env_logger::init(); env_logger::init();
@@ -40,14 +38,14 @@ fn main() -> std::io::Result<()> {
.default_service( .default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()), web::route().to(|| HttpResponse::MethodNotAllowed()),
) )
.route(web::get().to_async(index_async)), .route(web::get().to(index_async)),
) )
.service(web::resource("/test1.html").to(|| "Test\r\n")) .service(web::resource("/test1.html").to(|| async { "Test\r\n" }))
}) })
.bind_uds("/Users/fafhrd91/uds-test")? .bind_uds("/Users/fafhrd91/uds-test")?
.workers(1) .workers(1)
.run() .run()
} }
#[cfg(not(feature = "uds"))] #[cfg(not(unix))]
fn main() {} fn main() {}

View File

@@ -1,14 +1,17 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::body::{Body, MessageBody}; use actix_http::body::{Body, MessageBody};
use actix_service::boxed::{self, BoxedNewService}; use actix_service::boxed::{self, BoxedNewService};
use actix_service::{ use actix_service::{
apply_transform, IntoNewService, IntoTransform, NewService, Transform, apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform,
}; };
use futures::{Future, IntoFuture}; use futures::future::{FutureExt, LocalBoxFuture};
use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; use crate::app_service::{AppEntry, AppInit, AppRoutingFactory};
use crate::config::{AppConfig, AppConfigInner, ServiceConfig}; use crate::config::{AppConfig, AppConfigInner, ServiceConfig};
@@ -18,19 +21,19 @@ use crate::error::Error;
use crate::resource::Resource; use crate::resource::Resource;
use crate::route::Route; use crate::route::Route;
use crate::service::{ use crate::service::{
HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse, ServiceResponse,
}; };
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type FnDataFactory = type FnDataFactory =
Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>; Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>;
/// Application builder - structure that follows the builder pattern /// Application builder - structure that follows the builder pattern
/// for building application instances. /// for building application instances.
pub struct App<T, B> { pub struct App<T, B> {
endpoint: T, endpoint: T,
services: Vec<Box<dyn ServiceFactory>>, services: Vec<Box<dyn AppServiceFactory>>,
default: Option<Rc<HttpNewService>>, default: Option<Rc<HttpNewService>>,
factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>, factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
data: Vec<Box<dyn DataFactory>>, data: Vec<Box<dyn DataFactory>>,
@@ -61,7 +64,7 @@ impl App<AppEntry, Body> {
impl<T, B> App<T, B> impl<T, B> App<T, B>
where where
B: MessageBody, B: MessageBody,
T: NewService< T: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -87,7 +90,7 @@ where
/// counter: Cell<usize>, /// counter: Cell<usize>,
/// } /// }
/// ///
/// fn index(data: web::Data<MyData>) { /// async fn index(data: web::Data<MyData>) {
/// data.counter.set(data.counter.get() + 1); /// data.counter.set(data.counter.get() + 1);
/// } /// }
/// ///
@@ -107,24 +110,30 @@ where
/// Set application data factory. This function is /// Set application data factory. This function is
/// similar to `.data()` but it accepts data factory. Data object get /// similar to `.data()` but it accepts data factory. Data object get
/// constructed asynchronously during application initialization. /// constructed asynchronously during application initialization.
pub fn data_factory<F, Out>(mut self, data: F) -> Self pub fn data_factory<F, Out, D, E>(mut self, data: F) -> Self
where where
F: Fn() -> Out + 'static, F: Fn() -> Out + 'static,
Out: IntoFuture + 'static, Out: Future<Output = Result<D, E>> + 'static,
Out::Error: std::fmt::Debug, D: 'static,
E: std::fmt::Debug,
{ {
self.data_factories.push(Box::new(move || { self.data_factories.push(Box::new(move || {
Box::new( {
data() let fut = data();
.into_future() async move {
.map_err(|e| { match fut.await {
log::error!("Can not construct data instance: {:?}", e); Err(e) => {
}) log::error!("Can not construct data instance: {:?}", e);
.map(|data| { Err(())
let data: Box<dyn DataFactory> = Box::new(Data::new(data)); }
data Ok(data) => {
}), let data: Box<dyn DataFactory> = Box::new(Data::new(data));
) Ok(data)
}
}
}
}
.boxed_local()
})); }));
self self
} }
@@ -183,7 +192,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpResponse}; /// use actix_web::{web, App, HttpResponse};
/// ///
/// fn index(data: web::Path<(String, String)>) -> &'static str { /// async fn index(data: web::Path<(String, String)>) -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -238,7 +247,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpResponse}; /// use actix_web::{web, App, HttpResponse};
/// ///
/// fn index() -> &'static str { /// async fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -267,8 +276,8 @@ where
/// ``` /// ```
pub fn default_service<F, U>(mut self, f: F) -> Self pub fn default_service<F, U>(mut self, f: F) -> Self
where where
F: IntoNewService<U>, F: IntoServiceFactory<U>,
U: NewService< U: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
@@ -277,11 +286,9 @@ where
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
// create and configure default resource // create and configure default resource
self.default = Some(Rc::new(boxed::new_service( self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err(
f.into_new_service().map_init_err(|e| { |e| log::error!("Can not construct default service: {:?}", e),
log::error!("Can not construct default service: {:?}", e) ))));
}),
)));
self self
} }
@@ -295,7 +302,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; /// use actix_web::{web, App, HttpRequest, HttpResponse, Result};
/// ///
/// fn index(req: HttpRequest) -> Result<HttpResponse> { /// async fn index(req: HttpRequest) -> Result<HttpResponse> {
/// let url = req.url_for("youtube", &["asdlkjqme"])?; /// let url = req.url_for("youtube", &["asdlkjqme"])?;
/// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme"); /// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme");
/// Ok(HttpResponse::Ok().into()) /// Ok(HttpResponse::Ok().into())
@@ -336,11 +343,10 @@ where
/// ///
/// ```rust /// ```rust
/// use actix_service::Service; /// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{middleware, web, App}; /// use actix_web::{middleware, web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
/// ///
/// fn index() -> &'static str { /// async fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
@@ -350,11 +356,11 @@ where
/// .route("/index.html", web::get().to(index)); /// .route("/index.html", web::get().to(index));
/// } /// }
/// ``` /// ```
pub fn wrap<M, B1, F>( pub fn wrap<M, B1>(
self, self,
mw: F, mw: M,
) -> App< ) -> App<
impl NewService< impl ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B1>, Response = ServiceResponse<B1>,
@@ -372,11 +378,9 @@ where
InitError = (), InitError = (),
>, >,
B1: MessageBody, B1: MessageBody,
F: IntoTransform<M, T::Service>,
{ {
let endpoint = apply_transform(mw, self.endpoint);
App { App {
endpoint, endpoint: apply(mw, self.endpoint),
data: self.data, data: self.data,
data_factories: self.data_factories, data_factories: self.data_factories,
services: self.services, services: self.services,
@@ -397,23 +401,25 @@ where
/// ///
/// ```rust /// ```rust
/// use actix_service::Service; /// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{web, App}; /// use actix_web::{web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
/// ///
/// fn index() -> &'static str { /// async fn index() -> &'static str {
/// "Welcome!" /// "Welcome!"
/// } /// }
/// ///
/// fn main() { /// fn main() {
/// let app = App::new() /// let app = App::new()
/// .wrap_fn(|req, srv| /// .wrap_fn(|req, srv| {
/// srv.call(req).map(|mut res| { /// let fut = srv.call(req);
/// async {
/// let mut res = fut.await?;
/// res.headers_mut().insert( /// res.headers_mut().insert(
/// CONTENT_TYPE, HeaderValue::from_static("text/plain"), /// CONTENT_TYPE, HeaderValue::from_static("text/plain"),
/// ); /// );
/// res /// Ok(res)
/// })) /// }
/// })
/// .route("/index.html", web::get().to(index)); /// .route("/index.html", web::get().to(index));
/// } /// }
/// ``` /// ```
@@ -421,7 +427,7 @@ where
self, self,
mw: F, mw: F,
) -> App< ) -> App<
impl NewService< impl ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B1>, Response = ServiceResponse<B1>,
@@ -433,16 +439,26 @@ where
where where
B1: MessageBody, B1: MessageBody,
F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone,
R: IntoFuture<Item = ServiceResponse<B1>, Error = Error>, R: Future<Output = Result<ServiceResponse<B1>, Error>>,
{ {
self.wrap(mw) App {
endpoint: apply_fn_factory(self.endpoint, mw),
data: self.data,
data_factories: self.data_factories,
services: self.services,
default: self.default,
factory_ref: self.factory_ref,
config: self.config,
external: self.external,
_t: PhantomData,
}
} }
} }
impl<T, B> IntoNewService<AppInit<T, B>> for App<T, B> impl<T, B> IntoServiceFactory<AppInit<T, B>> for App<T, B>
where where
B: MessageBody, B: MessageBody,
T: NewService< T: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -450,7 +466,7 @@ where
InitError = (), InitError = (),
>, >,
{ {
fn into_new_service(self) -> AppInit<T, B> { fn into_factory(self) -> AppInit<T, B> {
AppInit { AppInit {
data: Rc::new(self.data), data: Rc::new(self.data),
data_factories: Rc::new(self.data_factories), data_factories: Rc::new(self.data_factories),
@@ -468,82 +484,89 @@ where
mod tests { mod tests {
use actix_service::Service; use actix_service::Service;
use bytes::Bytes; use bytes::Bytes;
use futures::{Future, IntoFuture}; use futures::future::{ok, Future};
use super::*; use super::*;
use crate::http::{header, HeaderValue, Method, StatusCode}; use crate::http::{header, HeaderValue, Method, StatusCode};
use crate::middleware::DefaultHeaders;
use crate::service::{ServiceRequest, ServiceResponse}; use crate::service::{ServiceRequest, ServiceResponse};
use crate::test::{ use crate::test::{block_on, call_service, init_service, read_body, TestRequest};
block_fn, block_on, call_service, init_service, read_body, TestRequest,
};
use crate::{web, Error, HttpRequest, HttpResponse}; use crate::{web, Error, HttpRequest, HttpResponse};
#[test] #[test]
fn test_default_resource() { fn test_default_resource() {
let mut srv = init_service( block_on(async {
App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), let mut srv = init_service(
); App::new().service(web::resource("/test").to(|| HttpResponse::Ok())),
let req = TestRequest::with_uri("/test").to_request(); )
let resp = block_fn(|| srv.call(req)).unwrap(); .await;
assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test").to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/blah").to_request(); let req = TestRequest::with_uri("/blah").to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND); assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let mut srv = init_service( let mut srv = init_service(
App::new() App::new()
.service(web::resource("/test").to(|| HttpResponse::Ok())) .service(web::resource("/test").to(|| HttpResponse::Ok()))
.service( .service(
web::resource("/test2") web::resource("/test2")
.default_service(|r: ServiceRequest| { .default_service(|r: ServiceRequest| {
r.into_response(HttpResponse::Created()) ok(r.into_response(HttpResponse::Created()))
}) })
.route(web::get().to(|| HttpResponse::Ok())), .route(web::get().to(|| HttpResponse::Ok())),
) )
.default_service(|r: ServiceRequest| { .default_service(|r: ServiceRequest| {
r.into_response(HttpResponse::MethodNotAllowed()) ok(r.into_response(HttpResponse::MethodNotAllowed()))
}), }),
); )
.await;
let req = TestRequest::with_uri("/blah").to_request(); let req = TestRequest::with_uri("/blah").to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let req = TestRequest::with_uri("/test2").to_request(); let req = TestRequest::with_uri("/test2").to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/test2") let req = TestRequest::with_uri("/test2")
.method(Method::POST) .method(Method::POST)
.to_request(); .to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
})
} }
#[test] #[test]
fn test_data_factory() { fn test_data_factory() {
let mut srv = block_on(async {
init_service(App::new().data_factory(|| Ok::<_, ()>(10usize)).service( let mut srv =
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service(
)); web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
let req = TestRequest::default().to_request(); ))
let resp = block_on(srv.call(req)).unwrap(); .await;
assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let mut srv = let mut srv =
init_service(App::new().data_factory(|| Ok::<_, ()>(10u32)).service( init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)); ))
let req = TestRequest::default().to_request(); .await;
let resp = block_on(srv.call(req)).unwrap(); let req = TestRequest::default().to_request();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
fn md<S, B>( fn md<S, B>(
req: ServiceRequest, req: ServiceRequest,
srv: &mut S, srv: &mut S,
) -> impl IntoFuture<Item = ServiceResponse<B>, Error = Error> ) -> impl Future<Output = Result<ServiceResponse<B>, Error>>
where where
S: Service< S: Service<
Request = ServiceRequest, Request = ServiceRequest,
@@ -551,112 +574,141 @@ mod tests {
Error = Error, Error = Error,
>, >,
{ {
srv.call(req).map(|mut res| { let fut = srv.call(req);
async move {
let mut res = fut.await?;
res.headers_mut() res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("0001")); .insert(header::CONTENT_TYPE, HeaderValue::from_static("0001"));
res Ok(res)
}) }
} }
#[test] #[test]
fn test_wrap() { fn test_wrap() {
let mut srv = init_service( block_on(async {
App::new() let mut srv =
.wrap(md) init_service(
.route("/test", web::get().to(|| HttpResponse::Ok())), App::new()
); .wrap(DefaultHeaders::new().header(
let req = TestRequest::with_uri("/test").to_request(); header::CONTENT_TYPE,
let resp = call_service(&mut srv, req); HeaderValue::from_static("0001"),
assert_eq!(resp.status(), StatusCode::OK); ))
assert_eq!( .route("/test", web::get().to(|| HttpResponse::Ok())),
resp.headers().get(header::CONTENT_TYPE).unwrap(), )
HeaderValue::from_static("0001") .await;
); let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_router_wrap() { fn test_router_wrap() {
let mut srv = init_service( block_on(async {
App::new() let mut srv =
.route("/test", web::get().to(|| HttpResponse::Ok())) init_service(
.wrap(md), App::new()
); .route("/test", web::get().to(|| HttpResponse::Ok()))
let req = TestRequest::with_uri("/test").to_request(); .wrap(DefaultHeaders::new().header(
let resp = call_service(&mut srv, req); header::CONTENT_TYPE,
assert_eq!(resp.status(), StatusCode::OK); HeaderValue::from_static("0001"),
assert_eq!( )),
resp.headers().get(header::CONTENT_TYPE).unwrap(), )
HeaderValue::from_static("0001") .await;
); let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_wrap_fn() { fn test_wrap_fn() {
let mut srv = init_service( block_on(async {
App::new() let mut srv = init_service(
.wrap_fn(|req, srv| { App::new()
srv.call(req).map(|mut res| { .wrap_fn(|req, srv| {
res.headers_mut().insert( let fut = srv.call(req);
header::CONTENT_TYPE, async move {
HeaderValue::from_static("0001"), let mut res = fut.await?;
); res.headers_mut().insert(
res header::CONTENT_TYPE,
HeaderValue::from_static("0001"),
);
Ok(res)
}
}) })
}) .service(web::resource("/test").to(|| HttpResponse::Ok())),
.service(web::resource("/test").to(|| HttpResponse::Ok())), )
); .await;
let req = TestRequest::with_uri("/test").to_request(); let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req); let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!( assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(), resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001") HeaderValue::from_static("0001")
); );
})
} }
#[test] #[test]
fn test_router_wrap_fn() { fn test_router_wrap_fn() {
let mut srv = init_service( block_on(async {
App::new() let mut srv = init_service(
.route("/test", web::get().to(|| HttpResponse::Ok())) App::new()
.wrap_fn(|req, srv| { .route("/test", web::get().to(|| HttpResponse::Ok()))
srv.call(req).map(|mut res| { .wrap_fn(|req, srv| {
res.headers_mut().insert( let fut = srv.call(req);
header::CONTENT_TYPE, async {
HeaderValue::from_static("0001"), let mut res = fut.await?;
); res.headers_mut().insert(
res header::CONTENT_TYPE,
}) HeaderValue::from_static("0001"),
}), );
); Ok(res)
let req = TestRequest::with_uri("/test").to_request(); }
let resp = call_service(&mut srv, req); }),
assert_eq!(resp.status(), StatusCode::OK); )
assert_eq!( .await;
resp.headers().get(header::CONTENT_TYPE).unwrap(), let req = TestRequest::with_uri("/test").to_request();
HeaderValue::from_static("0001") let resp = call_service(&mut srv, req).await;
); assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
} }
#[test] #[test]
fn test_external_resource() { fn test_external_resource() {
let mut srv = init_service( block_on(async {
App::new() let mut srv = init_service(
.external_resource("youtube", "https://youtube.com/watch/{video_id}") App::new()
.route( .external_resource("youtube", "https://youtube.com/watch/{video_id}")
"/test", .route(
web::get().to(|req: HttpRequest| { "/test",
HttpResponse::Ok().body(format!( web::get().to(|req: HttpRequest| {
"{}", HttpResponse::Ok().body(format!(
req.url_for("youtube", &["12345"]).unwrap() "{}",
)) req.url_for("youtube", &["12345"]).unwrap()
}), ))
), }),
); ),
let req = TestRequest::with_uri("/test").to_request(); )
let resp = call_service(&mut srv, req); .await;
assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test").to_request();
let body = read_body(resp); let resp = call_service(&mut srv, req).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); assert_eq!(resp.status(), StatusCode::OK);
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
} }
} }

View File

@@ -1,14 +1,16 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::{Extensions, Request, Response}; use actix_http::{Extensions, Request, Response};
use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url};
use actix_server_config::ServerConfig; use actix_server_config::ServerConfig;
use actix_service::boxed::{self, BoxedNewService, BoxedService}; use actix_service::boxed::{self, BoxedNewService, BoxedService};
use actix_service::{service_fn, NewService, Service}; use actix_service::{service_fn, Service, ServiceFactory};
use futures::future::{ok, Either, FutureResult}; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::{Async, Future, Poll};
use crate::config::{AppConfig, AppService}; use crate::config::{AppConfig, AppService};
use crate::data::DataFactory; use crate::data::DataFactory;
@@ -16,23 +18,20 @@ use crate::error::Error;
use crate::guard::Guard; use crate::guard::Guard;
use crate::request::{HttpRequest, HttpRequestPool}; use crate::request::{HttpRequest, HttpRequestPool};
use crate::rmap::ResourceMap; use crate::rmap::ResourceMap;
use crate::service::{ServiceFactory, ServiceRequest, ServiceResponse}; use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse};
type Guards = Vec<Box<dyn Guard>>; type Guards = Vec<Box<dyn Guard>>;
type HttpService = BoxedService<ServiceRequest, ServiceResponse, Error>; type HttpService = BoxedService<ServiceRequest, ServiceResponse, Error>;
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>; type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type BoxedResponse = Either< type BoxedResponse = LocalBoxFuture<'static, Result<ServiceResponse, Error>>;
FutureResult<ServiceResponse, Error>,
Box<dyn Future<Item = ServiceResponse, Error = Error>>,
>;
type FnDataFactory = type FnDataFactory =
Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>; Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>;
/// Service factory to convert `Request` to a `ServiceRequest<S>`. /// Service factory to convert `Request` to a `ServiceRequest<S>`.
/// It also executes data factories. /// It also executes data factories.
pub struct AppInit<T, B> pub struct AppInit<T, B>
where where
T: NewService< T: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -44,15 +43,15 @@ where
pub(crate) data: Rc<Vec<Box<dyn DataFactory>>>, pub(crate) data: Rc<Vec<Box<dyn DataFactory>>>,
pub(crate) data_factories: Rc<Vec<FnDataFactory>>, pub(crate) data_factories: Rc<Vec<FnDataFactory>>,
pub(crate) config: RefCell<AppConfig>, pub(crate) config: RefCell<AppConfig>,
pub(crate) services: Rc<RefCell<Vec<Box<dyn ServiceFactory>>>>, pub(crate) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory>>>>,
pub(crate) default: Option<Rc<HttpNewService>>, pub(crate) default: Option<Rc<HttpNewService>>,
pub(crate) factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>, pub(crate) factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
pub(crate) external: RefCell<Vec<ResourceDef>>, pub(crate) external: RefCell<Vec<ResourceDef>>,
} }
impl<T, B> NewService for AppInit<T, B> impl<T, B> ServiceFactory for AppInit<T, B>
where where
T: NewService< T: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -71,8 +70,8 @@ where
fn new_service(&self, cfg: &ServerConfig) -> Self::Future { fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
// update resource default service // update resource default service
let default = self.default.clone().unwrap_or_else(|| { let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::new_service(service_fn(|req: ServiceRequest| { Rc::new(boxed::factory(service_fn(|req: ServiceRequest| {
Ok(req.into_response(Response::NotFound().finish())) ok(req.into_response(Response::NotFound().finish()))
}))) })))
}); });
@@ -135,23 +134,25 @@ where
} }
} }
#[pin_project::pin_project]
pub struct AppInitResult<T, B> pub struct AppInitResult<T, B>
where where
T: NewService, T: ServiceFactory,
{ {
endpoint: Option<T::Service>, endpoint: Option<T::Service>,
#[pin]
endpoint_fut: T::Future, endpoint_fut: T::Future,
rmap: Rc<ResourceMap>, rmap: Rc<ResourceMap>,
config: AppConfig, config: AppConfig,
data: Rc<Vec<Box<dyn DataFactory>>>, data: Rc<Vec<Box<dyn DataFactory>>>,
data_factories: Vec<Box<dyn DataFactory>>, data_factories: Vec<Box<dyn DataFactory>>,
data_factories_fut: Vec<Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>, data_factories_fut: Vec<LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>,
_t: PhantomData<B>, _t: PhantomData<B>,
} }
impl<T, B> Future for AppInitResult<T, B> impl<T, B> Future for AppInitResult<T, B>
where where
T: NewService< T: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse<B>, Response = ServiceResponse<B>,
@@ -159,48 +160,49 @@ where
InitError = (), InitError = (),
>, >,
{ {
type Item = AppInitService<T::Service, B>; type Output = Result<AppInitService<T::Service, B>, ()>;
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
// async data factories // async data factories
let mut idx = 0; let mut idx = 0;
while idx < self.data_factories_fut.len() { while idx < this.data_factories_fut.len() {
match self.data_factories_fut[idx].poll()? { match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? {
Async::Ready(f) => { Poll::Ready(f) => {
self.data_factories.push(f); this.data_factories.push(f);
let _ = self.data_factories_fut.remove(idx); let _ = this.data_factories_fut.remove(idx);
} }
Async::NotReady => idx += 1, Poll::Pending => idx += 1,
} }
} }
if self.endpoint.is_none() { if this.endpoint.is_none() {
if let Async::Ready(srv) = self.endpoint_fut.poll()? { if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? {
self.endpoint = Some(srv); *this.endpoint = Some(srv);
} }
} }
if self.endpoint.is_some() && self.data_factories_fut.is_empty() { if this.endpoint.is_some() && this.data_factories_fut.is_empty() {
// create app data container // create app data container
let mut data = Extensions::new(); let mut data = Extensions::new();
for f in self.data.iter() { for f in this.data.iter() {
f.create(&mut data); f.create(&mut data);
} }
for f in &self.data_factories { for f in this.data_factories.iter() {
f.create(&mut data); f.create(&mut data);
} }
Ok(Async::Ready(AppInitService { Poll::Ready(Ok(AppInitService {
service: self.endpoint.take().unwrap(), service: this.endpoint.take().unwrap(),
rmap: self.rmap.clone(), rmap: this.rmap.clone(),
config: self.config.clone(), config: this.config.clone(),
data: Rc::new(data), data: Rc::new(data),
pool: HttpRequestPool::create(), pool: HttpRequestPool::create(),
})) }))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -226,8 +228,8 @@ where
type Error = T::Error; type Error = T::Error;
type Future = T::Future; type Future = T::Future;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready() self.service.poll_ready(cx)
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {
@@ -270,7 +272,7 @@ pub struct AppRoutingFactory {
default: Rc<HttpNewService>, default: Rc<HttpNewService>,
} }
impl NewService for AppRoutingFactory { impl ServiceFactory for AppRoutingFactory {
type Config = (); type Config = ();
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
@@ -288,7 +290,7 @@ impl NewService for AppRoutingFactory {
CreateAppRoutingItem::Future( CreateAppRoutingItem::Future(
Some(path.clone()), Some(path.clone()),
guards.borrow_mut().take(), guards.borrow_mut().take(),
service.new_service(&()), service.new_service(&()).boxed_local(),
) )
}) })
.collect(), .collect(),
@@ -298,14 +300,14 @@ impl NewService for AppRoutingFactory {
} }
} }
type HttpServiceFut = Box<dyn Future<Item = HttpService, Error = ()>>; type HttpServiceFut = LocalBoxFuture<'static, Result<HttpService, ()>>;
/// Create app service /// Create app service
#[doc(hidden)] #[doc(hidden)]
pub struct AppRoutingFactoryResponse { pub struct AppRoutingFactoryResponse {
fut: Vec<CreateAppRoutingItem>, fut: Vec<CreateAppRoutingItem>,
default: Option<HttpService>, default: Option<HttpService>,
default_fut: Option<Box<dyn Future<Item = HttpService, Error = ()>>>, default_fut: Option<LocalBoxFuture<'static, Result<HttpService, ()>>>,
} }
enum CreateAppRoutingItem { enum CreateAppRoutingItem {
@@ -314,16 +316,15 @@ enum CreateAppRoutingItem {
} }
impl Future for AppRoutingFactoryResponse { impl Future for AppRoutingFactoryResponse {
type Item = AppRouting; type Output = Result<AppRouting, ()>;
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut done = true; let mut done = true;
if let Some(ref mut fut) = self.default_fut { if let Some(ref mut fut) = self.default_fut {
match fut.poll()? { match Pin::new(fut).poll(cx)? {
Async::Ready(default) => self.default = Some(default), Poll::Ready(default) => self.default = Some(default),
Async::NotReady => done = false, Poll::Pending => done = false,
} }
} }
@@ -334,11 +335,12 @@ impl Future for AppRoutingFactoryResponse {
ref mut path, ref mut path,
ref mut guards, ref mut guards,
ref mut fut, ref mut fut,
) => match fut.poll()? { ) => match Pin::new(fut).poll(cx) {
Async::Ready(service) => { Poll::Ready(Ok(service)) => {
Some((path.take().unwrap(), guards.take(), service)) Some((path.take().unwrap(), guards.take(), service))
} }
Async::NotReady => { Poll::Ready(Err(_)) => return Poll::Ready(Err(())),
Poll::Pending => {
done = false; done = false;
None None
} }
@@ -364,13 +366,13 @@ impl Future for AppRoutingFactoryResponse {
} }
router router
}); });
Ok(Async::Ready(AppRouting { Poll::Ready(Ok(AppRouting {
ready: None, ready: None,
router: router.finish(), router: router.finish(),
default: self.default.take(), default: self.default.take(),
})) }))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }
@@ -387,11 +389,11 @@ impl Service for AppRouting {
type Error = Error; type Error = Error;
type Future = BoxedResponse; type Future = BoxedResponse;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
if self.ready.is_none() { if self.ready.is_none() {
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
@@ -413,7 +415,7 @@ impl Service for AppRouting {
default.call(req) default.call(req)
} else { } else {
let req = req.into_parts().0; let req = req.into_parts().0;
Either::A(ok(ServiceResponse::new(req, Response::NotFound().finish()))) ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local()
} }
} }
} }
@@ -429,7 +431,7 @@ impl AppEntry {
} }
} }
impl NewService for AppEntry { impl ServiceFactory for AppEntry {
type Config = (); type Config = ();
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse; type Response = ServiceResponse;
@@ -464,15 +466,16 @@ mod tests {
#[test] #[test]
fn drop_data() { fn drop_data() {
let data = Arc::new(AtomicBool::new(false)); let data = Arc::new(AtomicBool::new(false));
{ test::block_on(async {
let mut app = test::init_service( let mut app = test::init_service(
App::new() App::new()
.data(DropData(data.clone())) .data(DropData(data.clone()))
.service(web::resource("/test").to(|| HttpResponse::Ok())), .service(web::resource("/test").to(|| HttpResponse::Ok())),
); )
.await;
let req = test::TestRequest::with_uri("/test").to_request(); let req = test::TestRequest::with_uri("/test").to_request();
let _ = test::block_on(app.call(req)).unwrap(); let _ = app.call(req).await.unwrap();
} });
assert!(data.load(Ordering::Relaxed)); assert!(data.load(Ordering::Relaxed));
} }
} }

View File

@@ -3,7 +3,7 @@ use std::rc::Rc;
use actix_http::Extensions; use actix_http::Extensions;
use actix_router::ResourceDef; use actix_router::ResourceDef;
use actix_service::{boxed, IntoNewService, NewService}; use actix_service::{boxed, IntoServiceFactory, ServiceFactory};
use crate::data::{Data, DataFactory}; use crate::data::{Data, DataFactory};
use crate::error::Error; use crate::error::Error;
@@ -12,7 +12,7 @@ use crate::resource::Resource;
use crate::rmap::ResourceMap; use crate::rmap::ResourceMap;
use crate::route::Route; use crate::route::Route;
use crate::service::{ use crate::service::{
HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest, AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse, ServiceResponse,
}; };
@@ -102,11 +102,11 @@ impl AppService {
&mut self, &mut self,
rdef: ResourceDef, rdef: ResourceDef,
guards: Option<Vec<Box<dyn Guard>>>, guards: Option<Vec<Box<dyn Guard>>>,
service: F, factory: F,
nested: Option<Rc<ResourceMap>>, nested: Option<Rc<ResourceMap>>,
) where ) where
F: IntoNewService<S>, F: IntoServiceFactory<S>,
S: NewService< S: ServiceFactory<
Config = (), Config = (),
Request = ServiceRequest, Request = ServiceRequest,
Response = ServiceResponse, Response = ServiceResponse,
@@ -116,7 +116,7 @@ impl AppService {
{ {
self.services.push(( self.services.push((
rdef, rdef,
boxed::new_service(service.into_new_service()), boxed::factory(factory.into_factory()),
guards, guards,
nested, nested,
)); ));
@@ -174,7 +174,7 @@ impl Default for AppConfigInner {
/// to set of external methods. This could help with /// to set of external methods. This could help with
/// modularization of big application configuration. /// modularization of big application configuration.
pub struct ServiceConfig { pub struct ServiceConfig {
pub(crate) services: Vec<Box<dyn ServiceFactory>>, pub(crate) services: Vec<Box<dyn AppServiceFactory>>,
pub(crate) data: Vec<Box<dyn DataFactory>>, pub(crate) data: Vec<Box<dyn DataFactory>>,
pub(crate) external: Vec<ResourceDef>, pub(crate) external: Vec<ResourceDef>,
} }
@@ -251,17 +251,19 @@ mod tests {
#[test] #[test]
fn test_data() { fn test_data() {
let cfg = |cfg: &mut ServiceConfig| { block_on(async {
cfg.data(10usize); let cfg = |cfg: &mut ServiceConfig| {
}; cfg.data(10usize);
};
let mut srv = let mut srv = init_service(App::new().configure(cfg).service(
init_service(App::new().configure(cfg).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)); ))
let req = TestRequest::default().to_request(); .await;
let resp = block_on(srv.call(req)).unwrap(); let req = TestRequest::default().to_request();
assert_eq!(resp.status(), StatusCode::OK); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
})
} }
// #[test] // #[test]
@@ -298,50 +300,57 @@ mod tests {
#[test] #[test]
fn test_external_resource() { fn test_external_resource() {
let mut srv = init_service( block_on(async {
App::new() let mut srv = init_service(
.configure(|cfg| { App::new()
cfg.external_resource( .configure(|cfg| {
"youtube", cfg.external_resource(
"https://youtube.com/watch/{video_id}", "youtube",
); "https://youtube.com/watch/{video_id}",
}) );
.route( })
"/test", .route(
web::get().to(|req: HttpRequest| { "/test",
HttpResponse::Ok().body(format!( web::get().to(|req: HttpRequest| {
"{}", HttpResponse::Ok().body(format!(
req.url_for("youtube", &["12345"]).unwrap() "{}",
)) req.url_for("youtube", &["12345"]).unwrap()
}), ))
), }),
); ),
let req = TestRequest::with_uri("/test").to_request(); )
let resp = call_service(&mut srv, req); .await;
assert_eq!(resp.status(), StatusCode::OK); let req = TestRequest::with_uri("/test").to_request();
let body = read_body(resp); let resp = call_service(&mut srv, req).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); assert_eq!(resp.status(), StatusCode::OK);
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
} }
#[test] #[test]
fn test_service() { fn test_service() {
let mut srv = init_service(App::new().configure(|cfg| { block_on(async {
cfg.service( let mut srv = init_service(App::new().configure(|cfg| {
web::resource("/test").route(web::get().to(|| HttpResponse::Created())), cfg.service(
) web::resource("/test")
.route("/index.html", web::get().to(|| HttpResponse::Ok())); .route(web::get().to(|| HttpResponse::Created())),
})); )
.route("/index.html", web::get().to(|| HttpResponse::Ok()));
}))
.await;
let req = TestRequest::with_uri("/test") let req = TestRequest::with_uri("/test")
.method(Method::GET) .method(Method::GET)
.to_request(); .to_request();
let resp = call_service(&mut srv, req); let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
let req = TestRequest::with_uri("/index.html") let req = TestRequest::with_uri("/index.html")
.method(Method::GET) .method(Method::GET)
.to_request(); .to_request();
let resp = call_service(&mut srv, req); let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
} }

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use actix_http::error::{Error, ErrorInternalServerError}; use actix_http::error::{Error, ErrorInternalServerError};
use actix_http::Extensions; use actix_http::Extensions;
use futures::future::{err, ok, Ready};
use crate::dev::Payload; use crate::dev::Payload;
use crate::extract::FromRequest; use crate::extract::FromRequest;
@@ -44,7 +45,7 @@ pub(crate) trait DataFactory {
/// } /// }
/// ///
/// /// Use `Data<T>` extractor to access data in handler. /// /// Use `Data<T>` extractor to access data in handler.
/// fn index(data: web::Data<Mutex<MyData>>) { /// async fn index(data: web::Data<Mutex<MyData>>) {
/// let mut data = data.lock().unwrap(); /// let mut data = data.lock().unwrap();
/// data.counter += 1; /// data.counter += 1;
/// } /// }
@@ -101,19 +102,19 @@ impl<T> Clone for Data<T> {
impl<T: 'static> FromRequest for Data<T> { impl<T: 'static> FromRequest for Data<T> {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Result<Self, Error>; type Future = Ready<Result<Self, Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
if let Some(st) = req.get_app_data::<T>() { if let Some(st) = req.get_app_data::<T>() {
Ok(st) ok(st)
} else { } else {
log::debug!( log::debug!(
"Failed to construct App-level Data extractor. \ "Failed to construct App-level Data extractor. \
Request path: {:?}", Request path: {:?}",
req.path() req.path()
); );
Err(ErrorInternalServerError( err(ErrorInternalServerError(
"App data is not configured, to configure use App::data()", "App data is not configured, to configure use App::data()",
)) ))
} }
@@ -142,85 +143,99 @@ mod tests {
#[test] #[test]
fn test_data_extractor() { fn test_data_extractor() {
let mut srv = block_on(async {
init_service(App::new().data(10usize).service( let mut srv = init_service(App::new().data(10usize).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)); ))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let mut srv = let mut srv = init_service(App::new().data(10u32).service(
init_service(App::new().data(10u32).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)); ))
let req = TestRequest::default().to_request(); .await;
let resp = block_on(srv.call(req)).unwrap(); let req = TestRequest::default().to_request();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_register_data_extractor() { fn test_register_data_extractor() {
let mut srv = block_on(async {
init_service(App::new().register_data(Data::new(10usize)).service( let mut srv =
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), init_service(App::new().register_data(Data::new(10usize)).service(
)); web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let mut srv = let mut srv =
init_service(App::new().register_data(Data::new(10u32)).service( init_service(App::new().register_data(Data::new(10u32)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()), web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
)); ))
let req = TestRequest::default().to_request(); .await;
let resp = block_on(srv.call(req)).unwrap(); let req = TestRequest::default().to_request();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_route_data_extractor() { fn test_route_data_extractor() {
let mut srv = block_on(async {
init_service(App::new().service(web::resource("/").data(10usize).route( let mut srv = init_service(App::new().service(
web::get().to(|data: web::Data<usize>| { web::resource("/").data(10usize).route(web::get().to(
let _ = data.clone(); |data: web::Data<usize>| {
HttpResponse::Ok() let _ = data.clone();
}), HttpResponse::Ok()
))); },
)),
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
// different type // different type
let mut srv = init_service( let mut srv = init_service(
App::new().service( App::new().service(
web::resource("/") web::resource("/")
.data(10u32) .data(10u32)
.route(web::get().to(|_: web::Data<usize>| HttpResponse::Ok())), .route(web::get().to(|_: web::Data<usize>| HttpResponse::Ok())),
), ),
); )
let req = TestRequest::default().to_request(); .await;
let resp = block_on(srv.call(req)).unwrap(); let req = TestRequest::default().to_request();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
} }
#[test] #[test]
fn test_override_data() { fn test_override_data() {
let mut srv = init_service(App::new().data(1usize).service( block_on(async {
web::resource("/").data(10usize).route(web::get().to( let mut srv = init_service(App::new().data(1usize).service(
|data: web::Data<usize>| { web::resource("/").data(10usize).route(web::get().to(
assert_eq!(*data, 10); |data: web::Data<usize>| {
let _ = data.clone(); assert_eq!(*data, 10);
HttpResponse::Ok() let _ = data.clone();
}, HttpResponse::Ok()
)), },
)); )),
))
.await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap(); let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
})
} }
} }

View File

@@ -1,8 +1,10 @@
//! Request extractors //! Request extractors
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_http::error::Error; use actix_http::error::Error;
use futures::future::ok; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use futures::{future, Async, Future, IntoFuture, Poll};
use crate::dev::Payload; use crate::dev::Payload;
use crate::request::HttpRequest; use crate::request::HttpRequest;
@@ -15,7 +17,7 @@ pub trait FromRequest: Sized {
type Error: Into<Error>; type Error: Into<Error>;
/// Future that resolves to a Self /// Future that resolves to a Self
type Future: IntoFuture<Item = Self, Error = Self::Error>; type Future: Future<Output = Result<Self, Self::Error>>;
/// Configuration for this extractor /// Configuration for this extractor
type Config: Default + 'static; type Config: Default + 'static;
@@ -48,6 +50,7 @@ pub trait FromRequest: Sized {
/// ```rust /// ```rust
/// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; /// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest};
/// use actix_web::error::ErrorBadRequest; /// use actix_web::error::ErrorBadRequest;
/// use futures::future::{ok, err, Ready};
/// use serde_derive::Deserialize; /// use serde_derive::Deserialize;
/// use rand; /// use rand;
/// ///
@@ -58,21 +61,21 @@ pub trait FromRequest: Sized {
/// ///
/// impl FromRequest for Thing { /// impl FromRequest for Thing {
/// type Error = Error; /// type Error = Error;
/// type Future = Result<Self, Self::Error>; /// type Future = Ready<Result<Self, Self::Error>>;
/// type Config = (); /// type Config = ();
/// ///
/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
/// if rand::random() { /// if rand::random() {
/// Ok(Thing { name: "thingy".into() }) /// ok(Thing { name: "thingy".into() })
/// } else { /// } else {
/// Err(ErrorBadRequest("no luck")) /// err(ErrorBadRequest("no luck"))
/// } /// }
/// ///
/// } /// }
/// } /// }
/// ///
/// /// extract `Thing` from request /// /// extract `Thing` from request
/// fn index(supplied_thing: Option<Thing>) -> String { /// async fn index(supplied_thing: Option<Thing>) -> String {
/// match supplied_thing { /// match supplied_thing {
/// // Puns not intended /// // Puns not intended
/// Some(thing) => format!("Got something: {:?}", thing), /// Some(thing) => format!("Got something: {:?}", thing),
@@ -94,21 +97,19 @@ where
{ {
type Config = T::Config; type Config = T::Config;
type Error = Error; type Error = Error;
type Future = Box<dyn Future<Item = Option<T>, Error = Error>>; type Future = LocalBoxFuture<'static, Result<Option<T>, Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
Box::new( T::from_request(req, payload)
T::from_request(req, payload) .then(|r| match r {
.into_future() Ok(v) => ok(Some(v)),
.then(|r| match r { Err(e) => {
Ok(v) => future::ok(Some(v)), log::debug!("Error for Option<T> extractor: {}", e.into());
Err(e) => { ok(None)
log::debug!("Error for Option<T> extractor: {}", e.into()); }
future::ok(None) })
} .boxed_local()
}),
)
} }
} }
@@ -121,6 +122,7 @@ where
/// ```rust /// ```rust
/// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; /// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest};
/// use actix_web::error::ErrorBadRequest; /// use actix_web::error::ErrorBadRequest;
/// use futures::future::{ok, err, Ready};
/// use serde_derive::Deserialize; /// use serde_derive::Deserialize;
/// use rand; /// use rand;
/// ///
@@ -131,20 +133,20 @@ where
/// ///
/// impl FromRequest for Thing { /// impl FromRequest for Thing {
/// type Error = Error; /// type Error = Error;
/// type Future = Result<Thing, Error>; /// type Future = Ready<Result<Thing, Error>>;
/// type Config = (); /// type Config = ();
/// ///
/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
/// if rand::random() { /// if rand::random() {
/// Ok(Thing { name: "thingy".into() }) /// ok(Thing { name: "thingy".into() })
/// } else { /// } else {
/// Err(ErrorBadRequest("no luck")) /// err(ErrorBadRequest("no luck"))
/// } /// }
/// } /// }
/// } /// }
/// ///
/// /// extract `Thing` from request /// /// extract `Thing` from request
/// fn index(supplied_thing: Result<Thing>) -> String { /// async fn index(supplied_thing: Result<Thing>) -> String {
/// match supplied_thing { /// match supplied_thing {
/// Ok(thing) => format!("Got thing: {:?}", thing), /// Ok(thing) => format!("Got thing: {:?}", thing),
/// Err(e) => format!("Error extracting thing: {}", e) /// Err(e) => format!("Error extracting thing: {}", e)
@@ -157,26 +159,24 @@ where
/// ); /// );
/// } /// }
/// ``` /// ```
impl<T: 'static> FromRequest for Result<T, T::Error> impl<T> FromRequest for Result<T, T::Error>
where where
T: FromRequest, T: FromRequest + 'static,
T::Future: 'static,
T::Error: 'static, T::Error: 'static,
T::Future: 'static,
{ {
type Config = T::Config; type Config = T::Config;
type Error = Error; type Error = Error;
type Future = Box<dyn Future<Item = Result<T, T::Error>, Error = Error>>; type Future = LocalBoxFuture<'static, Result<Result<T, T::Error>, Error>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
Box::new( T::from_request(req, payload)
T::from_request(req, payload) .then(|res| match res {
.into_future() Ok(v) => ok(Ok(v)),
.then(|res| match res { Err(e) => ok(Err(e)),
Ok(v) => ok(Ok(v)), })
Err(e) => ok(Err(e)), .boxed_local()
}),
)
} }
} }
@@ -184,10 +184,10 @@ where
impl FromRequest for () { impl FromRequest for () {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Result<(), Error>; type Future = Ready<Result<(), Error>>;
fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future {
Ok(()) ok(())
} }
} }
@@ -204,43 +204,44 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => {
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
$fut_type { $fut_type {
items: <($(Option<$T>,)+)>::default(), items: <($(Option<$T>,)+)>::default(),
futs: ($($T::from_request(req, payload).into_future(),)+), futs: ($($T::from_request(req, payload),)+),
} }
} }
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project]
pub struct $fut_type<$($T: FromRequest),+> { pub struct $fut_type<$($T: FromRequest),+> {
items: ($(Option<$T>,)+), items: ($(Option<$T>,)+),
futs: ($(<$T::Future as futures::IntoFuture>::Future,)+), futs: ($($T::Future,)+),
} }
impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> impl<$($T: FromRequest),+> Future for $fut_type<$($T),+>
{ {
type Item = ($($T,)+); type Output = Result<($($T,)+), Error>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut ready = true; let mut ready = true;
$( $(
if self.items.$n.is_none() { if this.items.$n.is_none() {
match self.futs.$n.poll() { match unsafe { Pin::new_unchecked(&mut this.futs.$n) }.poll(cx) {
Ok(Async::Ready(item)) => { Poll::Ready(Ok(item)) => {
self.items.$n = Some(item); this.items.$n = Some(item);
} }
Ok(Async::NotReady) => ready = false, Poll::Pending => ready = false,
Err(e) => return Err(e.into()), Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
} }
} }
)+ )+
if ready { if ready {
Ok(Async::Ready( Poll::Ready(Ok(
($(self.items.$n.take().unwrap(),)+) ($(this.items.$n.take().unwrap(),)+)
)) ))
} else { } else {
Ok(Async::NotReady) Poll::Pending
} }
} }
} }

Some files were not shown because too many files have changed in this diff Show More