1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-07-15 14:26:15 +02:00

Compare commits

...

42 Commits

Author SHA1 Message Date
Nikolay Kim
57981ca04a update tests to async handlers 2019-11-22 11:49:35 +06:00
Nikolay Kim
e668acc596 update travis config 2019-11-22 10:13:32 +06:00
Nikolay Kim
512dd2be63 disable rustls support 2019-11-22 07:01:05 +06:00
Nikolay Kim
8683ba8bb0 rename .to_async() to .to() 2019-11-21 21:36:35 +06:00
Nikolay Kim
0b9e3d381b add test with custom connector 2019-11-21 17:36:18 +06:00
Nikolay Kim
1f0577f8d5 cleanup api doc examples 2019-11-21 16:02:17 +06:00
Nikolay Kim
53c5151692 use response instead of result for asyn c handlers 2019-11-21 16:02:17 +06:00
Nikolay Kim
55698f2524 migrade rest of middlewares 2019-11-21 16:02:17 +06:00
Nikolay Kim
471f82f0e0 migrate actix-multipart 2019-11-21 16:02:17 +06:00
Nikolay Kim
60ada97b3d migrate actix-session 2019-11-21 16:02:17 +06:00
Nikolay Kim
0de101bc4d update actix-web-codegen tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
95e2a0ef2e migrate actix-framed 2019-11-21 16:02:17 +06:00
Nikolay Kim
69cadcdedb migrate actix-files 2019-11-21 16:02:17 +06:00
Nikolay Kim
6ac4ac66b9 migrate actix-cors 2019-11-21 16:02:17 +06:00
Nikolay Kim
3646725cf6 migrate actix-identity 2019-11-21 16:02:17 +06:00
Nikolay Kim
ff62facc0d disable unmigrated crates 2019-11-21 16:02:17 +06:00
Nikolay Kim
b510527a9f update awc tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
3127dd4db6 migrate actix-web to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
d081e57316 fix h2 client send body 2019-11-21 16:02:17 +06:00
Nikolay Kim
1ffa7d18d3 drop unpin constraint 2019-11-21 16:02:17 +06:00
Nikolay Kim
687884fb94 update test-server tests 2019-11-21 16:02:17 +06:00
Nikolay Kim
5ab29b2e62 migrate awc and test-server to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
a6a2d2f444 update ssl impls 2019-11-21 16:02:17 +06:00
Nikolay Kim
9e95efcc16 migrate client to std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
8cba1170e6 make actix-http compile with std::future 2019-11-21 16:02:17 +06:00
Nikolay Kim
5cb2d500d1 update actix-web-actors 2019-11-14 08:58:24 +06:00
Nikolay Kim
0212c618c6 prepare actix-web release 2019-11-14 08:55:37 +06:00
Feiko Nanninga
88110ed268 Add security note to ConnectionInfo::remote() (#1158) 2019-11-14 08:32:47 +06:00
Nikolay Kim
fba02fdd8c prep awc release 2019-11-06 11:33:25 -08:00
Nikolay Kim
b2934ad8d2 prep actix-file release 2019-11-06 11:25:26 -08:00
Nikolay Kim
f7f410d033 fix test order dep 2019-11-06 11:20:47 -08:00
Nikolay Kim
885ff7396e prepare actox-http release 2019-11-06 10:35:13 -08:00
Erlend Langseth
61b38e8d0d Increase timeouts in test-server (#1153) 2019-11-06 06:09:22 -08:00
Hung-I Wang
edcde67076 Fix escaping/encoding problems in Content-Disposition header (#1151)
* Fix filename encoding in Content-Disposition of acitx_files::NamedFile

* Add more comments on how to use Content-Disposition header properly & Fix some trivial problems

* Improve Content-Disposition filename(*) parameters of actix_files::NamedFile

* Tweak Content-Disposition parse to accept empty param value in quoted-string

* Fix typos in comments in .../content_disposition.rs (pointed out by @JohnTitor)

* Update CHANGES.md

* Update CHANGES.md again
2019-11-06 06:08:37 -08:00
Jonathas Conceição
f0612f7570 awc: Add support for setting query from Serialize type for client request (#1130)
Signed-off-by: Jonathas-Conceicao <jadoliveira@inf.ufpel.edu.br>
2019-10-26 08:27:14 +03:00
Anton Lazarev
ace98e3a1e support Host guards when Host header is unset (#1129) 2019-10-15 05:05:54 +06:00
Nikolay Kim
1ca9d87f0a prep actix-web-codegen release 2019-10-14 21:35:53 +06:00
DanSnow
967f965405 Update syn & quote to 1.0 (#1133)
* chore(actix-web-codegen): Upgrade syn and quote to 1.0

* feat(actix-web-codegen): Generate better error message

* doc(actix-web-codegen): Update CHANGES.md

* fix: Build with stable rust
2019-10-14 21:34:17 +06:00
Nikolay Kim
062e51e8ce prep actix-file release 2019-10-14 21:26:26 +06:00
Roberto Huertas
a48e616def feat(files): add possibility to redirect to slash-ended path (#1134)
When accessing to a folder without a final slash, the index file will be loaded ok, but if it has
references (like a css or an image in an html file) these resources won't be loaded correctly if
they are using relative paths. In order to solve this, this PR adds the possibility to detect
folders without a final slash and make a 302 redirect to mitigate this issue. The behavior is off by
default. We're adding a new method called `redirect_to_slash_directory` which can be used to enable
this behavior.
2019-10-14 21:23:15 +06:00
MaySantucci
effa96f5e4 Removed httpcode 'MovedPermanenty'. (#1128) 2019-10-12 06:45:12 +06:00
Nathan
cc0b4be5b7 Fix typo in response.rs body() comment (#1126)
Fixes https://github.com/actix/actix-web/issues/1125
2019-10-09 19:11:55 +06:00
134 changed files with 10962 additions and 9770 deletions

View File

@@ -10,9 +10,9 @@ matrix:
include:
- rust: stable
- rust: beta
- rust: nightly-2019-08-10
- rust: nightly-2019-11-20
allow_failures:
- rust: nightly-2019-08-10
- rust: nightly-2019-11-20
env:
global:
@@ -25,7 +25,7 @@ before_install:
- sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev
before_cache: |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then
RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin
fi
@@ -37,8 +37,8 @@ script:
- cargo update
- cargo check --all --no-default-features
- cargo test --all-features --all -- --nocapture
- cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
- cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
# - cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
# - cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
# Upload docs
after_success:
@@ -51,7 +51,7 @@ after_success:
echo "Uploaded documentation"
fi
- |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then
taskset -c 0 cargo tarpaulin --out Xml --all --all-features
bash <(curl -s https://codecov.io/bash)
echo "Uploaded code coverage"

View File

@@ -1,11 +1,16 @@
# Changes
## [1.0.9] - 2019-xx-xx
## [1.0.9] - 2019-11-14
### Added
* Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110)
### Changed
* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129)
## [1.0.8] - 2019-09-25
### Added

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-web"
version = "1.0.8"
version = "2.0.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
readme = "README.md"
@@ -16,7 +16,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"]
edition = "2018"
[package.metadata.docs.rs]
features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"]
features = ["openssl", "rustls", "brotli", "flate2-zlib", "secure-cookies", "client"]
[badges]
travis-ci = { repository = "actix/actix-web", branch = "master" }
@@ -63,37 +63,35 @@ secure-cookies = ["actix-http/secure-cookies"]
fail = ["actix-http/fail"]
# openssl
ssl = ["openssl", "actix-server/ssl", "awc/ssl"]
openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"]
# rustls
rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"]
# unix domain sockets support
uds = ["actix-server/uds"]
# rustls = ["rust-tls", "actix-server/rustls", "awc/rustls"]
[dependencies]
actix-codec = "0.1.2"
actix-service = "0.4.1"
actix-utils = "0.4.4"
actix-codec = "0.2.0-alpha.1"
actix-service = "1.0.0-alpha.1"
actix-utils = "0.5.0-alpha.1"
actix-router = "0.1.5"
actix-rt = "0.2.4"
actix-web-codegen = "0.1.2"
actix-http = "0.2.9"
actix-server = "0.6.1"
actix-server-config = "0.1.2"
actix-testing = "0.1.0"
actix-threadpool = "0.1.1"
awc = { version = "0.2.7", optional = true }
actix-rt = "1.0.0-alpha.1"
actix-web-codegen = "0.2.0-alpha.1"
actix-http = "0.3.0-alpha.1"
actix-server = "0.8.0-alpha.1"
actix-server-config = "0.3.0-alpha.1"
actix-testing = "0.3.0-alpha.1"
actix-threadpool = "0.2.0-alpha.1"
awc = { version = "0.3.0-alpha.1", optional = true }
bytes = "0.4"
derive_more = "0.15.0"
encoding_rs = "0.8"
futures = "0.1.25"
hashbrown = "0.5.0"
futures = "0.3.1"
hashbrown = "0.6.3"
log = "0.4"
mime = "0.3"
net2 = "0.2.33"
parking_lot = "0.9"
pin-project = "0.4.5"
regex = "1.0"
serde = { version = "1.0", features=["derive"] }
serde_json = "1.0"
@@ -102,17 +100,17 @@ time = "0.1.42"
url = "2.1"
# ssl support
openssl = { version="0.10", optional = true }
rustls = { version = "0.15", optional = true }
open-ssl = { version="0.10", package="openssl", optional = true }
rust-tls = { version = "0.16", package="rustls", optional = true }
[dev-dependencies]
actix = "0.8.3"
actix-connect = "0.2.2"
actix-http-test = "0.2.4"
# actix = "0.8.3"
actix-connect = "0.3.0-alpha.1"
actix-http-test = "0.3.0-alpha.1"
rand = "0.7"
env_logger = "0.6"
serde_derive = "1.0"
tokio-timer = "0.2.8"
tokio-timer = "0.3.0-alpha.6"
brotli2 = "0.3.2"
flate2 = "1.0.2"
@@ -126,8 +124,28 @@ actix-web = { path = "." }
actix-http = { path = "actix-http" }
actix-http-test = { path = "test-server" }
actix-web-codegen = { path = "actix-web-codegen" }
actix-web-actors = { path = "actix-web-actors" }
# actix-web-actors = { path = "actix-web-actors" }
actix-session = { path = "actix-session" }
actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" }
awc = { path = "awc" }
actix-codec = { git = "https://github.com/actix/actix-net.git" }
actix-connect = { git = "https://github.com/actix/actix-net.git" }
actix-rt = { git = "https://github.com/actix/actix-net.git" }
actix-server = { git = "https://github.com/actix/actix-net.git" }
actix-server-config = { git = "https://github.com/actix/actix-net.git" }
actix-service = { git = "https://github.com/actix/actix-net.git" }
actix-testing = { git = "https://github.com/actix/actix-net.git" }
actix-threadpool = { git = "https://github.com/actix/actix-net.git" }
actix-utils = { git = "https://github.com/actix/actix-net.git" }
# actix-codec = { path = "../actix-net/actix-codec" }
# actix-connect = { path = "../actix-net/actix-connect" }
# actix-rt = { path = "../actix-net/actix-rt" }
# actix-server = { path = "../actix-net/actix-server" }
# actix-server-config = { path = "../actix-net/actix-server-config" }
# actix-service = { path = "../actix-net/actix-service" }
# actix-testing = { path = "../actix-net/actix-testing" }
# actix-threadpool = { path = "../actix-net/actix-threadpool" }
# actix-utils = { path = "../actix-net/actix-utils" }

View File

@@ -1,3 +1,10 @@
## 2.0.0
* Sync handlers has been removed. `.to_async()` methtod has been renamed to `.to()`
replace `fn` with `async fn` to convert sync handler to async
## 1.0.1
* Cors middleware has been moved to `actix-cors` crate

View File

@@ -19,17 +19,16 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust.
* [User Guide](https://actix.rs/docs/)
* [API Documentation (1.0)](https://docs.rs/actix-web/)
* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/)
* [Chat on gitter](https://gitter.im/actix/actix)
* Cargo package: [actix-web](https://crates.io/crates/actix-web)
* Minimum supported Rust version: 1.36 or later
* Minimum supported Rust version: 1.39 or later
## Example
```rust
use actix_web::{web, App, HttpServer, Responder};
fn index(info: web::Path<(u32, String)>) -> impl Responder {
async fn index(info: web::Path<(u32, String)>) -> impl Responder {
format!("Hello {}! id:{}", info.1, info.0)
}

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-cors"
version = "0.1.0"
version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Cross-origin resource sharing (CORS) for Actix applications."
readme = "README.md"
@@ -10,14 +10,14 @@ repository = "https://github.com/actix/actix-web.git"
documentation = "https://docs.rs/actix-cors/"
license = "MIT/Apache-2.0"
edition = "2018"
#workspace = ".."
workspace = ".."
[lib]
name = "actix_cors"
path = "src/lib.rs"
[dependencies]
actix-web = "1.0.0"
actix-service = "0.4.0"
actix-web = "2.0.0-alpha.1"
actix-service = "1.0.0-alpha.1"
derive_more = "0.15.0"
futures = "0.1.25"
futures = "0.3.1"

View File

@@ -11,7 +11,7 @@
//! use actix_cors::Cors;
//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer};
//!
//! fn index(req: HttpRequest) -> &'static str {
//! async fn index(req: HttpRequest) -> &'static str {
//! "Hello world"
//! }
//!
@@ -23,7 +23,8 @@
//! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600))
//! .max_age(3600)
//! .finish())
//! .service(
//! web::resource("/index.html")
//! .route(web::get().to(index))
@@ -41,16 +42,16 @@
use std::collections::HashSet;
use std::iter::FromIterator;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{IntoTransform, Service, Transform};
use actix_service::{Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri};
use actix_web::HttpResponse;
use derive_more::Display;
use futures::future::{ok, Either, Future, FutureResult};
use futures::Poll;
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
/// A set of errors that can occur during processing CORS
#[derive(Debug, Display)]
@@ -456,25 +457,9 @@ impl Cors {
}
self
}
}
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl<S, B> IntoTransform<CorsFactory, S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn into_transform(self) -> CorsFactory {
/// Construct cors middleware
pub fn finish(self) -> CorsFactory {
let mut slf = if !self.methods {
self.allowed_methods(vec![
Method::GET,
@@ -521,6 +506,16 @@ where
}
}
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
/// `Middleware` for Cross-origin resource sharing support
///
/// The Cors struct contains the settings for CORS requests to be validated and
@@ -540,7 +535,7 @@ where
type Error = Error;
type InitError = ();
type Transform = CorsMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware {
@@ -682,12 +677,12 @@ where
type Response = ServiceResponse<B>;
type Error = Error;
type Future = Either<
FutureResult<Self::Response, Error>,
Either<S::Future, Box<dyn Future<Item = Self::Response, Error = Error>>>,
Ready<Result<Self::Response, Error>>,
LocalBoxFuture<'static, Result<Self::Response, Error>>,
>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -698,7 +693,7 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head()))
{
return Either::A(ok(req.error_response(e)));
return Either::Left(ok(req.error_response(e)));
}
// allowed headers
@@ -751,22 +746,32 @@ where
.finish()
.into_body();
Either::A(ok(req.into_response(res)))
} else if req.headers().contains_key(&header::ORIGIN) {
Either::Left(ok(req.into_response(res)))
} else {
if req.headers().contains_key(&header::ORIGIN) {
// Only check requests with a origin header.
if let Err(e) = self.inner.validate_origin(req.head()) {
return Either::A(ok(req.error_response(e)));
return Either::Left(ok(req.error_response(e)));
}
}
let inner = self.inner.clone();
let has_origin = req.headers().contains_key(&header::ORIGIN);
let fut = self.service.call(req);
Either::B(Either::B(Box::new(self.service.call(req).and_then(
move |mut res| {
Either::Right(
async move {
let res = fut.await;
if has_origin {
let mut res = res?;
if let Some(origin) =
inner.access_control_allow_origin(res.request().head())
{
res.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
res.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
origin.clone(),
);
};
if let Some(ref expose) = inner.expose_hdrs {
@@ -782,8 +787,9 @@ where
);
}
if inner.vary_header {
let value =
if let Some(hdr) = res.headers_mut().get(&header::VARY) {
let value = if let Some(hdr) =
res.headers_mut().get(&header::VARY)
{
let mut val: Vec<u8> =
Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes());
@@ -795,80 +801,73 @@ where
res.headers_mut().insert(header::VARY, value);
}
Ok(res)
},
))))
} else {
Either::B(Either::A(self.service.call(req)))
res
}
}
.boxed_local(),
)
}
}
}
#[cfg(test)]
mod tests {
use actix_service::{IntoService, Transform};
use actix_service::{service_fn2, Transform};
use actix_web::test::{self, block_on, TestRequest};
use super::*;
impl Cors {
fn finish<F, S, B>(self, srv: F) -> CorsMiddleware<S>
where
F: IntoService<S>,
S: Service<
Request = ServiceRequest,
Response = ServiceResponse<B>,
Error = Error,
> + 'static,
S::Future: 'static,
B: 'static,
{
block_on(
IntoTransform::<CorsFactory, S>::into_transform(self)
.new_transform(srv.into_service()),
)
.unwrap()
}
}
#[test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new()
.supports_credentials()
.send_wildcard()
.finish(test::ok_service());
let _cors = Cors::new().supports_credentials().send_wildcard().finish();
}
#[test]
fn validate_origin_allows_all_origins() {
let mut cors = Cors::new().finish(test::ok_service());
block_on(async {
let mut cors = Cors::new()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
})
}
#[test]
fn default() {
let mut cors =
block_on(Cors::default().new_transform(test::ok_service())).unwrap();
block_on(async {
let mut cors = Cors::default()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
})
}
#[test]
fn test_preflight() {
block_on(async {
let mut cors = Cors::new()
.send_wildcard()
.max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE)
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
@@ -877,7 +876,7 @@ mod tests {
assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com")
@@ -897,7 +896,7 @@ mod tests {
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"*"[..],
resp.headers()
@@ -943,8 +942,9 @@ mod tests {
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
})
}
// #[test]
@@ -960,9 +960,13 @@ mod tests {
#[test]
#[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() {
block_on(async {
let cors = Cors::new()
.allowed_origin("https://www.example.com")
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET)
@@ -970,28 +974,40 @@ mod tests {
cors.inner.validate_origin(req.head()).unwrap();
cors.inner.validate_allowed_method(req.head()).unwrap();
cors.inner.validate_allowed_headers(req.head()).unwrap();
})
}
#[test]
fn test_validate_origin() {
block_on(async {
let mut cors = Cors::new()
.allowed_origin("https://www.example.com")
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(resp.status(), StatusCode::OK);
})
}
#[test]
fn test_no_origin_response() {
let mut cors = Cors::new().disable_preflight().finish(test::ok_service());
block_on(async {
let mut cors = Cors::new()
.disable_preflight()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert!(resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
@@ -1000,7 +1016,7 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://www.example.com"[..],
resp.headers()
@@ -1008,10 +1024,12 @@ mod tests {
.unwrap()
.as_bytes()
);
})
}
#[test]
fn test_response() {
block_on(async {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new()
.send_wildcard()
@@ -1021,13 +1039,16 @@ mod tests {
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"*"[..],
resp.headers()
@@ -1065,15 +1086,18 @@ mod tests {
.allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE)
.finish(|req: ServiceRequest| {
req.into_response(
.finish()
.new_transform(service_fn2(|req: ServiceRequest| {
ok(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(),
)
});
))
}))
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes()
@@ -1083,13 +1107,16 @@ mod tests {
.disable_vary_header()
.allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com")
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
let origins_str = resp
.headers()
@@ -1099,21 +1126,26 @@ mod tests {
.unwrap();
assert_eq!("https://www.example.com", origins_str);
})
}
#[test]
fn test_multiple_origins() {
block_on(async {
let mut cors = Cors::new()
.allowed_origin("https://example.com")
.allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET])
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com")
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://example.com"[..],
resp.headers()
@@ -1126,7 +1158,7 @@ mod tests {
.method(Method::GET)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://example.org"[..],
resp.headers()
@@ -1134,22 +1166,27 @@ mod tests {
.unwrap()
.as_bytes()
);
})
}
#[test]
fn test_multiple_origins_preflight() {
block_on(async {
let mut cors = Cors::new()
.allowed_origin("https://example.com")
.allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET])
.finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://example.com"[..],
resp.headers()
@@ -1163,7 +1200,7 @@ mod tests {
.method(Method::OPTIONS)
.to_srv_request();
let resp = test::call_service(&mut cors, req);
let resp = test::call_service(&mut cors, req).await;
assert_eq!(
&b"https://example.org"[..],
resp.headers()
@@ -1171,5 +1208,6 @@ mod tests {
.unwrap()
.as_bytes()
);
})
}
}

View File

@@ -1,5 +1,13 @@
# Changes
## [0.1.7] - 2019-11-06
* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151)
## [0.1.6] - 2019-10-14
* Add option to redirect to a slash-ended path `Files` #1132
## [0.1.5] - 2019-10-08
* Bump up `mime_guess` crate version to 2.0.1
@@ -8,17 +16,14 @@
* Allow user defined request guards for `Files` #1113
## [0.1.4] - 2019-07-20
* Allow to disable `Content-Disposition` header #686
## [0.1.3] - 2019-06-28
* Do not set `Content-Length` header, let actix-http set it #930
## [0.1.2] - 2019-06-13
* Content-Length is 0 for NamedFile HEAD request #914

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-files"
version = "0.1.5"
version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Static files support for actix web."
readme = "README.md"
@@ -18,12 +18,12 @@ name = "actix_files"
path = "src/lib.rs"
[dependencies]
actix-web = { version = "1.0.8", default-features = false }
actix-http = "0.2.9"
actix-service = "0.4.1"
actix-web = { version = "2.0.0-alpha.1", default-features = false }
actix-http = "0.3.0-alpha.1"
actix-service = "1.0.0-alpha.1"
bitflags = "1"
bytes = "0.4"
futures = "0.1.25"
futures = "0.3.1"
derive_more = "0.15.0"
log = "0.4"
mime = "0.3"
@@ -32,4 +32,4 @@ percent-encoding = "2.1"
v_htmlescape = "0.4"
[dev-dependencies]
actix-web = { version = "1.0.8", features=["ssl"] }
actix-web = { version = "2.0.0-alpha.1", features=["openssl"] }

File diff suppressed because it is too large Load Diff

View File

@@ -13,11 +13,12 @@ use mime_guess::from_path;
use actix_http::body::SizedStream;
use actix_web::http::header::{
self, ContentDisposition, DispositionParam, DispositionType,
self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue,
};
use actix_web::http::{ContentEncoding, StatusCode};
use actix_web::middleware::BodyEncoding;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder};
use futures::future::{ready, Ready};
use crate::range::HttpRange;
use crate::ChunkedReadFile;
@@ -93,9 +94,18 @@ impl NamedFile {
mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline,
_ => DispositionType::Attachment,
};
let mut parameters =
vec![DispositionParam::Filename(String::from(filename.as_ref()))];
if !filename.is_ascii() {
parameters.push(DispositionParam::FilenameExt(ExtendedValue {
charset: Charset::Ext(String::from("UTF-8")),
language_tag: None,
value: filename.into_owned().into_bytes(),
}))
}
let cd = ContentDisposition {
disposition: disposition_type,
parameters: vec![DispositionParam::Filename(filename.into_owned())],
parameters: parameters,
};
(ct, cd)
};
@@ -246,62 +256,8 @@ impl NamedFile {
pub(crate) fn last_modified(&self) -> Option<header::HttpDate> {
self.modified.map(|mtime| mtime.into())
}
}
impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Result<HttpResponse, Error>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
pub fn into_response(self, req: &HttpRequest) -> Result<HttpResponse, Error> {
if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code);
resp.set(header::ContentType(self.content_type.clone()))
@@ -433,8 +389,67 @@ impl Responder for NamedFile {
counter: 0,
};
if offset != 0 || length != self.md.len() {
return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader));
};
Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader))
} else {
Ok(resp.body(SizedStream::new(length, reader)))
}
}
}
impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Ready<Result<HttpResponse, Error>>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
ready(self.into_response(req))
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-framed"
version = "0.2.1"
version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix framed app server"
readme = "README.md"
@@ -20,19 +20,20 @@ name = "actix_framed"
path = "src/lib.rs"
[dependencies]
actix-codec = "0.1.2"
actix-service = "0.4.1"
actix-codec = "0.2.0-alpha.1"
actix-service = "1.0.0-alpha.1"
actix-router = "0.1.2"
actix-rt = "0.2.2"
actix-http = "0.2.7"
actix-server-config = "0.1.2"
actix-rt = "1.0.0-alpha.1"
actix-http = "0.3.0-alpha.1"
actix-server-config = "0.3.0-alpha.1"
bytes = "0.4"
futures = "0.1.25"
futures = "0.3.1"
pin-project = "0.4.6"
log = "0.4"
[dev-dependencies]
actix-server = { version = "0.6.0", features=["ssl"] }
actix-connect = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "0.2.4", features=["ssl"] }
actix-utils = "0.4.4"
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
actix-connect = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-utils = "0.5.0-alpha.1"

View File

@@ -1,21 +1,24 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::h1::{Codec, SendResponse};
use actix_http::{Error, Request, Response};
use actix_router::{Path, Router, Url};
use actix_server_config::ServerConfig;
use actix_service::{IntoNewService, NewService, Service};
use futures::{Async, Future, Poll};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use futures::future::{ok, FutureExt, LocalBoxFuture};
use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService};
use crate::request::FramedRequest;
use crate::state::State;
type BoxedResponse = Box<dyn Future<Item = (), Error = Error>>;
type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>;
pub trait HttpServiceFactory {
type Factory: NewService;
type Factory: ServiceFactory;
fn path(&self) -> &str;
@@ -48,19 +51,19 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
pub fn service<U>(mut self, factory: U) -> Self
where
U: HttpServiceFactory,
U::Factory: NewService<
U::Factory: ServiceFactory<
Config = (),
Request = FramedRequest<T, S>,
Response = (),
Error = Error,
InitError = (),
> + 'static,
<U::Factory as NewService>::Future: 'static,
<U::Factory as NewService>::Service: Service<
<U::Factory as ServiceFactory>::Future: 'static,
<U::Factory as ServiceFactory>::Service: Service<
Request = FramedRequest<T, S>,
Response = (),
Error = Error,
Future = Box<dyn Future<Item = (), Error = Error>>,
Future = LocalBoxFuture<'static, Result<(), Error>>,
>,
{
let path = factory.path().to_string();
@@ -70,12 +73,12 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
}
}
impl<T, S> IntoNewService<FramedAppFactory<T, S>> for FramedApp<T, S>
impl<T, S> IntoServiceFactory<FramedAppFactory<T, S>> for FramedApp<T, S>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: 'static,
{
fn into_new_service(self) -> FramedAppFactory<T, S> {
fn into_factory(self) -> FramedAppFactory<T, S> {
FramedAppFactory {
state: self.state,
services: Rc::new(self.services),
@@ -89,9 +92,9 @@ pub struct FramedAppFactory<T, S> {
services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>,
}
impl<T, S> NewService for FramedAppFactory<T, S>
impl<T, S> ServiceFactory for FramedAppFactory<T, S>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: 'static,
{
type Config = ServerConfig;
@@ -128,28 +131,30 @@ pub struct CreateService<T, S> {
enum CreateServiceItem<T, S> {
Future(
Option<String>,
Box<dyn Future<Item = BoxedHttpService<FramedRequest<T, S>>, Error = ()>>,
LocalBoxFuture<'static, Result<BoxedHttpService<FramedRequest<T, S>>, ()>>,
),
Service(String, BoxedHttpService<FramedRequest<T, S>>),
}
impl<S: 'static, T: 'static> Future for CreateService<T, S>
where
T: AsyncRead + AsyncWrite,
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = FramedAppService<T, S>;
type Error = ();
type Output = Result<FramedAppService<T, S>, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut done = true;
// poll http services
for item in &mut self.fut {
let res = match item {
CreateServiceItem::Future(ref mut path, ref mut fut) => {
match fut.poll()? {
Async::Ready(service) => Some((path.take().unwrap(), service)),
Async::NotReady => {
match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(service)) => {
Some((path.take().unwrap(), service))
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
done = false;
None
}
@@ -176,12 +181,12 @@ where
}
router
});
Ok(Async::Ready(FramedAppService {
Poll::Ready(Ok(FramedAppService {
router: router.finish(),
state: self.state.clone(),
}))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -193,15 +198,15 @@ pub struct FramedAppService<T, S> {
impl<S: 'static, T: 'static> Service for FramedAppService<T, S>
where
T: AsyncRead + AsyncWrite,
T: AsyncRead + AsyncWrite + Unpin,
{
type Request = (Request, Framed<T, Codec>);
type Response = ();
type Error = Error;
type Future = BoxedResponse;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
@@ -210,8 +215,8 @@ where
if let Some((srv, _info)) = self.router.recognize_mut(&mut path) {
return srv.call(FramedRequest::new(req, framed, path, self.state.clone()));
}
Box::new(
SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())),
)
SendResponse::new(framed, Response::NotFound().finish())
.then(|_| ok(()))
.boxed_local()
}
}

View File

@@ -1,36 +1,38 @@
use std::task::{Context, Poll};
use actix_http::Error;
use actix_service::{NewService, Service};
use futures::{Future, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future::{FutureExt, LocalBoxFuture};
pub(crate) type BoxedHttpService<Req> = Box<
dyn Service<
Request = Req,
Response = (),
Error = Error,
Future = Box<dyn Future<Item = (), Error = Error>>,
Future = LocalBoxFuture<'static, Result<(), Error>>,
>,
>;
pub(crate) type BoxedHttpNewService<Req> = Box<
dyn NewService<
dyn ServiceFactory<
Config = (),
Request = Req,
Response = (),
Error = Error,
InitError = (),
Service = BoxedHttpService<Req>,
Future = Box<dyn Future<Item = BoxedHttpService<Req>, Error = ()>>,
Future = LocalBoxFuture<'static, Result<BoxedHttpService<Req>, ()>>,
>,
>;
pub(crate) struct HttpNewService<T: NewService>(T);
pub(crate) struct HttpNewService<T: ServiceFactory>(T);
impl<T> HttpNewService<T>
where
T: NewService<Response = (), Error = Error>,
T: ServiceFactory<Response = (), Error = Error>,
T::Response: 'static,
T::Future: 'static,
T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static,
<T::Service as Service>::Future: 'static,
{
pub fn new(service: T) -> Self {
@@ -38,12 +40,12 @@ where
}
}
impl<T> NewService for HttpNewService<T>
impl<T> ServiceFactory for HttpNewService<T>
where
T: NewService<Config = (), Response = (), Error = Error>,
T: ServiceFactory<Config = (), Response = (), Error = Error>,
T::Request: 'static,
T::Future: 'static,
T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static,
<T::Service as Service>::Future: 'static,
{
type Config = ();
@@ -52,13 +54,19 @@ where
type Error = Error;
type InitError = ();
type Service = BoxedHttpService<T::Request>;
type Future = Box<dyn Future<Item = Self::Service, Error = ()>>;
type Future = LocalBoxFuture<'static, Result<Self::Service, ()>>;
fn new_service(&self, _: &()) -> Self::Future {
Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| {
let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service });
Ok(service)
}))
let fut = self.0.new_service(&());
async move {
fut.await.map_err(|_| ()).map(|service| {
let service: BoxedHttpService<_> =
Box::new(HttpServiceWrapper { service });
service
})
}
.boxed_local()
}
}
@@ -70,7 +78,7 @@ impl<T> Service for HttpServiceWrapper<T>
where
T: Service<
Response = (),
Future = Box<dyn Future<Item = (), Error = Error>>,
Future = LocalBoxFuture<'static, Result<(), Error>>,
Error = Error,
>,
T::Request: 'static,
@@ -78,10 +86,10 @@ where
type Request = T::Request;
type Response = ();
type Error = Error;
type Future = Box<dyn Future<Item = (), Error = Error>>;
type Future = LocalBoxFuture<'static, Result<(), Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: Self::Request) -> Self::Future {

View File

@@ -1,11 +1,12 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{http::Method, Error};
use actix_service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, Future, IntoFuture, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use log::error;
use crate::app::HttpServiceFactory;
@@ -15,11 +16,11 @@ use crate::request::FramedRequest;
///
/// Route uses builder-like pattern for configuration.
/// If handler is not explicitly set, default *404 Not Found* handler is used.
pub struct FramedRoute<Io, S, F = (), R = ()> {
pub struct FramedRoute<Io, S, F = (), R = (), E = ()> {
handler: F,
pattern: String,
methods: Vec<Method>,
state: PhantomData<(Io, S, R)>,
state: PhantomData<(Io, S, R, E)>,
}
impl<Io, S> FramedRoute<Io, S> {
@@ -53,12 +54,12 @@ impl<Io, S> FramedRoute<Io, S> {
self
}
pub fn to<F, R>(self, handler: F) -> FramedRoute<Io, S, F, R>
pub fn to<F, R, E>(self, handler: F) -> FramedRoute<Io, S, F, R, E>
where
F: FnMut(FramedRequest<Io, S>) -> R,
R: IntoFuture<Item = ()>,
R::Future: 'static,
R::Error: fmt::Debug,
R: Future<Output = Result<(), E>> + 'static,
E: fmt::Debug,
{
FramedRoute {
handler,
@@ -69,15 +70,14 @@ impl<Io, S> FramedRoute<Io, S> {
}
}
impl<Io, S, F, R> HttpServiceFactory for FramedRoute<Io, S, F, R>
impl<Io, S, F, R, E> HttpServiceFactory for FramedRoute<Io, S, F, R, E>
where
Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>,
R::Future: 'static,
R::Error: fmt::Display,
R: Future<Output = Result<(), E>> + 'static,
E: fmt::Display,
{
type Factory = FramedRouteFactory<Io, S, F, R>;
type Factory = FramedRouteFactory<Io, S, F, R, E>;
fn path(&self) -> &str {
&self.pattern
@@ -92,27 +92,26 @@ where
}
}
pub struct FramedRouteFactory<Io, S, F, R> {
pub struct FramedRouteFactory<Io, S, F, R, E> {
handler: F,
methods: Vec<Method>,
_t: PhantomData<(Io, S, R)>,
_t: PhantomData<(Io, S, R, E)>,
}
impl<Io, S, F, R> NewService for FramedRouteFactory<Io, S, F, R>
impl<Io, S, F, R, E> ServiceFactory for FramedRouteFactory<Io, S, F, R, E>
where
Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>,
R::Future: 'static,
R::Error: fmt::Display,
R: Future<Output = Result<(), E>> + 'static,
E: fmt::Display,
{
type Config = ();
type Request = FramedRequest<Io, S>;
type Response = ();
type Error = Error;
type InitError = ();
type Service = FramedRouteService<Io, S, F, R>;
type Future = FutureResult<Self::Service, Self::InitError>;
type Service = FramedRouteService<Io, S, F, R, E>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &()) -> Self::Future {
ok(FramedRouteService {
@@ -123,35 +122,38 @@ where
}
}
pub struct FramedRouteService<Io, S, F, R> {
pub struct FramedRouteService<Io, S, F, R, E> {
handler: F,
methods: Vec<Method>,
_t: PhantomData<(Io, S, R)>,
_t: PhantomData<(Io, S, R, E)>,
}
impl<Io, S, F, R> Service for FramedRouteService<Io, S, F, R>
impl<Io, S, F, R, E> Service for FramedRouteService<Io, S, F, R, E>
where
Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: IntoFuture<Item = ()>,
R::Future: 'static,
R::Error: fmt::Display,
R: Future<Output = Result<(), E>> + 'static,
E: fmt::Display,
{
type Request = FramedRequest<Io, S>;
type Response = ();
type Error = Error;
type Future = Box<dyn Future<Item = (), Error = Error>>;
type Future = LocalBoxFuture<'static, Result<(), Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future {
Box::new((self.handler)(req).into_future().then(|res| {
let fut = (self.handler)(req);
async move {
let res = fut.await;
if let Err(e) = res {
error!("Error in request handler: {}", e);
}
Ok(())
}))
}
.boxed_local()
}
}

View File

@@ -1,4 +1,6 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::body::BodySize;
@@ -6,9 +8,9 @@ use actix_http::error::ResponseError;
use actix_http::h1::{Codec, Message};
use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{Request, Response};
use actix_service::{NewService, Service};
use futures::future::{ok, Either, FutureResult};
use futures::{Async, Future, IntoFuture, Poll, Sink};
use actix_service::{Service, ServiceFactory};
use futures::future::{err, ok, Either, Ready};
use futures::Future;
/// Service that verifies incoming request if it is valid websocket
/// upgrade request. In case of error returns `HandshakeError`
@@ -22,14 +24,14 @@ impl<T, C> Default for VerifyWebSockets<T, C> {
}
}
impl<T, C> NewService for VerifyWebSockets<T, C> {
impl<T, C> ServiceFactory for VerifyWebSockets<T, C> {
type Config = C;
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type InitError = ();
type Service = VerifyWebSockets<T, C>;
type Future = FutureResult<Self::Service, Self::InitError>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &C) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
@@ -40,16 +42,16 @@ impl<T, C> Service for VerifyWebSockets<T, C> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(req.head()) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
Err(e) => err((e, framed)),
Ok(_) => ok((req, framed)),
}
}
}
@@ -67,9 +69,9 @@ where
}
}
impl<T, R, E, C> NewService for SendError<T, R, E, C>
impl<T, R, E, C> ServiceFactory for SendError<T, R, E, C>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
R: 'static,
E: ResponseError + 'static,
{
@@ -79,7 +81,7 @@ where
type Error = (E, Framed<T, Codec>);
type InitError = ();
type Service = SendError<T, R, E, C>;
type Future = FutureResult<Self::Service, Self::InitError>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &C) -> Self::Future {
ok(SendError(PhantomData))
@@ -88,25 +90,25 @@ where
impl<T, R, E, C> Service for SendError<T, R, E, C>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
R: 'static,
E: ResponseError + 'static,
{
type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R;
type Error = (E, Framed<T, Codec>);
type Future = Either<FutureResult<R, (E, Framed<T, Codec>)>, SendErrorFut<T, R, E>>;
type Future = Either<Ready<Result<R, (E, Framed<T, Codec>)>>, SendErrorFut<T, R, E>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future {
match req {
Ok(r) => Either::A(ok(r)),
Ok(r) => Either::Left(ok(r)),
Err((e, framed)) => {
let res = e.error_response().drop_body();
Either::B(SendErrorFut {
Either::Right(SendErrorFut {
framed: Some(framed),
res: Some((res, BodySize::Empty).into()),
err: Some(e),
@@ -117,6 +119,7 @@ where
}
}
#[pin_project::pin_project]
pub struct SendErrorFut<T, R, E> {
res: Option<Message<(Response<()>, BodySize)>>,
framed: Option<Framed<T, Codec>>,
@@ -127,23 +130,27 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E>
where
E: ResponseError,
T: AsyncRead + AsyncWrite,
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = R;
type Error = (E, Framed<T, Codec>);
type Output = Result<R, (E, Framed<T, Codec>)>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(res) = self.res.take() {
if self.framed.as_mut().unwrap().force_send(res).is_err() {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()));
if self.framed.as_mut().unwrap().write(res).is_err() {
return Poll::Ready(Err((
self.err.take().unwrap(),
self.framed.take().unwrap(),
)));
}
}
match self.framed.as_mut().unwrap().poll_complete() {
Ok(Async::Ready(_)) => {
Err((self.err.take().unwrap(), self.framed.take().unwrap()))
match self.framed.as_mut().unwrap().flush(cx) {
Poll::Ready(Ok(_)) => {
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap())))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())),
Poll::Ready(Err(_)) => {
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap())))
}
Poll::Pending => Poll::Pending,
}
}
}

View File

@@ -1,4 +1,6 @@
//! Various helpers for Actix applications to use during testing.
use std::future::Future;
use actix_codec::Framed;
use actix_http::h1::Codec;
use actix_http::http::header::{Header, HeaderName, IntoHeaderValue};
@@ -6,7 +8,6 @@ use actix_http::http::{HttpTryFrom, Method, Uri, Version};
use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest};
use actix_router::{Path, Url};
use actix_rt::Runtime;
use futures::IntoFuture;
use crate::{FramedRequest, State};
@@ -121,10 +122,10 @@ impl<S> TestRequest<S> {
pub fn run<F, R, I, E>(self, f: F) -> Result<I, E>
where
F: FnOnce(FramedRequest<TestBuffer, S>) -> R,
R: IntoFuture<Item = I, Error = E>,
R: Future<Output = Result<I, E>>,
{
let mut rt = Runtime::new().unwrap();
rt.block_on(f(self.finish()).into_future())
rt.block_on(f(self.finish()))
}
}

View File

@@ -1,30 +1,31 @@
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response};
use actix_http_test::TestServer;
use actix_service::{IntoNewService, NewService};
use actix_http_test::{block_on, TestServer};
use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory};
use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
use futures::{future, SinkExt, StreamExt};
use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets};
fn ws_service<T: AsyncRead + AsyncWrite>(
async fn ws_service<T: AsyncRead + AsyncWrite>(
req: FramedRequest<T>,
) -> impl Future<Item = (), Error = Error> {
let (req, framed, _) = req.into_parts();
) -> Result<(), Error> {
let (req, mut framed, _) = req.into_parts();
let res = ws::handshake(req.head()).unwrap().message_body(());
framed
.send((res, body::BodySize::None).into())
.map_err(|_| panic!())
.and_then(|framed| {
.await
.unwrap();
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!())
})
.await
.unwrap();
Ok(())
}
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
@@ -34,108 +35,129 @@ fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
};
ok(msg)
Ok(msg)
}
#[test]
fn test_simple() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build()
.upgrade(
FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)),
FramedApp::new()
.service(FramedRoute::get("/index.html").to(ws_service)),
)
.finish(|_| future::ok::<_, Error>(Response::NotFound()))
});
assert!(srv.ws_at("/test").is_err());
assert!(srv.ws_at("/test").await.is_err());
// client service
let framed = srv.ws_at("/index.html").unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
let mut framed = srv.ws_at("/index.html").await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
framed
.send(ws::Message::Binary("text".into()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
item.unwrap().unwrap(),
ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
framed.send(ws::Message::Ping("text".into())).await.unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap();
let (item, _) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
);
})
}
#[test]
fn test_service() {
let mut srv = TestServer::new(|| {
actix_http::h1::OneRequest::new().map_err(|_| ()).and_then(
VerifyWebSockets::default()
block_on(async {
let mut srv = TestServer::start(|| {
pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then(
pipeline_factory(
pipeline_factory(VerifyWebSockets::default())
.then(SendError::default())
.map_err(|_| ())
.map_err(|_| ()),
)
.and_then(
FramedApp::new()
.service(FramedRoute::get("/index.html").to(ws_service))
.into_new_service()
.into_factory()
.map_err(|_| ()),
),
)
});
// non ws request
let res = srv.block_on(srv.get("/index.html").send()).unwrap();
let res = srv.get("/index.html").send().await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
// not found
assert!(srv.ws_at("/test").is_err());
assert!(srv.ws_at("/test").await.is_err());
// client service
let framed = srv.ws_at("/index.html").unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
let mut framed = srv.ws_at("/index.html").await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
framed
.send(ws::Message::Binary("text".into()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
item.unwrap().unwrap(),
ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
framed.send(ws::Message::Ping("text".into())).await.unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap();
let (item, _) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
);
})
}

View File

@@ -1,11 +1,15 @@
# Changes
## Not released yet
## [0.2.11] - 2019-11-06
### Added
* Add support for serde_json::Value to be passed as argument to ResponseBuilder.body()
* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151)
* Allow to use `std::convert::Infallible` as `actix_http::error::Error`
### Fixed
* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118
@@ -17,9 +21,6 @@
* Add support for sending HTTP requests with `Rc<RequestHead>` in addition to sending HTTP requests with `RequestHead`
* Allow to use `std::convert::Infallible` as `actix_http::error::Error`
### Fixed
* h2 will use error response #1080

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-http"
version = "0.2.10"
version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http primitives"
readme = "README.md"
@@ -16,7 +16,7 @@ edition = "2018"
workspace = ".."
[package.metadata.docs.rs]
features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"]
features = ["openssl", "fail", "brotli", "flate2-zlib", "secure-cookies"]
[lib]
name = "actix_http"
@@ -26,10 +26,10 @@ path = "src/lib.rs"
default = []
# openssl
ssl = ["openssl", "actix-connect/ssl"]
openssl = ["open-ssl", "actix-connect/openssl", "tokio-openssl"]
# rustls support
rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"]
# rustls = ["rust-tls", "webpki-roots", "actix-connect/rustls"]
# brotli encoding, requires c compiler
brotli = ["brotli2"]
@@ -47,23 +47,24 @@ fail = ["failure"]
secure-cookies = ["ring"]
[dependencies]
actix-service = "0.4.1"
actix-codec = "0.1.2"
actix-connect = "0.2.4"
actix-utils = "0.4.4"
actix-server-config = "0.1.2"
actix-threadpool = "0.1.1"
actix-service = "1.0.0-alpha.1"
actix-codec = "0.2.0-alpha.1"
actix-connect = "1.0.0-alpha.1"
actix-utils = "0.5.0-alpha.1"
actix-server-config = "0.3.0-alpha.1"
actix-threadpool = "0.2.0-alpha.1"
base64 = "0.10"
bitflags = "1.0"
bytes = "0.4"
copyless = "0.1.4"
chrono = "0.4.6"
derive_more = "0.15.0"
either = "1.5.2"
encoding_rs = "0.8"
futures = "0.1.25"
hashbrown = "0.5.0"
h2 = "0.1.16"
futures = "0.3.1"
hashbrown = "0.6.3"
h2 = "0.2.0-alpha.3"
http = "0.1.17"
httparse = "1.3"
indexmap = "1.2"
@@ -72,6 +73,7 @@ language-tags = "0.2"
log = "0.4"
mime = "0.3"
percent-encoding = "2.1"
pin-project = "0.4.5"
rand = "0.7"
regex = "1.0"
serde = "1.0"
@@ -80,13 +82,16 @@ sha1 = "0.6"
slab = "0.4"
serde_urlencoded = "0.6.1"
time = "0.1.42"
tokio-tcp = "0.1.3"
tokio-timer = "0.2.8"
tokio-current-thread = "0.1"
trust-dns-resolver = { version="0.11.1", default-features = false }
tokio = "=0.2.0-alpha.6"
tokio-io = "=0.2.0-alpha.6"
tokio-net = "=0.2.0-alpha.6"
tokio-timer = "0.3.0-alpha.6"
tokio-executor = "=0.2.0-alpha.6"
trust-dns-resolver = { version="0.18.0-alpha.1", default-features = false }
# for secure cookie
ring = { version = "0.14.6", optional = true }
ring = { version = "0.16.9", optional = true }
# compression
brotli2 = { version="0.3.2", optional = true }
@@ -94,17 +99,17 @@ flate2 = { version="1.0.7", optional = true, default-features = false }
# optional deps
failure = { version = "0.1.5", optional = true }
openssl = { version="0.10", optional = true }
rustls = { version = "0.15.2", optional = true }
webpki-roots = { version = "0.16", optional = true }
chrono = "0.4.6"
open-ssl = { version="0.10", package="openssl", optional = true }
tokio-openssl = { version = "0.4.0-alpha.6", optional = true }
rust-tls = { version = "0.16.0", package="rustls", optional = true }
webpki-roots = { version = "0.18", optional = true }
[dev-dependencies]
actix-rt = "0.2.2"
actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] }
actix-connect = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "0.2.4", features=["ssl"] }
actix-rt = "1.0.0-alpha.1"
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
env_logger = "0.6"
serde_derive = "1.0"
openssl = { version="0.10" }
tokio-tcp = "0.1"
open-ssl = { version="0.10", package="openssl" }

View File

@@ -1,9 +1,9 @@
use std::{env, io};
use actix_http::{error::PayloadError, HttpService, Request, Response};
use actix_http::{Error, HttpService, Request, Response};
use actix_server::Server;
use bytes::BytesMut;
use futures::{Future, Stream};
use futures::StreamExt;
use http::header::HeaderValue;
use log::info;
@@ -17,20 +17,22 @@ fn main() -> io::Result<()> {
.client_timeout(1000)
.client_disconnect(1000)
.finish(|mut req: Request| {
req.take_payload()
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
.and_then(|bytes| {
info!("request body: {:?}", bytes);
let mut res = Response::Ok();
res.header(
async move {
let mut body = BytesMut::new();
while let Some(item) = req.payload().next().await {
body.extend_from_slice(&item?);
}
info!("request body: {:?}", body);
Ok::<_, Error>(
Response::Ok()
.header(
"x-head",
HeaderValue::from_static("dummy value!"),
);
Ok(res.body(bytes))
})
)
.body(body),
)
}
})
})?
.run()

View File

@@ -1,25 +1,22 @@
use std::{env, io};
use actix_http::http::HeaderValue;
use actix_http::{error::PayloadError, Error, HttpService, Request, Response};
use actix_http::{Error, HttpService, Request, Response};
use actix_server::Server;
use bytes::BytesMut;
use futures::{Future, Stream};
use futures::StreamExt;
use log::info;
fn handle_request(mut req: Request) -> impl Future<Item = Response, Error = Error> {
req.take_payload()
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
.from_err()
.and_then(|bytes| {
info!("request body: {:?}", bytes);
let mut res = Response::Ok();
res.header("x-head", HeaderValue::from_static("dummy value!"));
Ok(res.body(bytes))
})
async fn handle_request(mut req: Request) -> Result<Response, Error> {
let mut body = BytesMut::new();
while let Some(item) = req.payload().next().await {
body.extend_from_slice(&item?)
}
info!("request body: {:?}", body);
Ok(Response::Ok()
.header("x-head", HeaderValue::from_static("dummy value!"))
.body(body))
}
fn main() -> io::Result<()> {
@@ -28,7 +25,7 @@ fn main() -> io::Result<()> {
Server::build()
.bind("echo", "127.0.0.1:8080", || {
HttpService::build().finish(|_req: Request| handle_request(_req))
HttpService::build().finish(handle_request)
})?
.run()
}

View File

@@ -1,8 +1,11 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem};
use bytes::{Bytes, BytesMut};
use futures::{Async, Poll, Stream};
use futures::Stream;
use pin_project::{pin_project, project};
use crate::error::Error;
@@ -32,7 +35,7 @@ impl BodySize {
pub trait MessageBody {
fn size(&self) -> BodySize;
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error>;
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>>;
}
impl MessageBody for () {
@@ -40,8 +43,8 @@ impl MessageBody for () {
BodySize::Empty
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
Ok(Async::Ready(None))
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
Poll::Ready(None)
}
}
@@ -50,11 +53,12 @@ impl<T: MessageBody> MessageBody for Box<T> {
self.as_ref().size()
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
self.as_mut().poll_next()
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
self.as_mut().poll_next(cx)
}
}
#[pin_project]
pub enum ResponseBody<B> {
Body(B),
Other(Body),
@@ -93,20 +97,24 @@ impl<B: MessageBody> MessageBody for ResponseBody<B> {
}
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
match self {
ResponseBody::Body(ref mut body) => body.poll_next(),
ResponseBody::Other(ref mut body) => body.poll_next(),
ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(cx),
}
}
}
impl<B: MessageBody> Stream for ResponseBody<B> {
type Item = Bytes;
type Error = Error;
type Item = Result<Bytes, Error>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self.poll_next()
#[project]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
#[project]
match self.project() {
ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(cx),
}
}
}
@@ -144,19 +152,19 @@ impl MessageBody for Body {
}
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
match self {
Body::None => Ok(Async::Ready(None)),
Body::Empty => Ok(Async::Ready(None)),
Body::None => Poll::Ready(None),
Body::Empty => Poll::Ready(None),
Body::Bytes(ref mut bin) => {
let len = bin.len();
if len == 0 {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(mem::replace(bin, Bytes::new()))))
Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new()))))
}
}
Body::Message(ref mut body) => body.poll_next(),
Body::Message(ref mut body) => body.poll_next(cx),
}
}
}
@@ -242,7 +250,7 @@ impl From<serde_json::Value> for Body {
impl<S> From<SizedStream<S>> for Body
where
S: Stream<Item = Bytes, Error = Error> + 'static,
S: Stream<Item = Result<Bytes, Error>> + 'static,
{
fn from(s: SizedStream<S>) -> Body {
Body::from_message(s)
@@ -251,7 +259,7 @@ where
impl<S, E> From<BodyStream<S, E>> for Body
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Error> + 'static,
{
fn from(s: BodyStream<S, E>) -> Body {
@@ -264,11 +272,11 @@ impl MessageBody for Bytes {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(mem::replace(self, Bytes::new()))))
Poll::Ready(Some(Ok(mem::replace(self, Bytes::new()))))
}
}
}
@@ -278,13 +286,11 @@ impl MessageBody for BytesMut {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(
mem::replace(self, BytesMut::new()).freeze(),
)))
Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze())))
}
}
}
@@ -294,11 +300,11 @@ impl MessageBody for &'static str {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(Bytes::from_static(
Poll::Ready(Some(Ok(Bytes::from_static(
mem::replace(self, "").as_ref(),
))))
}
@@ -310,13 +316,11 @@ impl MessageBody for &'static [u8] {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(Bytes::from_static(mem::replace(
self, b"",
)))))
Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b"")))))
}
}
}
@@ -326,14 +330,11 @@ impl MessageBody for Vec<u8> {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(Bytes::from(mem::replace(
self,
Vec::new(),
)))))
Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new())))))
}
}
}
@@ -343,11 +344,11 @@ impl MessageBody for String {
BodySize::Sized(self.len())
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, _: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
if self.is_empty() {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
Ok(Async::Ready(Some(Bytes::from(
Poll::Ready(Some(Ok(Bytes::from(
mem::replace(self, String::new()).into_bytes(),
))))
}
@@ -356,14 +357,16 @@ impl MessageBody for String {
/// Type represent streaming body.
/// Response does not contain `content-length` header and appropriate transfer encoding is used.
#[pin_project]
pub struct BodyStream<S, E> {
#[pin]
stream: S,
_t: PhantomData<E>,
}
impl<S, E> BodyStream<S, E>
where
S: Stream<Item = Bytes, Error = E>,
S: Stream<Item = Result<Bytes, E>>,
E: Into<Error>,
{
pub fn new(stream: S) -> Self {
@@ -376,28 +379,34 @@ where
impl<S, E> MessageBody for BodyStream<S, E>
where
S: Stream<Item = Bytes, Error = E>,
S: Stream<Item = Result<Bytes, E>>,
E: Into<Error>,
{
fn size(&self) -> BodySize {
BodySize::Stream
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
self.stream.poll().map_err(std::convert::Into::into)
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
unsafe { Pin::new_unchecked(self) }
.project()
.stream
.poll_next(cx)
.map(|res| res.map(|res| res.map_err(std::convert::Into::into)))
}
}
/// Type represent streaming body. This body implementation should be used
/// if total size of stream is known. Data get sent as is without using transfer encoding.
#[pin_project]
pub struct SizedStream<S> {
size: u64,
#[pin]
stream: S,
}
impl<S> SizedStream<S>
where
S: Stream<Item = Bytes, Error = Error>,
S: Stream<Item = Result<Bytes, Error>>,
{
pub fn new(size: u64, stream: S) -> Self {
SizedStream { size, stream }
@@ -406,20 +415,25 @@ where
impl<S> MessageBody for SizedStream<S>
where
S: Stream<Item = Bytes, Error = Error>,
S: Stream<Item = Result<Bytes, Error>>,
{
fn size(&self) -> BodySize {
BodySize::Sized64(self.size)
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
self.stream.poll()
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
unsafe { Pin::new_unchecked(self) }
.project()
.stream
.poll_next(cx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use actix_http_test::block_on;
use futures::future::{lazy, poll_fn};
impl Body {
pub(crate) fn get_ref(&self) -> &[u8] {
@@ -447,8 +461,8 @@ mod tests {
assert_eq!("test".size(), BodySize::Sized(4));
assert_eq!(
"test".poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| "test".poll_next(cx))).unwrap().ok(),
Some(Bytes::from("test"))
);
}
@@ -464,8 +478,10 @@ mod tests {
assert_eq!((&b"test"[..]).size(), BodySize::Sized(4));
assert_eq!(
(&b"test"[..]).poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| (&b"test"[..]).poll_next(cx)))
.unwrap()
.ok(),
Some(Bytes::from("test"))
);
}
@@ -476,8 +492,10 @@ mod tests {
assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
assert_eq!(
Vec::from("test").poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| Vec::from("test").poll_next(cx)))
.unwrap()
.ok(),
Some(Bytes::from("test"))
);
}
@@ -489,8 +507,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
b.poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Some(Bytes::from("test"))
);
}
@@ -502,8 +520,8 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
b.poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Some(Bytes::from("test"))
);
}
@@ -517,22 +535,22 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
b.poll_next().unwrap(),
Async::Ready(Some(Bytes::from("test")))
block_on(poll_fn(|cx| b.poll_next(cx))).unwrap().ok(),
Some(Bytes::from("test"))
);
}
#[test]
fn test_unit() {
assert_eq!(().size(), BodySize::Empty);
assert_eq!(().poll_next().unwrap(), Async::Ready(None));
assert!(block_on(poll_fn(|cx| ().poll_next(cx))).is_none());
}
#[test]
fn test_box() {
let mut val = Box::new(());
assert_eq!(val.size(), BodySize::Empty);
assert_eq!(val.poll_next().unwrap(), Async::Ready(None));
assert!(block_on(poll_fn(|cx| val.poll_next(cx))).is_none());
}
#[test]

View File

@@ -4,7 +4,7 @@ use std::rc::Rc;
use actix_codec::Framed;
use actix_server_config::ServerConfig as SrvConfig;
use actix_service::{IntoNewService, NewService, Service};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use crate::body::MessageBody;
use crate::config::{KeepAlive, ServiceConfig};
@@ -32,9 +32,10 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static,
{
/// Create instance of `ServiceConfigBuilder`
pub fn new() -> Self {
@@ -52,19 +53,22 @@ where
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
<S::Service as Service>::Future: 'static,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = SrvConfig,
Request = (Request, Framed<T, Codec>),
Response = (),
>,
U::Error: fmt::Display,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
/// Set server keep-alive setting.
///
@@ -108,16 +112,17 @@ where
/// request will be forwarded to main service.
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
where
F: IntoNewService<X1>,
X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
F: IntoServiceFactory<X1>,
X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>,
X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{
HttpServiceBuilder {
keep_alive: self.keep_alive,
client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect,
expect: expect.into_new_service(),
expect: expect.into_factory(),
upgrade: self.upgrade,
on_connect: self.on_connect,
_t: PhantomData,
@@ -130,21 +135,22 @@ where
/// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where
F: IntoNewService<U1>,
U1: NewService<
F: IntoServiceFactory<U1>,
U1: ServiceFactory<
Config = SrvConfig,
Request = (Request, Framed<T, Codec>),
Response = (),
>,
U1::Error: fmt::Display,
U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{
HttpServiceBuilder {
keep_alive: self.keep_alive,
client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect,
expect: self.expect,
upgrade: Some(upgrade.into_new_service()),
upgrade: Some(upgrade.into_factory()),
on_connect: self.on_connect,
_t: PhantomData,
}
@@ -166,8 +172,8 @@ where
/// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where
B: MessageBody + 'static,
F: IntoNewService<S>,
B: MessageBody,
F: IntoServiceFactory<S>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
@@ -177,7 +183,7 @@ where
self.client_timeout,
self.client_disconnect,
);
H1Service::with_config(cfg, service.into_new_service())
H1Service::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)
@@ -187,10 +193,10 @@ where
pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B>
where
B: MessageBody + 'static,
F: IntoNewService<S>,
S::Error: Into<Error>,
F: IntoServiceFactory<S>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
{
let cfg = ServiceConfig::new(
@@ -198,18 +204,17 @@ where
self.client_timeout,
self.client_disconnect,
);
H2Service::with_config(cfg, service.into_new_service())
.on_connect(self.on_connect)
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
}
/// Finish service configuration and create `HttpService` instance.
pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U>
where
B: MessageBody + 'static,
F: IntoNewService<S>,
S::Error: Into<Error>,
F: IntoServiceFactory<S>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
{
let cfg = ServiceConfig::new(
@@ -217,7 +222,7 @@ where
self.client_timeout,
self.client_disconnect,
);
HttpService::with_config(cfg, service.into_new_service())
HttpService::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)

View File

@@ -1,10 +1,12 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{Buf, Bytes};
use futures::future::{err, Either, Future, FutureResult};
use futures::Poll;
use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready};
use h2::client::SendRequest;
use pin_project::{pin_project, project};
use crate::body::MessageBody;
use crate::h1::ClientCodec;
@@ -21,8 +23,8 @@ pub(crate) enum ConnectionType<Io> {
}
pub trait Connection {
type Io: AsyncRead + AsyncWrite;
type Future: Future<Item = (ResponseHead, Payload), Error = SendRequestError>;
type Io: AsyncRead + AsyncWrite + Unpin;
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol;
@@ -34,8 +36,7 @@ pub trait Connection {
) -> Self::Future;
type TunnelFuture: Future<
Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
/// Send request, returns Response and Framed
@@ -71,7 +72,7 @@ where
}
}
impl<T: AsyncRead + AsyncWrite> IoConnection<T> {
impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
pub(crate) fn new(
io: ConnectionType<T>,
created: time::Instant,
@@ -91,11 +92,11 @@ impl<T: AsyncRead + AsyncWrite> IoConnection<T> {
impl<T> Connection for IoConnection<T>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Io = T;
type Future =
Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol {
match self.io {
@@ -111,38 +112,30 @@ where
body: B,
) -> Self::Future {
match self.io.take().unwrap() {
ConnectionType::H1(io) => Box::new(h1proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
ConnectionType::H2(io) => Box::new(h2proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
ConnectionType::H1(io) => {
h1proto::send_request(io, head.into(), body, self.created, self.pool)
.boxed_local()
}
ConnectionType::H2(io) => {
h2proto::send_request(io, head.into(), body, self.created, self.pool)
.boxed_local()
}
}
}
type TunnelFuture = Either<
Box<
dyn Future<
Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>,
>,
FutureResult<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>,
>;
/// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
match self.io.take().unwrap() {
ConnectionType::H1(io) => {
Either::A(Box::new(h1proto::open_tunnel(io, head.into())))
Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local())
}
ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() {
@@ -152,7 +145,7 @@ where
None,
));
}
Either::B(err(SendRequestError::TunnelNotSupported))
Either::Right(err(SendRequestError::TunnelNotSupported))
}
}
}
@@ -166,12 +159,12 @@ pub(crate) enum EitherConnection<A, B> {
impl<A, B> Connection for EitherConnection<A, B>
where
A: AsyncRead + AsyncWrite + 'static,
B: AsyncRead + AsyncWrite + 'static,
A: AsyncRead + AsyncWrite + Unpin + 'static,
B: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Io = EitherIo<A, B>;
type Future =
Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol {
match self {
@@ -191,44 +184,30 @@ where
}
}
type TunnelFuture = Box<
dyn Future<
Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>,
type TunnelFuture = LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
/// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
match self {
EitherConnection::A(con) => Box::new(
con.open_tunnel(head)
.map(|(head, framed)| (head, framed.map_io(EitherIo::A))),
),
EitherConnection::B(con) => Box::new(
con.open_tunnel(head)
.map(|(head, framed)| (head, framed.map_io(EitherIo::B))),
),
EitherConnection::A(con) => con
.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A))))
.boxed_local(),
EitherConnection::B(con) => con
.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B))))
.boxed_local(),
}
}
}
#[pin_project]
pub enum EitherIo<A, B> {
A(A),
B(B),
}
impl<A, B> io::Read for EitherIo<A, B>
where
A: io::Read,
B: io::Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.read(buf),
EitherIo::B(ref mut val) => val.read(buf),
}
}
A(#[pin] A),
B(#[pin] B),
}
impl<A, B> AsyncRead for EitherIo<A, B>
@@ -236,6 +215,19 @@ where
A: AsyncRead,
B: AsyncRead,
{
#[project]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_read(cx, buf),
EitherIo::B(val) => val.poll_read(cx, buf),
}
}
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self {
EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
@@ -244,45 +236,58 @@ where
}
}
impl<A, B> io::Write for EitherIo<A, B>
where
A: io::Write,
B: io::Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.write(buf),
EitherIo::B(ref mut val) => val.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
EitherIo::A(ref mut val) => val.flush(),
EitherIo::B(ref mut val) => val.flush(),
}
}
}
impl<A, B> AsyncWrite for EitherIo<A, B>
where
A: AsyncWrite,
B: AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self {
EitherIo::A(ref mut val) => val.shutdown(),
EitherIo::B(ref mut val) => val.shutdown(),
#[project]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_write(cx, buf),
EitherIo::B(val) => val.poll_write(cx, buf),
}
}
fn write_buf<U: Buf>(&mut self, buf: &mut U) -> Poll<usize, io::Error>
#[project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_flush(cx),
EitherIo::B(val) => val.poll_flush(cx),
}
}
#[project]
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_shutdown(cx),
EitherIo::B(val) => val.poll_shutdown(cx),
}
}
#[project]
fn poll_write_buf<U: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut U,
) -> Poll<Result<usize, io::Error>>
where
Self: Sized,
{
match self {
EitherIo::A(ref mut val) => val.write_buf(buf),
EitherIo::B(ref mut val) => val.write_buf(buf),
#[project]
match self.project() {
EitherIo::A(val) => val.poll_write_buf(cx, buf),
EitherIo::B(val) => val.poll_write_buf(cx, buf),
}
}
}

View File

@@ -1,37 +1,41 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_connect::{
default_connector, Connect as TcpConnect, Connection as TcpConnection,
};
use actix_service::{apply_fn, Service, ServiceExt};
use actix_service::{apply_fn, Service};
use actix_utils::timeout::{TimeoutError, TimeoutService};
use futures::future::Ready;
use http::Uri;
use tokio_tcp::TcpStream;
use tokio_net::tcp::TcpStream;
use super::connection::Connection;
use super::error::ConnectError;
use super::pool::{ConnectionPool, Protocol};
use super::Connect;
#[cfg(feature = "ssl")]
use openssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "openssl")]
use open_ssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "rust-tls")]
use rustls::ClientConfig;
#[cfg(feature = "rust-tls")]
#[cfg(feature = "rustls")]
use rust_tls::ClientConfig;
#[cfg(feature = "rustls")]
use std::sync::Arc;
#[cfg(any(feature = "ssl", feature = "rust-tls"))]
#[cfg(any(feature = "openssl", feature = "rustls"))]
enum SslConnector {
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
Openssl(OpensslConnector),
#[cfg(feature = "rust-tls")]
#[cfg(feature = "rustls")]
Rustls(Arc<ClientConfig>),
}
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
type SslConnector = ();
/// Manages http client network connectivity
@@ -58,8 +62,8 @@ pub struct Connector<T, U> {
_t: PhantomData<U>,
}
trait Io: AsyncRead + AsyncWrite {}
impl<T: AsyncRead + AsyncWrite> Io for T {}
trait Io: AsyncRead + AsyncWrite + Unpin {}
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
impl Connector<(), ()> {
#[allow(clippy::new_ret_no_self)]
@@ -72,9 +76,9 @@ impl Connector<(), ()> {
TcpStream,
> {
let ssl = {
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
{
use openssl::ssl::SslMethod;
use open_ssl::ssl::SslMethod;
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
let _ = ssl
@@ -82,7 +86,7 @@ impl Connector<(), ()> {
.map_err(|e| error!("Can not set alpn protocol: {:?}", e));
SslConnector::Openssl(ssl.build())
}
#[cfg(all(not(feature = "ssl"), feature = "rust-tls"))]
#[cfg(all(not(feature = "openssl"), feature = "rustls"))]
{
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let mut config = ClientConfig::new();
@@ -92,7 +96,7 @@ impl Connector<(), ()> {
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
SslConnector::Rustls(Arc::new(config))
}
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
{}
};
@@ -113,7 +117,7 @@ impl<T, U> Connector<T, U> {
/// Use custom connector.
pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
where
U1: AsyncRead + AsyncWrite + fmt::Debug,
U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
T1: Service<
Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U1>,
@@ -135,7 +139,7 @@ impl<T, U> Connector<T, U> {
impl<T, U> Connector<T, U>
where
U: AsyncRead + AsyncWrite + fmt::Debug + 'static,
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
T: Service<
Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U>,
@@ -150,14 +154,14 @@ where
self
}
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
/// Use custom `SslConnector` instance.
pub fn ssl(mut self, connector: OpensslConnector) -> Self {
self.ssl = SslConnector::Openssl(connector);
self
}
#[cfg(feature = "rust-tls")]
#[cfg(feature = "rustls")]
pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
self.ssl = SslConnector::Rustls(connector);
self
@@ -213,7 +217,7 @@ where
self,
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
+ Clone {
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
{
let connector = TimeoutService::new(
self.timeout,
@@ -238,32 +242,32 @@ where
),
}
}
#[cfg(any(feature = "ssl", feature = "rust-tls"))]
#[cfg(any(feature = "openssl", feature = "rustls"))]
{
const H2: &[u8] = b"h2";
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
use actix_connect::ssl::OpensslConnector;
#[cfg(feature = "rust-tls")]
#[cfg(feature = "rustls")]
use actix_connect::ssl::RustlsConnector;
use actix_service::boxed::service;
#[cfg(feature = "rust-tls")]
use rustls::Session;
use actix_service::{boxed::service, pipeline};
#[cfg(feature = "rustls")]
use rust_tls::Session;
let ssl_service = TimeoutService::new(
self.timeout,
pipeline(
apply_fn(self.connector.clone(), |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
})
.map_err(ConnectError::from)
.map_err(ConnectError::from),
)
.and_then(match self.ssl {
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
SslConnector::Openssl(ssl) => service(
OpensslConnector::service(ssl)
.map_err(ConnectError::from)
.map(|stream| {
let sock = stream.into_parts().0;
let h2 = sock
.get_ref()
.ssl()
.selected_alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2))
@@ -273,9 +277,10 @@ where
} else {
(Box::new(sock) as Box<dyn Io>, Protocol::Http1)
}
}),
})
.map_err(ConnectError::from),
),
#[cfg(feature = "rust-tls")]
#[cfg(feature = "rustls")]
SslConnector::Rustls(ssl) => service(
RustlsConnector::service(ssl)
.map_err(ConnectError::from)
@@ -303,7 +308,7 @@ where
let tcp_service = TimeoutService::new(
self.timeout,
apply_fn(self.connector.clone(), |msg: Connect, srv| {
apply_fn(self.connector, |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
})
.map_err(ConnectError::from)
@@ -334,19 +339,20 @@ where
}
}
#[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
mod connect_impl {
use futures::future::{err, Either, FutureResult};
use futures::Poll;
use std::task::{Context, Poll};
use futures::future::{err, Either, Ready};
use futures::ready;
use super::*;
use crate::client::connection::IoConnection;
pub(crate) struct InnerConnector<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
pub(crate) tcp_pool: ConnectionPool<T, Io>,
@@ -354,9 +360,8 @@ mod connect_impl {
impl<T, Io> Clone for InnerConnector<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
fn clone(&self) -> Self {
@@ -368,9 +373,8 @@ mod connect_impl {
impl<T, Io> Service for InnerConnector<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
type Request = Connect;
@@ -378,38 +382,38 @@ mod connect_impl {
type Error = ConnectError;
type Future = Either<
<ConnectionPool<T, Io> as Service>::Future,
FutureResult<IoConnection<Io>, ConnectError>,
Ready<Result<IoConnection<Io>, ConnectError>>,
>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.tcp_pool.poll_ready(cx)
}
fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => {
Either::B(err(ConnectError::SslIsNotSupported))
Either::Right(err(ConnectError::SslIsNotSupported))
}
_ => Either::A(self.tcp_pool.call(req)),
_ => Either::Left(self.tcp_pool.call(req)),
}
}
}
}
#[cfg(any(feature = "ssl", feature = "rust-tls"))]
#[cfg(any(feature = "openssl", feature = "rustls"))]
mod connect_impl {
use std::marker::PhantomData;
use futures::future::{Either, FutureResult};
use futures::{Async, Future, Poll};
use futures::future::Either;
use futures::ready;
use super::*;
use crate::client::connection::EitherConnection;
pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
where
Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
{
@@ -419,13 +423,11 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
where
Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
fn clone(&self) -> Self {
@@ -438,53 +440,47 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
where
Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
type Request = Connect;
type Response = EitherConnection<Io1, Io2>;
type Error = ConnectError;
type Future = Either<
FutureResult<Self::Response, Self::Error>,
Either<
InnerConnectorResponseA<T1, Io1, Io2>,
InnerConnectorResponseB<T2, Io1, Io2>,
>,
>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.tcp_pool.poll_ready(cx)
}
fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => {
Either::B(Either::B(InnerConnectorResponseB {
Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB {
fut: self.ssl_pool.call(req),
_t: PhantomData,
}))
}
_ => Either::B(Either::A(InnerConnectorResponseA {
}),
_ => Either::Left(InnerConnectorResponseA {
fut: self.tcp_pool.call(req),
_t: PhantomData,
})),
}),
}
}
}
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
where
Io1: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
#[pin]
fut: <ConnectionPool<T, Io1> as Service>::Future,
_t: PhantomData<Io2>,
}
@@ -492,29 +488,28 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
where
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.fut.poll()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))),
}
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
Poll::Ready(
ready!(Pin::new(&mut self.get_mut().fut).poll(cx))
.map(|res| EitherConnection::A(res)),
)
}
}
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
where
Io2: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
#[pin]
fut: <ConnectionPool<T, Io2> as Service>::Future,
_t: PhantomData<Io1>,
}
@@ -522,19 +517,17 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
where
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.fut.poll()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))),
}
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
Poll::Ready(
ready!(Pin::new(&mut self.get_mut().fut).poll(cx))
.map(|res| EitherConnection::B(res)),
)
}
}
}

View File

@@ -3,8 +3,8 @@ use std::io;
use derive_more::{Display, From};
use trust_dns_resolver::error::ResolveError;
#[cfg(feature = "ssl")]
use openssl::ssl::{Error as SslError, HandshakeError};
#[cfg(feature = "openssl")]
use open_ssl::ssl::{Error as SslError, HandshakeError};
use crate::error::{Error, ParseError, ResponseError};
use crate::http::Error as HttpError;
@@ -18,7 +18,7 @@ pub enum ConnectError {
SslIsNotSupported,
/// SSL error
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
#[display(fmt = "{}", _0)]
SslError(SslError),
@@ -63,7 +63,7 @@ impl From<actix_connect::ConnectError> for ConnectError {
}
}
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
impl<T> From<HandshakeError<T>> for ConnectError {
fn from(err: HandshakeError<T>) -> ConnectError {
match err {

View File

@@ -1,10 +1,13 @@
use std::future::Future;
use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, Either};
use futures::{Async, Future, Poll, Sink, Stream};
use futures::future::{ok, poll_fn, Either};
use futures::{Sink, SinkExt, Stream, StreamExt};
use crate::error::PayloadError;
use crate::h1;
@@ -18,15 +21,15 @@ use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired;
use crate::body::{BodySize, MessageBody};
pub(crate) fn send_request<T, B>(
pub(crate) async fn send_request<T, B>(
io: T,
mut head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody,
{
// set request host header
@@ -62,68 +65,99 @@ where
io: Some(io),
};
let len = body.size();
// create Framed and send request
Framed::new(io, h1::ClientCodec::default())
.send((head, len).into())
.from_err()
let mut framed = Framed::new(io, h1::ClientCodec::default());
framed.send((head, body.size()).into()).await?;
// send request body
.and_then(move |framed| match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
Either::A(ok(framed))
}
_ => Either::B(SendBody::new(body, framed)),
})
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => (),
_ => send_body(body, &mut framed).await?,
};
// read response and init read body
.and_then(|framed| {
framed
.into_future()
.map_err(|(e, _)| SendRequestError::from(e))
.and_then(|(item, framed)| {
if let Some(res) = item {
let res = framed.into_future().await;
let (head, framed) = if let (Some(result), framed) = res {
let item = result.map_err(SendRequestError::from)?;
(item, framed)
} else {
return Err(SendRequestError::from(ConnectError::Disconnected));
};
match framed.get_codec().message_type() {
h1::MessageType::None => {
let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close);
Ok((res, Payload::None))
Ok((head, Payload::None))
}
_ => {
let pl: PayloadStream = Box::new(PlStream::new(framed));
Ok((res, pl.into()))
let pl: PayloadStream = PlStream::new(framed).boxed_local();
Ok((head, pl.into()))
}
}
} else {
Err(ConnectError::Disconnected.into())
}
})
})
}
pub(crate) fn open_tunnel<T>(
pub(crate) async fn open_tunnel<T>(
io: T,
head: RequestHeadType,
) -> impl Future<Item = (ResponseHead, Framed<T, h1::ClientCodec>), Error = SendRequestError>
) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
// create Framed and send request
Framed::new(io, h1::ClientCodec::default())
.send((head, BodySize::None).into())
.from_err()
let mut framed = Framed::new(io, h1::ClientCodec::default());
framed.send((head, BodySize::None).into()).await?;
// read response
.and_then(|framed| {
framed
.into_future()
.map_err(|(e, _)| SendRequestError::from(e))
.and_then(|(head, framed)| {
if let Some(head) = head {
if let (Some(result), framed) = framed.into_future().await {
let head = result.map_err(SendRequestError::from)?;
Ok((head, framed))
} else {
Err(SendRequestError::from(ConnectError::Disconnected))
}
}
/// send request body to the peer
pub(crate) async fn send_body<I, B>(
mut body: B,
framed: &mut Framed<I, h1::ClientCodec>,
) -> Result<(), SendRequestError>
where
I: ConnectionLifetime,
B: MessageBody,
{
let mut eof = false;
while !eof {
while !eof && !framed.is_write_buf_full() {
match poll_fn(|cx| body.poll_next(cx)).await {
Some(result) => {
framed.write(h1::Message::Chunk(Some(result?)))?;
}
None => {
eof = true;
framed.write(h1::Message::Chunk(None))?;
}
}
}
if !framed.is_write_buf_empty() {
poll_fn(|cx| match framed.flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.is_write_buf_full() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
})
})
.await?;
}
}
SinkExt::flush(framed).await?;
Ok(())
}
#[doc(hidden)]
@@ -134,7 +168,10 @@ pub struct H1Connection<T> {
pool: Option<Acquired<T>>,
}
impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T> {
impl<T> ConnectionLifetime for H1Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close connection
fn close(&mut self) {
if let Some(mut pool) = self.pool.take() {
@@ -162,98 +199,41 @@ impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T>
}
}
impl<T: AsyncRead + AsyncWrite + 'static> io::Read for H1Connection<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.io.as_mut().unwrap().read(buf)
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + 'static> AsyncRead for H1Connection<T> {}
impl<T: AsyncRead + AsyncWrite + 'static> io::Write for H1Connection<T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.io.as_mut().unwrap().write(buf)
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
}
fn flush(&mut self) -> io::Result<()> {
self.io.as_mut().unwrap().flush()
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
}
impl<T: AsyncRead + AsyncWrite + 'static> AsyncWrite for H1Connection<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.io.as_mut().unwrap().shutdown()
}
}
/// Future responsible for sending request body to the peer
pub(crate) struct SendBody<I, B> {
body: Option<B>,
framed: Option<Framed<I, h1::ClientCodec>>,
flushed: bool,
}
impl<I, B> SendBody<I, B>
where
I: AsyncRead + AsyncWrite + 'static,
B: MessageBody,
{
pub(crate) fn new(body: B, framed: Framed<I, h1::ClientCodec>) -> Self {
SendBody {
body: Some(body),
framed: Some(framed),
flushed: true,
}
}
}
impl<I, B> Future for SendBody<I, B>
where
I: ConnectionLifetime,
B: MessageBody,
{
type Item = Framed<I, h1::ClientCodec>;
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut body_ready = true;
loop {
while body_ready
&& self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self.body.as_mut().unwrap().poll_next()? {
Async::Ready(item) => {
// check if body is done
if item.is_none() {
let _ = self.body.take();
}
self.flushed = false;
self.framed
.as_mut()
.unwrap()
.force_send(h1::Message::Chunk(item))?;
break;
}
Async::NotReady => body_ready = false,
}
}
if !self.flushed {
match self.framed.as_mut().unwrap().poll_complete()? {
Async::Ready(_) => {
self.flushed = true;
continue;
}
Async::NotReady => return Ok(Async::NotReady),
}
}
if self.body.is_none() {
return Ok(Async::Ready(self.framed.take().unwrap()));
}
return Ok(Async::NotReady);
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), io::Error>> {
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)
}
}
@@ -270,23 +250,24 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
}
impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.framed.as_mut().unwrap().poll()? {
Async::NotReady => Ok(Async::NotReady),
Async::Ready(Some(chunk)) => {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.framed.as_mut().unwrap().next_item(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk {
Ok(Async::Ready(Some(chunk)))
Poll::Ready(Some(Ok(chunk)))
} else {
let framed = self.framed.take().unwrap();
let framed = this.framed.take().unwrap();
let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close);
Ok(Async::Ready(None))
Poll::Ready(None)
}
}
Async::Ready(None) => Ok(Async::Ready(None)),
Poll::Ready(None) => Poll::Ready(None),
}
}
}

View File

@@ -1,9 +1,11 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time;
use actix_codec::{AsyncRead, AsyncWrite};
use bytes::Bytes;
use futures::future::{err, Either};
use futures::{Async, Future, Poll};
use futures::future::{err, poll_fn, Either};
use h2::{client::SendRequest, SendStream};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{request::Request, HttpTryFrom, Method, Version};
@@ -17,15 +19,15 @@ use super::connection::{ConnectionType, IoConnection};
use super::error::SendRequestError;
use super::pool::Acquired;
pub(crate) fn send_request<T, B>(
io: SendRequest<Bytes>,
pub(crate) async fn send_request<T, B>(
mut io: SendRequest<Bytes>,
head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody,
{
trace!("Sending client request: {:?} {:?}", head, body.size());
@@ -36,9 +38,6 @@ where
_ => false,
};
io.ready()
.map_err(SendRequestError::from)
.and_then(move |mut io| {
let mut req = Request::new(());
*req.uri_mut() = head.as_ref().uri.clone();
*req.method_mut() = head.as_ref().method.clone();
@@ -69,9 +68,7 @@ where
// Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate.
let (head, extra_headers) = match head {
RequestHeadType::Owned(head) => {
(RequestHeadType::Owned(head), HeaderMap::new())
}
RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()),
RequestHeadType::Rc(head, extra_headers) => (
RequestHeadType::Rc(head, None),
extra_headers.unwrap_or_else(HeaderMap::new),
@@ -97,30 +94,27 @@ where
req.headers_mut().append(key, value.clone());
}
match io.send_request(req, eof) {
Ok((res, send)) => {
let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res {
release(io, pool, created, e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
release(io, pool, created, false);
if !eof {
Either::A(Either::B(
SendBody {
body,
send,
buf: None,
}
.and_then(move |_| res.map_err(SendRequestError::from)),
))
} else {
Either::B(res.map_err(SendRequestError::from))
send_body(body, send).await?;
}
fut.await.map_err(SendRequestError::from)?
}
Err(e) => {
release(io, pool, created, e.is_io());
Either::A(Either::A(err(e.into())))
return Err(e.into());
}
}
})
.and_then(move |resp| {
};
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };
@@ -128,66 +122,56 @@ where
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
})
.from_err()
}
struct SendBody<B: MessageBody> {
body: B,
send: SendStream<Bytes>,
buf: Option<Bytes>,
}
impl<B: MessageBody> Future for SendBody<B> {
type Item = ();
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
async fn send_body<B: MessageBody>(
mut body: B,
mut send: SendStream<Bytes>,
) -> Result<(), SendRequestError> {
let mut buf = None;
loop {
if self.buf.is_none() {
match self.body.poll_next() {
Ok(Async::Ready(Some(buf))) => {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
if buf.is_none() {
match poll_fn(|cx| body.poll_next(cx)).await {
Some(Ok(b)) => {
send.reserve_capacity(b.len());
buf = Some(b);
}
Ok(Async::Ready(None)) => {
if let Err(e) = self.send.send_data(Bytes::new(), true) {
Some(Err(e)) => return Err(e.into()),
None => {
if let Err(e) = send.send_data(Bytes::new(), true) {
return Err(e.into());
}
self.send.reserve_capacity(0);
return Ok(Async::Ready(()));
send.reserve_capacity(0);
return Ok(());
}
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => return Err(e.into()),
}
}
match self.send.poll_capacity() {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
Ok(Async::Ready(Some(cap))) => {
let mut buf = self.buf.take().unwrap();
let len = buf.len();
let bytes = buf.split_to(std::cmp::min(cap, len));
match poll_fn(|cx| send.poll_capacity(cx)).await {
None => return Ok(()),
Some(Ok(cap)) => {
let b = buf.as_mut().unwrap();
let len = b.len();
let bytes = b.split_to(std::cmp::min(cap, len));
if let Err(e) = self.send.send_data(bytes, false) {
if let Err(e) = send.send_data(bytes, false) {
return Err(e.into());
} else {
if !buf.is_empty() {
self.send.reserve_capacity(buf.len());
self.buf = Some(buf);
if !b.is_empty() {
send.reserve_capacity(b.len());
} else {
buf = None;
}
continue;
}
}
Err(e) => return Err(e.into()),
}
Some(Err(e)) => return Err(e.into()),
}
}
}
// release SendRequest object
fn release<T: AsyncRead + AsyncWrite + 'static>(
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
io: SendRequest<Bytes>,
pool: Option<Acquired<T>>,
created: time::Instant,

View File

@@ -1,22 +1,23 @@
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::Service;
use actix_utils::{oneshot, task::LocalWaker};
use bytes::Bytes;
use futures::future::{err, ok, Either, FutureResult};
use futures::task::AtomicTask;
use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use h2::client::{handshake, Handshake};
use futures::future::{err, ok, poll_fn, Either, FutureExt, LocalBoxFuture, Ready};
use h2::client::{handshake, Connection, SendRequest};
use hashbrown::HashMap;
use http::uri::Authority;
use indexmap::IndexSet;
use slab::Slab;
use tokio_timer::{sleep, Delay};
use tokio_timer::{delay_for, Delay};
use super::connection::{ConnectionType, IoConnection};
use super::error::ConnectError;
@@ -41,16 +42,12 @@ impl From<Authority> for Key {
}
/// Connections pool
pub(crate) struct ConnectionPool<T, Io: AsyncRead + AsyncWrite + 'static>(
T,
Rc<RefCell<Inner<Io>>>,
);
pub(crate) struct ConnectionPool<T, Io: 'static>(Rc<RefCell<T>>, Rc<RefCell<Inner<Io>>>);
impl<T, Io> ConnectionPool<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
pub(crate) fn new(
@@ -61,7 +58,7 @@ where
limit: usize,
) -> Self {
ConnectionPool(
connector,
Rc::new(RefCell::new(connector)),
Rc::new(RefCell::new(Inner {
conn_lifetime,
conn_keep_alive,
@@ -71,7 +68,7 @@ where
waiters: Slab::new(),
waiters_queue: IndexSet::new(),
available: HashMap::new(),
task: None,
waker: LocalWaker::new(),
})),
)
}
@@ -79,8 +76,7 @@ where
impl<T, Io> Clone for ConnectionPool<T, Io>
where
T: Clone,
Io: AsyncRead + AsyncWrite + 'static,
Io: 'static,
{
fn clone(&self) -> Self {
ConnectionPool(self.0.clone(), self.1.clone())
@@ -89,86 +85,116 @@ where
impl<T, Io> Service for ConnectionPool<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
{
type Request = Connect;
type Response = IoConnection<Io>;
type Error = ConnectError;
type Future = Either<
FutureResult<Self::Response, Self::Error>,
Either<WaitForConnection<Io>, OpenConnection<T::Future, Io>>,
>;
type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.0.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: Connect) -> Self::Future {
// start support future
tokio_executor::current_thread::spawn(ConnectorPoolSupport {
connector: self.0.clone(),
inner: self.1.clone(),
});
let mut connector = self.0.clone();
let inner = self.1.clone();
let fut = async move {
let key = if let Some(authority) = req.uri.authority_part() {
authority.clone().into()
} else {
return Either::A(err(ConnectError::Unresolverd));
return Err(ConnectError::Unresolverd);
};
// acquire connection
match self.1.as_ref().borrow_mut().acquire(&key) {
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
Acquire::Acquired(io, created) => {
// use existing connection
return Either::A(ok(IoConnection::new(
return Ok(IoConnection::new(
io,
created,
Some(Acquired(key, Some(self.1.clone()))),
)));
Some(Acquired(key, Some(inner))),
));
}
Acquire::Available => {
// open new connection
return Either::B(Either::B(OpenConnection::new(
key,
self.1.clone(),
self.0.call(req),
)));
}
_ => (),
}
// open tcp connection
let (io, proto) = connector.call(req).await?;
let guard = OpenGuard::new(key, inner);
if proto == Protocol::Http1 {
Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(guard.consume()),
))
} else {
let (snd, connection) = handshake(io).await?;
tokio_executor::current_thread::spawn(connection.map(|_| ()));
Ok(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(guard.consume()),
))
}
}
_ => {
// connection is not available, wait
let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req);
let (rx, token) = inner.borrow_mut().wait_for(req);
// start support future
if !support {
self.1.as_ref().borrow_mut().task = Some(AtomicTask::new());
tokio_current_thread::spawn(ConnectorPoolSupport {
connector: self.0.clone(),
inner: self.1.clone(),
})
let guard = WaiterGuard::new(key, token, inner);
let res = match rx.await {
Err(_) => Err(ConnectError::Disconnected),
Ok(res) => res,
};
guard.consume();
res
}
}
};
Either::B(Either::A(WaitForConnection {
rx,
key,
token,
inner: Some(self.1.clone()),
}))
fut.boxed_local()
}
}
#[doc(hidden)]
pub struct WaitForConnection<Io>
struct WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
key: Key,
token: usize,
rx: oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
inner: Option<Rc<RefCell<Inner<Io>>>>,
}
impl<Io> Drop for WaitForConnection<Io>
impl<Io> WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(key: Key, token: usize, inner: Rc<RefCell<Inner<Io>>>) -> Self {
Self {
key,
token,
inner: Some(inner),
}
}
fn consume(mut self) {
let _ = self.inner.take();
}
}
impl<Io> Drop for WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn drop(&mut self) {
if let Some(i) = self.inner.take() {
@@ -179,113 +205,43 @@ where
}
}
impl<Io> Future for WaitForConnection<Io>
struct OpenGuard<Io>
where
Io: AsyncRead + AsyncWrite,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.rx.poll() {
Ok(Async::Ready(item)) => match item {
Err(err) => Err(err),
Ok(conn) => {
let _ = self.inner.take();
Ok(Async::Ready(conn))
}
},
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => {
let _ = self.inner.take();
Err(ConnectError::Disconnected)
}
}
}
}
#[doc(hidden)]
pub struct OpenConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
{
fut: F,
key: Key,
h2: Option<Handshake<Io, Bytes>>,
inner: Option<Rc<RefCell<Inner<Io>>>>,
}
impl<F, Io> OpenConnection<F, Io>
impl<Io> OpenGuard<Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>, fut: F) -> Self {
OpenConnection {
fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>) -> Self {
Self {
key,
fut,
inner: Some(inner),
h2: None,
}
}
}
impl<F, Io> Drop for OpenConnection<F, Io>
fn consume(mut self) -> Acquired<Io> {
Acquired(self.key.clone(), self.inner.take())
}
}
impl<Io> Drop for OpenGuard<Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let mut inner = inner.as_ref().borrow_mut();
if let Some(i) = self.inner.take() {
let mut inner = i.as_ref().borrow_mut();
inner.release();
inner.check_availibility();
}
}
}
impl<F, Io> Future for OpenConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite,
{
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => {
tokio_current_thread::spawn(connection.map_err(|_| ()));
Ok(Async::Ready(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e.into()),
};
}
match self.fut.poll() {
Err(err) => Err(err),
Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 {
Ok(Async::Ready(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
} else {
self.h2 = Some(handshake(io));
self.poll()
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
}
}
}
enum Acquire<T> {
Acquired(ConnectionType<T>, Instant),
Available,
@@ -312,7 +268,7 @@ pub(crate) struct Inner<Io> {
)>,
>,
waiters_queue: IndexSet<(Key, usize)>,
task: Option<AtomicTask>,
waker: LocalWaker,
}
impl<Io> Inner<Io> {
@@ -332,7 +288,7 @@ impl<Io> Inner<Io> {
impl<Io> Inner<Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// connection is not available, wait
fn wait_for(
@@ -341,7 +297,6 @@ where
) -> (
oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
usize,
bool,
) {
let (tx, rx) = oneshot::channel();
@@ -351,10 +306,10 @@ where
entry.insert(Some((connect, tx)));
assert!(self.waiters_queue.insert((key, token)));
(rx, token, self.task.is_some())
(rx, token)
}
fn acquire(&mut self, key: &Key) -> Acquire<Io> {
fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire<Io> {
// check limits
if self.limit > 0 && self.acquired >= self.limit {
return Acquire::NotAvailable;
@@ -373,7 +328,7 @@ where
{
if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = conn.io {
tokio_current_thread::spawn(CloseConnection::new(
tokio_executor::current_thread::spawn(CloseConnection::new(
io, timeout,
))
}
@@ -382,19 +337,19 @@ where
let mut io = conn.io;
let mut buf = [0; 2];
if let ConnectionType::H1(ref mut s) = io {
match s.read(&mut buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Ok(n) if n > 0 => {
match Pin::new(s).poll_read(cx, &mut buf) {
Poll::Pending => (),
Poll::Ready(Ok(n)) if n > 0 => {
if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io {
tokio_current_thread::spawn(
tokio_executor::current_thread::spawn(
CloseConnection::new(io, timeout),
)
}
}
continue;
}
Ok(_) | Err(_) => continue,
_ => continue,
}
}
return Acquire::Acquired(io, conn.created);
@@ -421,7 +376,7 @@ where
self.acquired -= 1;
if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io {
tokio_current_thread::spawn(CloseConnection::new(io, timeout))
tokio_executor::current_thread::spawn(CloseConnection::new(io, timeout))
}
}
self.check_availibility();
@@ -429,9 +384,7 @@ where
fn check_availibility(&self) {
if !self.waiters_queue.is_empty() && self.acquired < self.limit {
if let Some(t) = self.task.as_ref() {
t.notify()
}
self.waker.wake();
}
}
}
@@ -443,29 +396,30 @@ struct CloseConnection<T> {
impl<T> CloseConnection<T>
where
T: AsyncWrite,
T: AsyncWrite + Unpin,
{
fn new(io: T, timeout: Duration) -> Self {
CloseConnection {
io,
timeout: sleep(timeout),
timeout: delay_for(timeout),
}
}
}
impl<T> Future for CloseConnection<T>
where
T: AsyncWrite,
T: AsyncWrite + Unpin,
{
type Item = ();
type Error = ();
type Output = ();
fn poll(&mut self) -> Poll<(), ()> {
match self.timeout.poll() {
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
Ok(Async::NotReady) => match self.io.shutdown() {
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
Ok(Async::NotReady) => Ok(Async::NotReady),
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
let this = self.get_mut();
match Pin::new(&mut this.timeout).poll(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
},
}
}
@@ -473,7 +427,7 @@ where
struct ConnectorPoolSupport<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
connector: T,
inner: Rc<RefCell<Inner<Io>>>,
@@ -481,16 +435,17 @@ where
impl<T, Io> Future for ConnectorPoolSupport<T, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>,
T::Future: 'static,
{
type Item = ();
type Error = ();
type Output = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut inner = self.inner.as_ref().borrow_mut();
inner.task.as_ref().unwrap().register();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let mut inner = this.inner.as_ref().borrow_mut();
inner.waker.register(cx.waker());
// check waiters
loop {
@@ -505,14 +460,14 @@ where
continue;
}
match inner.acquire(&key) {
match inner.acquire(&key, cx) {
Acquire::NotAvailable => break,
Acquire::Acquired(io, created) => {
let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1;
if let Err(conn) = tx.send(Ok(IoConnection::new(
io,
created,
Some(Acquired(key.clone(), Some(self.inner.clone()))),
Some(Acquired(key.clone(), Some(this.inner.clone()))),
))) {
let (io, created) = conn.unwrap().into_inner();
inner.release_conn(&key, io, created);
@@ -524,33 +479,38 @@ where
OpenWaitingConnection::spawn(
key.clone(),
tx,
self.inner.clone(),
self.connector.call(connect),
this.inner.clone(),
this.connector.call(connect),
);
}
}
let _ = inner.waiters_queue.swap_remove_index(0);
}
Ok(Async::NotReady)
Poll::Pending
}
}
struct OpenWaitingConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fut: F,
key: Key,
h2: Option<Handshake<Io, Bytes>>,
h2: Option<
LocalBoxFuture<
'static,
Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
>,
>,
rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
inner: Option<Rc<RefCell<Inner<Io>>>>,
}
impl<F, Io> OpenWaitingConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError> + 'static,
Io: AsyncRead + AsyncWrite + 'static,
F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn spawn(
key: Key,
@@ -558,7 +518,7 @@ where
inner: Rc<RefCell<Inner<Io>>>,
fut: F,
) {
tokio_current_thread::spawn(OpenWaitingConnection {
tokio_executor::current_thread::spawn(OpenWaitingConnection {
key,
fut,
h2: None,
@@ -570,7 +530,7 @@ where
impl<F, Io> Drop for OpenWaitingConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
@@ -583,59 +543,60 @@ where
impl<F, Io> Future for OpenWaitingConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite,
F: Future<Output = Result<(Io, Protocol), ConnectError>>,
Io: AsyncRead + AsyncWrite + Unpin,
{
type Item = ();
type Error = ();
type Output = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => {
tokio_current_thread::spawn(connection.map_err(|_| ()));
let rx = self.rx.take().unwrap();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
if let Some(ref mut h2) = this.h2 {
return match Pin::new(h2).poll(cx) {
Poll::Ready(Ok((snd, connection))) => {
tokio_executor::current_thread::spawn(connection.map(|_| ()));
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
Some(Acquired(this.key.clone(), this.inner.take())),
)));
Ok(Async::Ready(()))
Poll::Ready(())
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
let _ = self.inner.take();
if let Some(rx) = self.rx.take() {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
let _ = this.inner.take();
if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(ConnectError::H2(err)));
}
Err(())
Poll::Ready(())
}
};
}
match self.fut.poll() {
Err(err) => {
let _ = self.inner.take();
if let Some(rx) = self.rx.take() {
match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) {
Poll::Ready(Err(err)) => {
let _ = this.inner.take();
if let Some(rx) = this.rx.take() {
let _ = rx.send(Err(err));
}
Err(())
Poll::Ready(())
}
Ok(Async::Ready((io, proto))) => {
Poll::Ready(Ok((io, proto))) => {
if proto == Protocol::Http1 {
let rx = self.rx.take().unwrap();
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
Some(Acquired(this.key.clone(), this.inner.take())),
)));
Ok(Async::Ready(()))
Poll::Ready(())
} else {
self.h2 = Some(handshake(io));
self.poll()
this.h2 = Some(handshake(io).boxed_local());
unsafe { Pin::new_unchecked(this) }.poll(cx)
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Pending => Poll::Pending,
}
}
}
@@ -644,7 +605,7 @@ pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
impl<T> Acquired<T>
where
T: AsyncRead + AsyncWrite + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub(crate) fn close(&mut self, conn: IoConnection<T>) {
if let Some(inner) = self.1.take() {

View File

@@ -1,8 +1,8 @@
use std::cell::UnsafeCell;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service;
use futures::Poll;
#[doc(hidden)]
/// Service that allows to turn non-clone service to a service with `Clone` impl
@@ -32,8 +32,8 @@ where
type Error = T::Error;
type Future = T::Future;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
unsafe { &mut *self.0.as_ref().get() }.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx)
}
fn call(&mut self, req: T::Request) -> Self::Future {

View File

@@ -5,9 +5,9 @@ use std::rc::Rc;
use std::time::{Duration, Instant};
use bytes::BytesMut;
use futures::{future, Future};
use futures::{future, Future, FutureExt};
use time;
use tokio_timer::{sleep, Delay};
use tokio_timer::{delay, delay_for, Delay};
// "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29;
@@ -104,10 +104,10 @@ impl ServiceConfig {
#[inline]
/// Client timeout for first request.
pub fn client_timer(&self) -> Option<Delay> {
let delay = self.0.client_timeout;
if delay != 0 {
Some(Delay::new(
self.0.timer.now() + Duration::from_millis(delay),
let delay_time = self.0.client_timeout;
if delay_time != 0 {
Some(delay(
self.0.timer.now() + Duration::from_millis(delay_time),
))
} else {
None
@@ -138,7 +138,7 @@ impl ServiceConfig {
/// Return keep-alive timer delay is configured.
pub fn keep_alive_timer(&self) -> Option<Delay> {
if let Some(ka) = self.0.keep_alive {
Some(Delay::new(self.0.timer.now() + ka))
Some(delay(self.0.timer.now() + ka))
} else {
None
}
@@ -242,12 +242,12 @@ impl DateService {
// periodic date update
let s = self.clone();
tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then(
move |_| {
tokio_executor::current_thread::spawn(
delay_for(Duration::from_millis(500)).then(move |_| {
s.0.reset();
future::ok(())
},
));
future::ready(())
}),
);
}
}
@@ -277,7 +277,7 @@ mod tests {
fn test_date() {
let mut rt = System::new("test");
let _ = rt.block_on(future::lazy(|| {
let _ = rt.block_on(future::lazy(|_| {
let settings = ServiceConfig::new(KeepAlive::Os, 0, 0);
let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10);
settings.set_date(&mut buf1);

View File

@@ -1,13 +1,12 @@
use ring::digest::{Algorithm, SHA256};
use ring::hkdf::expand;
use ring::hmac::SigningKey;
use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256};
use ring::hmac;
use ring::rand::{SecureRandom, SystemRandom};
use super::private::KEY_LEN as PRIVATE_KEY_LEN;
use super::signed::KEY_LEN as SIGNED_KEY_LEN;
static HKDF_DIGEST: &Algorithm = &SHA256;
const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM";
static HKDF_DIGEST: Algorithm = HKDF_SHA256;
const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"];
/// A cryptographic master key for use with `Signed` and/or `Private` jars.
///
@@ -25,6 +24,13 @@ pub struct Key {
encryption_key: [u8; PRIVATE_KEY_LEN],
}
impl KeyType for &Key {
#[inline]
fn len(&self) -> usize {
SIGNED_KEY_LEN + PRIVATE_KEY_LEN
}
}
impl Key {
/// Derives new signing/encryption keys from a master key.
///
@@ -56,21 +62,26 @@ impl Key {
);
}
// Expand the user's key into two.
let prk = SigningKey::new(HKDF_DIGEST, key);
// An empty `Key` structure; will be filled in with HKDF derived keys.
let mut output_key = Key {
signing_key: [0; SIGNED_KEY_LEN],
encryption_key: [0; PRIVATE_KEY_LEN],
};
// Expand the master key into two HKDF generated keys.
let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN];
expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys);
let prk = Prk::new_less_safe(HKDF_DIGEST, key);
let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand");
okm.fill(&mut both_keys).expect("fill keys");
// Copy the keys into their respective arrays.
let mut signing_key = [0; SIGNED_KEY_LEN];
let mut encryption_key = [0; PRIVATE_KEY_LEN];
signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]);
encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]);
Key {
signing_key,
encryption_key,
}
// Copy the key parts into their respective fields.
output_key
.signing_key
.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]);
output_key
.encryption_key
.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]);
output_key
}
/// Generates signing/encryption keys from a secure, random source. Keys are

View File

@@ -1,8 +1,8 @@
use std::str;
use log::warn;
use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM};
use ring::aead::{OpeningKey, SealingKey};
use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM};
use ring::aead::{LessSafeKey, UnboundKey};
use ring::rand::{SecureRandom, SystemRandom};
use super::Key;
@@ -10,7 +10,7 @@ use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `private` docs as
// well as the `KEYS_INFO` const in secure::Key.
static ALGO: &Algorithm = &AES_256_GCM;
static ALGO: &'static Algorithm = &AES_256_GCM;
const NONCE_LEN: usize = 12;
pub const KEY_LEN: usize = 32;
@@ -53,11 +53,14 @@ impl<'a> PrivateJar<'a> {
}
let ad = Aad::from(name.as_bytes());
let key = OpeningKey::new(ALGO, &self.key).expect("opening key");
let (nonce, sealed) = data.split_at_mut(NONCE_LEN);
let key = LessSafeKey::new(
UnboundKey::new(&ALGO, &self.key).expect("matching key length"),
);
let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN);
let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
let unsealed = open_in_place(&key, nonce, ad, 0, sealed)
let unsealed = key
.open_in_place(nonce, ad, &mut sealed)
.map_err(|_| "invalid key/nonce/value: bad seal")?;
if let Ok(unsealed_utf8) = str::from_utf8(unsealed) {
@@ -196,30 +199,33 @@ Please change it as soon as possible."
fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> {
// Create the `SealingKey` structure.
let key = SealingKey::new(ALGO, key).expect("sealing key creation");
let unbound = UnboundKey::new(&ALGO, key).expect("matching key length");
let key = LessSafeKey::new(unbound);
// Create a vec to hold the [nonce | cookie value | overhead].
let overhead = ALGO.tag_len();
let mut data = vec![0; NONCE_LEN + value.len() + overhead];
let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()];
// Randomly generate the nonce, then copy the cookie value as input.
let (nonce, in_out) = data.split_at_mut(NONCE_LEN);
let (in_out, tag) = in_out.split_at_mut(value.len());
in_out.copy_from_slice(value);
// Randomly generate the nonce into the nonce piece.
SystemRandom::new()
.fill(nonce)
.expect("couldn't random fill nonce");
in_out[..value.len()].copy_from_slice(value);
let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length");
// Use cookie's name as associated data to prevent value swapping.
let ad = Aad::from(name);
let ad_tag = key
.seal_in_place_separate_tag(nonce, ad, in_out)
.expect("in-place seal");
// Perform the actual sealing operation and get the output length.
let output_len =
seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal");
// Copy the tag into the tag piece.
tag.copy_from_slice(ad_tag.as_ref());
// Remove the overhead and return the sealed content.
data.truncate(NONCE_LEN + output_len);
data
}

View File

@@ -1,12 +1,11 @@
use ring::digest::{Algorithm, SHA256};
use ring::hmac::{sign, verify_with_own_key as verify, SigningKey};
use ring::hmac::{self, sign, verify};
use super::Key;
use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `signed` docs as
// well as the `KEYS_INFO` const in secure::Key.
static HMAC_DIGEST: &Algorithm = &SHA256;
static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256;
const BASE64_DIGEST_LEN: usize = 44;
pub const KEY_LEN: usize = 32;
@@ -21,7 +20,7 @@ pub const KEY_LEN: usize = 32;
/// This type is only available when the `secure` feature is enabled.
pub struct SignedJar<'a> {
parent: &'a mut CookieJar,
key: SigningKey,
key: hmac::Key,
}
impl<'a> SignedJar<'a> {
@@ -32,7 +31,7 @@ impl<'a> SignedJar<'a> {
pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> {
SignedJar {
parent,
key: SigningKey::new(HMAC_DIGEST, key.signing()),
key: hmac::Key::new(HMAC_DIGEST, key.signing()),
}
}

View File

@@ -1,4 +1,7 @@
use std::future::Future;
use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")]
@@ -6,7 +9,7 @@ use brotli2::write::BrotliDecoder;
use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzDecoder, ZlibDecoder};
use futures::{try_ready, Async, Future, Poll, Stream};
use futures::{ready, Stream};
use super::Writer;
use crate::error::PayloadError;
@@ -18,12 +21,12 @@ pub struct Decoder<S> {
decoder: Option<ContentDecoder>,
stream: S,
eof: bool,
fut: Option<CpuFuture<(Option<Bytes>, ContentDecoder), io::Error>>,
fut: Option<CpuFuture<Result<(Option<Bytes>, ContentDecoder), io::Error>>>,
}
impl<S> Decoder<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>>,
{
/// Construct a decoder.
#[inline]
@@ -71,34 +74,41 @@ where
impl<S> Stream for Decoder<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
loop {
if let Some(ref mut fut) = self.fut {
let (chunk, decoder) = try_ready!(fut.poll());
let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
self.decoder = Some(decoder);
self.fut.take();
if let Some(chunk) = chunk {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
}
if self.eof {
return Ok(Async::Ready(None));
return Poll::Ready(None);
}
match self.stream.poll()? {
Async::Ready(Some(chunk)) => {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(chunk))) => {
if let Some(mut decoder) = self.decoder.take() {
if chunk.len() < INPLACE {
let chunk = decoder.feed_data(chunk)?;
self.decoder = Some(decoder);
if let Some(chunk) = chunk {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
} else {
self.fut = Some(run(move || {
@@ -108,21 +118,25 @@ where
}
continue;
} else {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
}
Async::Ready(None) => {
Poll::Ready(None) => {
self.eof = true;
return if let Some(mut decoder) = self.decoder.take() {
Ok(Async::Ready(decoder.feed_eof()?))
match decoder.feed_eof() {
Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
Ok(None) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(err.into()))),
}
} else {
Ok(Async::Ready(None))
Poll::Ready(None)
};
}
Async::NotReady => break,
Poll::Pending => break,
}
}
Ok(Async::NotReady)
Poll::Pending
}
}

View File

@@ -1,5 +1,8 @@
//! Stream encoder
use std::future::Future;
use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")]
@@ -7,7 +10,6 @@ use brotli2::write::BrotliEncoder;
use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzEncoder, ZlibEncoder};
use futures::{Async, Future, Poll};
use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::http::header::{ContentEncoding, CONTENT_ENCODING};
@@ -22,7 +24,7 @@ pub struct Encoder<B> {
eof: bool,
body: EncoderBody<B>,
encoder: Option<ContentEncoder>,
fut: Option<CpuFuture<ContentEncoder, io::Error>>,
fut: Option<CpuFuture<Result<ContentEncoder, io::Error>>>,
}
impl<B: MessageBody> Encoder<B> {
@@ -94,43 +96,46 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
}
}
fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
fn poll_next(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, Error>>> {
loop {
if self.eof {
return Ok(Async::Ready(None));
return Poll::Ready(None);
}
if let Some(ref mut fut) = self.fut {
let mut encoder = futures::try_ready!(fut.poll());
let mut encoder = match futures::ready!(Pin::new(fut).poll(cx)) {
Ok(Ok(item)) => item,
Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let chunk = encoder.take();
self.encoder = Some(encoder);
self.fut.take();
if !chunk.is_empty() {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
}
let result = match self.body {
EncoderBody::Bytes(ref mut b) => {
if b.is_empty() {
Async::Ready(None)
Poll::Ready(None)
} else {
Async::Ready(Some(std::mem::replace(b, Bytes::new())))
Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new()))))
}
}
EncoderBody::Stream(ref mut b) => b.poll_next()?,
EncoderBody::BoxedStream(ref mut b) => b.poll_next()?,
EncoderBody::Stream(ref mut b) => b.poll_next(cx),
EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx),
};
match result {
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(chunk)) => {
Poll::Ready(Some(Ok(chunk))) => {
if let Some(mut encoder) = self.encoder.take() {
if chunk.len() < INPLACE {
encoder.write(&chunk)?;
let chunk = encoder.take();
self.encoder = Some(encoder);
if !chunk.is_empty() {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
} else {
self.fut = Some(run(move || {
@@ -139,22 +144,23 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
}));
}
} else {
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
}
Async::Ready(None) => {
Poll::Ready(None) => {
if let Some(encoder) = self.encoder.take() {
let chunk = encoder.finish()?;
if chunk.is_empty() {
return Ok(Async::Ready(None));
return Poll::Ready(None);
} else {
self.eof = true;
return Ok(Async::Ready(Some(chunk)));
return Poll::Ready(Some(Ok(chunk)));
}
} else {
return Ok(Async::Ready(None));
return Poll::Ready(None);
}
}
val => return val,
}
}
}

View File

@@ -6,11 +6,10 @@ use std::str::Utf8Error;
use std::string::FromUtf8Error;
use std::{fmt, io, result};
pub use actix_threadpool::BlockingError;
use actix_utils::timeout::TimeoutError;
use bytes::BytesMut;
use derive_more::{Display, From};
use futures::Canceled;
pub use futures::channel::oneshot::Canceled;
use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode};
use httparse;
@@ -182,13 +181,13 @@ impl ResponseError for FormError {}
/// `InternalServerError` for `TimerError`
impl ResponseError for TimerError {}
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
/// `InternalServerError` for `openssl::ssl::Error`
impl ResponseError for openssl::ssl::Error {}
impl ResponseError for open_ssl::ssl::Error {}
#[cfg(feature = "ssl")]
#[cfg(feature = "openssl")]
/// `InternalServerError` for `openssl::ssl::HandshakeError`
impl ResponseError for openssl::ssl::HandshakeError<tokio_tcp::TcpStream> {}
impl<T: std::fmt::Debug> ResponseError for open_ssl::ssl::HandshakeError<T> {}
/// Return `BAD_REQUEST` for `de::value::Error`
impl ResponseError for DeError {
@@ -197,8 +196,8 @@ impl ResponseError for DeError {
}
}
/// `InternalServerError` for `BlockingError`
impl<E: fmt::Debug> ResponseError for BlockingError<E> {}
/// `InternalServerError` for `Canceled`
impl ResponseError for Canceled {}
/// Return `BAD_REQUEST` for `Utf8Error`
impl ResponseError for Utf8Error {
@@ -236,9 +235,6 @@ impl ResponseError for header::InvalidHeaderValueBytes {
}
}
/// `InternalServerError` for `futures::Canceled`
impl ResponseError for Canceled {}
/// A set of errors that can occur during parsing HTTP streams
#[derive(Debug, Display)]
pub enum ParseError {
@@ -365,15 +361,12 @@ impl From<io::Error> for PayloadError {
}
}
impl From<BlockingError<io::Error>> for PayloadError {
fn from(err: BlockingError<io::Error>) -> Self {
match err {
BlockingError::Error(e) => PayloadError::Io(e),
BlockingError::Canceled => PayloadError::Io(io::Error::new(
impl From<Canceled> for PayloadError {
fn from(_: Canceled) -> Self {
PayloadError::Io(io::Error::new(
io::ErrorKind::Other,
"Thread pool is gone",
)),
}
"Operation is canceled",
))
}
}

View File

@@ -1,10 +1,12 @@
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Decoder;
use bytes::{Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{HeaderName, HeaderValue};
use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version};
use httparse;
@@ -442,9 +444,10 @@ impl Decoder for PayloadDecoder {
loop {
let mut buf = None;
// advances the chunked state
*state = match state.step(src, size, &mut buf)? {
Async::NotReady => return Ok(None),
Async::Ready(state) => state,
*state = match state.step(src, size, &mut buf) {
Poll::Pending => return Ok(None),
Poll::Ready(Ok(state)) => state,
Poll::Ready(Err(e)) => return Err(e),
};
if *state == ChunkedState::End {
trace!("End of chunked stream");
@@ -476,7 +479,7 @@ macro_rules! byte (
$rdr.split_to(1);
b
} else {
return Ok(Async::NotReady)
return Poll::Pending
}
})
);
@@ -487,7 +490,7 @@ impl ChunkedState {
body: &mut BytesMut,
size: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<ChunkedState, io::Error> {
) -> Poll<Result<ChunkedState, io::Error>> {
use self::ChunkedState::*;
match *self {
Size => ChunkedState::read_size(body, size),
@@ -499,10 +502,14 @@ impl ChunkedState {
BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body),
End => Ok(Async::Ready(ChunkedState::End)),
End => Poll::Ready(Ok(ChunkedState::End)),
}
}
fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll<ChunkedState, io::Error> {
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16;
match byte!(rdr) {
b @ b'0'..=b'9' => {
@@ -517,48 +524,49 @@ impl ChunkedState {
*size *= radix;
*size += u64::from(b + 10 - b'A');
}
b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => return Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)),
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => {
return Err(io::Error::new(
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size",
));
)));
}
}
Ok(Async::Ready(ChunkedState::Size))
Poll::Ready(Ok(ChunkedState::Size))
}
fn read_size_lws(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws");
match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Err(io::Error::new(
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space",
)),
))),
}
}
fn read_extension(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
}
}
fn read_size_lf(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<ChunkedState, io::Error> {
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)),
b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)),
_ => Err(io::Error::new(
b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)),
b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size LF",
)),
))),
}
}
@@ -566,12 +574,12 @@ impl ChunkedState {
rdr: &mut BytesMut,
rem: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<ChunkedState, io::Error> {
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64;
if len == 0 {
Ok(Async::Ready(ChunkedState::Body))
Poll::Ready(Ok(ChunkedState::Body))
} else {
let slice;
if *rem > len {
@@ -583,47 +591,47 @@ impl ChunkedState {
}
*buf = Some(slice);
if *rem > 0 {
Ok(Async::Ready(ChunkedState::Body))
Poll::Ready(Ok(ChunkedState::Body))
} else {
Ok(Async::Ready(ChunkedState::BodyCr))
Poll::Ready(Ok(ChunkedState::BodyCr))
}
}
}
fn read_body_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)),
_ => Err(io::Error::new(
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body CR",
)),
))),
}
}
fn read_body_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Ok(Async::Ready(ChunkedState::Size)),
_ => Err(io::Error::new(
b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body LF",
)),
))),
}
}
fn read_end_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Ok(Async::Ready(ChunkedState::EndLf)),
_ => Err(io::Error::new(
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end CR",
)),
))),
}
}
fn read_end_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Ok(Async::Ready(ChunkedState::End)),
_ => Err(io::Error::new(
b'\n' => Poll::Ready(Ok(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end LF",
)),
))),
}
}
}

View File

@@ -1,15 +1,17 @@
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use std::{fmt, io, net};
use std::{fmt, io, io::Write, net};
use actix_codec::{Decoder, Encoder, Framed, FramedParts};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
use actix_server_config::IoStream;
use actix_service::Service;
use bitflags::bitflags;
use bytes::{BufMut, BytesMut};
use futures::{Async, Future, Poll};
use log::{error, trace};
use tokio_timer::Delay;
use tokio_timer::{delay, Delay};
use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService;
@@ -261,14 +263,14 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn can_read(&self) -> bool {
fn can_read(&self, cx: &mut Context) -> bool {
if self
.flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{
false
} else if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read
info.need_read(cx) == PayloadStatus::Read
} else {
true
}
@@ -287,7 +289,7 @@ where
///
/// true - got whouldblock
/// false - didnt get whouldblock
fn poll_flush(&mut self) -> Result<bool, DispatchError> {
fn poll_flush(&mut self, cx: &mut Context) -> Result<bool, DispatchError> {
if self.write_buf.is_empty() {
return Ok(false);
}
@@ -295,32 +297,32 @@ where
let len = self.write_buf.len();
let mut written = 0;
while written < len {
match self.io.write(&self.write_buf[written..]) {
Ok(0) => {
match unsafe { Pin::new_unchecked(&mut self.io) }
.poll_write(cx, &self.write_buf[written..])
{
Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero,
"",
)));
}
Ok(n) => {
Poll::Ready(Ok(n)) => {
written += n;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
Poll::Pending => {
if written > 0 {
let _ = self.write_buf.split_to(written);
}
return Ok(true);
}
Err(err) => return Err(DispatchError::Io(err)),
Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
}
}
if written > 0 {
if written == self.write_buf.len() {
unsafe { self.write_buf.set_len(0) }
} else {
let _ = self.write_buf.split_to(written);
}
}
Ok(false)
}
@@ -350,12 +352,15 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
}
fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
fn poll_response(
&mut self,
cx: &mut Context,
) -> Result<PollResponse, DispatchError> {
loop {
let state = match self.state {
State::None => match self.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req)?)
Some(self.handle_request(req, cx)?)
}
Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
@@ -365,54 +370,58 @@ where
}
None => None,
},
State::ExpectCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(req)) => {
State::ExpectCall(ref mut fut) => {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
Poll::Ready(Ok(req)) => {
self.send_continue();
self.state = State::ServiceCall(self.service.call(req));
continue;
}
Ok(Async::NotReady) => None,
Err(e) => {
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
},
State::ServiceCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(res)) => {
Poll::Pending => None,
}
}
State::ServiceCall(ref mut fut) => {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?;
continue;
}
Ok(Async::NotReady) => None,
Err(e) => {
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?)
}
},
Poll::Pending => None,
}
}
State::SendPayload(ref mut stream) => {
loop {
if self.write_buf.len() < HW_BUFFER_SIZE {
match stream
.poll_next()
.map_err(|_| DispatchError::Unknown)?
{
Async::Ready(Some(item)) => {
match stream.poll_next(cx) {
Poll::Ready(Some(Ok(item))) => {
self.codec.encode(
Message::Chunk(Some(item)),
&mut self.write_buf,
)?;
continue;
}
Async::Ready(None) => {
Poll::Ready(None) => {
self.codec.encode(
Message::Chunk(None),
&mut self.write_buf,
)?;
self.state = State::None;
}
Async::NotReady => return Ok(PollResponse::DoNothing),
Poll::Ready(Some(Err(_))) => {
return Err(DispatchError::Unknown)
}
Poll::Pending => return Ok(PollResponse::DoNothing),
}
} else {
return Ok(PollResponse::DrainWriteBuf);
@@ -433,7 +442,7 @@ where
// if read-backpressure is enabled and we consumed some data.
// we may read more data and retry
if self.state.is_call() {
if self.poll_request()? {
if self.poll_request(cx)? {
continue;
}
} else if !self.messages.is_empty() {
@@ -446,17 +455,21 @@ where
Ok(PollResponse::DoNothing)
}
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
fn handle_request(
&mut self,
req: Request,
cx: &mut Context,
) -> Result<State<S, B, X>, DispatchError> {
// Handle `EXPECT: 100-Continue` header
let req = if req.head().expect() {
let mut task = self.expect.call(req);
match task.poll() {
Ok(Async::Ready(req)) => {
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Poll::Ready(Ok(req)) => {
self.send_continue();
req
}
Ok(Async::NotReady) => return Ok(State::ExpectCall(task)),
Err(e) => {
Poll::Pending => return Ok(State::ExpectCall(task)),
Poll::Ready(Err(e)) => {
let e = e.into();
let res: Response = e.into();
let (res, body) = res.replace_body(());
@@ -469,13 +482,13 @@ where
// Call service
let mut task = self.service.call(req);
match task.poll() {
Ok(Async::Ready(res)) => {
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(());
self.send_response(res, body)
}
Ok(Async::NotReady) => Ok(State::ServiceCall(task)),
Err(e) => {
Poll::Pending => Ok(State::ServiceCall(task)),
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
self.send_response(res, body.into_body())
@@ -484,9 +497,12 @@ where
}
/// Process one incoming requests
pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> {
pub(self) fn poll_request(
&mut self,
cx: &mut Context,
) -> Result<bool, DispatchError> {
// limit a mount of non processed requests
if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() {
if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) {
return Ok(false);
}
@@ -521,7 +537,7 @@ where
// handle request early
if self.state.is_empty() {
self.state = self.handle_request(req)?;
self.state = self.handle_request(req, cx)?;
} else {
self.messages.push_back(DispatcherMessage::Item(req));
}
@@ -587,12 +603,12 @@ where
}
/// keep-alive timer
fn poll_keepalive(&mut self) -> Result<(), DispatchError> {
fn poll_keepalive(&mut self, cx: &mut Context) -> Result<(), DispatchError> {
if self.ka_timer.is_none() {
// shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.codec.config().client_disconnect_timer() {
self.ka_timer = Some(Delay::new(interval));
self.ka_timer = Some(delay(interval));
} else {
self.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = self.payload.take() {
@@ -605,11 +621,8 @@ where
}
}
match self.ka_timer.as_mut().unwrap().poll().map_err(|e| {
error!("Timer error {:?}", e);
DispatchError::Unknown
})? {
Async::Ready(_) => {
match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) {
Poll::Ready(()) => {
// if we get timeout during shutdown, drop connection
if self.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout);
@@ -624,9 +637,9 @@ where
if let Some(deadline) =
self.codec.config().client_disconnect_timer()
{
if let Some(timer) = self.ka_timer.as_mut() {
if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(deadline);
let _ = timer.poll();
let _ = Pin::new(&mut timer).poll(cx);
}
} else {
// no shutdown timeout, drop socket
@@ -650,23 +663,37 @@ where
} else if let Some(deadline) =
self.codec.config().keep_alive_expire()
{
if let Some(timer) = self.ka_timer.as_mut() {
if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(deadline);
let _ = timer.poll();
let _ = Pin::new(&mut timer).poll(cx);
}
}
} else if let Some(timer) = self.ka_timer.as_mut() {
} else if let Some(mut timer) = self.ka_timer.as_mut() {
timer.reset(self.ka_expire);
let _ = timer.poll();
let _ = Pin::new(&mut timer).poll(cx);
}
}
Async::NotReady => (),
Poll::Pending => (),
}
Ok(())
}
}
impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
T: IoStream,
@@ -679,27 +706,28 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
type Item = ();
type Error = DispatchError;
type Output = Result<(), DispatchError>;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.inner {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.as_mut().inner {
DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive()?;
inner.poll_keepalive(cx)?;
if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Ok(Async::Ready(()))
Poll::Ready(Ok(()))
} else {
// flush buffer
inner.poll_flush()?;
inner.poll_flush(cx)?;
if !inner.write_buf.is_empty() {
Ok(Async::NotReady)
Poll::Pending
} else {
match inner.io.shutdown()? {
Async::Ready(_) => Ok(Async::Ready(())),
Async::NotReady => Ok(Async::NotReady),
match Pin::new(&mut inner.io).poll_shutdown(cx) {
Poll::Ready(res) => {
Poll::Ready(res.map_err(DispatchError::from))
}
Poll::Pending => Poll::Pending,
}
}
}
@@ -707,12 +735,12 @@ where
// read socket into a buf
let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) {
read_available(&mut inner.io, &mut inner.read_buf)?
read_available(cx, &mut inner.io, &mut inner.read_buf)?
} else {
None
};
inner.poll_request()?;
inner.poll_request(cx)?;
if let Some(true) = should_disconnect {
inner.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = inner.payload.take() {
@@ -724,7 +752,7 @@ where
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE);
}
let result = inner.poll_response()?;
let result = inner.poll_response(cx)?;
let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler
@@ -742,7 +770,7 @@ where
self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)),
);
return self.poll();
return self.poll(cx);
} else {
panic!()
}
@@ -751,14 +779,14 @@ where
// we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck
if inner.poll_flush()? || !drain {
if inner.poll_flush(cx)? || !drain {
break;
}
}
// client is gone
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
return Ok(Async::Ready(()));
return Poll::Ready(Ok(()));
}
let is_empty = inner.state.is_empty();
@@ -771,38 +799,44 @@ where
// keep-alive and stream errors
if is_empty && inner.write_buf.is_empty() {
if let Some(err) = inner.error.take() {
Err(err)
Poll::Ready(Err(err))
}
// disconnect if keep-alive is not enabled
else if inner.flags.contains(Flags::STARTED)
&& !inner.flags.intersects(Flags::KEEPALIVE)
{
inner.flags.insert(Flags::SHUTDOWN);
self.poll()
self.poll(cx)
}
// disconnect if shutdown
else if inner.flags.contains(Flags::SHUTDOWN) {
self.poll()
self.poll(cx)
} else {
Ok(Async::NotReady)
Poll::Pending
}
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
DispatcherState::Upgrade(ref mut fut) => {
unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
error!("Upgrade handler error: {}", e);
DispatchError::Upgrade
}),
})
}
DispatcherState::None => panic!(),
}
}
}
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
fn read_available<T>(
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Result<Option<bool>, io::Error>
where
T: io::Read,
T: AsyncRead + Unpin,
{
let mut read_some = false;
loop {
@@ -810,19 +844,18 @@ where
buf.reserve(HW_BUFFER_SIZE);
}
let read = unsafe { io.read(buf.bytes_mut()) };
match read {
Ok(n) => {
match read(cx, io, buf) {
Poll::Pending => {
return if read_some { Ok(Some(false)) } else { Ok(None) };
}
Poll::Ready(Ok(n)) => {
if n == 0 {
return Ok(Some(true));
} else {
read_some = true;
unsafe {
buf.advance_mut(n);
}
}
}
Err(e) => {
Poll::Ready(Err(e)) => {
return if e.kind() == io::ErrorKind::WouldBlock {
if read_some {
Ok(Some(false))
@@ -833,11 +866,22 @@ where
Ok(Some(true))
} else {
Err(e)
};
}
}
}
}
}
fn read<T>(
cx: &mut Context,
io: &mut T,
buf: &mut BytesMut,
) -> Poll<Result<usize, io::Error>>
where
T: AsyncRead + Unpin,
{
Pin::new(io).poll_read_buf(cx, buf)
}
#[cfg(test)]
mod tests {
@@ -852,7 +896,7 @@ mod tests {
#[test]
fn test_req_parse_err() {
let mut sys = actix_rt::System::new("test");
let _ = sys.block_on(lazy(|| {
let _ = sys.block_on(lazy(|cx| {
let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
@@ -865,7 +909,10 @@ mod tests {
None,
None,
);
assert!(h1.poll().is_err());
match Pin::new(&mut h1).poll(cx) {
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),
}
if let DispatcherState::Normal(ref inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT));

View File

@@ -548,10 +548,11 @@ mod tests {
ConnectionType::Close,
&ServiceConfig::default(),
);
assert_eq!(
bytes.take().freeze(),
Bytes::from_static(b"\r\nContent-Length: 0\r\nConnection: close\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n")
);
let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
assert!(data.contains("Content-Length: 0\r\n"));
assert!(data.contains("Connection: close\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n"));
let _ = head.encode_headers(
&mut bytes,
@@ -560,10 +561,10 @@ mod tests {
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
assert_eq!(
bytes.take().freeze(),
Bytes::from_static(b"\r\nTransfer-Encoding: chunked\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n")
);
let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
assert!(data.contains("Transfer-Encoding: chunked\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n"));
let _ = head.encode_headers(
&mut bytes,
@@ -572,10 +573,10 @@ mod tests {
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
assert_eq!(
bytes.take().freeze(),
Bytes::from_static(b"\r\nContent-Length: 100\r\nDate: date\r\nContent-Type: plain/text\r\n\r\n")
);
let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
assert!(data.contains("Content-Length: 100\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n"));
let mut head = RequestHead::default();
head.set_camel_case_headers(false);
@@ -586,7 +587,6 @@ mod tests {
.append(CONTENT_TYPE, HeaderValue::from_static("xml"));
let mut head = RequestHeadType::Owned(head);
let _ = head.encode_headers(
&mut bytes,
Version::HTTP_11,
@@ -594,10 +594,11 @@ mod tests {
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
assert_eq!(
bytes.take().freeze(),
Bytes::from_static(b"\r\ntransfer-encoding: chunked\r\ndate: date\r\ncontent-type: xml\r\ncontent-type: plain/text\r\n\r\n")
);
let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
assert!(data.contains("transfer-encoding: chunked\r\n"));
assert!(data.contains("content-type: xml\r\n"));
assert!(data.contains("content-type: plain/text\r\n"));
assert!(data.contains("date: date\r\n"));
}
#[test]
@@ -626,9 +627,10 @@ mod tests {
ConnectionType::Close,
&ServiceConfig::default(),
);
assert_eq!(
bytes.take().freeze(),
Bytes::from_static(b"\r\ncontent-length: 0\r\nconnection: close\r\nauthorization: another authorization\r\ndate: date\r\n\r\n")
);
let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
assert!(data.contains("content-length: 0\r\n"));
assert!(data.contains("connection: close\r\n"));
assert!(data.contains("authorization: another authorization\r\n"));
assert!(data.contains("date: date\r\n"));
}
}

View File

@@ -1,21 +1,24 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_server_config::ServerConfig;
use actix_service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future::{ok, Ready};
use crate::error::Error;
use crate::request::Request;
pub struct ExpectHandler;
impl NewService for ExpectHandler {
impl ServiceFactory for ExpectHandler {
type Config = ServerConfig;
type Request = Request;
type Response = Request;
type Error = Error;
type Service = ExpectHandler;
type InitError = Error;
type Future = FutureResult<Self::Service, Self::InitError>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &ServerConfig) -> Self::Future {
ok(ExpectHandler)
@@ -26,10 +29,10 @@ impl Service for ExpectHandler {
type Request = Request;
type Response = Request;
type Error = Error;
type Future = FutureResult<Self::Response, Self::Error>;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {

View File

@@ -1,12 +1,14 @@
//! Payload stream
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::rc::{Rc, Weak};
use std::task::{Context, Poll};
use actix_utils::task::LocalWaker;
use bytes::Bytes;
use futures::task::current as current_task;
use futures::task::Task;
use futures::{Async, Poll, Stream};
use futures::Stream;
use crate::error::PayloadError;
@@ -77,15 +79,24 @@ impl Payload {
pub fn unread_data(&mut self, data: Bytes) {
self.inner.borrow_mut().unread_data(data);
}
#[inline]
pub fn readany(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
}
}
impl Stream for Payload {
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
#[inline]
fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> {
self.inner.borrow_mut().readany()
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
}
}
@@ -117,19 +128,14 @@ impl PayloadSender {
}
#[inline]
pub fn need_read(&self) -> PayloadStatus {
pub fn need_read(&self, cx: &mut Context) -> PayloadStatus {
// we check need_read only if Payload (other side) is alive,
// otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read {
PayloadStatus::Read
} else {
#[cfg(not(test))]
{
if shared.borrow_mut().io_task.is_none() {
shared.borrow_mut().io_task = Some(current_task());
}
}
shared.borrow_mut().io_task.register(cx.waker());
PayloadStatus::Pause
}
} else {
@@ -145,8 +151,8 @@ struct Inner {
err: Option<PayloadError>,
need_read: bool,
items: VecDeque<Bytes>,
task: Option<Task>,
io_task: Option<Task>,
task: LocalWaker,
io_task: LocalWaker,
}
impl Inner {
@@ -157,8 +163,8 @@ impl Inner {
err: None,
items: VecDeque::new(),
need_read: true,
task: None,
io_task: None,
task: LocalWaker::new(),
io_task: LocalWaker::new(),
}
}
@@ -178,7 +184,7 @@ impl Inner {
self.items.push_back(data);
self.need_read = self.len < MAX_BUFFER_SIZE;
if let Some(task) = self.task.take() {
task.notify()
task.wake()
}
}
@@ -187,34 +193,28 @@ impl Inner {
self.len
}
fn readany(&mut self) -> Poll<Option<Bytes>, PayloadError> {
fn readany(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
if let Some(data) = self.items.pop_front() {
self.len -= data.len();
self.need_read = self.len < MAX_BUFFER_SIZE;
if self.need_read && self.task.is_none() && !self.eof {
self.task = Some(current_task());
if self.need_read && !self.eof {
self.task.register(cx.waker());
}
if let Some(task) = self.io_task.take() {
task.notify()
}
Ok(Async::Ready(Some(data)))
self.io_task.wake();
Poll::Ready(Some(Ok(data)))
} else if let Some(err) = self.err.take() {
Err(err)
Poll::Ready(Some(Err(err)))
} else if self.eof {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
self.need_read = true;
#[cfg(not(test))]
{
if self.task.is_none() {
self.task = Some(current_task());
}
if let Some(task) = self.io_task.take() {
task.notify()
}
}
Ok(Async::NotReady)
self.task.register(cx.waker());
self.io_task.wake();
Poll::Pending
}
}
@@ -228,13 +228,11 @@ impl Inner {
mod tests {
use super::*;
use actix_rt::Runtime;
use futures::future::{lazy, result};
use futures::future::{poll_fn, ready};
#[test]
fn test_unread_data() {
Runtime::new()
.unwrap()
.block_on(lazy(|| {
Runtime::new().unwrap().block_on(async {
let (_, mut payload) = Payload::create(false);
payload.unread_data(Bytes::from("data"));
@@ -242,13 +240,11 @@ mod tests {
assert_eq!(payload.len(), 4);
assert_eq!(
Async::Ready(Some(Bytes::from("data"))),
payload.poll().ok().unwrap()
Bytes::from("data"),
poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap()
);
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
ready(())
});
}
}

View File

@@ -1,12 +1,15 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::Framed;
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoNewService, NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use futures::future::{ok, Ready};
use futures::{ready, Stream};
use crate::body::MessageBody;
use crate::cloneable::CloneableService;
@@ -20,7 +23,7 @@ use super::codec::Codec;
use super::dispatcher::Dispatcher;
use super::{ExpectHandler, Message, UpgradeHandler};
/// `NewService` implementation for HTTP1 transport
/// `ServiceFactory` implementation for HTTP1 transport
pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
srv: S,
cfg: ServiceConfig,
@@ -32,19 +35,19 @@ pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, P, S, B> H1Service<T, P, S, B>
where
S: NewService<Config = SrvConfig, Request = Request>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
{
/// Create new `HttpService` instance with default config.
pub fn new<F: IntoNewService<S>>(service: F) -> Self {
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H1Service {
cfg,
srv: service.into_new_service(),
srv: service.into_factory(),
expect: ExpectHandler,
upgrade: None,
on_connect: None,
@@ -53,10 +56,13 @@ where
}
/// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
pub fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
) -> Self {
H1Service {
cfg,
srv: service.into_new_service(),
srv: service.into_factory(),
expect: ExpectHandler,
upgrade: None,
on_connect: None,
@@ -67,7 +73,7 @@ where
impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U>
where
S: NewService<Config = SrvConfig, Request = Request>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
@@ -75,7 +81,7 @@ where
{
pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U>
where
X1: NewService<Request = Request, Response = Request>,
X1: ServiceFactory<Request = Request, Response = Request>,
X1::Error: Into<Error>,
X1::InitError: fmt::Debug,
{
@@ -91,7 +97,7 @@ where
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1>
where
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U1: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display,
U1::InitError: fmt::Debug,
{
@@ -115,18 +121,18 @@ where
}
}
impl<T, P, S, B, X, U> NewService for H1Service<T, P, S, B, X, U>
impl<T, P, S, B, X, U> ServiceFactory for H1Service<T, P, S, B, X, U>
where
T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
B: MessageBody,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<
U: ServiceFactory<
Config = SrvConfig,
Request = (Request, Framed<T, Codec>),
Response = (),
@@ -144,7 +150,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H1ServiceResponse {
fut: self.srv.new_service(cfg).into_future(),
fut: self.srv.new_service(cfg),
fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None,
@@ -157,20 +163,24 @@ where
}
#[doc(hidden)]
#[pin_project::pin_project]
pub struct H1ServiceResponse<T, P, S, B, X, U>
where
S: NewService<Request = Request>,
S: ServiceFactory<Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
X: NewService<Request = Request, Response = Request>,
X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
U::InitError: fmt::Debug,
{
#[pin]
fut: S::Future,
#[pin]
fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>,
expect: Option<X::Service>,
upgrade: Option<U::Service>,
@@ -182,49 +192,57 @@ where
impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U>
where
T: IoStream,
S: NewService<Request = Request>,
S: ServiceFactory<Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
B: MessageBody,
X: NewService<Request = Request, Response = Request>,
X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
U::InitError: fmt::Debug,
{
type Item = H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
type Error = ();
type Output =
Result<H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
self.expect = Some(expect);
self.fut_ex.take();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
if let Some(fut) = this.fut_ex.as_pin_mut() {
let expect = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
}
if let Some(ref mut fut) = self.fut_upg {
let upgrade = try_ready!(fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
self.upgrade = Some(upgrade);
self.fut_ex.take();
if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
}
let service = try_ready!(self
let result = ready!(this
.fut
.poll()
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(H1ServiceHandler::new(
self.cfg.take().unwrap(),
Poll::Ready(result.map(|service| {
let this = self.as_mut().project();
H1ServiceHandler::new(
this.cfg.take().unwrap(),
service,
self.expect.take().unwrap(),
self.upgrade.take(),
self.on_connect.clone(),
)))
this.expect.take().unwrap(),
this.upgrade.take(),
this.on_connect.clone(),
)
}))
}
}
@@ -284,10 +302,10 @@ where
type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let ready = self
.expect
.poll_ready()
.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
@@ -297,7 +315,7 @@ where
let ready = self
.srv
.poll_ready()
.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
@@ -307,9 +325,9 @@ where
&& ready;
if ready {
Ok(Async::Ready(()))
Poll::Ready(Ok(()))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
@@ -333,7 +351,7 @@ where
}
}
/// `NewService` implementation for `OneRequestService` service
/// `ServiceFactory` implementation for `OneRequestService` service
#[derive(Default)]
pub struct OneRequest<T, P> {
config: ServiceConfig,
@@ -353,7 +371,7 @@ where
}
}
impl<T, P> NewService for OneRequest<T, P>
impl<T, P> ServiceFactory for OneRequest<T, P>
where
T: IoStream,
{
@@ -363,7 +381,7 @@ where
type Error = ParseError;
type InitError = ();
type Service = OneRequestService<T, P>;
type Future = FutureResult<Self::Service, Self::InitError>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &SrvConfig) -> Self::Future {
ok(OneRequestService {
@@ -389,8 +407,8 @@ where
type Error = ParseError;
type Future = OneRequestServiceResponse<T>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Self::Request) -> Self::Future {
@@ -415,19 +433,19 @@ impl<T> Future for OneRequestServiceResponse<T>
where
T: IoStream,
{
type Item = (Request, Framed<T, Codec>);
type Error = ParseError;
type Output = Result<(Request, Framed<T, Codec>), ParseError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.framed.as_mut().unwrap().poll()? {
Async::Ready(Some(req)) => match req {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.framed.as_mut().unwrap().next_item(cx) {
Poll::Ready(Some(Ok(req))) => match req {
Message::Item(req) => {
Ok(Async::Ready((req, self.framed.take().unwrap())))
Poll::Ready(Ok((req, self.framed.take().unwrap())))
}
Message::Chunk(_) => unreachable!("Something is wrong"),
},
Async::Ready(None) => Err(ParseError::Incomplete),
Async::NotReady => Ok(Async::NotReady),
Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)),
Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)),
Poll::Pending => Poll::Pending,
}
}
}

View File

@@ -1,10 +1,12 @@
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::Framed;
use actix_server_config::ServerConfig;
use actix_service::{NewService, Service};
use futures::future::FutureResult;
use futures::{Async, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future::Ready;
use crate::error::Error;
use crate::h1::Codec;
@@ -12,14 +14,14 @@ use crate::request::Request;
pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> NewService for UpgradeHandler<T> {
impl<T> ServiceFactory for UpgradeHandler<T> {
type Config = ServerConfig;
type Request = (Request, Framed<T, Codec>);
type Response = ();
type Error = Error;
type Service = UpgradeHandler<T>;
type InitError = Error;
type Future = FutureResult<Self::Service, Self::InitError>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: &ServerConfig) -> Self::Future {
unimplemented!()
@@ -30,10 +32,10 @@ impl<T> Service for UpgradeHandler<T> {
type Request = (Request, Framed<T, Codec>);
type Response = ();
type Error = Error;
type Future = FutureResult<Self::Response, Self::Error>;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: Self::Request) -> Self::Future {

View File

@@ -1,5 +1,9 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use futures::{Async, Future, Poll, Sink};
use futures::Sink;
use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::error::Error;
@@ -7,6 +11,7 @@ use crate::h1::{Codec, Message};
use crate::response::Response;
/// Send http/1 response
#[pin_project::pin_project]
pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>,
body: Option<ResponseBody<B>>,
@@ -33,60 +38,61 @@ where
T: AsyncRead + AsyncWrite,
B: MessageBody,
{
type Item = Framed<T, Codec>;
type Error = Error;
type Output = Result<Framed<T, Codec>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
let mut body_ready = self.body.is_some();
let framed = self.framed.as_mut().unwrap();
let mut body_ready = this.body.is_some();
let framed = this.framed.as_mut().unwrap();
// send body
if self.res.is_none() && self.body.is_some() {
while body_ready && self.body.is_some() && !framed.is_write_buf_full() {
match self.body.as_mut().unwrap().poll_next()? {
Async::Ready(item) => {
if this.res.is_none() && this.body.is_some() {
while body_ready && this.body.is_some() && !framed.is_write_buf_full() {
match this.body.as_mut().unwrap().poll_next(cx)? {
Poll::Ready(item) => {
// body is done
if item.is_none() {
let _ = self.body.take();
let _ = this.body.take();
}
framed.force_send(Message::Chunk(item))?;
framed.write(Message::Chunk(item))?;
}
Async::NotReady => body_ready = false,
Poll::Pending => body_ready = false,
}
}
}
// flush write buffer
if !framed.is_write_buf_empty() {
match framed.poll_complete()? {
Async::Ready(_) => {
match framed.flush(cx)? {
Poll::Ready(_) => {
if body_ready {
continue;
} else {
return Ok(Async::NotReady);
return Poll::Pending;
}
}
Async::NotReady => return Ok(Async::NotReady),
Poll::Pending => return Poll::Pending,
}
}
// send response
if let Some(res) = self.res.take() {
framed.force_send(res)?;
if let Some(res) = this.res.take() {
framed.write(res)?;
continue;
}
if self.body.is_some() {
if this.body.is_some() {
if body_ready {
continue;
} else {
return Ok(Async::NotReady);
return Poll::Pending;
}
} else {
break;
}
}
Ok(Async::Ready(self.framed.take().unwrap()))
Poll::Ready(Ok(this.framed.take().unwrap()))
}
}

View File

@@ -1,5 +1,8 @@
use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use std::{fmt, mem, net};
@@ -8,7 +11,7 @@ use actix_server_config::IoStream;
use actix_service::Service;
use bitflags::bitflags;
use bytes::{Bytes, BytesMut};
use futures::{try_ready, Async, Future, Poll, Sink, Stream};
use futures::{ready, Sink, Stream};
use h2::server::{Connection, SendResponse};
use h2::{RecvStream, SendStream};
use http::header::{
@@ -32,6 +35,7 @@ use crate::response::Response;
const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol
#[pin_project::pin_project]
pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> {
service: CloneableService<S>,
connection: Connection<T, Bytes>,
@@ -48,9 +52,9 @@ where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Future: 'static,
// S::Future: 'static,
S::Response: Into<Response<B>>,
B: MessageBody + 'static,
B: MessageBody,
{
pub(crate) fn new(
service: CloneableService<S>,
@@ -93,61 +97,75 @@ impl<T, S, B> Future for Dispatcher<T, S, B>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
type Item = ();
type Error = DispatchError;
type Output = Result<(), DispatchError>;
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match self.connection.poll()? {
Async::Ready(None) => return Ok(Async::Ready(())),
Async::Ready(Some((req, res))) => {
match Pin::new(&mut this.connection).poll_accept(cx) {
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
Poll::Ready(Some(Ok((req, res)))) => {
// update keep-alive expire
if self.ka_timer.is_some() {
if let Some(expire) = self.config.keep_alive_expire() {
self.ka_expire = expire;
if this.ka_timer.is_some() {
if let Some(expire) = this.config.keep_alive_expire() {
this.ka_expire = expire;
}
}
let (parts, body) = req.into_parts();
let mut req = Request::with_payload(body.into());
let mut req = Request::with_payload(Payload::<
crate::payload::PayloadStream,
>::H2(
crate::h2::Payload::new(body)
));
let head = &mut req.head_mut();
head.uri = parts.uri;
head.method = parts.method;
head.version = parts.version;
head.headers = parts.headers.into();
head.peer_addr = self.peer_addr;
head.peer_addr = this.peer_addr;
// set on_connect data
if let Some(ref on_connect) = self.on_connect {
if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut());
}
tokio_current_thread::spawn(ServiceResponse::<S::Future, B> {
tokio_executor::current_thread::spawn(ServiceResponse::<
S::Future,
S::Response,
S::Error,
B,
> {
state: ServiceResponseState::ServiceCall(
self.service.call(req),
this.service.call(req),
Some(res),
),
config: self.config.clone(),
config: this.config.clone(),
buffer: None,
})
_t: PhantomData,
});
}
Async::NotReady => return Ok(Async::NotReady),
Poll::Pending => return Poll::Pending,
}
}
}
}
struct ServiceResponse<F, B> {
#[pin_project::pin_project]
struct ServiceResponse<F, I, E, B> {
state: ServiceResponseState<F, B>,
config: ServiceConfig,
buffer: Option<Bytes>,
_t: PhantomData<(I, E)>,
}
enum ServiceResponseState<F, B> {
@@ -155,12 +173,12 @@ enum ServiceResponseState<F, B> {
SendPayload(SendStream<Bytes>, ResponseBody<B>),
}
impl<F, B> ServiceResponse<F, B>
impl<F, I, E, B> ServiceResponse<F, I, E, B>
where
F: Future,
F::Error: Into<Error>,
F::Item: Into<Response<B>>,
B: MessageBody + 'static,
F: Future<Output = Result<I, E>>,
E: Into<Error>,
I: Into<Response<B>>,
B: MessageBody,
{
fn prepare_response(
&self,
@@ -223,109 +241,121 @@ where
}
}
impl<F, B> Future for ServiceResponse<F, B>
impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
where
F: Future,
F::Error: Into<Error>,
F::Item: Into<Response<B>>,
B: MessageBody + 'static,
F: Future<Output = Result<I, E>>,
E: Into<Error>,
I: Into<Response<B>>,
B: MessageBody,
{
type Item = ();
type Error = ();
type Output = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.state {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
match this.state {
ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
match call.poll() {
Ok(Async::Ready(res)) => {
match unsafe { Pin::new_unchecked(call) }.poll(cx) {
Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap();
let mut size = body.size();
let h2_res = self.prepare_response(res.head(), &mut size);
let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream =
send.send_response(h2_res, size.is_eof()).map_err(|e| {
let stream = match send.send_response(h2_res, size.is_eof()) {
Err(e) => {
trace!("Error sending h2 response: {:?}", e);
})?;
return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() {
Ok(Async::Ready(()))
Poll::Ready(())
} else {
self.state = ServiceResponseState::SendPayload(stream, body);
self.poll()
*this.state =
ServiceResponseState::SendPayload(stream, body);
self.poll(cx)
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let res: Response = e.into().into();
let (res, body) = res.replace_body(());
let mut send = send.take().unwrap();
let mut size = body.size();
let h2_res = self.prepare_response(res.head(), &mut size);
let h2_res =
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream =
send.send_response(h2_res, size.is_eof()).map_err(|e| {
let stream = match send.send_response(h2_res, size.is_eof()) {
Err(e) => {
trace!("Error sending h2 response: {:?}", e);
})?;
return Poll::Ready(());
}
Ok(stream) => stream,
};
if size.is_eof() {
Ok(Async::Ready(()))
Poll::Ready(())
} else {
self.state = ServiceResponseState::SendPayload(
*this.state = ServiceResponseState::SendPayload(
stream,
body.into_body(),
);
self.poll()
self.poll(cx)
}
}
}
}
ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
loop {
if let Some(ref mut buffer) = self.buffer {
match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? {
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(None) => return Ok(Async::Ready(())),
Async::Ready(Some(cap)) => {
if let Some(ref mut buffer) = this.buffer {
match stream.poll_capacity(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(()),
Poll::Ready(Some(Ok(cap))) => {
let len = buffer.len();
let bytes = buffer.split_to(std::cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e);
return Err(());
return Poll::Ready(());
} else if !buffer.is_empty() {
let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap);
} else {
self.buffer.take();
this.buffer.take();
}
}
Poll::Ready(Some(Err(e))) => {
warn!("{:?}", e);
return Poll::Ready(());
}
}
} else {
match body.poll_next() {
Ok(Async::NotReady) => {
return Ok(Async::NotReady);
}
Ok(Async::Ready(None)) => {
match body.poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e);
return Err(());
} else {
return Ok(Async::Ready(()));
}
return Poll::Ready(());
}
Ok(Async::Ready(Some(chunk))) => {
Poll::Ready(Some(Ok(chunk))) => {
stream.reserve_capacity(std::cmp::min(
chunk.len(),
CHUNK_SIZE,
));
self.buffer = Some(chunk);
*this.buffer = Some(chunk);
}
Err(e) => {
Poll::Ready(Some(Err(e))) => {
error!("Response payload stream error: {:?}", e);
return Err(());
return Poll::Ready(());
}
}
}

View File

@@ -1,9 +1,11 @@
#![allow(dead_code, unused_imports)]
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::{Async, Poll, Stream};
use futures::Stream;
use h2::RecvStream;
mod dispatcher;
@@ -25,22 +27,23 @@ impl Payload {
}
impl Stream for Payload {
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.pl.poll() {
Ok(Async::Ready(Some(chunk))) => {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match Pin::new(&mut this.pl).poll_data(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let len = chunk.len();
if let Err(err) = self.pl.release_capacity().release_capacity(len) {
Err(err.into())
if let Err(err) = this.pl.release_capacity().release_capacity(len) {
Poll::Ready(Some(Err(err.into())))
} else {
Ok(Async::Ready(Some(chunk)))
Poll::Ready(Some(Ok(chunk)))
}
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => Err(err.into()),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))),
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
}
}
}

View File

@@ -1,13 +1,16 @@
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{IntoNewService, NewService, Service};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes;
use futures::future::{ok, FutureResult};
use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use futures::future::{ok, Ready};
use futures::{ready, Stream};
use h2::server::{self, Connection, Handshake};
use h2::RecvStream;
use log::error;
@@ -23,7 +26,7 @@ use crate::response::Response;
use super::dispatcher::Dispatcher;
/// `NewService` implementation for HTTP2 transport
/// `ServiceFactory` implementation for HTTP2 transport
pub struct H2Service<T, P, S, B> {
srv: S,
cfg: ServiceConfig,
@@ -33,30 +36,33 @@ pub struct H2Service<T, P, S, B> {
impl<T, P, S, B> H2Service<T, P, S, B>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
/// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self {
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
H2Service {
cfg,
on_connect: None,
srv: service.into_new_service(),
srv: service.into_factory(),
_t: PhantomData,
}
}
/// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
pub fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
) -> Self {
H2Service {
cfg,
on_connect: None,
srv: service.into_new_service(),
srv: service.into_factory(),
_t: PhantomData,
}
}
@@ -71,12 +77,12 @@ where
}
}
impl<T, P, S, B> NewService for H2Service<T, P, S, B>
impl<T, P, S, B> ServiceFactory for H2Service<T, P, S, B>
where
T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
@@ -90,7 +96,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H2ServiceResponse {
fut: self.srv.new_service(cfg).into_future(),
fut: self.srv.new_service(cfg),
cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(),
_t: PhantomData,
@@ -99,8 +105,10 @@ where
}
#[doc(hidden)]
pub struct H2ServiceResponse<T, P, S: NewService, B> {
fut: <S::Future as IntoFuture>::Future,
#[pin_project::pin_project]
pub struct H2ServiceResponse<T, P, S: ServiceFactory, B> {
#[pin]
fut: S::Future,
cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, P, B)>,
@@ -109,22 +117,25 @@ pub struct H2ServiceResponse<T, P, S: NewService, B> {
impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B>
where
T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
type Item = H2ServiceHandler<T, P, S::Service, B>;
type Error = S::InitError;
type Output = Result<H2ServiceHandler<T, P, S::Service, B>, S::InitError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let service = try_ready!(self.fut.poll());
Ok(Async::Ready(H2ServiceHandler::new(
self.cfg.take().unwrap(),
self.on_connect.clone(),
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.as_mut().project();
Poll::Ready(ready!(this.fut.poll(cx)).map(|service| {
let this = self.as_mut().project();
H2ServiceHandler::new(
this.cfg.take().unwrap(),
this.on_connect.clone(),
service,
)))
)
}))
}
}
@@ -139,9 +150,9 @@ pub struct H2ServiceHandler<T, P, S, B> {
impl<T, P, S, B> H2ServiceHandler<T, P, S, B>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
fn new(
@@ -162,9 +173,9 @@ impl<T, P, S, B> Service for H2ServiceHandler<T, P, S, B>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
type Request = Io<T, P>;
@@ -172,8 +183,8 @@ where
type Error = DispatchError;
type Future = H2ServiceHandlerResponse<T, S, B>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.srv.poll_ready().map_err(|e| {
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.srv.poll_ready(cx).map_err(|e| {
let e = e.into();
error!("Service readiness error: {:?}", e);
DispatchError::Service(e)
@@ -219,9 +230,9 @@ pub struct H2ServiceHandlerResponse<T, S, B>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
state: State<T, S, B>,
@@ -231,25 +242,24 @@ impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody,
{
type Item = ();
type Error = DispatchError;
type Output = Result<(), DispatchError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match self.state {
State::Incoming(ref mut disp) => disp.poll(),
State::Incoming(ref mut disp) => Pin::new(disp).poll(cx),
State::Handshake(
ref mut srv,
ref mut config,
ref peer_addr,
ref mut on_connect,
ref mut handshake,
) => match handshake.poll() {
Ok(Async::Ready(conn)) => {
) => match Pin::new(handshake).poll(cx) {
Poll::Ready(Ok(conn)) => {
self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(),
conn,
@@ -258,13 +268,13 @@ where
None,
*peer_addr,
));
self.poll()
self.poll(cx)
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
Err(err.into())
Poll::Ready(Err(err.into()))
}
Poll::Pending => Poll::Pending,
},
}
}

View File

@@ -76,6 +76,11 @@ pub enum DispositionParam {
/// the form.
Name(String),
/// A plain file name.
///
/// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any
/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where
/// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead
/// in case there are Unicode characters in file names.
Filename(String),
/// An extended file name. It must not exist for `ContentType::Formdata` according to
/// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2).
@@ -220,7 +225,16 @@ impl DispositionParam {
/// ext-token = <the characters in token, followed by "*">
/// ```
///
/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within
/// # Note
///
/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any
/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where
/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file
/// names.
/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded
/// directly in a *Content-Disposition* header for *multipart/form-data*, though.
///
/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within
/// *multipart/form-data*.
///
/// # Example
@@ -251,6 +265,22 @@ impl DispositionParam {
/// };
/// assert_eq!(cd2.get_name(), Some("file")); // field name
/// assert_eq!(cd2.get_filename(), Some("bill.odt"));
///
/// // HTTP response header with Unicode characters in file names
/// let cd3 = ContentDisposition {
/// disposition: DispositionType::Attachment,
/// parameters: vec![
/// DispositionParam::FilenameExt(ExtendedValue {
/// charset: Charset::Ext(String::from("UTF-8")),
/// language_tag: None,
/// value: String::from("\u{1f600}.svg").into_bytes(),
/// }),
/// // fallback for better compatibility
/// DispositionParam::Filename(String::from("Grinning-Face-Emoji.svg"))
/// ],
/// };
/// assert_eq!(cd3.get_filename_ext().map(|ev| ev.value.as_ref()),
/// Some("\u{1f600}.svg".as_bytes()));
/// ```
///
/// # WARN
@@ -333,15 +363,17 @@ impl ContentDisposition {
// token: won't contains semicolon according to RFC 2616 Section 2.2
let (token, new_left) = split_once_and_trim(left, ';');
left = new_left;
token.to_owned()
};
if value.is_empty() {
if token.is_empty() {
// quoted-string can be empty, but token cannot be empty
return Err(crate::error::ParseError::Header);
}
token.to_owned()
};
let param = if param_name.eq_ignore_ascii_case("name") {
DispositionParam::Name(value)
} else if param_name.eq_ignore_ascii_case("filename") {
// See also comments in test_from_raw_uncessary_percent_decode.
DispositionParam::Filename(value)
} else {
DispositionParam::Unknown(param_name.to_owned(), value)
@@ -466,11 +498,40 @@ impl fmt::Display for DispositionType {
impl fmt::Display for DispositionParam {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and
// All ASCII control characters (0-30, 127) including horizontal tab, double quote, and
// backslash should be escaped in quoted-string (i.e. "foobar").
// Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... .
// Ref: RFC6266 S4.1 -> RFC2616 S3.6
// filename-parm = "filename" "=" value
// value = token | quoted-string
// quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
// qdtext = <any TEXT except <">>
// quoted-pair = "\" CHAR
// TEXT = <any OCTET except CTLs,
// but including LWS>
// LWS = [CRLF] 1*( SP | HT )
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character
// (octets 0 - 31) and DEL (127)>
//
// Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1
// parameter := attribute "=" value
// attribute := token
// ; Matching of attributes
// ; is ALWAYS case-insensitive.
// value := token / quoted-string
// token := 1*<any (US-ASCII) CHAR except SPACE, CTLs,
// or tspecials>
// tspecials := "(" / ")" / "<" / ">" / "@" /
// "," / ";" / ":" / "\" / <">
// "/" / "[" / "]" / "?" / "="
// ; Must be in quoted-string,
// ; to use within parameter values
//
//
// See also comments in test_from_raw_uncessary_percent_decode.
lazy_static! {
static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap();
static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap();
}
match self {
DispositionParam::Name(ref value) => write!(f, "name={}", value),
@@ -774,8 +835,18 @@ mod tests {
#[test]
fn test_from_raw_uncessary_percent_decode() {
// In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with
// non-ASCII characters MAY be percent-encoded.
// On the contrary, RFC6266 or other RFCs related to Content-Disposition response header
// do not mention such percent-encoding.
// So, it appears to be undecidable whether to percent-decode or not without
// knowing the usage scenario (multipart/form-data v.s. HTTP response header) and
// inevitable to unnecessarily percent-decode filename with %XX in the former scenario.
// Fortunately, it seems that almost all mainstream browsers just send UTF-8 encoded file
// names in quoted-string format (tested on Edge, IE11, Chrome and Firefox) without
// percent-encoding. So we do not bother to attempt to percent-decode.
let a = HeaderValue::from_static(
"form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded!
"form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"",
);
let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap();
let b = ContentDisposition {
@@ -811,6 +882,9 @@ mod tests {
let a = HeaderValue::from_static("inline; filename= ");
assert!(ContentDisposition::from_raw(&a).is_err());
let a = HeaderValue::from_static("inline; filename=\"\"");
assert!(ContentDisposition::from_raw(&a).expect("parse cd").get_filename().expect("filename").is_empty());
}
#[test]

View File

@@ -29,7 +29,6 @@ impl Response {
STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED);
STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES);
STATIC_RESP!(MovedPermanenty, StatusCode::MOVED_PERMANENTLY);
STATIC_RESP!(MovedPermanently, StatusCode::MOVED_PERMANENTLY);
STATIC_RESP!(Found, StatusCode::FOUND);
STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER);

View File

@@ -4,7 +4,8 @@
clippy::too_many_arguments,
clippy::new_without_default,
clippy::borrow_interior_mutable_const,
clippy::write_with_newline
clippy::write_with_newline,
unused_imports
)]
#[macro_use]

View File

@@ -388,6 +388,12 @@ impl BoxedResponseHead {
pub fn new(status: StatusCode) -> Self {
RESPONSE_POOL.with(|p| p.get_message(status))
}
pub(crate) fn take(&mut self) -> Self {
BoxedResponseHead {
head: self.head.take(),
}
}
}
impl std::ops::Deref for BoxedResponseHead {
@@ -406,7 +412,9 @@ impl std::ops::DerefMut for BoxedResponseHead {
impl Drop for BoxedResponseHead {
fn drop(&mut self) {
RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap()))
if let Some(head) = self.head.take() {
RESPONSE_POOL.with(move |p| p.release(head))
}
}
}

View File

@@ -1,11 +1,15 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::{Async, Poll, Stream};
use futures::Stream;
use h2::RecvStream;
use crate::error::PayloadError;
/// Type represent boxed payload
pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
pub type PayloadStream = Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>>>>;
/// Type represent streaming payload
pub enum Payload<S = PayloadStream> {
@@ -48,18 +52,17 @@ impl<S> Payload<S> {
impl<S> Stream for Payload<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
#[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self {
Payload::None => Ok(Async::Ready(None)),
Payload::H1(ref mut pl) => pl.poll(),
Payload::H2(ref mut pl) => pl.poll(),
Payload::Stream(ref mut pl) => pl.poll(),
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self.get_mut() {
Payload::None => Poll::Ready(None),
Payload::H1(ref mut pl) => pl.readany(cx),
Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx),
Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx),
}
}
}

View File

@@ -80,6 +80,11 @@ impl<P> Request<P> {
)
}
/// Get request's payload
pub fn payload(&mut self) -> &mut Payload<P> {
&mut self.payload
}
/// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> {
std::mem::replace(&mut self.payload, Payload::None)
@@ -199,7 +204,6 @@ mod tests {
assert_eq!(req.uri().query(), Some("q=1"));
let s = format!("{:?}", req);
println!("T: {:?}", s);
assert!(s.contains("Request HTTP/1.1 GET:/index.html"));
}
}

View File

@@ -1,11 +1,14 @@
//! Http response
use std::cell::{Ref, RefMut};
use std::future::Future;
use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, str};
use bytes::{BufMut, Bytes, BytesMut};
use futures::future::{ok, FutureResult, IntoFuture};
use futures::Stream;
use futures::future::{ok, Ready};
use futures::stream::Stream;
use serde::Serialize;
use serde_json;
@@ -194,7 +197,7 @@ impl<B> Response<B> {
self.head.extensions.borrow_mut()
}
/// Get body os this response
/// Get body of this response
#[inline]
pub fn body(&self) -> &ResponseBody<B> {
&self.body
@@ -280,13 +283,15 @@ impl<B: MessageBody> fmt::Debug for Response<B> {
}
}
impl IntoFuture for Response {
type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
impl Future for Response {
type Output = Result<Response, Error>;
fn into_future(self) -> Self::Future {
ok(self)
fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> {
Poll::Ready(Ok(Response {
head: self.head.take(),
body: self.body.take_body(),
error: self.error.take(),
}))
}
}
@@ -635,7 +640,7 @@ impl ResponseBuilder {
/// `ResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Response
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Error> + 'static,
{
self.body(Body::from_message(BodyStream::new(stream)))
@@ -757,13 +762,11 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder {
}
}
impl IntoFuture for ResponseBuilder {
type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
impl Future for ResponseBuilder {
type Output = Result<Response, Error>;
fn into_future(mut self) -> Self::Future {
ok(self.finish())
fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Self::Output> {
Poll::Ready(Ok(self.finish()))
}
}

View File

@@ -1,14 +1,17 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_server_config::{
Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig,
};
use actix_service::{IntoNewService, NewService, Service};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{try_ready, Async, Future, IntoFuture, Poll};
use futures::{ready, Future};
use h2::server::{self, Handshake};
use pin_project::{pin_project, project};
use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder;
@@ -20,7 +23,7 @@ use crate::request::Request;
use crate::response::Response;
use crate::{h1, h2::Dispatcher};
/// `NewService` HTTP1.1/HTTP2 transport implementation
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation
pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
srv: S,
cfg: ServiceConfig,
@@ -32,10 +35,10 @@ pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler
impl<T, S, B> HttpService<T, (), S, B>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
@@ -47,20 +50,20 @@ where
impl<T, P, S, B> HttpService<T, P, S, B>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
/// Create new `HttpService` instance.
pub fn new<F: IntoNewService<S>>(service: F) -> Self {
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
HttpService {
cfg,
srv: service.into_new_service(),
srv: service.into_factory(),
expect: h1::ExpectHandler,
upgrade: None,
on_connect: None,
@@ -69,13 +72,13 @@ where
}
/// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoNewService<S>>(
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
) -> Self {
HttpService {
cfg,
srv: service.into_new_service(),
srv: service.into_factory(),
expect: h1::ExpectHandler,
upgrade: None,
on_connect: None,
@@ -86,10 +89,11 @@ where
impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U>
where
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody,
{
/// Provide service for `EXPECT: 100-Continue` support.
@@ -99,9 +103,10 @@ where
/// request will be forwarded to main service.
pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U>
where
X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X1: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>,
X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{
HttpService {
expect,
@@ -119,13 +124,14 @@ where
/// and this service get called with original request and framed object.
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1>
where
U1: NewService<
U1: ServiceFactory<
Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>),
Response = (),
>,
U1::Error: fmt::Display,
U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{
HttpService {
upgrade,
@@ -147,25 +153,27 @@ where
}
}
impl<T, P, S, B, X, U> NewService for HttpService<T, P, S, B, X, U>
impl<T, P, S, B, X, U> ServiceFactory for HttpService<T, P, S, B, X, U>
where
T: IoStream,
S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X: ServiceFactory<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
type Config = SrvConfig;
type Request = ServerIo<T, P>;
@@ -177,7 +185,7 @@ where
fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
HttpServiceResponse {
fut: self.srv.new_service(cfg).into_future(),
fut: self.srv.new_service(cfg),
fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None,
@@ -190,9 +198,20 @@ where
}
#[doc(hidden)]
pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewService> {
#[pin_project]
pub struct HttpServiceResponse<
T,
P,
S: ServiceFactory,
B,
X: ServiceFactory,
U: ServiceFactory,
> {
#[pin]
fut: S::Future,
#[pin]
fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>,
expect: Option<X::Service>,
upgrade: Option<U::Service>,
@@ -204,50 +223,59 @@ pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewServ
impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U>
where
T: IoStream,
S: NewService<Request = Request>,
S::Error: Into<Error>,
S: ServiceFactory<Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: NewService<Request = Request, Response = Request>,
X: ServiceFactory<Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: NewService<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
type Item = HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
type Error = ();
type Output =
Result<HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
self.expect = Some(expect);
self.fut_ex.take();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
if let Some(fut) = this.fut_ex.as_pin_mut() {
let expect = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
}
if let Some(ref mut fut) = self.fut_upg {
let upgrade = try_ready!(fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
self.upgrade = Some(upgrade);
self.fut_ex.take();
if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
}
let service = try_ready!(self
let result = ready!(this
.fut
.poll()
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(HttpServiceHandler::new(
self.cfg.take().unwrap(),
Poll::Ready(result.map(|service| {
let this = self.as_mut().project();
HttpServiceHandler::new(
this.cfg.take().unwrap(),
service,
self.expect.take().unwrap(),
self.upgrade.take(),
self.on_connect.clone(),
)))
this.expect.take().unwrap(),
this.upgrade.take(),
this.on_connect.clone(),
)
}))
}
}
@@ -264,9 +292,9 @@ pub struct HttpServiceHandler<T, P, S, B, X, U> {
impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
@@ -295,9 +323,9 @@ impl<T, P, S, B, X, U> Service for HttpServiceHandler<T, P, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
@@ -309,10 +337,10 @@ where
type Error = DispatchError;
type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let ready = self
.expect
.poll_ready()
.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
@@ -322,7 +350,7 @@ where
let ready = self
.srv
.poll_ready()
.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
@@ -332,9 +360,9 @@ where
&& ready;
if ready {
Ok(Async::Ready(()))
Poll::Ready(Ok(()))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
@@ -389,6 +417,7 @@ where
}
}
#[pin_project]
enum State<T, S, B, X, U>
where
S: Service<Request = Request>,
@@ -401,8 +430,8 @@ where
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
H1(h1::Dispatcher<T, S, B, X, U>),
H2(Dispatcher<Io<T>, S, B>),
H1(#[pin] h1::Dispatcher<T, S, B, X, U>),
H2(#[pin] Dispatcher<Io<T>, S, B>),
Unknown(
Option<(
T,
@@ -425,19 +454,21 @@ where
),
}
#[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[pin]
state: State<T, S, B, X, U>,
}
@@ -447,30 +478,51 @@ impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Error: Into<Error> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: Into<Response<B>> + 'static,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
type Item = ();
type Error = DispatchError;
type Output = Result<(), DispatchError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.state {
State::H1(ref mut disp) => disp.poll(),
State::H2(ref mut disp) => disp.poll(),
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.project().state.poll(cx)
}
}
impl<T, S, B, X, U> State<T, S, B, X, U>
where
T: IoStream,
S: Service<Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[project]
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<(), DispatchError>> {
#[project]
match self.as_mut().project() {
State::H1(disp) => disp.poll(cx),
State::H2(disp) => disp.poll(cx),
State::Unknown(ref mut data) => {
if let Some(ref mut item) = data {
loop {
// Safety - we only write to the returned slice.
let b = unsafe { item.1.bytes_mut() };
let n = try_ready!(item.0.poll_read(b));
let n = ready!(Pin::new(&mut item.0).poll_read(cx, b))?;
if n == 0 {
return Ok(Async::Ready(()));
return Poll::Ready(Ok(()));
}
// Safety - we know that 'n' bytes have
// been initialized via the contract of
@@ -491,15 +543,15 @@ where
inner: io,
unread: Some(buf),
};
self.state = State::Handshake(Some((
self.set(State::Handshake(Some((
server::handshake(io),
cfg,
srv,
peer_addr,
on_connect,
)));
))));
} else {
self.state = State::H1(h1::Dispatcher::with_timeout(
self.set(State::H1(h1::Dispatcher::with_timeout(
io,
h1::Codec::new(cfg.clone()),
cfg,
@@ -509,36 +561,38 @@ where
expect,
upgrade,
on_connect,
))
)))
}
self.poll()
self.poll(cx)
}
State::Handshake(ref mut data) => {
let conn = if let Some(ref mut item) = data {
match item.0.poll() {
Ok(Async::Ready(conn)) => conn,
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
match Pin::new(&mut item.0).poll(cx) {
Poll::Ready(Ok(conn)) => conn,
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
return Err(err.into());
return Poll::Ready(Err(err.into()));
}
Poll::Pending => return Poll::Pending,
}
} else {
panic!()
};
let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap();
self.state = State::H2(Dispatcher::new(
self.set(State::H2(Dispatcher::new(
srv, conn, on_connect, cfg, None, peer_addr,
));
self.poll()
)));
self.poll(cx)
}
}
}
}
/// Wrapper for `AsyncRead + AsyncWrite` types
#[pin_project::pin_project]
struct Io<T> {
unread: Option<BytesMut>,
#[pin]
inner: T,
}
@@ -568,21 +622,65 @@ impl<T: io::Write> io::Write for Io<T> {
}
impl<T: AsyncRead> AsyncRead for Io<T> {
// unsafe fn initializer(&self) -> io::Initializer {
// self.get_mut().inner.initializer()
// }
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
}
impl<T: AsyncWrite> AsyncWrite for Io<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
if let Some(mut bytes) = this.unread.take() {
let size = std::cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
*this.unread = Some(bytes);
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.inner.write_buf(buf)
Poll::Ready(Ok(size))
} else {
this.inner.poll_read(cx, buf)
}
}
impl<T: IoStream> IoStream for Io<T> {
// fn poll_read_vectored(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// bufs: &mut [io::IoSliceMut<'_>],
// ) -> Poll<io::Result<usize>> {
// self.get_mut().inner.poll_read_vectored(cx, bufs)
// }
}
impl<T: AsyncWrite> tokio_io::AsyncWrite for Io<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}
}
impl<T: IoStream> actix_server_config::IoStream for Io<T> {
#[inline]
fn peer_addr(&self) -> Option<net::SocketAddr> {
self.inner.peer_addr()

View File

@@ -1,12 +1,13 @@
//! Test Various helpers for Actix applications to use during testing.
use std::fmt::Write as FmtWrite;
use std::io;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_server_config::IoStream;
use bytes::{Buf, Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue};
use http::{HttpTryFrom, Method, Uri, Version};
use percent_encoding::percent_encode;
@@ -244,14 +245,31 @@ impl io::Write for TestBuffer {
}
}
impl AsyncRead for TestBuffer {}
impl AsyncRead for TestBuffer {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().read(buf))
}
}
impl AsyncWrite for TestBuffer {
fn shutdown(&mut self) -> Poll<(), io::Error> {
Ok(Async::Ready(()))
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().write(buf))
}
fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
Ok(Async::NotReady)
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}

View File

@@ -1,7 +1,10 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service};
use actix_utils::framed::{FramedTransport, FramedTransportError};
use futures::{Future, Poll};
use super::{Codec, Frame, Message};
@@ -40,10 +43,9 @@ where
S::Future: 'static,
S::Error: 'static,
{
type Item = ();
type Error = FramedTransportError<S::Error, Codec>;
type Output = Result<(), FramedTransportError<S::Error, Codec>>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.inner.poll()
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx)
}
}

View File

@@ -1,9 +1,9 @@
use actix_service::NewService;
use actix_service::ServiceFactory;
use bytes::Bytes;
use futures::future::{self, ok};
use actix_http::{http, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
@@ -29,43 +29,50 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[test]
fn test_h1_v2() {
env_logger::init();
let mut srv = TestServer::new(move || {
HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
block_on(async {
let srv = TestServer::start(move || {
HttpService::build()
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
let request = srv.get("/").header("x-test", "111").send();
let response = srv.block_on(request).unwrap();
let mut response = request.await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let response = srv.block_on(srv.post("/").send()).unwrap();
let mut response = srv.post("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_connection_close() {
let mut srv = TestServer::new(move || {
block_on(async {
let srv = TestServer::start(move || {
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map(|_| ())
});
let response = srv.block_on(srv.get("/").force_close().send()).unwrap();
let response = srv.get("/").force_close().send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_with_query_parameter() {
let mut srv = TestServer::new(move || {
block_on(async {
let srv = TestServer::start(move || {
HttpService::build()
.finish(|req: Request| {
if req.uri().query().unwrap().contains("qp=") {
@@ -77,7 +84,8 @@ fn test_with_query_parameter() {
.map(|_| ())
});
let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send();
let response = srv.block_on(request).unwrap();
let request = srv.request(http::Method::GET, srv.url("/?qp=5"));
let response = request.send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -0,0 +1,545 @@
#![cfg(feature = "openssl")]
use std::io;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{err, ok, ready};
use futures::stream::{once, Stream, StreamExt};
use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
let body = stream
.map(|res| match res {
Ok(chunk) => chunk,
Err(_) => panic!(),
})
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
ready(body)
})
.await;
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let openssl = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}
#[test]
fn test_h2_on_connect() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -0,0 +1,474 @@
#![cfg(feature = "rustls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{self, err, ok};
use futures::stream::{once, Stream, StreamExt};
use rust_tls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{self, BufReader};
async fn load_body<S>(mut stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
let mut body = BytesMut::new();
while let Some(item) = stream.next().await {
body.extend_from_slice(&item?)
}
Ok(body)
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> io::Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_1() -> io::Result<()> {
block_on(async {
let rustls = ssl_acceptor()?;
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
})
}
#[test]
fn test_h2_body1() -> io::Result<()> {
block_on(async {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
})
}
#[test]
fn test_h2_content_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h2_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
let data = data.clone();
pipeline_factory(rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_head_empty() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok()
.content_length(STR.len() as u64)
.body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h2_head_binary2() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h2_body_length() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(
STR.len() as u64,
body,
)),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_body_chunked_explicit() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h2_response_http_error_handling() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(service_fn2(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(
http::header::CONTENT_TYPE,
broken_header,
)
.body(STR),
)
}))
}))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h2_service_error() {
block_on(async {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::start(move || {
pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
HttpService::build()
.h2(|_| err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}

View File

@@ -1,462 +0,0 @@
#![cfg(feature = "rust-tls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok, Future};
use futures::stream::{once, Stream};
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{BufReader, Result};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}

View File

@@ -2,14 +2,14 @@ use std::io::{Read, Write};
use std::time::Duration;
use std::{net, thread};
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, service_fn, NewService};
use actix_service::{factory_fn_cfg, pipeline, service_fn, ServiceFactory};
use bytes::Bytes;
use futures::future::{self, ok, Future};
use futures::stream::{once, Stream};
use futures::future::{self, err, ok, ready, FutureExt};
use futures::stream::{once, StreamExt};
use regex::Regex;
use tokio_timer::sleep;
use tokio_timer::delay_for;
use actix_http::httpmessage::HttpMessage;
use actix_http::{
@@ -18,7 +18,8 @@ use actix_http::{
#[test]
fn test_h1() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
@@ -29,13 +30,15 @@ fn test_h1() {
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_h1_2() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
@@ -48,19 +51,21 @@ fn test_h1_2() {
.map(|_| ())
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_expect_continue() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.expect(service_fn(|req: Request| {
if req.head().uri.query() == Some("yes=") {
Ok(req)
ok(req)
} else {
Err(error::ErrorPreconditionFailed("error"))
err(error::ErrorPreconditionFailed("error"))
}
}))
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
@@ -73,22 +78,25 @@ fn test_expect_continue() {
assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length"));
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
let _ =
stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
})
}
#[test]
fn test_expect_continue_h1() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.expect(service_fn(|req: Request| {
sleep(Duration::from_millis(20)).then(move |_| {
delay_for(Duration::from_millis(20)).then(move |_| {
if req.head().uri.query() == Some("yes=") {
Ok(req)
ok(req)
} else {
Err(error::ErrorPreconditionFailed("error"))
err(error::ErrorPreconditionFailed("error"))
}
})
}))
@@ -102,25 +110,33 @@ fn test_expect_continue_h1() {
assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length"));
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
let _ =
stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
})
}
#[test]
fn test_chunked_payload() {
block_on(async {
let chunk_sizes = vec![32768, 32, 32768];
let total_size: usize = chunk_sizes.iter().sum();
let srv = TestServer::new(|| {
HttpService::build().h1(|mut request: Request| {
let srv = TestServer::start(|| {
HttpService::build().h1(service_fn(|mut request: Request| {
request
.take_payload()
.map_err(|e| panic!(format!("Error reading payload: {}", e)))
.fold(0usize, |acc, chunk| future::ok::<_, ()>(acc + chunk.len()))
.map(|req_size| Response::Ok().body(format!("size={}", req_size)))
.map(|res| match res {
Ok(pl) => pl,
Err(e) => panic!(format!("Error reading payload: {}", e)),
})
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
.map(|req_size| {
Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size)))
})
}))
});
let returned_size = {
@@ -148,17 +164,21 @@ fn test_chunked_payload() {
let re = Regex::new(r"size=(\d+)").unwrap();
let size: usize = match re.captures(&data) {
Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(),
None => panic!(format!("Failed to find size in HTTP Response: {}", data)),
None => {
panic!(format!("Failed to find size in HTTP Response: {}", data))
}
};
size
};
assert_eq!(returned_size, total_size);
})
}
#[test]
fn test_slow_request() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.client_timeout(100)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
@@ -169,11 +189,13 @@ fn test_slow_request() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 408 Request Timeout"));
})
}
#[test]
fn test_http1_malformed_request() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
});
@@ -182,11 +204,13 @@ fn test_http1_malformed_request() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 400 Bad Request"));
})
}
#[test]
fn test_http1_keepalive() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
});
@@ -200,11 +224,13 @@ fn test_http1_keepalive() {
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
})
}
#[test]
fn test_http1_keepalive_timeout() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.keep_alive(1)
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
@@ -220,17 +246,19 @@ fn test_http1_keepalive_timeout() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
})
}
#[test]
fn test_http1_keepalive_close() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ =
stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n");
let _ = stream
.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
@@ -238,11 +266,13 @@ fn test_http1_keepalive_close() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
})
}
#[test]
fn test_http10_keepalive_default_close() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
});
@@ -255,17 +285,20 @@ fn test_http10_keepalive_default_close() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
})
}
#[test]
fn test_http10_keepalive() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream
.write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n");
let _ = stream.write_all(
b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n",
);
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n");
@@ -279,11 +312,13 @@ fn test_http10_keepalive() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
})
}
#[test]
fn test_http1_keepalive_disabled() {
let srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
@@ -298,16 +333,18 @@ fn test_http1_keepalive_disabled() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
})
}
#[test]
fn test_content_length() {
block_on(async {
use actix_http::http::{
header::{HeaderName, HeaderValue},
StatusCode,
};
let mut srv = TestServer::new(|| {
let srv = TestServer::start(|| {
HttpService::build().h1(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
@@ -327,35 +364,31 @@ fn test_content_length() {
{
for i in 0..4 {
let req = srv
.request(http::Method::GET, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i)));
let response = req.send().await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(http::Method::HEAD, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
let req = srv.request(http::Method::HEAD, srv.url(&format!("/{}", i)));
let response = req.send().await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(http::Method::GET, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i)));
let response = req.send().await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
})
}
#[test]
fn test_h1_headers() {
block_on(async {
let data = STR.repeat(10);
let data2 = data.clone();
let mut srv = TestServer::new(move || {
let mut srv = TestServer::start(move || {
let data = data.clone();
HttpService::build().h1(move |_| {
let mut builder = Response::Ok();
@@ -381,12 +414,13 @@ fn test_h1_headers() {
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
})
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
@@ -413,25 +447,28 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[test]
fn test_h1_body() {
let mut srv = TestServer::new(|| {
HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h1_head_empty() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
});
let response = srv.block_on(srv.head("/").send()).unwrap();
let response = srv.head("/").send().await.unwrap();
assert!(response.status().is_success());
{
@@ -443,19 +480,21 @@ fn test_h1_head_empty() {
}
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h1_head_binary() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR))
})
});
let response = srv.block_on(srv.head("/").send()).unwrap();
let response = srv.head("/").send().await.unwrap();
assert!(response.status().is_success());
{
@@ -467,17 +506,19 @@ fn test_h1_head_binary() {
}
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
})
}
#[test]
fn test_h1_head_binary2() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
});
let response = srv.block_on(srv.head("/").send()).unwrap();
let response = srv.head("/").send().await.unwrap();
assert!(response.status().is_success());
{
@@ -487,32 +528,36 @@ fn test_h1_head_binary2() {
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
})
}
#[test]
fn test_h1_body_length() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
)
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h1_body_chunked_explicit() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| {
let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
@@ -521,7 +566,7 @@ fn test_h1_body_chunked_explicit() {
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(
response
@@ -534,22 +579,24 @@ fn test_h1_body_chunked_explicit() {
);
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h1_body_chunked_implicit() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(|_| {
let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(Response::Ok().streaming(body))
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(
response
@@ -562,51 +609,57 @@ fn test_h1_body_chunked_implicit() {
);
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_h1_response_http_error_handling() {
let mut srv = TestServer::new(|| {
HttpService::build().h1(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build().h1(factory_fn_cfg(|_: &ServerConfig| {
ok::<_, ()>(pipeline(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
}))
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
})
}
#[test]
fn test_h1_service_error() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build()
.h1(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.h1(|_| future::err::<Response, Error>(error::ErrorBadRequest("error")))
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
})
}
#[test]
fn test_h1_on_connect() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::build()
.on_connect(|_| 10usize)
.h1(|req: Request| {
@@ -615,6 +668,7 @@ fn test_h1_on_connect() {
})
});
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -1,480 +0,0 @@
#![cfg(feature = "ssl")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{ok, Future};
use futures::stream::{once, Stream};
use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::io::Result;
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[test]
fn test_h2_on_connect() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
}

View File

@@ -1,26 +1,27 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
use futures::future;
use futures::{SinkExt, StreamExt};
fn ws_service<T: AsyncRead + AsyncWrite>(
(req, framed): (Request, Framed<T, h1::Codec>),
) -> impl Future<Item = (), Error = Error> {
async fn ws_service<T: AsyncRead + AsyncWrite + Unpin>(
(req, mut framed): (Request, Framed<T, h1::Codec>),
) -> Result<(), Error> {
let res = ws::handshake(req.head()).unwrap().message_body(());
framed
.send((res, body::BodySize::None).into())
.map_err(|_| panic!())
.and_then(|framed| {
.await
.unwrap();
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.await
.map_err(|_| panic!())
})
}
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
@@ -30,47 +31,56 @@ fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
};
ok(msg)
Ok(msg)
}
#[test]
fn test_simple() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build()
.upgrade(ws_service)
.upgrade(actix_service::service_fn(ws_service))
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
});
// client service
let framed = srv.ws().unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
let mut framed = srv.ws().await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
item.unwrap().unwrap(),
ws::Frame::Text(Some(BytesMut::from("text")))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
framed
.send(ws::Message::Binary("text".into()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
item.unwrap().unwrap(),
ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
framed.send(ws::Message::Ping("text".into())).await.unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap();
let (item, _framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
);
})
}

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-identity"
version = "0.1.0"
version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Identity service for actix web framework."
readme = "README.md"
@@ -17,14 +17,14 @@ name = "actix_identity"
path = "src/lib.rs"
[dependencies]
actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] }
actix-service = "0.4.0"
futures = "0.1.25"
actix-web = { version = "2.0.0-alpha.1", default-features = false, features = ["secure-cookies"] }
actix-service = "1.0.0-alpha.1"
futures = "0.3.1"
serde = "1.0"
serde_json = "1.0"
time = "0.1.42"
[dev-dependencies]
actix-rt = "0.2.2"
actix-http = "0.2.3"
actix-rt = "1.0.0-alpha.1"
actix-http = "0.3.0-alpha.1"
bytes = "0.4"

View File

@@ -16,7 +16,7 @@
//! use actix_web::*;
//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService};
//!
//! fn index(id: Identity) -> String {
//! async fn index(id: Identity) -> String {
//! // access request identity
//! if let Some(id) = id.identity() {
//! format!("Welcome! {}", id)
@@ -25,12 +25,12 @@
//! }
//! }
//!
//! fn login(id: Identity) -> HttpResponse {
//! async fn login(id: Identity) -> HttpResponse {
//! id.remember("User1".to_owned()); // <- remember identity
//! HttpResponse::Ok().finish()
//! }
//!
//! fn logout(id: Identity) -> HttpResponse {
//! async fn logout(id: Identity) -> HttpResponse {
//! id.forget(); // <- remove identity
//! HttpResponse::Ok().finish()
//! }
@@ -47,12 +47,13 @@
//! }
//! ```
use std::cell::RefCell;
use std::future::Future;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::SystemTime;
use actix_service::{Service, Transform};
use futures::future::{ok, Either, FutureResult};
use futures::{Future, IntoFuture, Poll};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use serde::{Deserialize, Serialize};
use time::Duration;
@@ -165,21 +166,21 @@ where
impl FromRequest for Identity {
type Config = ();
type Error = Error;
type Future = Result<Identity, Error>;
type Future = Ready<Result<Identity, Error>>;
#[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
Ok(Identity(req.clone()))
ok(Identity(req.clone()))
}
}
/// Identity policy definition.
pub trait IdentityPolicy: Sized + 'static {
/// The return type of the middleware
type Future: IntoFuture<Item = Option<String>, Error = Error>;
type Future: Future<Output = Result<Option<String>, Error>>;
/// The return type of the middleware
type ResponseFuture: IntoFuture<Item = (), Error = Error>;
type ResponseFuture: Future<Output = Result<(), Error>>;
/// Parse the session from request and load data from a service identity.
fn from_request(&self, request: &mut ServiceRequest) -> Self::Future;
@@ -234,7 +235,7 @@ where
type Error = Error;
type InitError = ();
type Transform = IdentityServiceMiddleware<S, T>;
type Future = FutureResult<Self::Transform, Self::InitError>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(IdentityServiceMiddleware {
@@ -261,46 +262,39 @@ where
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.borrow_mut().poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.borrow_mut().poll_ready(cx)
}
fn call(&mut self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone();
let backend = self.backend.clone();
let fut = self.backend.from_request(&mut req);
Box::new(
self.backend.from_request(&mut req).into_future().then(
move |res| match res {
async move {
match fut.await {
Ok(id) => {
req.extensions_mut()
.insert(IdentityItem { id, changed: false });
Either::A(srv.borrow_mut().call(req).and_then(move |mut res| {
let id =
res.request().extensions_mut().remove::<IdentityItem>();
let mut res = srv.borrow_mut().call(req).await?;
let id = res.request().extensions_mut().remove::<IdentityItem>();
if let Some(id) = id {
Either::A(
backend
.to_response(id.id, id.changed, &mut res)
.into_future()
.then(move |t| match t {
match backend.to_response(id.id, id.changed, &mut res).await {
Ok(_) => Ok(res),
Err(e) => Ok(res.error_response(e)),
}),
)
}
} else {
Either::B(ok(res))
Ok(res)
}
}))
}
Err(err) => Either::B(ok(req.error_response(err))),
},
),
)
Err(err) => Ok(req.error_response(err)),
}
}
.boxed_local()
}
}
@@ -547,11 +541,11 @@ impl CookieIdentityPolicy {
}
impl IdentityPolicy for CookieIdentityPolicy {
type Future = Result<Option<String>, Error>;
type ResponseFuture = Result<(), Error>;
type Future = Ready<Result<Option<String>, Error>>;
type ResponseFuture = Ready<Result<(), Error>>;
fn from_request(&self, req: &mut ServiceRequest) -> Self::Future {
Ok(self.0.load(req).map(
ok(self.0.load(req).map(
|CookieValue {
identity,
login_timestamp,
@@ -603,7 +597,7 @@ impl IdentityPolicy for CookieIdentityPolicy {
} else {
Ok(())
};
Ok(())
ok(())
}
}
@@ -613,7 +607,7 @@ mod tests {
use super::*;
use actix_web::http::StatusCode;
use actix_web::test::{self, TestRequest};
use actix_web::test::{self, block_on, TestRequest};
use actix_web::{web, App, Error, HttpResponse};
const COOKIE_KEY_MASTER: [u8; 32] = [0; 32];
@@ -622,6 +616,7 @@ mod tests {
#[test]
fn test_identity() {
block_on(async {
let mut srv = test::init_service(
App::new()
.wrap(IdentityService::new(
@@ -650,13 +645,20 @@ mod tests {
HttpResponse::BadRequest()
}
})),
);
let resp =
test::call_service(&mut srv, TestRequest::with_uri("/index").to_request());
)
.await;
let resp = test::call_service(
&mut srv,
TestRequest::with_uri("/index").to_request(),
)
.await;
assert_eq!(resp.status(), StatusCode::OK);
let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
let resp = test::call_service(
&mut srv,
TestRequest::with_uri("/login").to_request(),
)
.await;
assert_eq!(resp.status(), StatusCode::OK);
let c = resp.response().cookies().next().unwrap().to_owned();
@@ -665,7 +667,8 @@ mod tests {
TestRequest::with_uri("/index")
.cookie(c.clone())
.to_request(),
);
)
.await;
assert_eq!(resp.status(), StatusCode::CREATED);
let resp = test::call_service(
@@ -673,13 +676,16 @@ mod tests {
TestRequest::with_uri("/logout")
.cookie(c.clone())
.to_request(),
);
)
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE))
})
}
#[test]
fn test_identity_max_age_time() {
block_on(async {
let duration = Duration::days(1);
let mut srv = test::init_service(
App::new()
@@ -695,17 +701,23 @@ mod tests {
id.remember("test".to_string());
HttpResponse::Ok()
})),
);
let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
)
.await;
let resp = test::call_service(
&mut srv,
TestRequest::with_uri("/login").to_request(),
)
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(duration, c.max_age().unwrap());
})
}
#[test]
fn test_identity_max_age() {
block_on(async {
let seconds = 60;
let mut srv = test::init_service(
App::new()
@@ -721,16 +733,21 @@ mod tests {
id.remember("test".to_string());
HttpResponse::Ok()
})),
);
let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
)
.await;
let resp = test::call_service(
&mut srv,
TestRequest::with_uri("/login").to_request(),
)
.await;
assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap());
})
}
fn create_identity_server<
async fn create_identity_server<
F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static,
>(
f: F,
@@ -747,13 +764,16 @@ mod tests {
.secure(false)
.name(COOKIE_NAME))))
.service(web::resource("/").to(|id: Identity| {
async move {
let identity = id.identity();
if identity.is_none() {
id.remember(COOKIE_LOGIN.to_string())
}
web::Json(identity)
}
})),
)
.await
}
fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> {
@@ -786,15 +806,8 @@ mod tests {
jar.get(COOKIE_NAME).unwrap().clone()
}
fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) {
use bytes::BytesMut;
use futures::Stream;
let bytes =
test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| {
b.extend(c);
Ok::<_, Error>(b)
}))
.unwrap();
async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) {
let bytes = test::read_body(response).await;
let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap();
assert_eq!(resp.as_ref().map(|s| s.borrow()), identity);
}
@@ -874,106 +887,130 @@ mod tests {
#[test]
fn test_identity_legacy_cookie_is_set() {
let mut srv = create_identity_server(|c| c);
block_on(async {
let mut srv = create_identity_server(|c| c).await;
let mut resp =
test::call_service(&mut srv, TestRequest::with_uri("/").to_request());
assert_logged_in(&mut resp, None);
test::call_service(&mut srv, TestRequest::with_uri("/").to_request())
.await;
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_legacy_cookie_works() {
let mut srv = create_identity_server(|c| c);
block_on(async {
let mut srv = create_identity_server(|c| c).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
)
.await;
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
}
#[test]
fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_cookie_rejected_if_login_timestamp_needed() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now()));
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_cookie_rejected_if_login_timestamp_too_old() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(
COOKIE_LOGIN,
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
@@ -984,19 +1021,23 @@ mod tests {
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_cookie_rejected_if_visit_timestamp_too_old() {
let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie(
COOKIE_LOGIN,
None,
@@ -1007,36 +1048,44 @@ mod tests {
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, None);
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, None).await;
})
}
#[test]
fn test_identity_cookie_not_updated_on_login_deadline() {
let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
block_on(async {
let mut srv =
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service(
&mut srv,
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
)
.await;
assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
}
#[test]
fn test_identity_cookie_updated_on_visit_deadline() {
block_on(async {
let mut srv = create_identity_server(|c| {
c.visit_deadline(Duration::days(90))
.login_deadline(Duration::days(90))
});
})
.await;
let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap();
let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp));
let mut resp = test::call_service(
@@ -1044,13 +1093,15 @@ mod tests {
TestRequest::with_uri("/")
.cookie(cookie.clone())
.to_request(),
);
assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
)
.await;
assert_login_cookie(
&mut resp,
COOKIE_LOGIN,
LoginTimestampCheck::OldTimestamp(timestamp),
VisitTimeStampCheck::NewTimestamp,
);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
})
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-multipart"
version = "0.1.4"
version = "0.2.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Multipart support for actix web framework."
readme = "README.md"
@@ -18,17 +18,18 @@ name = "actix_multipart"
path = "src/lib.rs"
[dependencies]
actix-web = { version = "1.0.0", default-features = false }
actix-service = "0.4.1"
actix-web = { version = "2.0.0-alpha.1", default-features = false }
actix-service = "1.0.0-alpha.1"
actix-utils = "0.5.0-alpha.1"
bytes = "0.4"
derive_more = "0.15.0"
httparse = "1.3"
futures = "0.1.25"
futures = "0.3.1"
log = "0.4"
mime = "0.3"
time = "0.1"
twoway = "0.2"
[dev-dependencies]
actix-rt = "0.2.2"
actix-http = "0.2.4"
actix-rt = "1.0.0-alpha.1"
actix-http = "0.3.0-alpha.1"

View File

@@ -1,5 +1,6 @@
//! Multipart payload support
use actix_web::{dev::Payload, Error, FromRequest, HttpRequest};
use futures::future::{ok, Ready};
use crate::server::Multipart;
@@ -10,33 +11,31 @@ use crate::server::Multipart;
/// ## Server example
///
/// ```rust
/// # use futures::{Future, Stream};
/// # use futures::future::{ok, result, Either};
/// use futures::{Stream, StreamExt};
/// use actix_web::{web, HttpResponse, Error};
/// use actix_multipart as mp;
///
/// fn index(payload: mp::Multipart) -> impl Future<Item = HttpResponse, Error = Error> {
/// payload.from_err() // <- get multipart stream for current request
/// .and_then(|field| { // <- iterate over multipart items
/// async fn index(mut payload: mp::Multipart) -> Result<HttpResponse, Error> {
/// // iterate over multipart stream
/// while let Some(item) = payload.next().await {
/// let mut field = item?;
///
/// // Field in turn is stream of *Bytes* object
/// field.from_err()
/// .fold((), |_, chunk| {
/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk));
/// Ok::<_, Error>(())
/// })
/// })
/// .fold((), |_, _| Ok::<_, Error>(()))
/// .map(|_| HttpResponse::Ok().into())
/// while let Some(chunk) = field.next().await {
/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?));
/// }
/// }
/// Ok(HttpResponse::Ok().into())
/// }
/// # fn main() {}
/// ```
impl FromRequest for Multipart {
type Error = Error;
type Future = Result<Multipart, Error>;
type Future = Ready<Result<Multipart, Error>>;
type Config = ();
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
Ok(Multipart::new(req.headers(), payload.take()))
ok(Multipart::new(req.headers(), payload.take()))
}
}

View File

@@ -1,15 +1,17 @@
//! Multipart payload support
use std::cell::{Cell, RefCell, RefMut};
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{cmp, fmt};
use bytes::{Bytes, BytesMut};
use futures::task::{current as current_task, Task};
use futures::{Async, Poll, Stream};
use futures::stream::{LocalBoxStream, Stream, StreamExt};
use httparse;
use mime;
use actix_utils::task::LocalWaker;
use actix_web::error::{ParseError, PayloadError};
use actix_web::http::header::{
self, ContentDisposition, HeaderMap, HeaderName, HeaderValue,
@@ -60,7 +62,7 @@ impl Multipart {
/// Create multipart instance for boundary.
pub fn new<S>(headers: &HeaderMap, stream: S) -> Multipart
where
S: Stream<Item = Bytes, Error = PayloadError> + 'static,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
{
match Self::boundary(headers) {
Ok(boundary) => Multipart {
@@ -104,22 +106,25 @@ impl Multipart {
}
impl Stream for Multipart {
type Item = Field;
type Error = MultipartError;
type Item = Result<Field, MultipartError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
if let Some(err) = self.error.take() {
Err(err)
Poll::Ready(Some(Err(err)))
} else if self.safety.current() {
let mut inner = self.inner.as_mut().unwrap().borrow_mut();
if let Some(mut payload) = inner.payload.get_mut(&self.safety) {
payload.poll_stream()?;
let this = self.get_mut();
let mut inner = this.inner.as_mut().unwrap().borrow_mut();
if let Some(mut payload) = inner.payload.get_mut(&this.safety) {
payload.poll_stream(cx)?;
}
inner.poll(&self.safety)
inner.poll(&this.safety, cx)
} else if !self.safety.is_clean() {
Err(MultipartError::NotConsumed)
Poll::Ready(Some(Err(MultipartError::NotConsumed)))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -238,9 +243,13 @@ impl InnerMultipart {
Ok(Some(eof))
}
fn poll(&mut self, safety: &Safety) -> Poll<Option<Field>, MultipartError> {
fn poll(
&mut self,
safety: &Safety,
cx: &mut Context,
) -> Poll<Option<Result<Field, MultipartError>>> {
if self.state == InnerState::Eof {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
// release field
loop {
@@ -249,10 +258,13 @@ impl InnerMultipart {
if safety.current() {
let stop = match self.item {
InnerMultipartItem::Field(ref mut field) => {
match field.borrow_mut().poll(safety)? {
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(_)) => continue,
Async::Ready(None) => true,
match field.borrow_mut().poll(safety) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(_))) => continue,
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => true,
}
}
InnerMultipartItem::None => false,
@@ -277,12 +289,12 @@ impl InnerMultipart {
Some(eof) => {
if eof {
self.state = InnerState::Eof;
return Ok(Async::Ready(None));
return Poll::Ready(None);
} else {
self.state = InnerState::Headers;
}
}
None => return Ok(Async::NotReady),
None => return Poll::Pending,
}
}
// read boundary
@@ -291,11 +303,11 @@ impl InnerMultipart {
&mut *payload,
&self.boundary,
)? {
None => return Ok(Async::NotReady),
None => return Poll::Pending,
Some(eof) => {
if eof {
self.state = InnerState::Eof;
return Ok(Async::Ready(None));
return Poll::Ready(None);
} else {
self.state = InnerState::Headers;
}
@@ -311,14 +323,14 @@ impl InnerMultipart {
self.state = InnerState::Boundary;
headers
} else {
return Ok(Async::NotReady);
return Poll::Pending;
}
} else {
unreachable!()
}
} else {
log::debug!("NotReady: field is in flight");
return Ok(Async::NotReady);
return Poll::Pending;
};
// content type
@@ -335,7 +347,7 @@ impl InnerMultipart {
// nested multipart stream
if mt.type_() == mime::MULTIPART {
Err(MultipartError::Nested)
Poll::Ready(Some(Err(MultipartError::Nested)))
} else {
let field = Rc::new(RefCell::new(InnerField::new(
self.payload.clone(),
@@ -344,12 +356,7 @@ impl InnerMultipart {
)?));
self.item = InnerMultipartItem::Field(Rc::clone(&field));
Ok(Async::Ready(Some(Field::new(
safety.clone(),
headers,
mt,
field,
))))
Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field))))
}
}
}
@@ -409,23 +416,21 @@ impl Field {
}
impl Stream for Field {
type Item = Bytes;
type Error = MultipartError;
type Item = Result<Bytes, MultipartError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if self.safety.current() {
let mut inner = self.inner.borrow_mut();
if let Some(mut payload) =
inner.payload.as_ref().unwrap().get_mut(&self.safety)
{
payload.poll_stream()?;
payload.poll_stream(cx)?;
}
inner.poll(&self.safety)
} else if !self.safety.is_clean() {
Err(MultipartError::NotConsumed)
Poll::Ready(Some(Err(MultipartError::NotConsumed)))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -482,9 +487,9 @@ impl InnerField {
fn read_len(
payload: &mut PayloadBuffer,
size: &mut u64,
) -> Poll<Option<Bytes>, MultipartError> {
) -> Poll<Option<Result<Bytes, MultipartError>>> {
if *size == 0 {
Ok(Async::Ready(None))
Poll::Ready(None)
} else {
match payload.read_max(*size)? {
Some(mut chunk) => {
@@ -494,13 +499,13 @@ impl InnerField {
if !chunk.is_empty() {
payload.unprocessed(chunk);
}
Ok(Async::Ready(Some(ch)))
Poll::Ready(Some(Ok(ch)))
}
None => {
if payload.eof && (*size != 0) {
Err(MultipartError::Incomplete)
Poll::Ready(Some(Err(MultipartError::Incomplete)))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -512,15 +517,15 @@ impl InnerField {
fn read_stream(
payload: &mut PayloadBuffer,
boundary: &str,
) -> Poll<Option<Bytes>, MultipartError> {
) -> Poll<Option<Result<Bytes, MultipartError>>> {
let mut pos = 0;
let len = payload.buf.len();
if len == 0 {
return if payload.eof {
Err(MultipartError::Incomplete)
Poll::Ready(Some(Err(MultipartError::Incomplete)))
} else {
Ok(Async::NotReady)
Poll::Pending
};
}
@@ -537,10 +542,10 @@ impl InnerField {
if let Some(b_len) = b_len {
let b_size = boundary.len() + b_len;
if len < b_size {
return Ok(Async::NotReady);
return Poll::Pending;
} else if &payload.buf[b_len..b_size] == boundary.as_bytes() {
// found boundary
return Ok(Async::Ready(None));
return Poll::Ready(None);
}
}
}
@@ -552,9 +557,9 @@ impl InnerField {
// check if we have enough data for boundary detection
if cur + 4 > len {
if cur > 0 {
Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze())))
Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
} else {
Ok(Async::NotReady)
Poll::Pending
}
} else {
// check boundary
@@ -565,7 +570,7 @@ impl InnerField {
{
if cur != 0 {
// return buffer
Ok(Async::Ready(Some(payload.buf.split_to(cur).freeze())))
Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze())))
} else {
pos = cur + 1;
continue;
@@ -577,49 +582,51 @@ impl InnerField {
}
}
} else {
Ok(Async::Ready(Some(payload.buf.take().freeze())))
Poll::Ready(Some(Ok(payload.buf.take().freeze())))
};
}
}
fn poll(&mut self, s: &Safety) -> Poll<Option<Bytes>, MultipartError> {
fn poll(&mut self, s: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> {
if self.payload.is_none() {
return Ok(Async::Ready(None));
return Poll::Ready(None);
}
let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s)
{
if !self.eof {
let res = if let Some(ref mut len) = self.length {
InnerField::read_len(&mut *payload, len)?
InnerField::read_len(&mut *payload, len)
} else {
InnerField::read_stream(&mut *payload, &self.boundary)?
InnerField::read_stream(&mut *payload, &self.boundary)
};
match res {
Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(bytes)) => return Ok(Async::Ready(Some(bytes))),
Async::Ready(None) => self.eof = true,
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(None) => self.eof = true,
}
}
match payload.readline()? {
None => Async::Ready(None),
Some(line) => {
match payload.readline() {
Ok(None) => Poll::Ready(None),
Ok(Some(line)) => {
if line.as_ref() != b"\r\n" {
log::warn!("multipart field did not read all the data or it is malformed");
}
Async::Ready(None)
Poll::Ready(None)
}
Err(e) => Poll::Ready(Some(Err(e))),
}
} else {
Async::NotReady
Poll::Pending
};
if Async::Ready(None) == result {
if let Poll::Ready(None) = result {
self.payload.take();
}
Ok(result)
result
}
}
@@ -659,7 +666,7 @@ impl Clone for PayloadRef {
/// most task.
#[derive(Debug)]
struct Safety {
task: Option<Task>,
task: LocalWaker,
level: usize,
payload: Rc<PhantomData<bool>>,
clean: Rc<Cell<bool>>,
@@ -669,7 +676,7 @@ impl Safety {
fn new() -> Safety {
let payload = Rc::new(PhantomData);
Safety {
task: None,
task: LocalWaker::new(),
level: Rc::strong_count(&payload),
clean: Rc::new(Cell::new(true)),
payload,
@@ -683,17 +690,17 @@ impl Safety {
fn is_clean(&self) -> bool {
self.clean.get()
}
}
impl Clone for Safety {
fn clone(&self) -> Safety {
fn clone(&self, cx: &mut Context) -> Safety {
let payload = Rc::clone(&self.payload);
Safety {
task: Some(current_task()),
let s = Safety {
task: LocalWaker::new(),
level: Rc::strong_count(&payload),
clean: self.clean.clone(),
payload,
}
};
s.task.register(cx.waker());
s
}
}
@@ -704,7 +711,7 @@ impl Drop for Safety {
self.clean.set(true);
}
if let Some(task) = self.task.take() {
task.notify()
task.wake()
}
}
}
@@ -713,31 +720,32 @@ impl Drop for Safety {
struct PayloadBuffer {
eof: bool,
buf: BytesMut,
stream: Box<dyn Stream<Item = Bytes, Error = PayloadError>>,
stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
}
impl PayloadBuffer {
/// Create new `PayloadBuffer` instance
fn new<S>(stream: S) -> Self
where
S: Stream<Item = Bytes, Error = PayloadError> + 'static,
S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
PayloadBuffer {
eof: false,
buf: BytesMut::new(),
stream: Box::new(stream),
stream: stream.boxed_local(),
}
}
fn poll_stream(&mut self) -> Result<(), PayloadError> {
fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> {
loop {
match self.stream.poll()? {
Async::Ready(Some(data)) => self.buf.extend_from_slice(&data),
Async::Ready(None) => {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
Poll::Ready(Some(Err(e))) => return Err(e),
Poll::Ready(None) => {
self.eof = true;
return Ok(());
}
Async::NotReady => return Ok(()),
Poll::Pending => return Ok(()),
}
}
}
@@ -800,13 +808,14 @@ impl PayloadBuffer {
#[cfg(test)]
mod tests {
use actix_http::h1::Payload;
use bytes::Bytes;
use futures::unsync::mpsc;
use super::*;
use actix_http::h1::Payload;
use actix_utils::mpsc;
use actix_web::http::header::{DispositionParam, DispositionType};
use actix_web::test::run_on;
use actix_web::test::block_on;
use bytes::Bytes;
use futures::future::lazy;
#[test]
fn test_boundary() {
@@ -852,12 +861,12 @@ mod tests {
}
fn create_stream() -> (
mpsc::UnboundedSender<Result<Bytes, PayloadError>>,
impl Stream<Item = Bytes, Error = PayloadError>,
mpsc::Sender<Result<Bytes, PayloadError>>,
impl Stream<Item = Result<Bytes, PayloadError>>,
) {
let (tx, rx) = mpsc::unbounded();
let (tx, rx) = mpsc::channel();
(tx, rx.map_err(|_| panic!()).and_then(|res| res))
(tx, rx.map(|res| res.map_err(|_| panic!())))
}
fn create_simple_request_with_header() -> (Bytes, HeaderMap) {
@@ -884,28 +893,28 @@ mod tests {
#[test]
fn test_multipart_no_end_crlf() {
run_on(|| {
block_on(async {
let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header();
let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf
sender.unbounded_send(Ok(bytes_stripped)).unwrap();
sender.send(Ok(bytes_stripped)).unwrap();
drop(sender); // eof
let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() {
Async::Ready(Some(_)) => (),
match multipart.next().await.unwrap() {
Ok(_) => (),
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(Some(_)) => (),
match multipart.next().await.unwrap() {
Ok(_) => (),
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(None) => (),
match multipart.next().await {
None => (),
_ => unreachable!(),
}
})
@@ -913,15 +922,15 @@ mod tests {
#[test]
fn test_multipart() {
run_on(|| {
block_on(async {
let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header();
sender.unbounded_send(Ok(bytes)).unwrap();
sender.send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() {
Async::Ready(Some(mut field)) => {
match multipart.next().await {
Some(Ok(mut field)) => {
let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -929,37 +938,37 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll().unwrap() {
Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"),
match field.next().await.unwrap() {
Ok(chunk) => assert_eq!(chunk, "test"),
_ => unreachable!(),
}
match field.poll().unwrap() {
Async::Ready(None) => (),
match field.next().await {
None => (),
_ => unreachable!(),
}
}
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(Some(mut field)) => {
match multipart.next().await.unwrap() {
Ok(mut field) => {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() {
Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"),
match field.next().await {
Some(Ok(chunk)) => assert_eq!(chunk, "data"),
_ => unreachable!(),
}
match field.poll() {
Ok(Async::Ready(None)) => (),
match field.next().await {
None => (),
_ => unreachable!(),
}
}
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(None) => (),
match multipart.next().await {
None => (),
_ => unreachable!(),
}
});
@@ -967,15 +976,15 @@ mod tests {
#[test]
fn test_stream() {
run_on(|| {
block_on(async {
let (sender, payload) = create_stream();
let (bytes, headers) = create_simple_request_with_header();
sender.unbounded_send(Ok(bytes)).unwrap();
sender.send(Ok(bytes)).unwrap();
let mut multipart = Multipart::new(&headers, payload);
match multipart.poll().unwrap() {
Async::Ready(Some(mut field)) => {
match multipart.next().await.unwrap() {
Ok(mut field) => {
let cd = field.content_disposition().unwrap();
assert_eq!(cd.disposition, DispositionType::FormData);
assert_eq!(cd.parameters[0], DispositionParam::Name("file".into()));
@@ -983,37 +992,37 @@ mod tests {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll().unwrap() {
Async::Ready(Some(chunk)) => assert_eq!(chunk, "test"),
match field.next().await.unwrap() {
Ok(chunk) => assert_eq!(chunk, "test"),
_ => unreachable!(),
}
match field.poll().unwrap() {
Async::Ready(None) => (),
match field.next().await {
None => (),
_ => unreachable!(),
}
}
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(Some(mut field)) => {
match multipart.next().await {
Some(Ok(mut field)) => {
assert_eq!(field.content_type().type_(), mime::TEXT);
assert_eq!(field.content_type().subtype(), mime::PLAIN);
match field.poll() {
Ok(Async::Ready(Some(chunk))) => assert_eq!(chunk, "data"),
match field.next().await {
Some(Ok(chunk)) => assert_eq!(chunk, "data"),
_ => unreachable!(),
}
match field.poll() {
Ok(Async::Ready(None)) => (),
match field.next().await {
None => (),
_ => unreachable!(),
}
}
_ => unreachable!(),
}
match multipart.poll().unwrap() {
Async::Ready(None) => (),
match multipart.next().await {
None => (),
_ => unreachable!(),
}
});
@@ -1021,26 +1030,26 @@ mod tests {
#[test]
fn test_basic() {
run_on(|| {
block_on(async {
let (_, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(payload.buf.len(), 0);
payload.poll_stream().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(None, payload.read_max(1).unwrap());
})
}
#[test]
fn test_eof() {
run_on(|| {
block_on(async {
let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(4).unwrap());
sender.feed_data(Bytes::from("data"));
sender.feed_eof();
payload.poll_stream().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap());
assert_eq!(payload.buf.len(), 0);
@@ -1051,24 +1060,24 @@ mod tests {
#[test]
fn test_err() {
run_on(|| {
block_on(async {
let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(1).unwrap());
sender.set_error(PayloadError::Incomplete(None));
payload.poll_stream().err().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.err().unwrap();
})
}
#[test]
fn test_readmax() {
run_on(|| {
block_on(async {
let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(payload.buf.len(), 10);
assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap());
@@ -1081,7 +1090,7 @@ mod tests {
#[test]
fn test_readexactly() {
run_on(|| {
block_on(async {
let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
@@ -1089,7 +1098,7 @@ mod tests {
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2));
assert_eq!(payload.buf.len(), 8);
@@ -1101,7 +1110,7 @@ mod tests {
#[test]
fn test_readuntil() {
run_on(|| {
block_on(async {
let (mut sender, payload) = Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
@@ -1109,7 +1118,7 @@ mod tests {
sender.feed_data(Bytes::from("line1"));
sender.feed_data(Bytes::from("line2"));
payload.poll_stream().unwrap();
lazy(|cx| payload.poll_stream(cx)).await.unwrap();
assert_eq!(
Some(Bytes::from("line")),

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-session"
version = "0.2.0"
version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Session for actix web framework."
readme = "README.md"
@@ -24,15 +24,15 @@ default = ["cookie-session"]
cookie-session = ["actix-web/secure-cookies"]
[dependencies]
actix-web = "1.0.0"
actix-service = "0.4.1"
actix-web = "2.0.0-alpha.1"
actix-service = "1.0.0-alpha.1"
bytes = "0.4"
derive_more = "0.15.0"
futures = "0.1.25"
hashbrown = "0.5.0"
futures = "0.3.1"
hashbrown = "0.6.3"
serde = "1.0"
serde_json = "1.0"
time = "0.1.42"
[dev-dependencies]
actix-rt = "0.2.2"
actix-rt = "1.0.0-alpha.1"

View File

@@ -17,6 +17,7 @@
use std::collections::HashMap;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
@@ -24,8 +25,7 @@ use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::{header::SET_COOKIE, HeaderValue};
use actix_web::{Error, HttpMessage, ResponseError};
use derive_more::{Display, From};
use futures::future::{ok, Future, FutureResult};
use futures::Poll;
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
use serde_json::error::Error as JsonError;
use crate::{Session, SessionStatus};
@@ -284,7 +284,7 @@ where
type Error = S::Error;
type InitError = ();
type Transform = CookieSessionMiddleware<S>;
type Future = FutureResult<Self::Transform, Self::InitError>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(CookieSessionMiddleware {
@@ -309,10 +309,10 @@ where
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = S::Error;
type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
/// On first request, a new session cookie is returned in response, regardless
@@ -325,7 +325,10 @@ where
let (is_new, state) = self.inner.load(&req);
Session::set_session(state.into_iter(), &mut req);
Box::new(self.service.call(req).map(move |mut res| {
let fut = self.service.call(req);
async move {
fut.await.map(|mut res| {
match Session::get_changes(&mut res) {
(SessionStatus::Changed, Some(state))
| (SessionStatus::Renewed, Some(state)) => {
@@ -336,7 +339,9 @@ where
{
if is_new {
let state: HashMap<String, String> = HashMap::new();
res.checked_expr(|res| inner.set_cookie(res, state.into_iter()))
res.checked_expr(|res| {
inner.set_cookie(res, state.into_iter())
})
} else {
res
}
@@ -347,7 +352,9 @@ where
}
_ => res,
}
}))
})
}
.boxed_local()
}
}
@@ -359,66 +366,82 @@ mod tests {
#[test]
fn cookie_session() {
test::block_on(async {
let mut app = test::init_service(
App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| {
async move {
let _ = ses.set("counter", 100);
"test"
}
})),
);
)
.await;
let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap();
let response = app.call(request).await.unwrap();
assert!(response
.response()
.cookies()
.find(|c| c.name() == "actix-session")
.is_some());
})
}
#[test]
fn private_cookie() {
test::block_on(async {
let mut app = test::init_service(
App::new()
.wrap(CookieSession::private(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| {
async move {
let _ = ses.set("counter", 100);
"test"
}
})),
);
)
.await;
let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap();
let response = app.call(request).await.unwrap();
assert!(response
.response()
.cookies()
.find(|c| c.name() == "actix-session")
.is_some());
})
}
#[test]
fn cookie_session_extractor() {
test::block_on(async {
let mut app = test::init_service(
App::new()
.wrap(CookieSession::signed(&[0; 32]).secure(false))
.service(web::resource("/").to(|ses: Session| {
async move {
let _ = ses.set("counter", 100);
"test"
}
})),
);
)
.await;
let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap();
let response = app.call(request).await.unwrap();
assert!(response
.response()
.cookies()
.find(|c| c.name() == "actix-session")
.is_some());
})
}
#[test]
fn basics() {
test::block_on(async {
let mut app = test::init_service(
App::new()
.wrap(
@@ -431,17 +454,22 @@ mod tests {
.max_age(100),
)
.service(web::resource("/").to(|ses: Session| {
async move {
let _ = ses.set("counter", 100);
"test"
}
}))
.service(web::resource("/test/").to(|ses: Session| {
async move {
let val: usize = ses.get("counter").unwrap().unwrap();
format!("counter: {}", val)
}
})),
);
)
.await;
let request = test::TestRequest::get().to_request();
let response = test::block_on(app.call(request)).unwrap();
let response = app.call(request).await.unwrap();
let cookie = response
.response()
.cookies()
@@ -453,7 +481,8 @@ mod tests {
let request = test::TestRequest::with_uri("/test/")
.cookie(cookie)
.to_request();
let body = test::read_response(&mut app, request);
let body = test::read_response(&mut app, request).await;
assert_eq!(body, Bytes::from_static(b"counter: 100"));
})
}
}

View File

@@ -47,6 +47,7 @@ use std::rc::Rc;
use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse};
use actix_web::{Error, FromRequest, HttpMessage, HttpRequest};
use futures::future::{ok, Ready};
use hashbrown::HashMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
@@ -230,12 +231,12 @@ impl Session {
/// ```
impl FromRequest for Session {
type Error = Error;
type Future = Result<Session, Error>;
type Future = Ready<Result<Session, Error>>;
type Config = ();
#[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
Ok(Session::get_session(&mut *req.extensions_mut()))
ok(Session::get_session(&mut *req.extensions_mut()))
}
}

View File

@@ -1,5 +1,9 @@
# Changes
## [1.0.3] - 2019-11-14
* Update actix-web and actix-http dependencies
## [1.0.2] - 2019-07-20
* Add `ws::start_with_addr()`, returning the address of the created actor, along

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-web-actors"
version = "1.0.2"
version = "1.0.3"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix actors support for actix web framework."
readme = "README.md"
@@ -19,8 +19,8 @@ path = "src/lib.rs"
[dependencies]
actix = "0.8.3"
actix-web = "1.0.3"
actix-http = "0.2.5"
actix-web = "1.0.9"
actix-http = "0.2.11"
actix-codec = "0.1.2"
bytes = "0.4"
futures = "0.1.25"

View File

@@ -1,5 +1,11 @@
# Changes
## [0.1.3] - 2019-10-14
* Bump up `syn` & `quote` to 1.0
* Provide better error message
## [0.1.2] - 2019-06-04
* Add macros for head, options, trace, connect and patch http methods

View File

@@ -1,6 +1,6 @@
[package]
name = "actix-web-codegen"
version = "0.1.2"
version = "0.2.0-alpha.1"
description = "Actix web proc macros"
readme = "README.md"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
@@ -12,11 +12,12 @@ workspace = ".."
proc-macro = true
[dependencies]
quote = "0.6.12"
syn = { version = "0.15.34", features = ["full", "parsing", "extra-traits"] }
quote = "1"
syn = { version = "1", features = ["full", "parsing"] }
proc-macro2 = "1"
[dev-dependencies]
actix-web = { version = "1.0.0" }
actix-http = { version = "0.2.4", features=["ssl"] }
actix-http-test = { version = "0.2.0", features=["ssl"] }
futures = { version = "0.1" }
actix-web = { version = "2.0.0-alph.a" }
actix-http = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
futures = { version = "0.3.1" }

View File

@@ -35,8 +35,8 @@
//! use futures::{future, Future};
//!
//! #[get("/test")]
//! fn async_test() -> impl Future<Item=HttpResponse, Error=actix_web::Error> {
//! future::ok(HttpResponse::Ok().finish())
//! async fn async_test() -> Result<HttpResponse, actix_web::Error> {
//! Ok(HttpResponse::Ok().finish())
//! }
//! ```
@@ -58,7 +58,10 @@ use syn::parse_macro_input;
#[proc_macro_attribute]
pub fn get(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Get);
let gen = match route::Route::new(args, input, route::GuardType::Get) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -70,7 +73,10 @@ pub fn get(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn post(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Post);
let gen = match route::Route::new(args, input, route::GuardType::Post) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -82,7 +88,10 @@ pub fn post(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn put(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Put);
let gen = match route::Route::new(args, input, route::GuardType::Put) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -94,7 +103,10 @@ pub fn put(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Delete);
let gen = match route::Route::new(args, input, route::GuardType::Delete) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -106,7 +118,10 @@ pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn head(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Head);
let gen = match route::Route::new(args, input, route::GuardType::Head) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -118,7 +133,10 @@ pub fn head(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Connect);
let gen = match route::Route::new(args, input, route::GuardType::Connect) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -130,7 +148,10 @@ pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn options(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Options);
let gen = match route::Route::new(args, input, route::GuardType::Options) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -142,7 +163,10 @@ pub fn options(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Trace);
let gen = match route::Route::new(args, input, route::GuardType::Trace) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}
@@ -154,6 +178,9 @@ pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn patch(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as syn::AttributeArgs);
let gen = route::Args::new(&args, input, route::GuardType::Patch);
let gen = match route::Route::new(args, input, route::GuardType::Patch) {
Ok(gen) => gen,
Err(err) => return err.to_compile_error().into(),
};
gen.generate()
}

View File

@@ -1,21 +1,23 @@
extern crate proc_macro;
use std::fmt;
use proc_macro::TokenStream;
use quote::quote;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{AttributeArgs, Ident, NestedMeta};
enum ResourceType {
Async,
Sync,
}
impl fmt::Display for ResourceType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ResourceType::Async => write!(f, "to_async"),
ResourceType::Sync => write!(f, "to"),
}
impl ToTokens for ResourceType {
fn to_tokens(&self, stream: &mut TokenStream2) {
let ident = match self {
ResourceType::Async => "to",
ResourceType::Sync => "to",
};
let ident = Ident::new(ident, Span::call_site());
stream.append(ident);
}
}
@@ -32,61 +34,87 @@ pub enum GuardType {
Patch,
}
impl fmt::Display for GuardType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
GuardType::Get => write!(f, "Get"),
GuardType::Post => write!(f, "Post"),
GuardType::Put => write!(f, "Put"),
GuardType::Delete => write!(f, "Delete"),
GuardType::Head => write!(f, "Head"),
GuardType::Connect => write!(f, "Connect"),
GuardType::Options => write!(f, "Options"),
GuardType::Trace => write!(f, "Trace"),
GuardType::Patch => write!(f, "Patch"),
impl GuardType {
fn as_str(&self) -> &'static str {
match self {
GuardType::Get => "Get",
GuardType::Post => "Post",
GuardType::Put => "Put",
GuardType::Delete => "Delete",
GuardType::Head => "Head",
GuardType::Connect => "Connect",
GuardType::Options => "Options",
GuardType::Trace => "Trace",
GuardType::Patch => "Patch",
}
}
}
pub struct Args {
impl ToTokens for GuardType {
fn to_tokens(&self, stream: &mut TokenStream2) {
let ident = self.as_str();
let ident = Ident::new(ident, Span::call_site());
stream.append(ident);
}
}
struct Args {
path: syn::LitStr,
guards: Vec<Ident>,
}
impl Args {
fn new(args: AttributeArgs) -> syn::Result<Self> {
let mut path = None;
let mut guards = Vec::new();
for arg in args {
match arg {
NestedMeta::Lit(syn::Lit::Str(lit)) => match path {
None => {
path = Some(lit);
}
_ => {
return Err(syn::Error::new_spanned(
lit,
"Multiple paths specified! Should be only one!",
));
}
},
NestedMeta::Meta(syn::Meta::NameValue(nv)) => {
if nv.path.is_ident("guard") {
if let syn::Lit::Str(lit) = nv.lit {
guards.push(Ident::new(&lit.value(), Span::call_site()));
} else {
return Err(syn::Error::new_spanned(
nv.lit,
"Attribute guard expects literal string!",
));
}
} else {
return Err(syn::Error::new_spanned(
nv.path,
"Unknown attribute key is specified. Allowed: guard",
));
}
}
arg => {
return Err(syn::Error::new_spanned(arg, "Unknown attribute"));
}
}
}
Ok(Args {
path: path.unwrap(),
guards,
})
}
}
pub struct Route {
name: syn::Ident,
path: String,
args: Args,
ast: syn::ItemFn,
resource_type: ResourceType,
pub guard: GuardType,
pub extra_guards: Vec<String>,
}
impl fmt::Display for Args {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let ast = &self.ast;
let guards = format!(".guard(actix_web::guard::{}())", self.guard);
let guards = self.extra_guards.iter().fold(guards, |acc, val| {
format!("{}.guard(actix_web::guard::fn_guard({}))", acc, val)
});
write!(
f,
"
#[allow(non_camel_case_types)]
pub struct {name};
impl actix_web::dev::HttpServiceFactory for {name} {{
fn register(self, config: &mut actix_web::dev::AppService) {{
{ast}
let resource = actix_web::Resource::new(\"{path}\"){guards}.{to}({name});
actix_web::dev::HttpServiceFactory::register(resource, config)
}}
}}",
name = self.name,
ast = quote!(#ast),
path = self.path,
guards = guards,
to = self.resource_type
)
}
guard: GuardType,
}
fn guess_resource_type(typ: &syn::Type) -> ResourceType {
@@ -111,75 +139,73 @@ fn guess_resource_type(typ: &syn::Type) -> ResourceType {
guess
}
impl Args {
pub fn new(args: &[syn::NestedMeta], input: TokenStream, guard: GuardType) -> Self {
impl Route {
pub fn new(
args: AttributeArgs,
input: TokenStream,
guard: GuardType,
) -> syn::Result<Self> {
if args.is_empty() {
panic!(
"invalid server definition, expected: #[{}(\"some path\")]",
guard
);
}
let ast: syn::ItemFn = syn::parse(input).expect("Parse input as function");
let name = ast.ident.clone();
let mut extra_guards = Vec::new();
let mut path = None;
for arg in args {
match arg {
syn::NestedMeta::Literal(syn::Lit::Str(ref fname)) => {
if path.is_some() {
panic!("Multiple paths specified! Should be only one!")
}
let fname = quote!(#fname).to_string();
path = Some(fname.as_str()[1..fname.len() - 1].to_owned())
}
syn::NestedMeta::Meta(syn::Meta::NameValue(ident)) => {
match ident.ident.to_string().to_lowercase().as_str() {
"guard" => match ident.lit {
syn::Lit::Str(ref text) => extra_guards.push(text.value()),
_ => panic!("Attribute guard expects literal string!"),
},
attr => panic!(
"Unknown attribute key is specified: {}. Allowed: guard",
attr
return Err(syn::Error::new(
Span::call_site(),
format!(
r#"invalid server definition, expected #[{}("<some path>")]"#,
guard.as_str().to_ascii_lowercase()
),
));
}
}
attr => panic!("Unknown attribute{:?}", attr),
}
}
let ast: syn::ItemFn = syn::parse(input)?;
let name = ast.sig.ident.clone();
let resource_type = if ast.asyncness.is_some() {
let args = Args::new(args)?;
let resource_type = if ast.sig.asyncness.is_some() {
ResourceType::Async
} else {
match ast.decl.output {
syn::ReturnType::Default => panic!(
"Function {} has no return type. Cannot be used as handler",
name
),
match ast.sig.output {
syn::ReturnType::Default => {
return Err(syn::Error::new_spanned(
ast,
"Function has no return type. Cannot be used as handler",
));
}
syn::ReturnType::Type(_, ref typ) => guess_resource_type(typ.as_ref()),
}
};
let path = path.unwrap();
Self {
Ok(Self {
name,
path,
args,
ast,
resource_type,
guard,
extra_guards,
}
})
}
pub fn generate(&self) -> TokenStream {
let text = self.to_string();
let name = &self.name;
let guard = &self.guard;
let ast = &self.ast;
let path = &self.args.path;
let extra_guards = &self.args.guards;
let resource_type = &self.resource_type;
let stream = quote! {
#[allow(non_camel_case_types)]
pub struct #name;
match text.parse() {
Ok(res) => res,
Err(error) => panic!("Error: {:?}\nGenerated code: {}", error, text),
impl actix_web::dev::HttpServiceFactory for #name {
fn register(self, config: &mut actix_web::dev::AppService) {
#ast
let resource = actix_web::Resource::new(#path)
.guard(actix_web::guard::#guard())
#(.guard(actix_web::guard::fn_guard(#extra_guards)))*
.#resource_type(#name);
actix_web::dev::HttpServiceFactory::register(resource, config)
}
}
};
stream.into()
}
}

View File

@@ -1,77 +1,78 @@
use actix_http::HttpService;
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
use actix_web::{http, web::Path, App, HttpResponse, Responder};
use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace};
use futures::{future, Future};
#[get("/test")]
fn test() -> impl Responder {
async fn test() -> impl Responder {
HttpResponse::Ok()
}
#[put("/test")]
fn put_test() -> impl Responder {
async fn put_test() -> impl Responder {
HttpResponse::Created()
}
#[patch("/test")]
fn patch_test() -> impl Responder {
async fn patch_test() -> impl Responder {
HttpResponse::Ok()
}
#[post("/test")]
fn post_test() -> impl Responder {
async fn post_test() -> impl Responder {
HttpResponse::NoContent()
}
#[head("/test")]
fn head_test() -> impl Responder {
async fn head_test() -> impl Responder {
HttpResponse::Ok()
}
#[connect("/test")]
fn connect_test() -> impl Responder {
async fn connect_test() -> impl Responder {
HttpResponse::Ok()
}
#[options("/test")]
fn options_test() -> impl Responder {
async fn options_test() -> impl Responder {
HttpResponse::Ok()
}
#[trace("/test")]
fn trace_test() -> impl Responder {
async fn trace_test() -> impl Responder {
HttpResponse::Ok()
}
#[get("/test")]
fn auto_async() -> impl Future<Item = HttpResponse, Error = actix_web::Error> {
fn auto_async() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> {
future::ok(HttpResponse::Ok().finish())
}
#[get("/test")]
fn auto_sync() -> impl Future<Item = HttpResponse, Error = actix_web::Error> {
fn auto_sync() -> impl Future<Output = Result<HttpResponse, actix_web::Error>> {
future::ok(HttpResponse::Ok().finish())
}
#[put("/test/{param}")]
fn put_param_test(_: Path<String>) -> impl Responder {
async fn put_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Created()
}
#[delete("/test/{param}")]
fn delete_param_test(_: Path<String>) -> impl Responder {
async fn delete_param_test(_: Path<String>) -> impl Responder {
HttpResponse::NoContent()
}
#[get("/test/{param}")]
fn get_param_test(_: Path<String>) -> impl Responder {
async fn get_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Ok()
}
#[test]
fn test_params() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(
App::new()
.service(get_param_test)
@@ -81,21 +82,23 @@ fn test_params() {
});
let request = srv.request(http::Method::GET, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::OK);
let request = srv.request(http::Method::PUT, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::DELETE, srv.url("/test/it"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
})
}
#[test]
fn test_body() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(
App::new()
.service(post_test)
@@ -109,49 +112,52 @@ fn test_body() {
)
});
let request = srv.request(http::Method::GET, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::HEAD, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::CONNECT, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::OPTIONS, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::TRACE, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::PATCH, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(http::Method::PUT, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::CREATED);
let request = srv.request(http::Method::POST, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.status(), http::StatusCode::NO_CONTENT);
let request = srv.request(http::Method::GET, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_auto_async() {
let mut srv = TestServer::new(|| HttpService::new(App::new().service(auto_async)));
block_on(async {
let srv = TestServer::start(|| HttpService::new(App::new().service(auto_async)));
let request = srv.request(http::Method::GET, srv.url("/test"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -1,5 +1,9 @@
# Changes
## [0.2.8] - 2019-11-06
* Add support for setting query from Serialize type for client request.
## [0.2.7] - 2019-09-25

View File

@@ -1,6 +1,6 @@
[package]
name = "awc"
version = "0.2.7"
version = "0.3.0-alpha.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http client."
readme = "README.md"
@@ -21,16 +21,16 @@ name = "awc"
path = "src/lib.rs"
[package.metadata.docs.rs]
features = ["ssl", "brotli", "flate2-zlib"]
features = ["openssl", "brotli", "flate2-zlib"]
[features]
default = ["brotli", "flate2-zlib"]
# openssl
ssl = ["openssl", "actix-http/ssl"]
openssl = ["open-ssl", "actix-http/openssl"]
# rustls
rust-tls = ["rustls", "actix-http/rust-tls"]
# rustls = ["rust-tls", "actix-http/rustls"]
# brotli encoding, requires c compiler
brotli = ["actix-http/brotli"]
@@ -42,13 +42,14 @@ flate2-zlib = ["actix-http/flate2-zlib"]
flate2-rust = ["actix-http/flate2-rust"]
[dependencies]
actix-codec = "0.1.2"
actix-service = "0.4.1"
actix-http = "0.2.10"
actix-codec = "0.2.0-alpha.1"
actix-service = "1.0.0-alpha.1"
actix-http = "0.3.0-alpha.1"
base64 = "0.10.1"
bytes = "0.4"
derive_more = "0.15.0"
futures = "0.1.25"
futures = "0.3.1"
log =" 0.4"
mime = "0.3"
percent-encoding = "2.1"
@@ -56,21 +57,21 @@ rand = "0.7"
serde = "1.0"
serde_json = "1.0"
serde_urlencoded = "0.6.1"
tokio-timer = "0.2.8"
openssl = { version="0.10", optional = true }
rustls = { version = "0.15.2", optional = true }
tokio-timer = "0.3.0-alpha.6"
open-ssl = { version="0.10", package="openssl", optional = true }
rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] }
[dev-dependencies]
actix-rt = "0.2.2"
actix-web = { version = "1.0.0", features=["ssl"] }
actix-http = { version = "0.2.10", features=["ssl"] }
actix-http-test = { version = "0.2.0", features=["ssl"] }
actix-utils = "0.4.1"
actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] }
actix-rt = "1.0.0-alpha.1"
actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] }
actix-web = { version = "2.0.0-alpha.1", features=["openssl"] }
actix-http = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] }
actix-utils = "0.5.0-alpha.1"
actix-server = { version = "0.8.0-alpha.1", features=["openssl"] }
brotli2 = { version="0.3.2" }
flate2 = { version="1.0.2" }
env_logger = "0.6"
rand = "0.7"
tokio-tcp = "0.1"
webpki = "0.19"
rustls = { version = "0.15.2", features = ["dangerous_configuration"] }
webpki = { version = "0.21" }

View File

@@ -1,4 +1,6 @@
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, io, net};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
@@ -10,7 +12,7 @@ use actix_http::h1::ClientCodec;
use actix_http::http::HeaderMap;
use actix_http::{RequestHead, RequestHeadType, ResponseHead};
use actix_service::Service;
use futures::{Future, Poll};
use futures::future::{FutureExt, LocalBoxFuture};
use crate::response::ClientResponse;
@@ -22,7 +24,7 @@ pub(crate) trait Connect {
head: RequestHead,
body: Body,
addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>;
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>;
fn send_request_extra(
&mut self,
@@ -30,18 +32,16 @@ pub(crate) trait Connect {
extra_headers: Option<HeaderMap>,
body: Body,
addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>;
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>;
/// Send request, returns Response and Framed
fn open_tunnel(
&mut self,
head: RequestHead,
addr: Option<net::SocketAddr>,
) -> Box<
dyn Future<
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
>;
/// Send request and extra headers, returns Response and Framed
@@ -50,11 +50,9 @@ pub(crate) trait Connect {
head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>,
) -> Box<
dyn Future<
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
>;
}
@@ -72,21 +70,23 @@ where
head: RequestHead,
body: Body,
addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> {
Box::new(
self.0
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> {
// connect to the host
.call(ClientConnect {
let fut = self.0.call(ClientConnect {
uri: head.uri.clone(),
addr,
})
.from_err()
});
async move {
let connection = fut.await?;
// send request
.and_then(move |connection| {
connection.send_request(RequestHeadType::from(head), body)
})
.map(|(head, payload)| ClientResponse::new(head, payload)),
)
connection
.send_request(RequestHeadType::from(head), body)
.await
.map(|(head, payload)| ClientResponse::new(head, payload))
}
.boxed_local()
}
fn send_request_extra(
@@ -95,51 +95,51 @@ where
extra_headers: Option<HeaderMap>,
body: Body,
addr: Option<net::SocketAddr>,
) -> Box<dyn Future<Item = ClientResponse, Error = SendRequestError>> {
Box::new(
self.0
) -> LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>> {
// connect to the host
.call(ClientConnect {
let fut = self.0.call(ClientConnect {
uri: head.uri.clone(),
addr,
})
.from_err()
});
async move {
let connection = fut.await?;
// send request
.and_then(move |connection| {
connection
let (head, payload) = connection
.send_request(RequestHeadType::Rc(head, extra_headers), body)
})
.map(|(head, payload)| ClientResponse::new(head, payload)),
)
.await?;
Ok(ClientResponse::new(head, payload))
}
.boxed_local()
}
fn open_tunnel(
&mut self,
head: RequestHead,
addr: Option<net::SocketAddr>,
) -> Box<
dyn Future<
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
> {
Box::new(
self.0
// connect to the host
.call(ClientConnect {
let fut = self.0.call(ClientConnect {
uri: head.uri.clone(),
addr,
})
.from_err()
});
async move {
let connection = fut.await?;
// send request
.and_then(move |connection| {
connection.open_tunnel(RequestHeadType::from(head))
})
.map(|(head, framed)| {
let (head, framed) =
connection.open_tunnel(RequestHeadType::from(head)).await?;
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
(head, framed)
}),
)
Ok((head, framed))
}
.boxed_local()
}
fn open_tunnel_extra(
@@ -147,48 +147,47 @@ where
head: Rc<RequestHead>,
extra_headers: Option<HeaderMap>,
addr: Option<net::SocketAddr>,
) -> Box<
dyn Future<
Item = (ResponseHead, Framed<BoxedSocket, ClientCodec>),
Error = SendRequestError,
>,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<BoxedSocket, ClientCodec>), SendRequestError>,
> {
Box::new(
self.0
// connect to the host
.call(ClientConnect {
let fut = self.0.call(ClientConnect {
uri: head.uri.clone(),
addr,
})
.from_err()
});
async move {
let connection = fut.await?;
// send request
.and_then(move |connection| {
connection.open_tunnel(RequestHeadType::Rc(head, extra_headers))
})
.map(|(head, framed)| {
let (head, framed) = connection
.open_tunnel(RequestHeadType::Rc(head, extra_headers))
.await?;
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
(head, framed)
}),
)
Ok((head, framed))
}
.boxed_local()
}
}
trait AsyncSocket {
fn as_read(&self) -> &dyn AsyncRead;
fn as_read_mut(&mut self) -> &mut dyn AsyncRead;
fn as_write(&mut self) -> &mut dyn AsyncWrite;
fn as_read(&self) -> &(dyn AsyncRead + Unpin);
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin);
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin);
}
struct Socket<T: AsyncRead + AsyncWrite>(T);
struct Socket<T: AsyncRead + AsyncWrite + Unpin>(T);
impl<T: AsyncRead + AsyncWrite> AsyncSocket for Socket<T> {
fn as_read(&self) -> &dyn AsyncRead {
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncSocket for Socket<T> {
fn as_read(&self) -> &(dyn AsyncRead + Unpin) {
&self.0
}
fn as_read_mut(&mut self) -> &mut dyn AsyncRead {
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) {
&mut self.0
}
fn as_write(&mut self) -> &mut dyn AsyncWrite {
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) {
&mut self.0
}
}
@@ -201,30 +200,37 @@ impl fmt::Debug for BoxedSocket {
}
}
impl io::Read for BoxedSocket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.as_read_mut().read(buf)
}
}
impl AsyncRead for BoxedSocket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.as_read().prepare_uninitialized_buffer(buf)
}
}
impl io::Write for BoxedSocket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.as_write().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.as_write().flush()
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf)
}
}
impl AsyncWrite for BoxedSocket {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.as_write().shutdown()
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx)
}
}

View File

@@ -82,7 +82,7 @@ impl FrozenClientRequest {
/// Send a streaming body.
pub fn send_stream<S, E>(&self, stream: S) -> SendClientRequest
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static,
{
RequestSender::Rc(self.head.clone(), None).send_stream(
@@ -203,7 +203,7 @@ impl FrozenSendBuilder {
/// Complete request construction and send a streaming body.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static,
{
if let Some(e) = self.err {

View File

@@ -7,18 +7,18 @@
//! use awc::Client;
//!
//! fn main() {
//! System::new("test").block_on(lazy(|| {
//! System::new("test").block_on(async {
//! let mut client = Client::default();
//!
//! client.get("http://www.rust-lang.org") // <- Create request builder
//! .header("User-Agent", "Actix-web")
//! .send() // <- Send http request
//! .map_err(|_| ())
//! .await
//! .and_then(|response| { // <- server http response
//! println!("Response: {:?}", response);
//! Ok(())
//! })
//! }));
//! });
//! }
//! ```
use std::cell::RefCell;
@@ -52,23 +52,22 @@ use self::connect::{Connect, ConnectorWrapper};
/// An HTTP Client
///
/// ```rust
/// # use futures::future::{Future, lazy};
/// use actix_rt::System;
/// use awc::Client;
///
/// fn main() {
/// System::new("test").block_on(lazy(|| {
/// System::new("test").block_on(async {
/// let mut client = Client::default();
///
/// client.get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request
/// .map_err(|_| ())
/// .await
/// .and_then(|response| { // <- server http response
/// println!("Response: {:?}", response);
/// Ok(())
/// })
/// }));
/// });
/// }
/// ```
#[derive(Clone)]

View File

@@ -37,21 +37,21 @@ const HTTPS_ENCODING: &str = "gzip, deflate";
/// builder-like pattern.
///
/// ```rust
/// use futures::future::{Future, lazy};
/// use actix_rt::System;
///
/// fn main() {
/// System::new("test").block_on(lazy(|| {
/// awc::Client::new()
/// System::new("test").block_on(async {
/// let response = awc::Client::new()
/// .get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web")
/// .send() // <- Send http request
/// .map_err(|_| ())
/// .and_then(|response| { // <- server http response
/// .await;
///
/// response.and_then(|response| { // <- server http response
/// println!("Response: {:?}", response);
/// Ok(())
/// })
/// }));
/// });
/// }
/// ```
pub struct ClientRequest {
@@ -158,7 +158,7 @@ impl ClientRequest {
///
/// ```rust
/// fn main() {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|| {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| {
/// let req = awc::Client::new()
/// .get("http://www.rust-lang.org")
/// .set(awc::http::header::Date::now())
@@ -186,13 +186,13 @@ impl ClientRequest {
/// use awc::{http, Client};
///
/// fn main() {
/// # actix_rt::System::new("test").block_on(futures::future::lazy(|| {
/// # actix_rt::System::new("test").block_on(async {
/// let req = Client::new()
/// .get("http://www.rust-lang.org")
/// .header("X-TEST", "value")
/// .header(http::header::CONTENT_TYPE, "application/json");
/// # Ok::<_, ()>(())
/// # }));
/// # });
/// }
/// ```
pub fn header<K, V>(mut self, key: K, value: V) -> Self
@@ -309,9 +309,8 @@ impl ClientRequest {
///
/// ```rust
/// # use actix_rt::System;
/// # use futures::future::{lazy, Future};
/// fn main() {
/// System::new("test").block_on(lazy(|| {
/// System::new("test").block_on(async {
/// awc::Client::new().get("https://www.rust-lang.org")
/// .cookie(
/// awc::http::Cookie::build("name", "value")
@@ -322,12 +321,12 @@ impl ClientRequest {
/// .finish(),
/// )
/// .send()
/// .map_err(|_| ())
/// .await
/// .and_then(|response| {
/// println!("Response: {:?}", response);
/// Ok(())
/// })
/// }));
/// });
/// }
/// ```
pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
@@ -382,6 +381,27 @@ impl ClientRequest {
}
}
/// Sets the query part of the request
pub fn query<T: Serialize>(
mut self,
query: &T,
) -> Result<Self, serde_urlencoded::ser::Error> {
let mut parts = self.head.uri.clone().into_parts();
if let Some(path_and_query) = parts.path_and_query {
let query = serde_urlencoded::to_string(query)?;
let path = path_and_query.path();
parts.path_and_query = format!("{}?{}", path, query).parse().ok();
match Uri::from_parts(parts) {
Ok(uri) => self.head.uri = uri,
Err(e) => self.err = Some(e.into()),
}
}
Ok(self)
}
/// Freeze request builder and construct `FrozenClientRequest`,
/// which could be used for sending same request multiple times.
pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
@@ -457,7 +477,7 @@ impl ClientRequest {
/// Set an streaming body and generate `ClientRequest`.
pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static,
{
let slf = match self.prep_for_sending() {
@@ -690,4 +710,13 @@ mod tests {
"Bearer someS3cr3tAutht0k3n"
);
}
#[test]
fn client_query() {
let req = Client::new()
.get("/")
.query(&[("key1", "val1"), ("key2", "val2")])
.unwrap();
assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2");
}
}

View File

@@ -1,9 +1,11 @@
use std::cell::{Ref, RefMut};
use std::fmt;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures::{Async, Future, Poll, Stream};
use futures::{ready, Future, Stream};
use actix_http::cookie::Cookie;
use actix_http::error::{CookieParseError, PayloadError};
@@ -104,7 +106,7 @@ impl<S> ClientResponse<S> {
impl<S> ClientResponse<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>>,
{
/// Loads http response's body.
pub fn body(&mut self) -> MessageBody<S> {
@@ -125,13 +127,12 @@ where
impl<S> Stream for ClientResponse<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self.payload.poll()
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.get_mut().payload).poll_next(cx)
}
}
@@ -155,7 +156,7 @@ pub struct MessageBody<S> {
impl<S> MessageBody<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>>,
{
/// Create `MessageBody` for request.
pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> {
@@ -198,23 +199,24 @@ where
impl<S> Future for MessageBody<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = Bytes;
type Error = PayloadError;
type Output = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.err.take() {
return Err(err);
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
if let Some(err) = this.err.take() {
return Poll::Ready(Err(err));
}
if let Some(len) = self.length.take() {
if len > self.fut.as_ref().unwrap().limit {
return Err(PayloadError::Overflow);
if let Some(len) = this.length.take() {
if len > this.fut.as_ref().unwrap().limit {
return Poll::Ready(Err(PayloadError::Overflow));
}
}
self.fut.as_mut().unwrap().poll()
Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
}
}
@@ -233,7 +235,7 @@ pub struct JsonBody<S, U> {
impl<S, U> JsonBody<S, U>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>>,
U: DeserializeOwned,
{
/// Create `JsonBody` for request.
@@ -279,27 +281,35 @@ where
}
}
impl<T, U> Future for JsonBody<T, U>
impl<T, U> Unpin for JsonBody<T, U>
where
T: Stream<Item = Bytes, Error = PayloadError>,
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
type Item = U;
type Error = JsonPayloadError;
}
fn poll(&mut self) -> Poll<U, JsonPayloadError> {
impl<T, U> Future for JsonBody<T, U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
type Output = Result<U, JsonPayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(err) = self.err.take() {
return Err(err);
return Poll::Ready(Err(err));
}
if let Some(len) = self.length.take() {
if len > self.fut.as_ref().unwrap().limit {
return Err(JsonPayloadError::Payload(PayloadError::Overflow));
return Poll::Ready(Err(JsonPayloadError::Payload(
PayloadError::Overflow,
)));
}
}
let body = futures::try_ready!(self.fut.as_mut().unwrap().poll());
Ok(Async::Ready(serde_json::from_slice::<U>(&body)?))
let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?;
Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
}
}
@@ -321,24 +331,25 @@ impl<S> ReadBody<S> {
impl<S> Future for ReadBody<S>
where
S: Stream<Item = Bytes, Error = PayloadError>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = Bytes;
type Error = PayloadError;
type Output = Result<Bytes, PayloadError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
return match self.stream.poll()? {
Async::Ready(Some(chunk)) => {
if (self.buf.len() + chunk.len()) > self.limit {
Err(PayloadError::Overflow)
return match Pin::new(&mut this.stream).poll_next(cx)? {
Poll::Ready(Some(chunk)) => {
if (this.buf.len() + chunk.len()) > this.limit {
Poll::Ready(Err(PayloadError::Overflow))
} else {
self.buf.extend_from_slice(&chunk);
this.buf.extend_from_slice(&chunk);
continue;
}
}
Async::Ready(None) => Ok(Async::Ready(self.buf.take().freeze())),
Async::NotReady => Ok(Async::NotReady),
Poll::Ready(None) => Poll::Ready(Ok(this.buf.take().freeze())),
Poll::Pending => Poll::Pending,
};
}
}
@@ -348,22 +359,23 @@ where
mod tests {
use super::*;
use actix_http_test::block_on;
use futures::Async;
use serde::{Deserialize, Serialize};
use crate::{http::header, test::TestResponse};
#[test]
fn test_body() {
let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
match req.body().poll().err().unwrap() {
block_on(async {
let mut req =
TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
match req.body().await.err().unwrap() {
PayloadError::UnknownLength => (),
_ => unreachable!("error"),
}
let mut req =
TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish();
match req.body().poll().err().unwrap() {
match req.body().await.err().unwrap() {
PayloadError::Overflow => (),
_ => unreachable!("error"),
}
@@ -371,18 +383,16 @@ mod tests {
let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"test"))
.finish();
match req.body().poll().ok().unwrap() {
Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")),
_ => unreachable!("error"),
}
assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test"));
let mut req = TestResponse::default()
.set_payload(Bytes::from_static(b"11111111111111"))
.finish();
match req.body().limit(5).poll().err().unwrap() {
match req.body().limit(5).await.err().unwrap() {
PayloadError::Overflow => (),
_ => unreachable!("error"),
}
})
}
#[derive(Serialize, Deserialize, PartialEq, Debug)]
@@ -406,8 +416,9 @@ mod tests {
#[test]
fn test_json_body() {
block_on(async {
let mut req = TestResponse::default().finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default()
@@ -416,7 +427,7 @@ mod tests {
header::HeaderValue::from_static("application/text"),
)
.finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default()
@@ -430,7 +441,7 @@ mod tests {
)
.finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req).limit(100));
let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await;
assert!(json_eq(
json.err().unwrap(),
JsonPayloadError::Payload(PayloadError::Overflow)
@@ -448,12 +459,13 @@ mod tests {
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.finish();
let json = block_on(JsonBody::<_, MyObject>::new(&mut req));
let json = JsonBody::<_, MyObject>::new(&mut req).await;
assert_eq!(
json.ok().unwrap(),
MyObject {
name: "test".to_owned()
}
);
})
}
}

View File

@@ -1,13 +1,15 @@
use std::net;
use std::pin::Pin;
use std::rc::Rc;
use std::time::{Duration, Instant};
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use derive_more::From;
use futures::{try_ready, Async, Future, Poll, Stream};
use futures::{future::LocalBoxFuture, ready, Future, Stream};
use serde::Serialize;
use serde_json;
use tokio_timer::Delay;
use tokio_timer::{delay_for, Delay};
use actix_http::body::{Body, BodyStream};
use actix_http::encoding::Decoder;
@@ -47,7 +49,7 @@ impl Into<SendRequestError> for PrepForSendingError {
#[must_use = "futures do nothing unless polled"]
pub enum SendClientRequest {
Fut(
Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>,
LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>,
Option<Delay>,
bool,
),
@@ -56,41 +58,51 @@ pub enum SendClientRequest {
impl SendClientRequest {
pub(crate) fn new(
send: Box<dyn Future<Item = ClientResponse, Error = SendRequestError>>,
send: LocalBoxFuture<'static, Result<ClientResponse, SendRequestError>>,
response_decompress: bool,
timeout: Option<Duration>,
) -> SendClientRequest {
let delay = timeout.map(|t| Delay::new(Instant::now() + t));
let delay = timeout.map(|t| delay_for(t));
SendClientRequest::Fut(send, delay, response_decompress)
}
}
impl Future for SendClientRequest {
type Item = ClientResponse<Decoder<Payload<PayloadStream>>>;
type Error = SendRequestError;
type Output =
Result<ClientResponse<Decoder<Payload<PayloadStream>>>, SendRequestError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
match this {
SendClientRequest::Fut(send, delay, response_decompress) => {
if delay.is_some() {
match delay.poll() {
Ok(Async::NotReady) => (),
_ => return Err(SendRequestError::Timeout),
match Pin::new(delay.as_mut().unwrap()).poll(cx) {
Poll::Pending => (),
_ => return Poll::Ready(Err(SendRequestError::Timeout)),
}
}
let res = try_ready!(send.poll()).map_body(|head, payload| {
let res = ready!(Pin::new(send).poll(cx)).map(|res| {
res.map_body(|head, payload| {
if *response_decompress {
Payload::Stream(Decoder::from_headers(payload, &head.headers))
Payload::Stream(Decoder::from_headers(
payload,
&head.headers,
))
} else {
Payload::Stream(Decoder::new(payload, ContentEncoding::Identity))
Payload::Stream(Decoder::new(
payload,
ContentEncoding::Identity,
))
}
})
});
Ok(Async::Ready(res))
Poll::Ready(res)
}
SendClientRequest::Err(ref mut e) => match e.take() {
Some(e) => Err(e),
Some(e) => Poll::Ready(Err(e)),
None => panic!("Attempting to call completed future"),
},
}
@@ -223,7 +235,7 @@ impl RequestSender {
stream: S,
) -> SendClientRequest
where
S: Stream<Item = Bytes, Error = E> + 'static,
S: Stream<Item = Result<Bytes, E>> + Unpin + 'static,
E: Into<Error> + 'static,
{
self.send_body(

View File

@@ -7,7 +7,6 @@ use std::{fmt, str};
use actix_codec::Framed;
use actix_http::cookie::{Cookie, CookieJar};
use actix_http::{ws, Payload, RequestHead};
use futures::future::{err, Either, Future};
use percent_encoding::percent_encode;
use tokio_timer::Timeout;
@@ -210,27 +209,26 @@ impl WebsocketsRequest {
}
/// Complete request construction and connect to a websockets server.
pub fn connect(
pub async fn connect(
mut self,
) -> impl Future<Item = (ClientResponse, Framed<BoxedSocket, Codec>), Error = WsClientError>
{
) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
if let Some(e) = self.err.take() {
return Either::A(err(e.into()));
return Err(e.into());
}
// validate uri
let uri = &self.head.uri;
if uri.host().is_none() {
return Either::A(err(InvalidUrl::MissingHost.into()));
return Err(InvalidUrl::MissingHost.into());
} else if uri.scheme_part().is_none() {
return Either::A(err(InvalidUrl::MissingScheme.into()));
return Err(InvalidUrl::MissingScheme.into());
} else if let Some(scheme) = uri.scheme_part() {
match scheme.as_str() {
"http" | "ws" | "https" | "wss" => (),
_ => return Either::A(err(InvalidUrl::UnknownScheme.into())),
_ => return Err(InvalidUrl::UnknownScheme.into()),
}
} else {
return Either::A(err(InvalidUrl::UnknownScheme.into()));
return Err(InvalidUrl::UnknownScheme.into());
}
if !self.head.headers.contains_key(header::HOST) {
@@ -294,13 +292,23 @@ impl WebsocketsRequest {
.config
.connector
.borrow_mut()
.open_tunnel(head, self.addr)
.from_err()
.and_then(move |(head, framed)| {
.open_tunnel(head, self.addr);
// set request timeout
let (head, framed) = if let Some(timeout) = self.config.timeout {
Timeout::new(fut, timeout)
.await
.map_err(|_| SendRequestError::Timeout.into())
.and_then(|res| res)?
} else {
fut.await?
};
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
@@ -315,20 +323,17 @@ impl WebsocketsRequest {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(
conn.clone(),
));
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Missing connection header");
@@ -364,20 +369,6 @@ impl WebsocketsRequest {
}
}),
))
});
// set request timeout
if let Some(timeout) = self.config.timeout {
Either::B(Either::A(Timeout::new(fut, timeout).map_err(|e| {
if let Some(e) = e.into_inner() {
e
} else {
SendRequestError::Timeout.into()
}
})))
} else {
Either::B(Either::B(fut))
}
}
}
@@ -398,6 +389,8 @@ impl fmt::Debug for WebsocketsRequest {
#[cfg(test)]
mod tests {
use actix_web::test::block_on;
use super::*;
use crate::Client;
@@ -472,6 +465,7 @@ mod tests {
#[test]
fn basics() {
block_on(async {
let req = Client::new()
.ws("http://localhost/")
.origin("test-origin")
@@ -493,14 +487,11 @@ mod tests {
header::HeaderValue::from_static("json")
);
let _ = actix_http_test::block_fn(move || req.connect());
let _ = req.connect().await;
assert!(Client::new().ws("/").connect().poll().is_err());
assert!(Client::new().ws("http:///test").connect().poll().is_err());
assert!(Client::new()
.ws("hmm://test.com/")
.connect()
.poll()
.is_err());
assert!(Client::new().ws("/").connect().await.is_err());
assert!(Client::new().ws("http:///test").connect().await.is_err());
assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
})
}
}

View File

@@ -9,12 +9,12 @@ use bytes::Bytes;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use futures::Future;
use futures::future::ok;
use rand::Rng;
use actix_http::HttpService;
use actix_http_test::TestServer;
use actix_service::{service_fn, NewService};
use actix_http_test::{block_on, TestServer};
use actix_service::pipeline_factory;
use actix_web::http::Cookie;
use actix_web::middleware::{BodyEncoding, Compress};
use actix_web::{http::header, web, App, Error, HttpMessage, HttpRequest, HttpResponse};
@@ -44,52 +44,59 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[test]
fn test_simple() {
let mut srv =
TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))),
))
});
let request = srv.get("/").header("x-test", "111").send();
let mut response = srv.block_on(request).unwrap();
let mut response = request.await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let mut response = srv.block_on(srv.post("/").send()).unwrap();
let mut response = srv.post("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// camel case
let response = srv.block_on(srv.post("/").camel_case().send()).unwrap();
let response = srv.post("/").camel_case().send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_json() {
let mut srv = TestServer::new(|| {
HttpService::new(App::new().service(
web::resource("/").route(web::to(|_: web::Json<String>| HttpResponse::Ok())),
))
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(
App::new().service(
web::resource("/")
.route(web::to(|_: web::Json<String>| HttpResponse::Ok())),
),
)
});
let request = srv
.get("/")
.header("x-test", "111")
.send_json(&"TEST".to_string());
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_form() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|_: web::Form<HashMap<String, String>>| HttpResponse::Ok(),
))))
@@ -99,40 +106,55 @@ fn test_form() {
let _ = data.insert("key".to_string(), "TEST".to_string());
let request = srv.get("/").header("x-test", "111").send_form(&data);
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn test_timeout() {
let mut srv = TestServer::new(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to_async(
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|| {
tokio_timer::sleep(Duration::from_millis(200))
.then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR)))
async {
tokio_timer::delay_for(Duration::from_millis(200)).await;
Ok::<_, Error>(HttpResponse::Ok().body(STR))
}
},
))))
});
let client = srv.execute(|| {
awc::Client::build()
let connector = awc::Connector::new()
.connector(actix_connect::new_connector(
actix_connect::start_default_resolver(),
))
.timeout(Duration::from_secs(15))
.finish();
let client = awc::Client::build()
.connector(connector)
.timeout(Duration::from_millis(50))
.finish()
});
.finish();
let request = client.get(srv.url("/")).send();
match srv.block_on(request) {
match request.await {
Err(SendRequestError::Timeout) => (),
_ => panic!(),
}
})
}
#[test]
fn test_timeout_override() {
let mut srv = TestServer::new(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to_async(
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|| {
tokio_timer::sleep(Duration::from_millis(200))
.then(|_| Ok::<_, Error>(HttpResponse::Ok().body(STR)))
async {
tokio_timer::delay_for(Duration::from_millis(200)).await;
Ok::<_, Error>(HttpResponse::Ok().body(STR))
}
},
))))
});
@@ -144,86 +166,92 @@ fn test_timeout_override() {
.get(srv.url("/"))
.timeout(Duration::from_millis(50))
.send();
match srv.block_on(request) {
match request.await {
Err(SendRequestError::Timeout) => (),
_ => panic!(),
}
})
}
#[test]
fn test_connection_reuse() {
block_on(async {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(HttpService::new(
App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))),
))
.and_then(HttpService::new(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok())),
)))
});
let client = awc::Client::default();
// req 1
let request = client.get(srv.url("/")).send();
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req = client.post(srv.url("/"));
let response = srv.block_on_fn(move || req.send()).unwrap();
let response = req.send().await.unwrap();
assert!(response.status().is_success());
// one connection
assert_eq!(num.load(Ordering::Relaxed), 1);
})
}
#[test]
fn test_connection_force_close() {
block_on(async {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(HttpService::new(
App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))),
))
.and_then(HttpService::new(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok())),
)))
});
let client = awc::Client::default();
// req 1
let request = client.get(srv.url("/")).force_close().send();
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req = client.post(srv.url("/")).force_close();
let response = srv.block_on_fn(move || req.send()).unwrap();
let response = req.send().await.unwrap();
assert!(response.status().is_success());
// two connection
assert_eq!(num.load(Ordering::Relaxed), 2);
})
}
#[test]
fn test_connection_server_close() {
block_on(async {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(HttpService::new(
App::new().service(
@@ -237,28 +265,30 @@ fn test_connection_server_close() {
// req 1
let request = client.get(srv.url("/")).send();
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req = client.post(srv.url("/"));
let response = srv.block_on_fn(move || req.send()).unwrap();
let response = req.send().await.unwrap();
assert!(response.status().is_success());
// two connection
assert_eq!(num.load(Ordering::Relaxed), 2);
})
}
#[test]
fn test_connection_wait_queue() {
block_on(async {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(HttpService::new(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))),
@@ -271,39 +301,37 @@ fn test_connection_wait_queue() {
// req 1
let request = client.get(srv.url("/")).send();
let mut response = srv.block_on(request).unwrap();
let mut response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req2 = client.post(srv.url("/"));
let req2_fut = srv.execute(move || {
let mut fut = req2.send();
assert!(fut.poll().unwrap().is_not_ready());
fut
});
let req2_fut = req2.send();
// read response 1
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// req 2
let response = srv.block_on(req2_fut).unwrap();
let response = req2_fut.await.unwrap();
assert!(response.status().is_success());
// two connection
assert_eq!(num.load(Ordering::Relaxed), 1);
})
}
#[test]
fn test_connection_wait_queue_force_close() {
block_on(async {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(HttpService::new(
App::new().service(
@@ -319,32 +347,30 @@ fn test_connection_wait_queue_force_close() {
// req 1
let request = client.get(srv.url("/")).send();
let mut response = srv.block_on(request).unwrap();
let mut response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req2 = client.post(srv.url("/"));
let req2_fut = srv.execute(move || {
let mut fut = req2.send();
assert!(fut.poll().unwrap().is_not_ready());
fut
});
let req2_fut = req2.send();
// read response 1
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
// req 2
let response = srv.block_on(req2_fut).unwrap();
let response = req2_fut.await.unwrap();
assert!(response.status().is_success());
// two connection
assert_eq!(num.load(Ordering::Relaxed), 2);
})
}
#[test]
fn test_with_query_parameter() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").to(
|req: HttpRequest| {
if req.query_string().contains("qp") {
@@ -356,15 +382,19 @@ fn test_with_query_parameter() {
)))
});
let res = srv
.block_on(awc::Client::new().get(srv.url("/?qp=5")).send())
let res = awc::Client::new()
.get(srv.url("/?qp=5"))
.send()
.await
.unwrap();
assert!(res.status().is_success());
})
}
#[test]
fn test_no_decompress() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().wrap(Compress::default()).service(
web::resource("/").route(web::to(|| {
let mut res = HttpResponse::Ok().body(STR);
@@ -374,13 +404,16 @@ fn test_no_decompress() {
))
});
let mut res = srv
.block_on(awc::Client::new().get(srv.url("/")).no_decompress().send())
let mut res = awc::Client::new()
.get(srv.url("/"))
.no_decompress()
.send()
.await
.unwrap();
assert!(res.status().is_success());
// read response
let bytes = srv.block_on(res.body()).unwrap();
let bytes = res.body().await.unwrap();
let mut e = GzDecoder::new(&bytes[..]);
let mut dec = Vec::new();
@@ -388,22 +421,28 @@ fn test_no_decompress() {
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
// POST
let mut res = srv
.block_on(awc::Client::new().post(srv.url("/")).no_decompress().send())
let mut res = awc::Client::new()
.post(srv.url("/"))
.no_decompress()
.send()
.await
.unwrap();
assert!(res.status().is_success());
let bytes = srv.block_on(res.body()).unwrap();
let bytes = res.body().await.unwrap();
let mut e = GzDecoder::new(&bytes[..]);
let mut dec = Vec::new();
e.read_to_end(&mut dec).unwrap();
assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_client_gzip_encoding() {
let mut srv = TestServer::new(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|| {
let mut e = GzEncoder::new(Vec::new(), Compression::default());
e.write_all(STR.as_ref()).unwrap();
let data = e.finish().unwrap();
@@ -411,22 +450,26 @@ fn test_client_gzip_encoding() {
HttpResponse::Ok()
.header("content-encoding", "gzip")
.body(data)
}))))
},
))))
});
// client request
let mut response = srv.block_on(srv.post("/").send()).unwrap();
let mut response = srv.post("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
#[test]
fn test_client_gzip_encoding_large() {
let mut srv = TestServer::new(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|| {
let mut e = GzEncoder::new(Vec::new(), Compression::default());
e.write_all(STR.repeat(10).as_ref()).unwrap();
let data = e.finish().unwrap();
@@ -434,26 +477,29 @@ fn test_client_gzip_encoding_large() {
HttpResponse::Ok()
.header("content-encoding", "gzip")
.body(data)
}))))
},
))))
});
// client request
let mut response = srv.block_on(srv.post("/").send()).unwrap();
let mut response = srv.post("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from(STR.repeat(10)));
})
}
#[test]
fn test_client_gzip_encoding_large_random() {
block_on(async {
let data = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(100_000)
.collect::<String>();
let mut srv = TestServer::new(|| {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|data: Bytes| {
let mut e = GzEncoder::new(Vec::new(), Compression::default());
@@ -467,17 +513,19 @@ fn test_client_gzip_encoding_large_random() {
});
// client request
let mut response = srv.block_on(srv.post("/").send_body(data.clone())).unwrap();
let mut response = srv.post("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from(data));
})
}
#[test]
fn test_client_brotli_encoding() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().service(web::resource("/").route(web::to(
|data: Bytes| {
let mut e = BrotliEncoder::new(Vec::new(), 5);
@@ -491,12 +539,13 @@ fn test_client_brotli_encoding() {
});
// client request
let mut response = srv.block_on(srv.post("/").send_body(STR)).unwrap();
let mut response = srv.post("/").send_body(STR).await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.block_on(response.body()).unwrap();
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
})
}
// #[test]
@@ -506,7 +555,7 @@ fn test_client_brotli_encoding() {
// .take(70_000)
// .collect::<String>();
// let mut srv = test::TestServer::new(|app| {
// let srv = test::TestServer::start(|app| {
// app.handler(|req: &HttpRequest| {
// req.body()
// .and_then(move |bytes: Bytes| {
@@ -524,11 +573,11 @@ fn test_client_brotli_encoding() {
// .content_encoding(http::ContentEncoding::Br)
// .body(data.clone())
// .unwrap();
// let response = srv.execute(request.send()).unwrap();
// let response = request.send().await.unwrap();
// assert!(response.status().is_success());
// // read response
// let bytes = srv.execute(response.body()).unwrap();
// let bytes = response.body().await.unwrap();
// assert_eq!(bytes.len(), data.len());
// assert_eq!(bytes, Bytes::from(data));
// }
@@ -536,7 +585,7 @@ fn test_client_brotli_encoding() {
// #[cfg(feature = "brotli")]
// #[test]
// fn test_client_deflate_encoding() {
// let mut srv = test::TestServer::new(|app| {
// let srv = test::TestServer::start(|app| {
// app.handler(|req: &HttpRequest| {
// req.body()
// .and_then(|bytes: Bytes| {
@@ -569,7 +618,7 @@ fn test_client_brotli_encoding() {
// .take(70_000)
// .collect::<String>();
// let mut srv = test::TestServer::new(|app| {
// let srv = test::TestServer::start(|app| {
// app.handler(|req: &HttpRequest| {
// req.body()
// .and_then(|bytes: Bytes| {
@@ -597,7 +646,7 @@ fn test_client_brotli_encoding() {
// #[test]
// fn test_client_streaming_explicit() {
// let mut srv = test::TestServer::new(|app| {
// let srv = test::TestServer::start(|app| {
// app.handler(|req: &HttpRequest| {
// req.body()
// .map_err(Error::from)
@@ -624,7 +673,7 @@ fn test_client_brotli_encoding() {
// #[test]
// fn test_body_streaming_implicit() {
// let mut srv = test::TestServer::new(|app| {
// let srv = test::TestServer::start(|app| {
// app.handler(|_| {
// let body = once(Ok(Bytes::from_static(STR.as_ref())));
// HttpResponse::Ok()
@@ -644,11 +693,9 @@ fn test_client_brotli_encoding() {
#[test]
fn test_client_cookie_handling() {
fn err() -> Error {
use std::io::{Error as IoError, ErrorKind};
// stub some generic error
Error::from(IoError::from(ErrorKind::NotFound))
}
block_on(async {
let cookie1 = Cookie::build("cookie1", "value1").finish();
let cookie2 = Cookie::build("cookie2", "value2")
.domain("www.example.org")
@@ -660,49 +707,64 @@ fn test_client_cookie_handling() {
let cookie1b = cookie1.clone();
let cookie2b = cookie2.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let cookie1 = cookie1b.clone();
let cookie2 = cookie2b.clone();
HttpService::new(App::new().route(
"/",
web::to(move |req: HttpRequest| {
let cookie1 = cookie1.clone();
let cookie2 = cookie2.clone();
async move {
// Check cookies were sent correctly
req.cookie("cookie1")
.ok_or_else(err)
let res: Result<(), Error> = req
.cookie("cookie1")
.ok_or(())
.and_then(|c1| {
if c1.value() == "value1" {
Ok(())
} else {
Err(err())
Err(())
}
})
.and_then(|()| req.cookie("cookie2").ok_or_else(err))
.and_then(|()| req.cookie("cookie2").ok_or(()))
.and_then(|c2| {
if c2.value() == "value2" {
Ok(())
} else {
Err(err())
Err(())
}
})
.map_err(|_| {
Error::from(IoError::from(ErrorKind::NotFound))
});
if let Err(e) = res {
Err(e)
} else {
// Send some cookies back
.map(|_| {
Ok::<_, Error>(
HttpResponse::Ok()
.cookie(cookie1.clone())
.cookie(cookie2.clone())
.finish()
})
.cookie(cookie1)
.cookie(cookie2)
.finish(),
)
}
}
}),
))
});
let request = srv.get("/").cookie(cookie1.clone()).cookie(cookie2.clone());
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let c1 = response.cookie("cookie1").expect("Missing cookie1");
assert_eq!(c1, cookie1);
let c2 = response.cookie("cookie2").expect("Missing cookie2");
assert_eq!(c2, cookie2);
})
}
// #[test]
@@ -727,17 +789,18 @@ fn test_client_cookie_handling() {
// let req = client::ClientRequest::get(format!("http://{}/", addr).as_str())
// .finish()
// .unwrap();
// let response = sys.block_on(req.send()).unwrap();
// let response = req.send().await.unwrap();
// assert!(response.status().is_success());
// // read response
// let bytes = sys.block_on(response.body()).unwrap();
// let bytes = response.body().await.unwrap();
// assert_eq!(bytes, Bytes::from_static(b"welcome!"));
// }
#[test]
fn client_basic_auth() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().route(
"/",
web::to(|req: HttpRequest| {
@@ -759,13 +822,15 @@ fn client_basic_auth() {
// set authorization header to Basic <base64 encoded username:password>
let request = srv.get("/").basic_auth("username", Some("password"));
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
})
}
#[test]
fn client_bearer_auth() {
let mut srv = TestServer::new(|| {
block_on(async {
let srv = TestServer::start(|| {
HttpService::new(App::new().route(
"/",
web::to(|req: HttpRequest| {
@@ -787,6 +852,7 @@ fn client_bearer_auth() {
// set authorization header to Bearer <token>
let request = srv.get("/").bearer_auth("someS3cr3tAutht0k3n");
let response = srv.block_on(request.send()).unwrap();
let response = request.send().await.unwrap();
assert!(response.status().is_success());
})
}

View File

@@ -1,69 +1,81 @@
#![cfg(feature = "rust-tls")]
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
ClientConfig, NoClientAuth,
};
#![cfg(feature = "rustls")]
use rust_tls::ClientConfig;
use std::fs::File;
use std::io::{BufReader, Result};
use std::io::Result;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService;
use actix_http_test::TestServer;
use actix_server::ssl::RustlsAcceptor;
use actix_service::{service_fn, NewService};
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor;
use actix_service::{pipeline_factory, ServiceFactory};
use actix_web::http::Version;
use actix_web::{web, App, HttpResponse};
use futures::future::ok;
use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode};
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
use rustls::ServerConfig;
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut config = ServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true);
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(open_ssl::ssl::AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(actix_server::ssl::OpensslAcceptor::new(builder.build()))
}
mod danger {
pub struct NoCertificateVerification {}
impl rustls::ServerCertVerifier for NoCertificateVerification {
impl rust_tls::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_roots: &rustls::RootCertStore,
_presented_certs: &[rustls::Certificate],
_roots: &rust_tls::RootCertStore,
_presented_certs: &[rust_tls::Certificate],
_dns_name: webpki::DNSNameRef<'_>,
_ocsp: &[u8],
) -> Result<rustls::ServerCertVerified, rustls::TLSError> {
Ok(rustls::ServerCertVerified::assertion())
) -> Result<rust_tls::ServerCertVerified, rust_tls::TLSError> {
Ok(rust_tls::ServerCertVerified::assertion())
}
}
}
#[test]
fn test_connection_reuse_h2() {
let rustls = ssl_acceptor().unwrap();
// #[test]
fn _test_connection_reuse_h2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(rustls.clone().map_err(|e| println!("Rustls error: {}", e)))
.and_then(
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e)),
)
.and_then(
HttpService::build()
.h2(App::new()
.service(web::resource("/").route(web::to(|| HttpResponse::Ok()))))
.h2(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok())),
))
.map_err(|_| ()),
)
});
@@ -82,15 +94,16 @@ fn test_connection_reuse_h2() {
// req 1
let request = client.get(srv.surl("/")).send();
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req = client.post(srv.surl("/"));
let response = srv.block_on_fn(move || req.send()).unwrap();
let response = req.send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
// one connection
assert_eq!(num.load(Ordering::Relaxed), 1);
})
}

View File

@@ -1,5 +1,5 @@
#![cfg(feature = "ssl")]
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
#![cfg(feature = "openssl")]
use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use std::io::Result;
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -7,11 +7,12 @@ use std::sync::Arc;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::HttpService;
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
use actix_server::ssl::OpensslAcceptor;
use actix_service::{service_fn, NewService};
use actix_service::{pipeline_factory, ServiceFactory};
use actix_web::http::Version;
use actix_web::{web, App, HttpResponse};
use futures::future::ok;
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
@@ -27,7 +28,7 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(openssl::ssl::AlpnError::NOACK)
Err(open_ssl::ssl::AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
@@ -36,15 +37,16 @@ fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
#[test]
fn test_connection_reuse_h2() {
block_on(async {
let openssl = ssl_acceptor().unwrap();
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let mut srv = TestServer::new(move || {
let srv = TestServer::start(move || {
let num2 = num2.clone();
service_fn(move |io| {
pipeline_factory(move |io| {
num2.fetch_add(1, Ordering::Relaxed);
Ok(io)
ok(io)
})
.and_then(
openssl
@@ -53,8 +55,9 @@ fn test_connection_reuse_h2() {
)
.and_then(
HttpService::build()
.h2(App::new()
.service(web::resource("/").route(web::to(|| HttpResponse::Ok()))))
.h2(App::new().service(
web::resource("/").route(web::to(|| HttpResponse::Ok())),
))
.map_err(|_| ()),
)
});
@@ -72,15 +75,16 @@ fn test_connection_reuse_h2() {
// req 1
let request = client.get(srv.surl("/")).send();
let response = srv.block_on(request).unwrap();
let response = request.await.unwrap();
assert!(response.status().is_success());
// req 2
let req = client.post(srv.surl("/"));
let response = srv.block_on_fn(move || req.send()).unwrap();
let response = req.send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
// one connection
assert_eq!(num.load(Ordering::Relaxed), 1);
})
}

View File

@@ -2,81 +2,82 @@ use std::io;
use actix_codec::Framed;
use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_http_test::{block_on, TestServer};
use bytes::{Bytes, BytesMut};
use futures::future::ok;
use futures::{Future, Sink, Stream};
use futures::{SinkExt, StreamExt};
fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> {
async fn ws_service(req: ws::Frame) -> Result<ws::Message, io::Error> {
match req {
ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)),
ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)),
ws::Frame::Text(text) => {
let text = if let Some(pl) = text {
String::from_utf8(Vec::from(pl.as_ref())).unwrap()
} else {
String::new()
};
ok(ws::Message::Text(text))
Ok(ws::Message::Text(text))
}
ws::Frame::Binary(bin) => ok(ws::Message::Binary(
ws::Frame::Binary(bin) => Ok(ws::Message::Binary(
bin.map(|e| e.freeze())
.unwrap_or_else(|| Bytes::from(""))
.into(),
)),
ws::Frame::Close(reason) => ok(ws::Message::Close(reason)),
_ => ok(ws::Message::Close(None)),
ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)),
_ => Ok(ws::Message::Close(None)),
}
}
#[test]
fn test_simple() {
let mut srv = TestServer::new(|| {
block_on(async {
let mut srv = TestServer::start(|| {
HttpService::build()
.upgrade(|(req, framed): (Request, Framed<_, _>)| {
.upgrade(|(req, mut framed): (Request, Framed<_, _>)| {
async move {
let res = ws::handshake_response(req.head()).finish();
// send handshake response
framed
.send(h1::Message::Item((res.drop_body(), BodySize::None)))
.map_err(|e: io::Error| e.into())
.and_then(|framed| {
.await?;
// start websocket service
let framed = framed.into_framed(ws::Codec::new());
ws::Transport::with(framed, ws_service)
})
ws::Transport::with(framed, ws_service).await
}
})
.finish(|_| ok::<_, Error>(Response::NotFound()))
});
// client service
let framed = srv.ws().unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
let mut framed = srv.ws().await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let item = framed.next().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Text(Some(BytesMut::from("text"))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
framed
.send(ws::Message::Binary("text".into()))
.await
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
let item = framed.next().await.unwrap().unwrap();
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
framed.send(ws::Message::Ping("text".into())).await.unwrap();
let item = framed.next().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Pong("text".to_string().into()));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
);
let item = framed.next().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
})
}

View File

@@ -1,22 +1,18 @@
use futures::IntoFuture;
use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
};
use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer};
#[get("/resource1/{name}/index.html")]
fn index(req: HttpRequest, name: web::Path<String>) -> String {
async fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name)
}
fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> {
async fn index_async(req: HttpRequest) -> &'static str {
println!("REQ: {:?}", req);
Ok("Hello world!\r\n")
"Hello world!\r\n"
}
#[get("/")]
fn no_params() -> &'static str {
async fn no_params() -> &'static str {
"Hello world!\r\n"
}
@@ -39,9 +35,9 @@ fn main() -> std::io::Result<()> {
.default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()),
)
.route(web::get().to_async(index_async)),
.route(web::get().to(index_async)),
)
.service(web::resource("/test1.html").to(|| "Test\r\n"))
.service(web::resource("/test1.html").to(|| async { "Test\r\n" }))
})
.bind("127.0.0.1:8080")?
.workers(1)

View File

@@ -1,26 +1,27 @@
use actix_http::Error;
use actix_rt::System;
use futures::{future::lazy, Future};
fn main() -> Result<(), Error> {
std::env::set_var("RUST_LOG", "actix_http=trace");
env_logger::init();
System::new("test").block_on(lazy(|| {
awc::Client::new()
.get("https://www.rust-lang.org/") // <- Create request builder
System::new("test").block_on(async {
let client = awc::Client::new();
// Create request builder, configure request and send
let mut response = client
.get("https://www.rust-lang.org/")
.header("User-Agent", "Actix-web")
.send() // <- Send http request
.from_err()
.and_then(|mut response| {
// <- server http response
.send()
.await?;
// server http response
println!("Response: {:?}", response);
// read response body
response
.body()
.from_err()
.map(|body| println!("Downloaded: {:?} bytes", body.len()))
let body = response.body().await?;
println!("Downloaded: {:?} bytes", body.len());
Ok(())
})
}))
}

View File

@@ -1,26 +1,24 @@
use futures::IntoFuture;
use actix_web::{
get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer,
};
#[get("/resource1/{name}/index.html")]
fn index(req: HttpRequest, name: web::Path<String>) -> String {
async fn index(req: HttpRequest, name: web::Path<String>) -> String {
println!("REQ: {:?}", req);
format!("Hello: {}!\r\n", name)
}
fn index_async(req: HttpRequest) -> impl IntoFuture<Item = &'static str, Error = Error> {
async fn index_async(req: HttpRequest) -> Result<&'static str, Error> {
println!("REQ: {:?}", req);
Ok("Hello world!\r\n")
}
#[get("/")]
fn no_params() -> &'static str {
async fn no_params() -> &'static str {
"Hello world!\r\n"
}
#[cfg(feature = "uds")]
#[cfg(unix)]
fn main() -> std::io::Result<()> {
std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info");
env_logger::init();
@@ -40,14 +38,14 @@ fn main() -> std::io::Result<()> {
.default_service(
web::route().to(|| HttpResponse::MethodNotAllowed()),
)
.route(web::get().to_async(index_async)),
.route(web::get().to(index_async)),
)
.service(web::resource("/test1.html").to(|| "Test\r\n"))
.service(web::resource("/test1.html").to(|| async { "Test\r\n" }))
})
.bind_uds("/Users/fafhrd91/uds-test")?
.workers(1)
.run()
}
#[cfg(not(feature = "uds"))]
#[cfg(not(unix))]
fn main() {}

View File

@@ -1,14 +1,17 @@
use std::cell::RefCell;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::body::{Body, MessageBody};
use actix_service::boxed::{self, BoxedNewService};
use actix_service::{
apply_transform, IntoNewService, IntoTransform, NewService, Transform,
apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform,
};
use futures::{Future, IntoFuture};
use futures::future::{FutureExt, LocalBoxFuture};
use crate::app_service::{AppEntry, AppInit, AppRoutingFactory};
use crate::config::{AppConfig, AppConfigInner, ServiceConfig};
@@ -18,19 +21,19 @@ use crate::error::Error;
use crate::resource::Resource;
use crate::route::Route;
use crate::service::{
HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest,
AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse,
};
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type FnDataFactory =
Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>;
Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>;
/// Application builder - structure that follows the builder pattern
/// for building application instances.
pub struct App<T, B> {
endpoint: T,
services: Vec<Box<dyn ServiceFactory>>,
services: Vec<Box<dyn AppServiceFactory>>,
default: Option<Rc<HttpNewService>>,
factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
data: Vec<Box<dyn DataFactory>>,
@@ -61,7 +64,7 @@ impl App<AppEntry, Body> {
impl<T, B> App<T, B>
where
B: MessageBody,
T: NewService<
T: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B>,
@@ -87,7 +90,7 @@ where
/// counter: Cell<usize>,
/// }
///
/// fn index(data: web::Data<MyData>) {
/// async fn index(data: web::Data<MyData>) {
/// data.counter.set(data.counter.get() + 1);
/// }
///
@@ -107,24 +110,30 @@ where
/// Set application data factory. This function is
/// similar to `.data()` but it accepts data factory. Data object get
/// constructed asynchronously during application initialization.
pub fn data_factory<F, Out>(mut self, data: F) -> Self
pub fn data_factory<F, Out, D, E>(mut self, data: F) -> Self
where
F: Fn() -> Out + 'static,
Out: IntoFuture + 'static,
Out::Error: std::fmt::Debug,
Out: Future<Output = Result<D, E>> + 'static,
D: 'static,
E: std::fmt::Debug,
{
self.data_factories.push(Box::new(move || {
Box::new(
data()
.into_future()
.map_err(|e| {
{
let fut = data();
async move {
match fut.await {
Err(e) => {
log::error!("Can not construct data instance: {:?}", e);
})
.map(|data| {
Err(())
}
Ok(data) => {
let data: Box<dyn DataFactory> = Box::new(Data::new(data));
data
}),
)
Ok(data)
}
}
}
}
.boxed_local()
}));
self
}
@@ -183,7 +192,7 @@ where
/// ```rust
/// use actix_web::{web, App, HttpResponse};
///
/// fn index(data: web::Path<(String, String)>) -> &'static str {
/// async fn index(data: web::Path<(String, String)>) -> &'static str {
/// "Welcome!"
/// }
///
@@ -238,7 +247,7 @@ where
/// ```rust
/// use actix_web::{web, App, HttpResponse};
///
/// fn index() -> &'static str {
/// async fn index() -> &'static str {
/// "Welcome!"
/// }
///
@@ -267,8 +276,8 @@ where
/// ```
pub fn default_service<F, U>(mut self, f: F) -> Self
where
F: IntoNewService<U>,
U: NewService<
F: IntoServiceFactory<U>,
U: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse,
@@ -277,11 +286,9 @@ where
U::InitError: fmt::Debug,
{
// create and configure default resource
self.default = Some(Rc::new(boxed::new_service(
f.into_new_service().map_init_err(|e| {
log::error!("Can not construct default service: {:?}", e)
}),
)));
self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err(
|e| log::error!("Can not construct default service: {:?}", e),
))));
self
}
@@ -295,7 +302,7 @@ where
/// ```rust
/// use actix_web::{web, App, HttpRequest, HttpResponse, Result};
///
/// fn index(req: HttpRequest) -> Result<HttpResponse> {
/// async fn index(req: HttpRequest) -> Result<HttpResponse> {
/// let url = req.url_for("youtube", &["asdlkjqme"])?;
/// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme");
/// Ok(HttpResponse::Ok().into())
@@ -336,11 +343,10 @@ where
///
/// ```rust
/// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{middleware, web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
///
/// fn index() -> &'static str {
/// async fn index() -> &'static str {
/// "Welcome!"
/// }
///
@@ -350,11 +356,11 @@ where
/// .route("/index.html", web::get().to(index));
/// }
/// ```
pub fn wrap<M, B1, F>(
pub fn wrap<M, B1>(
self,
mw: F,
mw: M,
) -> App<
impl NewService<
impl ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B1>,
@@ -372,11 +378,9 @@ where
InitError = (),
>,
B1: MessageBody,
F: IntoTransform<M, T::Service>,
{
let endpoint = apply_transform(mw, self.endpoint);
App {
endpoint,
endpoint: apply(mw, self.endpoint),
data: self.data,
data_factories: self.data_factories,
services: self.services,
@@ -397,23 +401,25 @@ where
///
/// ```rust
/// use actix_service::Service;
/// # use futures::Future;
/// use actix_web::{web, App};
/// use actix_web::http::{header::CONTENT_TYPE, HeaderValue};
///
/// fn index() -> &'static str {
/// async fn index() -> &'static str {
/// "Welcome!"
/// }
///
/// fn main() {
/// let app = App::new()
/// .wrap_fn(|req, srv|
/// srv.call(req).map(|mut res| {
/// .wrap_fn(|req, srv| {
/// let fut = srv.call(req);
/// async {
/// let mut res = fut.await?;
/// res.headers_mut().insert(
/// CONTENT_TYPE, HeaderValue::from_static("text/plain"),
/// );
/// res
/// }))
/// Ok(res)
/// }
/// })
/// .route("/index.html", web::get().to(index));
/// }
/// ```
@@ -421,7 +427,7 @@ where
self,
mw: F,
) -> App<
impl NewService<
impl ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B1>,
@@ -433,16 +439,26 @@ where
where
B1: MessageBody,
F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone,
R: IntoFuture<Item = ServiceResponse<B1>, Error = Error>,
R: Future<Output = Result<ServiceResponse<B1>, Error>>,
{
self.wrap(mw)
App {
endpoint: apply_fn_factory(self.endpoint, mw),
data: self.data,
data_factories: self.data_factories,
services: self.services,
default: self.default,
factory_ref: self.factory_ref,
config: self.config,
external: self.external,
_t: PhantomData,
}
}
}
impl<T, B> IntoNewService<AppInit<T, B>> for App<T, B>
impl<T, B> IntoServiceFactory<AppInit<T, B>> for App<T, B>
where
B: MessageBody,
T: NewService<
T: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B>,
@@ -450,7 +466,7 @@ where
InitError = (),
>,
{
fn into_new_service(self) -> AppInit<T, B> {
fn into_factory(self) -> AppInit<T, B> {
AppInit {
data: Rc::new(self.data),
data_factories: Rc::new(self.data_factories),
@@ -468,27 +484,28 @@ where
mod tests {
use actix_service::Service;
use bytes::Bytes;
use futures::{Future, IntoFuture};
use futures::future::{ok, Future};
use super::*;
use crate::http::{header, HeaderValue, Method, StatusCode};
use crate::middleware::DefaultHeaders;
use crate::service::{ServiceRequest, ServiceResponse};
use crate::test::{
block_fn, block_on, call_service, init_service, read_body, TestRequest,
};
use crate::test::{block_on, call_service, init_service, read_body, TestRequest};
use crate::{web, Error, HttpRequest, HttpResponse};
#[test]
fn test_default_resource() {
block_on(async {
let mut srv = init_service(
App::new().service(web::resource("/test").to(|| HttpResponse::Ok())),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = block_fn(|| srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/blah").to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let mut srv = init_service(
@@ -497,53 +514,59 @@ mod tests {
.service(
web::resource("/test2")
.default_service(|r: ServiceRequest| {
r.into_response(HttpResponse::Created())
ok(r.into_response(HttpResponse::Created()))
})
.route(web::get().to(|| HttpResponse::Ok())),
)
.default_service(|r: ServiceRequest| {
r.into_response(HttpResponse::MethodNotAllowed())
ok(r.into_response(HttpResponse::MethodNotAllowed()))
}),
);
)
.await;
let req = TestRequest::with_uri("/blah").to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let req = TestRequest::with_uri("/test2").to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = TestRequest::with_uri("/test2")
.method(Method::POST)
.to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::CREATED);
})
}
#[test]
fn test_data_factory() {
block_on(async {
let mut srv =
init_service(App::new().data_factory(|| Ok::<_, ()>(10usize)).service(
init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
));
))
.await;
let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let mut srv =
init_service(App::new().data_factory(|| Ok::<_, ()>(10u32)).service(
init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
));
))
.await;
let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
})
}
fn md<S, B>(
req: ServiceRequest,
srv: &mut S,
) -> impl IntoFuture<Item = ServiceResponse<B>, Error = Error>
) -> impl Future<Output = Result<ServiceResponse<B>, Error>>
where
S: Service<
Request = ServiceRequest,
@@ -551,95 +574,122 @@ mod tests {
Error = Error,
>,
{
srv.call(req).map(|mut res| {
let fut = srv.call(req);
async move {
let mut res = fut.await?;
res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("0001"));
res
})
Ok(res)
}
}
#[test]
fn test_wrap() {
let mut srv = init_service(
block_on(async {
let mut srv =
init_service(
App::new()
.wrap(md)
.wrap(DefaultHeaders::new().header(
header::CONTENT_TYPE,
HeaderValue::from_static("0001"),
))
.route("/test", web::get().to(|| HttpResponse::Ok())),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
}
#[test]
fn test_router_wrap() {
let mut srv = init_service(
block_on(async {
let mut srv =
init_service(
App::new()
.route("/test", web::get().to(|| HttpResponse::Ok()))
.wrap(md),
);
.wrap(DefaultHeaders::new().header(
header::CONTENT_TYPE,
HeaderValue::from_static("0001"),
)),
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
}
#[test]
fn test_wrap_fn() {
block_on(async {
let mut srv = init_service(
App::new()
.wrap_fn(|req, srv| {
srv.call(req).map(|mut res| {
let fut = srv.call(req);
async move {
let mut res = fut.await?;
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("0001"),
);
res
})
Ok(res)
}
})
.service(web::resource("/test").to(|| HttpResponse::Ok())),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
}
#[test]
fn test_router_wrap_fn() {
block_on(async {
let mut srv = init_service(
App::new()
.route("/test", web::get().to(|| HttpResponse::Ok()))
.wrap_fn(|req, srv| {
srv.call(req).map(|mut res| {
let fut = srv.call(req);
async {
let mut res = fut.await?;
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("0001"),
);
res
})
Ok(res)
}
}),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
HeaderValue::from_static("0001")
);
})
}
#[test]
fn test_external_resource() {
block_on(async {
let mut srv = init_service(
App::new()
.external_resource("youtube", "https://youtube.com/watch/{video_id}")
@@ -652,11 +702,13 @@ mod tests {
))
}),
),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body = read_body(resp);
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
}
}

View File

@@ -1,14 +1,16 @@
use std::cell::RefCell;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_http::{Extensions, Request, Response};
use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url};
use actix_server_config::ServerConfig;
use actix_service::boxed::{self, BoxedNewService, BoxedService};
use actix_service::{service_fn, NewService, Service};
use futures::future::{ok, Either, FutureResult};
use futures::{Async, Future, Poll};
use actix_service::{service_fn, Service, ServiceFactory};
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use crate::config::{AppConfig, AppService};
use crate::data::DataFactory;
@@ -16,23 +18,20 @@ use crate::error::Error;
use crate::guard::Guard;
use crate::request::{HttpRequest, HttpRequestPool};
use crate::rmap::ResourceMap;
use crate::service::{ServiceFactory, ServiceRequest, ServiceResponse};
use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse};
type Guards = Vec<Box<dyn Guard>>;
type HttpService = BoxedService<ServiceRequest, ServiceResponse, Error>;
type HttpNewService = BoxedNewService<(), ServiceRequest, ServiceResponse, Error, ()>;
type BoxedResponse = Either<
FutureResult<ServiceResponse, Error>,
Box<dyn Future<Item = ServiceResponse, Error = Error>>,
>;
type BoxedResponse = LocalBoxFuture<'static, Result<ServiceResponse, Error>>;
type FnDataFactory =
Box<dyn Fn() -> Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>;
Box<dyn Fn() -> LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>;
/// Service factory to convert `Request` to a `ServiceRequest<S>`.
/// It also executes data factories.
pub struct AppInit<T, B>
where
T: NewService<
T: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B>,
@@ -44,15 +43,15 @@ where
pub(crate) data: Rc<Vec<Box<dyn DataFactory>>>,
pub(crate) data_factories: Rc<Vec<FnDataFactory>>,
pub(crate) config: RefCell<AppConfig>,
pub(crate) services: Rc<RefCell<Vec<Box<dyn ServiceFactory>>>>,
pub(crate) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory>>>>,
pub(crate) default: Option<Rc<HttpNewService>>,
pub(crate) factory_ref: Rc<RefCell<Option<AppRoutingFactory>>>,
pub(crate) external: RefCell<Vec<ResourceDef>>,
}
impl<T, B> NewService for AppInit<T, B>
impl<T, B> ServiceFactory for AppInit<T, B>
where
T: NewService<
T: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B>,
@@ -71,8 +70,8 @@ where
fn new_service(&self, cfg: &ServerConfig) -> Self::Future {
// update resource default service
let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::new_service(service_fn(|req: ServiceRequest| {
Ok(req.into_response(Response::NotFound().finish()))
Rc::new(boxed::factory(service_fn(|req: ServiceRequest| {
ok(req.into_response(Response::NotFound().finish()))
})))
});
@@ -135,23 +134,25 @@ where
}
}
#[pin_project::pin_project]
pub struct AppInitResult<T, B>
where
T: NewService,
T: ServiceFactory,
{
endpoint: Option<T::Service>,
#[pin]
endpoint_fut: T::Future,
rmap: Rc<ResourceMap>,
config: AppConfig,
data: Rc<Vec<Box<dyn DataFactory>>>,
data_factories: Vec<Box<dyn DataFactory>>,
data_factories_fut: Vec<Box<dyn Future<Item = Box<dyn DataFactory>, Error = ()>>>,
data_factories_fut: Vec<LocalBoxFuture<'static, Result<Box<dyn DataFactory>, ()>>>,
_t: PhantomData<B>,
}
impl<T, B> Future for AppInitResult<T, B>
where
T: NewService<
T: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse<B>,
@@ -159,48 +160,49 @@ where
InitError = (),
>,
{
type Item = AppInitService<T::Service, B>;
type Error = ();
type Output = Result<AppInitService<T::Service, B>, ()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
// async data factories
let mut idx = 0;
while idx < self.data_factories_fut.len() {
match self.data_factories_fut[idx].poll()? {
Async::Ready(f) => {
self.data_factories.push(f);
let _ = self.data_factories_fut.remove(idx);
while idx < this.data_factories_fut.len() {
match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? {
Poll::Ready(f) => {
this.data_factories.push(f);
let _ = this.data_factories_fut.remove(idx);
}
Async::NotReady => idx += 1,
Poll::Pending => idx += 1,
}
}
if self.endpoint.is_none() {
if let Async::Ready(srv) = self.endpoint_fut.poll()? {
self.endpoint = Some(srv);
if this.endpoint.is_none() {
if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? {
*this.endpoint = Some(srv);
}
}
if self.endpoint.is_some() && self.data_factories_fut.is_empty() {
if this.endpoint.is_some() && this.data_factories_fut.is_empty() {
// create app data container
let mut data = Extensions::new();
for f in self.data.iter() {
for f in this.data.iter() {
f.create(&mut data);
}
for f in &self.data_factories {
for f in this.data_factories.iter() {
f.create(&mut data);
}
Ok(Async::Ready(AppInitService {
service: self.endpoint.take().unwrap(),
rmap: self.rmap.clone(),
config: self.config.clone(),
Poll::Ready(Ok(AppInitService {
service: this.endpoint.take().unwrap(),
rmap: this.rmap.clone(),
config: this.config.clone(),
data: Rc::new(data),
pool: HttpRequestPool::create(),
}))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -226,8 +228,8 @@ where
type Error = T::Error;
type Future = T::Future;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
@@ -270,7 +272,7 @@ pub struct AppRoutingFactory {
default: Rc<HttpNewService>,
}
impl NewService for AppRoutingFactory {
impl ServiceFactory for AppRoutingFactory {
type Config = ();
type Request = ServiceRequest;
type Response = ServiceResponse;
@@ -288,7 +290,7 @@ impl NewService for AppRoutingFactory {
CreateAppRoutingItem::Future(
Some(path.clone()),
guards.borrow_mut().take(),
service.new_service(&()),
service.new_service(&()).boxed_local(),
)
})
.collect(),
@@ -298,14 +300,14 @@ impl NewService for AppRoutingFactory {
}
}
type HttpServiceFut = Box<dyn Future<Item = HttpService, Error = ()>>;
type HttpServiceFut = LocalBoxFuture<'static, Result<HttpService, ()>>;
/// Create app service
#[doc(hidden)]
pub struct AppRoutingFactoryResponse {
fut: Vec<CreateAppRoutingItem>,
default: Option<HttpService>,
default_fut: Option<Box<dyn Future<Item = HttpService, Error = ()>>>,
default_fut: Option<LocalBoxFuture<'static, Result<HttpService, ()>>>,
}
enum CreateAppRoutingItem {
@@ -314,16 +316,15 @@ enum CreateAppRoutingItem {
}
impl Future for AppRoutingFactoryResponse {
type Item = AppRouting;
type Error = ();
type Output = Result<AppRouting, ()>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut done = true;
if let Some(ref mut fut) = self.default_fut {
match fut.poll()? {
Async::Ready(default) => self.default = Some(default),
Async::NotReady => done = false,
match Pin::new(fut).poll(cx)? {
Poll::Ready(default) => self.default = Some(default),
Poll::Pending => done = false,
}
}
@@ -334,11 +335,12 @@ impl Future for AppRoutingFactoryResponse {
ref mut path,
ref mut guards,
ref mut fut,
) => match fut.poll()? {
Async::Ready(service) => {
) => match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(service)) => {
Some((path.take().unwrap(), guards.take(), service))
}
Async::NotReady => {
Poll::Ready(Err(_)) => return Poll::Ready(Err(())),
Poll::Pending => {
done = false;
None
}
@@ -364,13 +366,13 @@ impl Future for AppRoutingFactoryResponse {
}
router
});
Ok(Async::Ready(AppRouting {
Poll::Ready(Ok(AppRouting {
ready: None,
router: router.finish(),
default: self.default.take(),
}))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
}
@@ -387,11 +389,11 @@ impl Service for AppRouting {
type Error = Error;
type Future = BoxedResponse;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
if self.ready.is_none() {
Ok(Async::Ready(()))
Poll::Ready(Ok(()))
} else {
Ok(Async::NotReady)
Poll::Pending
}
}
@@ -413,7 +415,7 @@ impl Service for AppRouting {
default.call(req)
} else {
let req = req.into_parts().0;
Either::A(ok(ServiceResponse::new(req, Response::NotFound().finish())))
ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local()
}
}
}
@@ -429,7 +431,7 @@ impl AppEntry {
}
}
impl NewService for AppEntry {
impl ServiceFactory for AppEntry {
type Config = ();
type Request = ServiceRequest;
type Response = ServiceResponse;
@@ -464,15 +466,16 @@ mod tests {
#[test]
fn drop_data() {
let data = Arc::new(AtomicBool::new(false));
{
test::block_on(async {
let mut app = test::init_service(
App::new()
.data(DropData(data.clone()))
.service(web::resource("/test").to(|| HttpResponse::Ok())),
);
)
.await;
let req = test::TestRequest::with_uri("/test").to_request();
let _ = test::block_on(app.call(req)).unwrap();
}
let _ = app.call(req).await.unwrap();
});
assert!(data.load(Ordering::Relaxed));
}
}

View File

@@ -3,7 +3,7 @@ use std::rc::Rc;
use actix_http::Extensions;
use actix_router::ResourceDef;
use actix_service::{boxed, IntoNewService, NewService};
use actix_service::{boxed, IntoServiceFactory, ServiceFactory};
use crate::data::{Data, DataFactory};
use crate::error::Error;
@@ -12,7 +12,7 @@ use crate::resource::Resource;
use crate::rmap::ResourceMap;
use crate::route::Route;
use crate::service::{
HttpServiceFactory, ServiceFactory, ServiceFactoryWrapper, ServiceRequest,
AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest,
ServiceResponse,
};
@@ -102,11 +102,11 @@ impl AppService {
&mut self,
rdef: ResourceDef,
guards: Option<Vec<Box<dyn Guard>>>,
service: F,
factory: F,
nested: Option<Rc<ResourceMap>>,
) where
F: IntoNewService<S>,
S: NewService<
F: IntoServiceFactory<S>,
S: ServiceFactory<
Config = (),
Request = ServiceRequest,
Response = ServiceResponse,
@@ -116,7 +116,7 @@ impl AppService {
{
self.services.push((
rdef,
boxed::new_service(service.into_new_service()),
boxed::factory(factory.into_factory()),
guards,
nested,
));
@@ -174,7 +174,7 @@ impl Default for AppConfigInner {
/// to set of external methods. This could help with
/// modularization of big application configuration.
pub struct ServiceConfig {
pub(crate) services: Vec<Box<dyn ServiceFactory>>,
pub(crate) services: Vec<Box<dyn AppServiceFactory>>,
pub(crate) data: Vec<Box<dyn DataFactory>>,
pub(crate) external: Vec<ResourceDef>,
}
@@ -251,17 +251,19 @@ mod tests {
#[test]
fn test_data() {
block_on(async {
let cfg = |cfg: &mut ServiceConfig| {
cfg.data(10usize);
};
let mut srv =
init_service(App::new().configure(cfg).service(
let mut srv = init_service(App::new().configure(cfg).service(
web::resource("/").to(|_: web::Data<usize>| HttpResponse::Ok()),
));
))
.await;
let req = TestRequest::default().to_request();
let resp = block_on(srv.call(req)).unwrap();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
})
}
// #[test]
@@ -298,6 +300,7 @@ mod tests {
#[test]
fn test_external_resource() {
block_on(async {
let mut srv = init_service(
App::new()
.configure(|cfg| {
@@ -315,33 +318,39 @@ mod tests {
))
}),
),
);
)
.await;
let req = TestRequest::with_uri("/test").to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body = read_body(resp);
let body = read_body(resp).await;
assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345"));
})
}
#[test]
fn test_service() {
block_on(async {
let mut srv = init_service(App::new().configure(|cfg| {
cfg.service(
web::resource("/test").route(web::get().to(|| HttpResponse::Created())),
web::resource("/test")
.route(web::get().to(|| HttpResponse::Created())),
)
.route("/index.html", web::get().to(|| HttpResponse::Ok()));
}));
}))
.await;
let req = TestRequest::with_uri("/test")
.method(Method::GET)
.to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::CREATED);
let req = TestRequest::with_uri("/index.html")
.method(Method::GET)
.to_request();
let resp = call_service(&mut srv, req);
let resp = call_service(&mut srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
})
}
}

Some files were not shown because too many files have changed in this diff Show More