1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-07-20 00:06:12 +02:00

Compare commits

..

4 Commits

Author SHA1 Message Date
Harrison
7880d5e2bf Remove extra '#' so docs render correctly. (#1338) 2020-02-07 23:01:23 +09:00
0x1793d1
c23020d266 Fix extra line feed (#1206) 2019-12-09 22:40:37 +06:00
Nikolay Kim
1d45639bed prepare actix-multipart release 2019-12-07 20:02:46 +06:00
Alexander Larsson
6a672c9097 actix-multipart: Fix multipart boundary reading (#1189)
* actix-multipart: Fix multipart boundary reading

If we're not ready to read the first line after the multipart field
(which should be a "\r\n" line) then return NotReady instead of Ready(None)
so that we will get called again to read that line.

Without this I was getting MultipartError::Boundary from read_boundary()
because it got the "\r\n" line instead of the boundary.

* actix-multipart: Test handling of NotReady

Use a stream that reports NoReady and does partial reads in the test_stream
test. This works now, but failed before the previous commit.
2019-12-07 19:58:38 +06:00
179 changed files with 9650 additions and 10777 deletions

View File

@@ -10,9 +10,9 @@ matrix:
include: include:
- rust: stable - rust: stable
- rust: beta - rust: beta
- rust: nightly-2019-11-20 - rust: nightly-2019-08-10
allow_failures: allow_failures:
- rust: nightly-2019-11-20 - rust: nightly-2019-08-10
env: env:
global: global:
@@ -25,7 +25,7 @@ before_install:
- sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev
before_cache: | before_cache: |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin
fi fi
@@ -36,12 +36,9 @@ before_script:
script: script:
- cargo update - cargo update
- cargo check --all --no-default-features - cargo check --all --no-default-features
- | - cargo test --all-features --all -- --nocapture
if [[ "$TRAVIS_RUST_VERSION" == "stable" || "$TRAVIS_RUST_VERSION" == "beta" ]]; then - cd actix-http; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
cargo test --all-features --all -- --nocapture - cd awc; cargo test --no-default-features --features="rust-tls" -- --nocapture; cd ..
cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd ..
fi
# Upload docs # Upload docs
after_success: after_success:
@@ -54,7 +51,7 @@ after_success:
echo "Uploaded documentation" echo "Uploaded documentation"
fi fi
- | - |
if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-08-10" ]]; then
taskset -c 0 cargo tarpaulin --out Xml --all --all-features taskset -c 0 cargo tarpaulin --out Xml --all --all-features
bash <(curl -s https://codecov.io/bash) bash <(curl -s https://codecov.io/bash)
echo "Uploaded code coverage" echo "Uploaded code coverage"

View File

@@ -1,58 +1,5 @@
# Changes # Changes
## [2.0.0-rc] - 2019-12-20
### Changed
* Move `BodyEncoding` to `dev` module #1220
* Allow to set `peer_addr` for TestRequest #1074
* Make web::Data deref to Arc<T> #1214
* Rename `App::register_data()` to `App::app_data()`
* `HttpRequest::app_data<T>()` returns `Option<&T>` instead of `Option<&Data<T>>`
### Fixed
* Fix `AppConfig::secure()` is always false. #1202
## [2.0.0-alpha.6] - 2019-12-15
### Fixed
* Fixed compilation with default features off
## [2.0.0-alpha.5] - 2019-12-13
### Added
* Add test server, `test::start()` and `test::start_with()`
## [2.0.0-alpha.4] - 2019-12-08
### Deleted
* Delete HttpServer::run(), it is not useful witht async/await
## [2.0.0-alpha.3] - 2019-12-07
### Changed
* Migrate to tokio 0.2
## [2.0.0-alpha.1] - 2019-11-22
### Changed
* Migrated to `std::future`
* Remove implementation of `Responder` for `()`. (#1167)
## [1.0.9] - 2019-11-14 ## [1.0.9] - 2019-11-14
### Added ### Added

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-web" name = "actix-web"
version = "2.0.0-rc" version = "1.0.9"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust."
readme = "README.md" readme = "README.md"
@@ -12,10 +12,11 @@ categories = ["network-programming", "asynchronous",
"web-programming::http-server", "web-programming::http-server",
"web-programming::websocket"] "web-programming::websocket"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"]
edition = "2018" edition = "2018"
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "rustls", "compress", "secure-cookies"] features = ["ssl", "brotli", "flate2-zlib", "secure-cookies", "client", "rust-tls", "uds"]
[badges] [badges]
travis-ci = { repository = "actix/actix-web", branch = "master" } travis-ci = { repository = "actix/actix-web", branch = "master" }
@@ -42,63 +43,78 @@ members = [
] ]
[features] [features]
default = ["compress", "failure"] default = ["brotli", "flate2-zlib", "client", "fail"]
# content-encoding support # http client
compress = ["actix-http/compress", "awc/compress"] client = ["awc"]
# brotli encoding, requires c compiler
brotli = ["actix-http/brotli"]
# miniz-sys backend for flate2 crate
flate2-zlib = ["actix-http/flate2-zlib"]
# rust backend for flate2 crate
flate2-rust = ["actix-http/flate2-rust"]
# sessions feature, session require "ring" crate and c compiler # sessions feature, session require "ring" crate and c compiler
secure-cookies = ["actix-http/secure-cookies"] secure-cookies = ["actix-http/secure-cookies"]
failure = ["actix-http/failure"] fail = ["actix-http/fail"]
# openssl # openssl
openssl = ["actix-tls/openssl", "awc/openssl", "open-ssl"] ssl = ["openssl", "actix-server/ssl", "awc/ssl"]
# rustls # rustls
rustls = ["actix-tls/rustls", "awc/rustls", "rust-tls"] rust-tls = ["rustls", "actix-server/rust-tls", "awc/rust-tls"]
# unix domain sockets support
uds = ["actix-server/uds"]
[dependencies] [dependencies]
actix-codec = "0.2.0" actix-codec = "0.1.2"
actix-service = "1.0.0" actix-service = "0.4.1"
actix-utils = "1.0.3" actix-utils = "0.4.4"
actix-router = "0.2.0" actix-router = "0.1.5"
actix-rt = "1.0.0" actix-rt = "0.2.4"
actix-server = "1.0.0" actix-web-codegen = "0.1.2"
actix-testing = "1.0.0" actix-http = "0.2.11"
actix-macros = "0.1.0" actix-server = "0.6.1"
actix-threadpool = "0.3.0" actix-server-config = "0.1.2"
actix-tls = "1.0.0" actix-testing = "0.1.0"
actix-threadpool = "0.1.1"
awc = { version = "0.2.7", optional = true }
actix-web-codegen = "0.2.0" bytes = "0.4"
actix-http = "1.0.0" derive_more = "0.15.0"
awc = { version = "1.0.1", default-features = false }
bytes = "0.5.3"
derive_more = "0.99.2"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.3.1" futures = "0.1.25"
fxhash = "0.2.1" hashbrown = "0.6.3"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
net2 = "0.2.33" net2 = "0.2.33"
pin-project = "0.4.6" parking_lot = "0.9"
regex = "1.3" regex = "1.0"
serde = { version = "1.0", features=["derive"] } serde = { version = "1.0", features=["derive"] }
serde_json = "1.0" serde_json = "1.0"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
time = "0.1.42" time = "0.1.42"
url = "2.1" url = "2.1"
open-ssl = { version="0.10", package = "openssl", optional = true }
rust-tls = { version = "0.16.0", package = "rustls", optional = true } # ssl support
openssl = { version="0.10", optional = true }
rustls = { version = "0.15", optional = true }
[dev-dependencies] [dev-dependencies]
actix = "0.9.0-alpha.2" actix = "0.8.3"
actix-connect = "0.2.2"
actix-http-test = "0.2.4"
rand = "0.7" rand = "0.7"
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
tokio-timer = "0.2.8"
brotli2 = "0.3.2" brotli2 = "0.3.2"
flate2 = "1.0.13" flate2 = "1.0.2"
[profile.release] [profile.release]
lto = true lto = true
@@ -110,8 +126,7 @@ actix-web = { path = "." }
actix-http = { path = "actix-http" } actix-http = { path = "actix-http" }
actix-http-test = { path = "test-server" } actix-http-test = { path = "test-server" }
actix-web-codegen = { path = "actix-web-codegen" } actix-web-codegen = { path = "actix-web-codegen" }
actix-cors = { path = "actix-cors" } actix-web-actors = { path = "actix-web-actors" }
actix-identity = { path = "actix-identity" }
actix-session = { path = "actix-session" } actix-session = { path = "actix-session" }
actix-files = { path = "actix-files" } actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" } actix-multipart = { path = "actix-multipart" }

View File

@@ -1,17 +1,3 @@
## 2.0.0
* `App::register_data()` renamed to `App::app_data()` and accepts any type `T: 'static`.
Stored data is available via `HttpRequest::app_data()` method at runtime.
* Extractor configuration must be registered with `App::app_data()` instead of `App::data()`
* Sync handlers has been removed. `.to_async()` method has been renamed to `.to()`
replace `fn` with `async fn` to convert sync handler to async
* `TestServer::new()` renamed to `TestServer::start()`
## 1.0.1 ## 1.0.1
* Cors middleware has been moved to `actix-cors` crate * Cors middleware has been moved to `actix-cors` crate

View File

@@ -1,28 +1,4 @@
<div align="center"> # Actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
<p><h1>Actix web</h1> </p>
<p><strong>Actix web is a small, pragmatic, and extremely fast rust web framework</strong> </p>
<p>
[![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web)
[![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web)
[![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web)
[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Documentation](https://docs.rs/actix-web/badge.svg)](https://docs.rs/actix-web)
[![Download](https://img.shields.io/crates/d/actix-web.svg)](https://crates.io/crates/actix-web)
[![Version](https://img.shields.io/badge/rustc-1.39+-lightgray.svg)](https://blog.rust-lang.org/2019/11/07/Rust-1.39.0.html)
![License](https://img.shields.io/crates/l/actix-web.svg)
</p>
<h3>
<a href="https://actix.rs">Website</a>
<span> | </span>
<a href="https://gitter.im/actix/actix">Chat</a>
<span> | </span>
<a href="https://github.com/actix/examples">Examples</a>
</h3>
</div>
<br>
Actix web is a simple, pragmatic and extremely fast web framework for Rust. Actix web is a simple, pragmatic and extremely fast web framework for Rust.
@@ -39,22 +15,30 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust.
* Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) * Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html)
* Supports [Actix actor framework](https://github.com/actix/actix) * Supports [Actix actor framework](https://github.com/actix/actix)
## Documentation & community resources
* [User Guide](https://actix.rs/docs/)
* [API Documentation (1.0)](https://docs.rs/actix-web/)
* [API Documentation (0.7)](https://docs.rs/actix-web/0.7.19/actix_web/)
* [Chat on gitter](https://gitter.im/actix/actix)
* Cargo package: [actix-web](https://crates.io/crates/actix-web)
* Minimum supported Rust version: 1.36 or later
## Example ## Example
```rust ```rust
use actix_web::{get, web, App, HttpServer, Responder}; use actix_web::{web, App, HttpServer, Responder};
#[get("/{id}/{name}/index.html")] fn index(info: web::Path<(u32, String)>) -> impl Responder {
async fn index(info: web::Path<(u32, String)>) -> impl Responder {
format!("Hello {}! id:{}", info.1, info.0) format!("Hello {}! id:{}", info.1, info.0)
} }
#[actix_rt::main] fn main() -> std::io::Result<()> {
async fn main() -> std::io::Result<()> { HttpServer::new(
HttpServer::new(|| App::new().service(index)) || App::new().service(
web::resource("/{id}/{name}/index.html").to(index)))
.bind("127.0.0.1:8080")? .bind("127.0.0.1:8080")?
.start() .run()
.await
} }
``` ```

View File

@@ -1,14 +1,8 @@
# Changes # Changes
## [0.2.0] - 2019-12-20 ## [0.1.1] - unreleased
* Release * Bump `derive_more` crate version to 0.15.0
## [0.2.0-alpha.3] - 2019-12-07
* Migrate to actix-web 2.0.0
* Bump `derive_more` crate version to 0.99.0
## [0.1.0] - 2019-06-15 ## [0.1.0] - 2019-06-15

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-cors" name = "actix-cors"
version = "0.2.0" version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Cross-origin resource sharing (CORS) for Actix applications." description = "Cross-origin resource sharing (CORS) for Actix applications."
readme = "README.md" readme = "README.md"
@@ -10,17 +10,14 @@ repository = "https://github.com/actix/actix-web.git"
documentation = "https://docs.rs/actix-cors/" documentation = "https://docs.rs/actix-cors/"
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
edition = "2018" edition = "2018"
workspace = ".." #workspace = ".."
[lib] [lib]
name = "actix_cors" name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = "2.0.0-rc" actix-web = "1.0.0"
actix-service = "1.0.0" actix-service = "0.4.0"
derive_more = "0.99.2" derive_more = "0.15.0"
futures = "0.3.1" futures = "0.1.25"
[dev-dependencies]
actix-rt = "1.0.0"

View File

@@ -11,7 +11,7 @@
//! use actix_cors::Cors; //! use actix_cors::Cors;
//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer};
//! //!
//! async fn index(req: HttpRequest) -> &'static str { //! fn index(req: HttpRequest) -> &'static str {
//! "Hello world" //! "Hello world"
//! } //! }
//! //!
@@ -23,8 +23,7 @@
//! .allowed_methods(vec!["GET", "POST"]) //! .allowed_methods(vec!["GET", "POST"])
//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
//! .allowed_header(http::header::CONTENT_TYPE) //! .allowed_header(http::header::CONTENT_TYPE)
//! .max_age(3600) //! .max_age(3600))
//! .finish())
//! .service( //! .service(
//! web::resource("/index.html") //! web::resource("/index.html")
//! .route(web::get().to(index)) //! .route(web::get().to(index))
@@ -40,19 +39,18 @@
//! //!
//! Cors middleware automatically handle *OPTIONS* preflight request. //! Cors middleware automatically handle *OPTIONS* preflight request.
use std::collections::HashSet; use std::collections::HashSet;
use std::convert::TryFrom;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::{Service, Transform}; use actix_service::{IntoTransform, Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result}; use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri};
use actix_web::HttpResponse; use actix_web::HttpResponse;
use derive_more::Display; use derive_more::Display;
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, Future, FutureResult};
use futures::Poll;
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Display)] #[derive(Debug, Display)]
@@ -94,10 +92,6 @@ pub enum CorsError {
} }
impl ResponseError for CorsError { impl ResponseError for CorsError {
fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into())
} }
@@ -275,8 +269,7 @@ impl Cors {
pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
where where
U: IntoIterator<Item = M>, U: IntoIterator<Item = M>,
Method: TryFrom<M>, Method: HttpTryFrom<M>,
<Method as TryFrom<M>>::Error: Into<HttpError>,
{ {
self.methods = true; self.methods = true;
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
@@ -298,8 +291,7 @@ impl Cors {
/// Set an allowed header /// Set an allowed header
pub fn allowed_header<H>(mut self, header: H) -> Cors pub fn allowed_header<H>(mut self, header: H) -> Cors
where where
HeaderName: TryFrom<H>, HeaderName: HttpTryFrom<H>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match HeaderName::try_from(header) { match HeaderName::try_from(header) {
@@ -331,8 +323,7 @@ impl Cors {
pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, HeaderName: HttpTryFrom<H>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>,
{ {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
for h in headers { for h in headers {
@@ -366,8 +357,7 @@ impl Cors {
pub fn expose_headers<U, H>(mut self, headers: U) -> Cors pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, HeaderName: HttpTryFrom<H>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>,
{ {
for h in headers { for h in headers {
match HeaderName::try_from(h) { match HeaderName::try_from(h) {
@@ -466,9 +456,25 @@ impl Cors {
} }
self self
} }
}
/// Construct cors middleware fn cors<'a>(
pub fn finish(self) -> CorsFactory { parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl<S, B> IntoTransform<CorsFactory, S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn into_transform(self) -> CorsFactory {
let mut slf = if !self.methods { let mut slf = if !self.methods {
self.allowed_methods(vec![ self.allowed_methods(vec![
Method::GET, Method::GET,
@@ -515,16 +521,6 @@ impl Cors {
} }
} }
fn cors<'a>(
parts: &'a mut Option<Inner>,
err: &Option<http::Error>,
) -> Option<&'a mut Inner> {
if err.is_some() {
return None;
}
parts.as_mut()
}
/// `Middleware` for Cross-origin resource sharing support /// `Middleware` for Cross-origin resource sharing support
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
@@ -544,7 +540,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware { ok(CorsMiddleware {
@@ -686,12 +682,12 @@ where
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = Either< type Future = Either<
Ready<Result<Self::Response, Error>>, FutureResult<Self::Response, Error>,
LocalBoxFuture<'static, Result<Self::Response, Error>>, Either<S::Future, Box<dyn Future<Item = Self::Response, Error = Error>>>,
>; >;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn call(&mut self, req: ServiceRequest) -> Self::Future {
@@ -702,7 +698,7 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head()))
{ {
return Either::Left(ok(req.error_response(e))); return Either::A(ok(req.error_response(e)));
} }
// allowed headers // allowed headers
@@ -755,32 +751,22 @@ where
.finish() .finish()
.into_body(); .into_body();
Either::Left(ok(req.into_response(res))) Either::A(ok(req.into_response(res)))
} else { } else if req.headers().contains_key(&header::ORIGIN) {
if req.headers().contains_key(&header::ORIGIN) {
// Only check requests with a origin header. // Only check requests with a origin header.
if let Err(e) = self.inner.validate_origin(req.head()) { if let Err(e) = self.inner.validate_origin(req.head()) {
return Either::Left(ok(req.error_response(e))); return Either::A(ok(req.error_response(e)));
}
} }
let inner = self.inner.clone(); let inner = self.inner.clone();
let has_origin = req.headers().contains_key(&header::ORIGIN);
let fut = self.service.call(req);
Either::Right( Either::B(Either::B(Box::new(self.service.call(req).and_then(
async move { move |mut res| {
let res = fut.await;
if has_origin {
let mut res = res?;
if let Some(origin) = if let Some(origin) =
inner.access_control_allow_origin(res.request().head()) inner.access_control_allow_origin(res.request().head())
{ {
res.headers_mut().insert( res.headers_mut()
header::ACCESS_CONTROL_ALLOW_ORIGIN, .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
origin.clone(),
);
}; };
if let Some(ref expose) = inner.expose_hdrs { if let Some(ref expose) = inner.expose_hdrs {
@@ -796,9 +782,8 @@ where
); );
} }
if inner.vary_header { if inner.vary_header {
let value = if let Some(hdr) = let value =
res.headers_mut().get(&header::VARY) if let Some(hdr) = res.headers_mut().get(&header::VARY) {
{
let mut val: Vec<u8> = let mut val: Vec<u8> =
Vec::with_capacity(hdr.as_bytes().len() + 8); Vec::with_capacity(hdr.as_bytes().len() + 8);
val.extend(hdr.as_bytes()); val.extend(hdr.as_bytes());
@@ -810,68 +795,80 @@ where
res.headers_mut().insert(header::VARY, value); res.headers_mut().insert(header::VARY, value);
} }
Ok(res) Ok(res)
},
))))
} else { } else {
res Either::B(Either::A(self.service.call(req)))
}
}
.boxed_local(),
)
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{fn_service, Transform}; use actix_service::{IntoService, Transform};
use actix_web::test::{self, TestRequest}; use actix_web::test::{self, block_on, TestRequest};
use super::*; use super::*;
#[actix_rt::test] impl Cors {
fn finish<F, S, B>(self, srv: F) -> CorsMiddleware<S>
where
F: IntoService<S>,
S: Service<
Request = ServiceRequest,
Response = ServiceResponse<B>,
Error = Error,
> + 'static,
S::Future: 'static,
B: 'static,
{
block_on(
IntoTransform::<CorsFactory, S>::into_transform(self)
.new_transform(srv.into_service()),
)
.unwrap()
}
}
#[test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
async fn cors_validates_illegal_allow_credentials() { fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new().supports_credentials().send_wildcard().finish(); let _cors = Cors::new()
.supports_credentials()
.send_wildcard()
.finish(test::ok_service());
} }
#[actix_rt::test] #[test]
async fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let mut cors = Cors::new() let mut cors = Cors::new().finish(test::ok_service());
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[test]
async fn default() { fn default() {
let mut cors = Cors::default() let mut cors =
.new_transform(test::ok_service()) block_on(Cors::default().new_transform(test::ok_service())).unwrap();
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[test]
async fn test_preflight() { fn test_preflight() {
let mut cors = Cors::new() let mut cors = Cors::new()
.send_wildcard() .send_wildcard()
.max_age(3600) .max_age(3600)
.allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
@@ -880,7 +877,7 @@ mod tests {
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
@@ -900,7 +897,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@@ -946,13 +943,13 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
// #[actix_rt::test] // #[test]
// #[should_panic(expected = "MissingOrigin")] // #[should_panic(expected = "MissingOrigin")]
// async fn test_validate_missing_origin() { // fn test_validate_missing_origin() {
// let cors = Cors::build() // let cors = Cors::build()
// .allowed_origin("https://www.example.com") // .allowed_origin("https://www.example.com")
// .finish(); // .finish();
@@ -960,15 +957,12 @@ mod tests {
// cors.start(&req).unwrap(); // cors.start(&req).unwrap();
// } // }
#[actix_rt::test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
async fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
let cors = Cors::new() let cors = Cors::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.unknown.com") let req = TestRequest::with_header("Origin", "https://www.unknown.com")
.method(Method::GET) .method(Method::GET)
@@ -978,34 +972,26 @@ mod tests {
cors.inner.validate_allowed_headers(req.head()).unwrap(); cors.inner.validate_allowed_headers(req.head()).unwrap();
} }
#[actix_rt::test] #[test]
async fn test_validate_origin() { fn test_validate_origin() {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[test]
async fn test_no_origin_response() { fn test_no_origin_response() {
let mut cors = Cors::new() let mut cors = Cors::new().disable_preflight().finish(test::ok_service());
.disable_preflight()
.finish()
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
@@ -1014,7 +1000,7 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://www.example.com"[..], &b"https://www.example.com"[..],
resp.headers() resp.headers()
@@ -1024,8 +1010,8 @@ mod tests {
); );
} }
#[actix_rt::test] #[test]
async fn test_response() { fn test_response() {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new() let mut cors = Cors::new()
.send_wildcard() .send_wildcard()
@@ -1035,16 +1021,13 @@ mod tests {
.allowed_headers(exposed_headers.clone()) .allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@@ -1082,18 +1065,15 @@ mod tests {
.allowed_headers(exposed_headers.clone()) .allowed_headers(exposed_headers.clone())
.expose_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish() .finish(|req: ServiceRequest| {
.new_transform(fn_service(|req: ServiceRequest| { req.into_response(
ok(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(), HttpResponse::Ok().header(header::VARY, "Accept").finish(),
)) )
})) });
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"Accept, Origin"[..], &b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes() resp.headers().get(header::VARY).unwrap().as_bytes()
@@ -1103,16 +1083,13 @@ mod tests {
.disable_vary_header() .disable_vary_header()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com") .allowed_origin("https://www.google.com")
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
let origins_str = resp let origins_str = resp
.headers() .headers()
@@ -1124,22 +1101,19 @@ mod tests {
assert_eq!("https://www.example.com", origins_str); assert_eq!("https://www.example.com", origins_str);
} }
#[actix_rt::test] #[test]
async fn test_multiple_origins() { fn test_multiple_origins() {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@@ -1152,7 +1126,7 @@ mod tests {
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
@@ -1162,23 +1136,20 @@ mod tests {
); );
} }
#[actix_rt::test] #[test]
async fn test_multiple_origins_preflight() { fn test_multiple_origins_preflight() {
let mut cors = Cors::new() let mut cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
.finish() .finish(test::ok_service());
.new_transform(test::ok_service())
.await
.unwrap();
let req = TestRequest::with_header("Origin", "https://example.com") let req = TestRequest::with_header("Origin", "https://example.com")
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@@ -1192,7 +1163,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&mut cors, req);
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()

View File

@@ -1,13 +1,5 @@
# Changes # Changes
## [0.2.0] - 2019-12-20
* Fix BodyEncoding trait import #1220
## [0.2.0-alpha.1] - 2019-12-07
* Migrate to `std::future`
## [0.1.7] - 2019-11-06 ## [0.1.7] - 2019-11-06
* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) * Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151)

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-files" name = "actix-files"
version = "0.2.0" version = "0.1.7"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Static files support for actix web." description = "Static files support for actix web."
readme = "README.md" readme = "README.md"
@@ -18,13 +18,13 @@ name = "actix_files"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "2.0.0-rc", default-features = false } actix-web = { version = "1.0.8", default-features = false }
actix-http = "1.0.1" actix-http = "0.2.11"
actix-service = "1.0.0" actix-service = "0.4.1"
bitflags = "1" bitflags = "1"
bytes = "0.5.3" bytes = "0.4"
futures = "0.3.1" futures = "0.1.25"
derive_more = "0.99.2" derive_more = "0.15.0"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
mime_guess = "2.0.1" mime_guess = "2.0.1"
@@ -32,5 +32,4 @@ percent-encoding = "2.1"
v_htmlescape = "0.4" v_htmlescape = "0.4"
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0" actix-web = { version = "1.0.8", features=["ssl"] }
actix-web = { version = "2.0.0-rc", features=["openssl"] }

View File

@@ -35,7 +35,7 @@ pub enum UriSegmentError {
/// Return `BadRequest` for `UriSegmentError` /// Return `BadRequest` for `UriSegmentError`
impl ResponseError for UriSegmentError { impl ResponseError for UriSegmentError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> HttpResponse {
StatusCode::BAD_REQUEST HttpResponse::new(StatusCode::BAD_REQUEST)
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -12,13 +12,12 @@ use mime;
use mime_guess::from_path; use mime_guess::from_path;
use actix_http::body::SizedStream; use actix_http::body::SizedStream;
use actix_web::dev::BodyEncoding;
use actix_web::http::header::{ use actix_web::http::header::{
self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue,
}; };
use actix_web::http::{ContentEncoding, StatusCode}; use actix_web::http::{ContentEncoding, StatusCode};
use actix_web::middleware::BodyEncoding;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder};
use futures::future::{ready, Ready};
use crate::range::HttpRange; use crate::range::HttpRange;
use crate::ChunkedReadFile; use crate::ChunkedReadFile;
@@ -256,8 +255,62 @@ impl NamedFile {
pub(crate) fn last_modified(&self) -> Option<header::HttpDate> { pub(crate) fn last_modified(&self) -> Option<header::HttpDate> {
self.modified.map(|mtime| mtime.into()) self.modified.map(|mtime| mtime.into())
} }
}
pub fn into_response(self, req: &HttpRequest) -> Result<HttpResponse, Error> { impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Result<HttpResponse, Error>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
if self.status_code != StatusCode::OK { if self.status_code != StatusCode::OK {
let mut resp = HttpResponse::build(self.status_code); let mut resp = HttpResponse::build(self.status_code);
resp.set(header::ContentType(self.content_type.clone())) resp.set(header::ContentType(self.content_type.clone()))
@@ -389,67 +442,8 @@ impl NamedFile {
counter: 0, counter: 0,
}; };
if offset != 0 || length != self.md.len() { if offset != 0 || length != self.md.len() {
Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)) return Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader));
} else { };
Ok(resp.body(SizedStream::new(length, reader))) Ok(resp.body(SizedStream::new(length, reader)))
} }
} }
}
impl Deref for NamedFile {
type Target = File;
fn deref(&self) -> &File {
&self.file
}
}
impl DerefMut for NamedFile {
fn deref_mut(&mut self) -> &mut File {
&mut self.file
}
}
/// Returns true if `req` has no `If-Match` header or one which matches `etag`.
fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfMatch>() {
None | Some(header::IfMatch::Any) => true,
Some(header::IfMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.strong_eq(some_etag) {
return true;
}
}
}
false
}
}
}
/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`.
fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool {
match req.get_header::<header::IfNoneMatch>() {
Some(header::IfNoneMatch::Any) => false,
Some(header::IfNoneMatch::Items(ref items)) => {
if let Some(some_etag) = etag {
for item in items {
if item.weak_eq(some_etag) {
return false;
}
}
}
true
}
None => true,
}
}
impl Responder for NamedFile {
type Error = Error;
type Future = Ready<Result<HttpResponse, Error>>;
fn respond_to(self, req: &HttpRequest) -> Self::Future {
ready(self.into_response(req))
}
}

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-framed" name = "actix-framed"
version = "0.3.0" version = "0.2.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix framed app server" description = "Actix framed app server"
readme = "README.md" readme = "README.md"
@@ -20,19 +20,19 @@ name = "actix_framed"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-codec = "0.2.0" actix-codec = "0.1.2"
actix-service = "1.0.0" actix-service = "0.4.1"
actix-router = "0.2.0" actix-router = "0.1.2"
actix-rt = "1.0.0" actix-rt = "0.2.2"
actix-http = "1.0.0" actix-http = "0.2.7"
actix-server-config = "0.1.2"
bytes = "0.5.3" bytes = "0.4"
futures = "0.3.1" futures = "0.1.25"
pin-project = "0.4.6"
log = "0.4" log = "0.4"
[dev-dependencies] [dev-dependencies]
actix-server = "1.0.0" actix-server = { version = "0.6.0", features=["ssl"] }
actix-connect = { version = "1.0.0", features=["openssl"] } actix-connect = { version = "0.2.0", features=["ssl"] }
actix-http-test = { version = "1.0.0", features=["openssl"] } actix-http-test = { version = "0.2.4", features=["ssl"] }
actix-utils = "1.0.3" actix-utils = "0.4.4"

View File

@@ -1,23 +1,21 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::h1::{Codec, SendResponse}; use actix_http::h1::{Codec, SendResponse};
use actix_http::{Error, Request, Response}; use actix_http::{Error, Request, Response};
use actix_router::{Path, Router, Url}; use actix_router::{Path, Router, Url};
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_server_config::ServerConfig;
use futures::future::{ok, FutureExt, LocalBoxFuture}; use actix_service::{IntoNewService, NewService, Service};
use futures::{Async, Future, Poll};
use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService};
use crate::request::FramedRequest; use crate::request::FramedRequest;
use crate::state::State; use crate::state::State;
type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>; type BoxedResponse = Box<dyn Future<Item = (), Error = Error>>;
pub trait HttpServiceFactory { pub trait HttpServiceFactory {
type Factory: ServiceFactory; type Factory: NewService;
fn path(&self) -> &str; fn path(&self) -> &str;
@@ -50,19 +48,19 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
pub fn service<U>(mut self, factory: U) -> Self pub fn service<U>(mut self, factory: U) -> Self
where where
U: HttpServiceFactory, U: HttpServiceFactory,
U::Factory: ServiceFactory< U::Factory: NewService<
Config = (), Config = (),
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
> + 'static, > + 'static,
<U::Factory as ServiceFactory>::Future: 'static, <U::Factory as NewService>::Future: 'static,
<U::Factory as ServiceFactory>::Service: Service< <U::Factory as NewService>::Service: Service<
Request = FramedRequest<T, S>, Request = FramedRequest<T, S>,
Response = (), Response = (),
Error = Error, Error = Error,
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
>, >,
{ {
let path = factory.path().to_string(); let path = factory.path().to_string();
@@ -72,12 +70,12 @@ impl<T: 'static, S: 'static> FramedApp<T, S> {
} }
} }
impl<T, S> IntoServiceFactory<FramedAppFactory<T, S>> for FramedApp<T, S> impl<T, S> IntoNewService<FramedAppFactory<T, S>> for FramedApp<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
S: 'static, S: 'static,
{ {
fn into_factory(self) -> FramedAppFactory<T, S> { fn into_new_service(self) -> FramedAppFactory<T, S> {
FramedAppFactory { FramedAppFactory {
state: self.state, state: self.state,
services: Rc::new(self.services), services: Rc::new(self.services),
@@ -91,12 +89,12 @@ pub struct FramedAppFactory<T, S> {
services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>, services: Rc<Vec<(String, BoxedHttpNewService<FramedRequest<T, S>>)>>,
} }
impl<T, S> ServiceFactory for FramedAppFactory<T, S> impl<T, S> NewService for FramedAppFactory<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
S: 'static, S: 'static,
{ {
type Config = (); type Config = ServerConfig;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
@@ -104,7 +102,7 @@ where
type Service = FramedAppService<T, S>; type Service = FramedAppService<T, S>;
type Future = CreateService<T, S>; type Future = CreateService<T, S>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
CreateService { CreateService {
fut: self fut: self
.services .services
@@ -112,7 +110,7 @@ where
.map(|(path, service)| { .map(|(path, service)| {
CreateServiceItem::Future( CreateServiceItem::Future(
Some(path.clone()), Some(path.clone()),
service.new_service(()), service.new_service(&()),
) )
}) })
.collect(), .collect(),
@@ -130,30 +128,28 @@ pub struct CreateService<T, S> {
enum CreateServiceItem<T, S> { enum CreateServiceItem<T, S> {
Future( Future(
Option<String>, Option<String>,
LocalBoxFuture<'static, Result<BoxedHttpService<FramedRequest<T, S>>, ()>>, Box<dyn Future<Item = BoxedHttpService<FramedRequest<T, S>>, Error = ()>>,
), ),
Service(String, BoxedHttpService<FramedRequest<T, S>>), Service(String, BoxedHttpService<FramedRequest<T, S>>),
} }
impl<S: 'static, T: 'static> Future for CreateService<T, S> impl<S: 'static, T: 'static> Future for CreateService<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Output = Result<FramedAppService<T, S>, ()>; type Item = FramedAppService<T, S>;
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut done = true; let mut done = true;
// poll http services // poll http services
for item in &mut self.fut { for item in &mut self.fut {
let res = match item { let res = match item {
CreateServiceItem::Future(ref mut path, ref mut fut) => { CreateServiceItem::Future(ref mut path, ref mut fut) => {
match Pin::new(fut).poll(cx) { match fut.poll()? {
Poll::Ready(Ok(service)) => { Async::Ready(service) => Some((path.take().unwrap(), service)),
Some((path.take().unwrap(), service)) Async::NotReady => {
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
done = false; done = false;
None None
} }
@@ -180,12 +176,12 @@ where
} }
router router
}); });
Poll::Ready(Ok(FramedAppService { Ok(Async::Ready(FramedAppService {
router: router.finish(), router: router.finish(),
state: self.state.clone(), state: self.state.clone(),
})) }))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
@@ -197,15 +193,15 @@ pub struct FramedAppService<T, S> {
impl<S: 'static, T: 'static> Service for FramedAppService<T, S> impl<S: 'static, T: 'static> Service for FramedAppService<T, S>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = BoxedResponse; type Future = BoxedResponse;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
@@ -214,8 +210,8 @@ where
if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { if let Some((srv, _info)) = self.router.recognize_mut(&mut path) {
return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); return srv.call(FramedRequest::new(req, framed, path, self.state.clone()));
} }
SendResponse::new(framed, Response::NotFound().finish()) Box::new(
.then(|_| ok(())) SendResponse::new(framed, Response::NotFound().finish()).then(|_| Ok(())),
.boxed_local() )
} }
} }

View File

@@ -1,38 +1,36 @@
use std::task::{Context, Poll};
use actix_http::Error; use actix_http::Error;
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{FutureExt, LocalBoxFuture}; use futures::{Future, Poll};
pub(crate) type BoxedHttpService<Req> = Box< pub(crate) type BoxedHttpService<Req> = Box<
dyn Service< dyn Service<
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
>, >,
>; >;
pub(crate) type BoxedHttpNewService<Req> = Box< pub(crate) type BoxedHttpNewService<Req> = Box<
dyn ServiceFactory< dyn NewService<
Config = (), Config = (),
Request = Req, Request = Req,
Response = (), Response = (),
Error = Error, Error = Error,
InitError = (), InitError = (),
Service = BoxedHttpService<Req>, Service = BoxedHttpService<Req>,
Future = LocalBoxFuture<'static, Result<BoxedHttpService<Req>, ()>>, Future = Box<dyn Future<Item = BoxedHttpService<Req>, Error = ()>>,
>, >,
>; >;
pub(crate) struct HttpNewService<T: ServiceFactory>(T); pub(crate) struct HttpNewService<T: NewService>(T);
impl<T> HttpNewService<T> impl<T> HttpNewService<T>
where where
T: ServiceFactory<Response = (), Error = Error>, T: NewService<Response = (), Error = Error>,
T::Response: 'static, T::Response: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static, T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
pub fn new(service: T) -> Self { pub fn new(service: T) -> Self {
@@ -40,12 +38,12 @@ where
} }
} }
impl<T> ServiceFactory for HttpNewService<T> impl<T> NewService for HttpNewService<T>
where where
T: ServiceFactory<Config = (), Response = (), Error = Error>, T: NewService<Config = (), Response = (), Error = Error>,
T::Request: 'static, T::Request: 'static,
T::Future: 'static, T::Future: 'static,
T::Service: Service<Future = LocalBoxFuture<'static, Result<(), Error>>> + 'static, T::Service: Service<Future = Box<dyn Future<Item = (), Error = Error>>> + 'static,
<T::Service as Service>::Future: 'static, <T::Service as Service>::Future: 'static,
{ {
type Config = (); type Config = ();
@@ -54,19 +52,13 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = BoxedHttpService<T::Request>; type Service = BoxedHttpService<T::Request>;
type Future = LocalBoxFuture<'static, Result<Self::Service, ()>>; type Future = Box<dyn Future<Item = Self::Service, Error = ()>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
let fut = self.0.new_service(()); Box::new(self.0.new_service(&()).map_err(|_| ()).and_then(|service| {
let service: BoxedHttpService<_> = Box::new(HttpServiceWrapper { service });
async move { Ok(service)
fut.await.map_err(|_| ()).map(|service| { }))
let service: BoxedHttpService<_> =
Box::new(HttpServiceWrapper { service });
service
})
}
.boxed_local()
} }
} }
@@ -78,7 +70,7 @@ impl<T> Service for HttpServiceWrapper<T>
where where
T: Service< T: Service<
Response = (), Response = (),
Future = LocalBoxFuture<'static, Result<(), Error>>, Future = Box<dyn Future<Item = (), Error = Error>>,
Error = Error, Error = Error,
>, >,
T::Request: 'static, T::Request: 'static,
@@ -86,10 +78,10 @@ where
type Request = T::Request; type Request = T::Request;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<(), Error>>; type Future = Box<dyn Future<Item = (), Error = Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready(cx) self.service.poll_ready()
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {

View File

@@ -123,9 +123,7 @@ impl<Io, S> FramedRequest<Io, S> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::convert::TryFrom; use actix_http::http::{HeaderName, HeaderValue, HttpTryFrom};
use actix_http::http::{HeaderName, HeaderValue};
use actix_http::test::{TestBuffer, TestRequest}; use actix_http::test::{TestBuffer, TestRequest};
use super::*; use super::*;

View File

@@ -1,12 +1,11 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{http::Method, Error}; use actix_http::{http::Method, Error};
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, FutureResult};
use futures::{Async, Future, IntoFuture, Poll};
use log::error; use log::error;
use crate::app::HttpServiceFactory; use crate::app::HttpServiceFactory;
@@ -16,11 +15,11 @@ use crate::request::FramedRequest;
/// ///
/// Route uses builder-like pattern for configuration. /// Route uses builder-like pattern for configuration.
/// If handler is not explicitly set, default *404 Not Found* handler is used. /// If handler is not explicitly set, default *404 Not Found* handler is used.
pub struct FramedRoute<Io, S, F = (), R = (), E = ()> { pub struct FramedRoute<Io, S, F = (), R = ()> {
handler: F, handler: F,
pattern: String, pattern: String,
methods: Vec<Method>, methods: Vec<Method>,
state: PhantomData<(Io, S, R, E)>, state: PhantomData<(Io, S, R)>,
} }
impl<Io, S> FramedRoute<Io, S> { impl<Io, S> FramedRoute<Io, S> {
@@ -54,12 +53,12 @@ impl<Io, S> FramedRoute<Io, S> {
self self
} }
pub fn to<F, R, E>(self, handler: F) -> FramedRoute<Io, S, F, R, E> pub fn to<F, R>(self, handler: F) -> FramedRoute<Io, S, F, R>
where where
F: FnMut(FramedRequest<Io, S>) -> R, F: FnMut(FramedRequest<Io, S>) -> R,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
R::Future: 'static,
E: fmt::Debug, R::Error: fmt::Debug,
{ {
FramedRoute { FramedRoute {
handler, handler,
@@ -70,14 +69,15 @@ impl<Io, S> FramedRoute<Io, S> {
} }
} }
impl<Io, S, F, R, E> HttpServiceFactory for FramedRoute<Io, S, F, R, E> impl<Io, S, F, R> HttpServiceFactory for FramedRoute<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Factory = FramedRouteFactory<Io, S, F, R, E>; type Factory = FramedRouteFactory<Io, S, F, R>;
fn path(&self) -> &str { fn path(&self) -> &str {
&self.pattern &self.pattern
@@ -92,28 +92,29 @@ where
} }
} }
pub struct FramedRouteFactory<Io, S, F, R, E> { pub struct FramedRouteFactory<Io, S, F, R> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R, E)>, _t: PhantomData<(Io, S, R)>,
} }
impl<Io, S, F, R, E> ServiceFactory for FramedRouteFactory<Io, S, F, R, E> impl<Io, S, F, R> NewService for FramedRouteFactory<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Config = (); type Config = ();
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Service = FramedRouteService<Io, S, F, R, E>; type Service = FramedRouteService<Io, S, F, R>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &()) -> Self::Future {
ok(FramedRouteService { ok(FramedRouteService {
handler: self.handler.clone(), handler: self.handler.clone(),
methods: self.methods.clone(), methods: self.methods.clone(),
@@ -122,38 +123,35 @@ where
} }
} }
pub struct FramedRouteService<Io, S, F, R, E> { pub struct FramedRouteService<Io, S, F, R> {
handler: F, handler: F,
methods: Vec<Method>, methods: Vec<Method>,
_t: PhantomData<(Io, S, R, E)>, _t: PhantomData<(Io, S, R)>,
} }
impl<Io, S, F, R, E> Service for FramedRouteService<Io, S, F, R, E> impl<Io, S, F, R> Service for FramedRouteService<Io, S, F, R>
where where
Io: AsyncRead + AsyncWrite + 'static, Io: AsyncRead + AsyncWrite + 'static,
F: FnMut(FramedRequest<Io, S>) -> R + Clone, F: FnMut(FramedRequest<Io, S>) -> R + Clone,
R: Future<Output = Result<(), E>> + 'static, R: IntoFuture<Item = ()>,
E: fmt::Display, R::Future: 'static,
R::Error: fmt::Display,
{ {
type Request = FramedRequest<Io, S>; type Request = FramedRequest<Io, S>;
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<(), Error>>; type Future = Box<dyn Future<Item = (), Error = Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future { fn call(&mut self, req: FramedRequest<Io, S>) -> Self::Future {
let fut = (self.handler)(req); Box::new((self.handler)(req).into_future().then(|res| {
async move {
let res = fut.await;
if let Err(e) = res { if let Err(e) = res {
error!("Error in request handler: {}", e); error!("Error in request handler: {}", e);
} }
Ok(()) Ok(())
} }))
.boxed_local()
} }
} }

View File

@@ -1,6 +1,4 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::body::BodySize; use actix_http::body::BodySize;
@@ -8,9 +6,9 @@ use actix_http::error::ResponseError;
use actix_http::h1::{Codec, Message}; use actix_http::h1::{Codec, Message};
use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{Request, Response}; use actix_http::{Request, Response};
use actix_service::{Service, ServiceFactory}; use actix_service::{NewService, Service};
use futures::future::{err, ok, Either, Ready}; use futures::future::{ok, Either, FutureResult};
use futures::Future; use futures::{Async, Future, IntoFuture, Poll, Sink};
/// Service that verifies incoming request if it is valid websocket /// Service that verifies incoming request if it is valid websocket
/// upgrade request. In case of error returns `HandshakeError` /// upgrade request. In case of error returns `HandshakeError`
@@ -24,16 +22,16 @@ impl<T, C> Default for VerifyWebSockets<T, C> {
} }
} }
impl<T, C> ServiceFactory for VerifyWebSockets<T, C> { impl<T, C> NewService for VerifyWebSockets<T, C> {
type Config = C; type Config = C;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = VerifyWebSockets<T, C>; type Service = VerifyWebSockets<T, C>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData }) ok(VerifyWebSockets { _t: PhantomData })
} }
} }
@@ -42,16 +40,16 @@ impl<T, C> Service for VerifyWebSockets<T, C> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>); type Error = (HandshakeError, Framed<T, Codec>);
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future { fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(req.head()) { match verify_handshake(req.head()) {
Err(e) => err((e, framed)), Err(e) => Err((e, framed)).into_future(),
Ok(_) => ok((req, framed)), Ok(_) => Ok((req, framed)).into_future(),
} }
} }
} }
@@ -69,9 +67,9 @@ where
} }
} }
impl<T, R, E, C> ServiceFactory for SendError<T, R, E, C> impl<T, R, E, C> NewService for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
@@ -81,34 +79,34 @@ where
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type InitError = (); type InitError = ();
type Service = SendError<T, R, E, C>; type Service = SendError<T, R, E, C>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: C) -> Self::Future { fn new_service(&self, _: &C) -> Self::Future {
ok(SendError(PhantomData)) ok(SendError(PhantomData))
} }
} }
impl<T, R, E, C> Service for SendError<T, R, E, C> impl<T, R, E, C> Service for SendError<T, R, E, C>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
R: 'static, R: 'static,
E: ResponseError + 'static, E: ResponseError + 'static,
{ {
type Request = Result<R, (E, Framed<T, Codec>)>; type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R; type Response = R;
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
type Future = Either<Ready<Result<R, (E, Framed<T, Codec>)>>, SendErrorFut<T, R, E>>; type Future = Either<FutureResult<R, (E, Framed<T, Codec>)>, SendErrorFut<T, R, E>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future { fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future {
match req { match req {
Ok(r) => Either::Left(ok(r)), Ok(r) => Either::A(ok(r)),
Err((e, framed)) => { Err((e, framed)) => {
let res = e.error_response().drop_body(); let res = e.error_response().drop_body();
Either::Right(SendErrorFut { Either::B(SendErrorFut {
framed: Some(framed), framed: Some(framed),
res: Some((res, BodySize::Empty).into()), res: Some((res, BodySize::Empty).into()),
err: Some(e), err: Some(e),
@@ -119,7 +117,6 @@ where
} }
} }
#[pin_project::pin_project]
pub struct SendErrorFut<T, R, E> { pub struct SendErrorFut<T, R, E> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
@@ -130,27 +127,23 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E> impl<T, R, E> Future for SendErrorFut<T, R, E>
where where
E: ResponseError, E: ResponseError,
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite,
{ {
type Output = Result<R, (E, Framed<T, Codec>)>; type Item = R;
type Error = (E, Framed<T, Codec>);
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() { if let Some(res) = self.res.take() {
if self.framed.as_mut().unwrap().write(res).is_err() { if self.framed.as_mut().unwrap().force_send(res).is_err() {
return Poll::Ready(Err(( return Err((self.err.take().unwrap(), self.framed.take().unwrap()));
self.err.take().unwrap(),
self.framed.take().unwrap(),
)));
} }
} }
match self.framed.as_mut().unwrap().flush(cx) { match self.framed.as_mut().unwrap().poll_complete() {
Poll::Ready(Ok(_)) => { Ok(Async::Ready(_)) => {
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) Err((self.err.take().unwrap(), self.framed.take().unwrap()))
} }
Poll::Ready(Err(_)) => { Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())),
}
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,13 +1,12 @@
//! Various helpers for Actix applications to use during testing. //! Various helpers for Actix applications to use during testing.
use std::convert::TryFrom;
use std::future::Future;
use actix_codec::Framed; use actix_codec::Framed;
use actix_http::h1::Codec; use actix_http::h1::Codec;
use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; use actix_http::http::header::{Header, HeaderName, IntoHeaderValue};
use actix_http::http::{Error as HttpError, Method, Uri, Version}; use actix_http::http::{HttpTryFrom, Method, Uri, Version};
use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest};
use actix_router::{Path, Url}; use actix_router::{Path, Url};
use actix_rt::Runtime;
use futures::IntoFuture;
use crate::{FramedRequest, State}; use crate::{FramedRequest, State};
@@ -42,8 +41,7 @@ impl TestRequest<()> {
/// Create TestRequest and set header /// Create TestRequest and set header
pub fn with_header<K, V>(key: K, value: V) -> Self pub fn with_header<K, V>(key: K, value: V) -> Self
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
Self::default().header(key, value) Self::default().header(key, value)
@@ -98,8 +96,7 @@ impl<S> TestRequest<S> {
/// Set a header /// Set a header
pub fn header<K, V>(mut self, key: K, value: V) -> Self pub fn header<K, V>(mut self, key: K, value: V) -> Self
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
self.req.header(key, value); self.req.header(key, value);
@@ -121,12 +118,13 @@ impl<S> TestRequest<S> {
} }
/// This method generates `FramedRequest` instance and executes async handler /// This method generates `FramedRequest` instance and executes async handler
pub async fn run<F, R, I, E>(self, f: F) -> Result<I, E> pub fn run<F, R, I, E>(self, f: F) -> Result<I, E>
where where
F: FnOnce(FramedRequest<TestBuffer, S>) -> R, F: FnOnce(FramedRequest<TestBuffer, S>) -> R,
R: Future<Output = Result<I, E>>, R: IntoFuture<Item = I, Error = E>,
{ {
f(self.finish()).await let mut rt = Runtime::new().unwrap();
rt.block_on(f(self.finish()).into_future())
} }
} }

View File

@@ -1,159 +1,141 @@
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response};
use actix_http_test::test_server; use actix_http_test::TestServer;
use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory}; use actix_service::{IntoNewService, NewService};
use actix_utils::framed::Dispatcher; use actix_utils::framed::FramedTransport;
use bytes::Bytes; use bytes::{Bytes, BytesMut};
use futures::{future, SinkExt, StreamExt}; use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets};
async fn ws_service<T: AsyncRead + AsyncWrite>( fn ws_service<T: AsyncRead + AsyncWrite>(
req: FramedRequest<T>, req: FramedRequest<T>,
) -> Result<(), Error> { ) -> impl Future<Item = (), Error = Error> {
let (req, mut framed, _) = req.into_parts(); let (req, framed, _) = req.into_parts();
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.await .map_err(|_| panic!())
.unwrap(); .and_then(|framed| {
Dispatcher::new(framed.into_framed(ws::Codec::new()), service) FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.await .map_err(|_| panic!())
.unwrap(); })
Ok(())
} }
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> { fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string()) ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
} }
ws::Frame::Binary(bin) => ws::Message::Binary(bin), ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
Ok(msg) ok(msg)
} }
#[actix_rt::test] #[test]
async fn test_simple() { fn test_simple() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.upgrade( .upgrade(
FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)), FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)),
) )
.finish(|_| future::ok::<_, Error>(Response::NotFound())) .finish(|_| future::ok::<_, Error>(Response::NotFound()))
.tcp()
}); });
assert!(srv.ws_at("/test").await.is_err()); assert!(srv.ws_at("/test").is_err());
// client service // client service
let mut framed = srv.ws_at("/index.html").await.unwrap(); let framed = srv.ws_at("/index.html").unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Text(Bytes::from_static(b"text")) Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
); );
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Ping("text".into())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
item.unwrap().unwrap(),
ws::Frame::Binary(Bytes::from_static(b"text"))
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap(); .unwrap();
let (item, _) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
} }
#[actix_rt::test] #[test]
async fn test_service() { fn test_service() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then( actix_http::h1::OneRequest::new().map_err(|_| ()).and_then(
pipeline_factory( VerifyWebSockets::default()
pipeline_factory(VerifyWebSockets::default())
.then(SendError::default()) .then(SendError::default())
.map_err(|_| ()), .map_err(|_| ())
)
.and_then( .and_then(
FramedApp::new() FramedApp::new()
.service(FramedRoute::get("/index.html").to(ws_service)) .service(FramedRoute::get("/index.html").to(ws_service))
.into_factory() .into_new_service()
.map_err(|_| ()), .map_err(|_| ()),
), ),
) )
}); });
// non ws request // non ws request
let res = srv.get("/index.html").send().await.unwrap(); let res = srv.block_on(srv.get("/index.html").send()).unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST); assert_eq!(res.status(), StatusCode::BAD_REQUEST);
// not found // not found
assert!(srv.ws_at("/test").await.is_err()); assert!(srv.ws_at("/test").is_err());
// client service // client service
let mut framed = srv.ws_at("/index.html").await.unwrap(); let framed = srv.ws_at("/index.html").unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Text(Bytes::from_static(b"text")) Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
); );
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Ping("text".into())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
item.unwrap().unwrap(),
ws::Frame::Binary(Bytes::from_static(b"text"))
);
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap(); .unwrap();
let (item, _) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
} }

View File

@@ -1,54 +1,5 @@
# Changes # Changes
## [1.0.1] - 2019-12-20
### Fixed
* Poll upgrade service's readiness from HTTP service handlers
* Replace brotli with brotli2 #1224
## [1.0.0] - 2019-12-13
### Added
* Add websockets continuation frame support
### Changed
* Replace `flate2-xxx` features with `compress`
## [1.0.0-alpha.5] - 2019-12-09
### Fixed
* Check `Upgrade` service readiness before calling it
* Fix buffer remaining capacity calcualtion
### Changed
* Websockets: Ping and Pong should have binary data #1049
## [1.0.0-alpha.4] - 2019-12-08
### Added
* Add impl ResponseBuilder for Error
### Changed
* Use rust based brotli compression library
## [1.0.0-alpha.3] - 2019-12-07
### Changed
* Migrate to tokio 0.2
* Migrate to `std::future`
## [0.2.11] - 2019-11-06 ## [0.2.11] - 2019-11-06
### Added ### Added

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-http" name = "actix-http"
version = "1.0.1" version = "0.2.11"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Actix http primitives" description = "Actix http primitives"
readme = "README.md" readme = "README.md"
@@ -13,9 +13,10 @@ categories = ["network-programming", "asynchronous",
"web-programming::websocket"] "web-programming::websocket"]
license = "MIT/Apache-2.0" license = "MIT/Apache-2.0"
edition = "2018" edition = "2018"
workspace = ".."
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["openssl", "rustls", "failure", "compress", "secure-cookies"] features = ["ssl", "fail", "brotli", "flate2-zlib", "secure-cookies"]
[lib] [lib]
name = "actix_http" name = "actix_http"
@@ -25,77 +26,85 @@ path = "src/lib.rs"
default = [] default = []
# openssl # openssl
openssl = ["actix-tls/openssl", "actix-connect/openssl"] ssl = ["openssl", "actix-connect/ssl"]
# rustls support # rustls support
rustls = ["actix-tls/rustls", "actix-connect/rustls"] rust-tls = ["rustls", "webpki-roots", "actix-connect/rust-tls"]
# enable compressison support # brotli encoding, requires c compiler
compress = ["flate2", "brotli2"] brotli = ["brotli2"]
# miniz-sys backend for flate2 crate
flate2-zlib = ["flate2/miniz-sys"]
# rust backend for flate2 crate
flate2-rust = ["flate2/rust_backend"]
# failure integration. actix does not use failure anymore # failure integration. actix does not use failure anymore
failure = ["fail-ure"] fail = ["failure"]
# support for secure cookies # support for secure cookies
secure-cookies = ["ring"] secure-cookies = ["ring"]
[dependencies] [dependencies]
actix-service = "1.0.0" actix-service = "0.4.1"
actix-codec = "0.2.0" actix-codec = "0.1.2"
actix-connect = "1.0.1" actix-connect = "0.2.4"
actix-utils = "1.0.3" actix-utils = "0.4.4"
actix-rt = "1.0.0" actix-server-config = "0.1.2"
actix-threadpool = "0.3.1" actix-threadpool = "0.1.1"
actix-tls = { version = "1.0.0", optional = true }
base64 = "0.11" base64 = "0.10"
bitflags = "1.2" bitflags = "1.0"
bytes = "0.5.3" bytes = "0.4"
copyless = "0.1.4" copyless = "0.1.4"
chrono = "0.4.6" derive_more = "0.15.0"
derive_more = "0.99.2" either = "1.5.2"
either = "1.5.3"
encoding_rs = "0.8" encoding_rs = "0.8"
futures-core = "0.3.1" futures = "0.1.25"
futures-util = "0.3.1" hashbrown = "0.6.3"
futures-channel = "0.3.1" h2 = "0.1.16"
fxhash = "0.2.1" http = "0.1.17"
h2 = "0.2.1"
http = "0.2.0"
httparse = "1.3" httparse = "1.3"
indexmap = "1.3" indexmap = "1.2"
lazy_static = "1.4" lazy_static = "1.0"
language-tags = "0.2" language-tags = "0.2"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
percent-encoding = "2.1" percent-encoding = "2.1"
pin-project = "0.4.6"
rand = "0.7" rand = "0.7"
regex = "1.3" regex = "1.0"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
sha1 = "0.6" sha1 = "0.6"
slab = "0.4" slab = "0.4"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
time = "0.1.42" time = "0.1.42"
tokio-tcp = "0.1.3"
tokio-timer = "0.2.8"
tokio-current-thread = "0.1"
trust-dns-resolver = { version="0.11.1", default-features = false }
# for secure cookie # for secure cookie
ring = { version = "0.16.9", optional = true } ring = { version = "0.14.6", optional = true }
# compression # compression
brotli2 = { version="0.3.2", optional = true } brotli2 = { version="0.3.2", optional = true }
flate2 = { version = "1.0.13", optional = true } flate2 = { version="1.0.7", optional = true, default-features = false }
# optional deps # optional deps
fail-ure = { version = "0.1.5", package="failure", optional = true } failure = { version = "0.1.5", optional = true }
openssl = { version="0.10", optional = true }
rustls = { version = "0.15.2", optional = true }
webpki-roots = { version = "0.16", optional = true }
chrono = "0.4.6"
[dev-dependencies] [dev-dependencies]
actix-server = "1.0.0" actix-rt = "0.2.2"
actix-connect = { version = "1.0.0", features=["openssl"] } actix-server = { version = "0.6.0", features=["ssl", "rust-tls"] }
actix-http-test = { version = "1.0.0", features=["openssl"] } actix-connect = { version = "0.2.0", features=["ssl"] }
actix-tls = { version = "1.0.0", features=["openssl"] } actix-http-test = { version = "0.2.4", features=["ssl"] }
futures = "0.3.1"
env_logger = "0.6" env_logger = "0.6"
serde_derive = "1.0" serde_derive = "1.0"
open-ssl = { version="0.10", package = "openssl" } openssl = { version="0.10" }
rust-tls = { version="0.16", package = "rustls" } tokio-tcp = "0.1"

View File

@@ -1,9 +1,9 @@
use std::{env, io}; use std::{env, io};
use actix_http::{Error, HttpService, Request, Response}; use actix_http::{error::PayloadError, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::StreamExt; use futures::{Future, Stream};
use http::header::HeaderValue; use http::header::HeaderValue;
use log::info; use log::info;
@@ -17,24 +17,21 @@ fn main() -> io::Result<()> {
.client_timeout(1000) .client_timeout(1000)
.client_disconnect(1000) .client_disconnect(1000)
.finish(|mut req: Request| { .finish(|mut req: Request| {
async move { req.take_payload()
let mut body = BytesMut::new(); .fold(BytesMut::new(), move |mut body, chunk| {
while let Some(item) = req.payload().next().await { body.extend_from_slice(&chunk);
body.extend_from_slice(&item?); Ok::<_, PayloadError>(body)
} })
.and_then(|bytes| {
info!("request body: {:?}", body); info!("request body: {:?}", bytes);
Ok::<_, Error>( let mut res = Response::Ok();
Response::Ok() res.header(
.header(
"x-head", "x-head",
HeaderValue::from_static("dummy value!"), HeaderValue::from_static("dummy value!"),
) );
.body(body), Ok(res.body(bytes))
) })
}
}) })
.tcp()
})? })?
.run() .run()
} }

View File

@@ -1,22 +1,25 @@
use std::{env, io}; use std::{env, io};
use actix_http::http::HeaderValue; use actix_http::http::HeaderValue;
use actix_http::{Error, HttpService, Request, Response}; use actix_http::{error::PayloadError, Error, HttpService, Request, Response};
use actix_server::Server; use actix_server::Server;
use bytes::BytesMut; use bytes::BytesMut;
use futures::StreamExt; use futures::{Future, Stream};
use log::info; use log::info;
async fn handle_request(mut req: Request) -> Result<Response, Error> { fn handle_request(mut req: Request) -> impl Future<Item = Response, Error = Error> {
let mut body = BytesMut::new(); req.take_payload()
while let Some(item) = req.payload().next().await { .fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&item?) body.extend_from_slice(&chunk);
} Ok::<_, PayloadError>(body)
})
info!("request body: {:?}", body); .from_err()
Ok(Response::Ok() .and_then(|bytes| {
.header("x-head", HeaderValue::from_static("dummy value!")) info!("request body: {:?}", bytes);
.body(body)) let mut res = Response::Ok();
res.header("x-head", HeaderValue::from_static("dummy value!"));
Ok(res.body(bytes))
})
} }
fn main() -> io::Result<()> { fn main() -> io::Result<()> {
@@ -25,7 +28,7 @@ fn main() -> io::Result<()> {
Server::build() Server::build()
.bind("echo", "127.0.0.1:8080", || { .bind("echo", "127.0.0.1:8080", || {
HttpService::build().finish(handle_request).tcp() HttpService::build().finish(|_req: Request| handle_request(_req))
})? })?
.run() .run()
} }

View File

@@ -21,7 +21,6 @@ fn main() -> io::Result<()> {
res.header("x-head", HeaderValue::from_static("dummy value!")); res.header("x-head", HeaderValue::from_static("dummy value!"));
future::ok::<_, ()>(res.body("Hello world!")) future::ok::<_, ()>(res.body("Hello world!"))
}) })
.tcp()
})? })?
.run() .run()
} }

View File

@@ -1,11 +1,8 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem}; use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::Stream; use futures::{Async, Poll, Stream};
use pin_project::{pin_project, project};
use crate::error::Error; use crate::error::Error;
@@ -35,7 +32,7 @@ impl BodySize {
pub trait MessageBody { pub trait MessageBody {
fn size(&self) -> BodySize; fn size(&self) -> BodySize;
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>; fn poll_next(&mut self) -> Poll<Option<Bytes>, Error>;
} }
impl MessageBody for () { impl MessageBody for () {
@@ -43,8 +40,8 @@ impl MessageBody for () {
BodySize::Empty BodySize::Empty
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
Poll::Ready(None) Ok(Async::Ready(None))
} }
} }
@@ -53,12 +50,11 @@ impl<T: MessageBody> MessageBody for Box<T> {
self.as_ref().size() self.as_ref().size()
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
self.as_mut().poll_next(cx) self.as_mut().poll_next()
} }
} }
#[pin_project]
pub enum ResponseBody<B> { pub enum ResponseBody<B> {
Body(B), Body(B),
Other(Body), Other(Body),
@@ -97,27 +93,20 @@ impl<B: MessageBody> MessageBody for ResponseBody<B> {
} }
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
match self { match self {
ResponseBody::Body(ref mut body) => body.poll_next(cx), ResponseBody::Body(ref mut body) => body.poll_next(),
ResponseBody::Other(ref mut body) => body.poll_next(cx), ResponseBody::Other(ref mut body) => body.poll_next(),
} }
} }
} }
impl<B: MessageBody> Stream for ResponseBody<B> { impl<B: MessageBody> Stream for ResponseBody<B> {
type Item = Result<Bytes, Error>; type Item = Bytes;
type Error = Error;
#[project] fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
fn poll_next( self.poll_next()
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
#[project]
match self.project() {
ResponseBody::Body(ref mut body) => body.poll_next(cx),
ResponseBody::Other(ref mut body) => body.poll_next(cx),
}
} }
} }
@@ -136,7 +125,7 @@ pub enum Body {
impl Body { impl Body {
/// Create body from slice (copy) /// Create body from slice (copy)
pub fn from_slice(s: &[u8]) -> Body { pub fn from_slice(s: &[u8]) -> Body {
Body::Bytes(Bytes::copy_from_slice(s)) Body::Bytes(Bytes::from(s))
} }
/// Create body from generic message body. /// Create body from generic message body.
@@ -155,19 +144,19 @@ impl MessageBody for Body {
} }
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
match self { match self {
Body::None => Poll::Ready(None), Body::None => Ok(Async::Ready(None)),
Body::Empty => Poll::Ready(None), Body::Empty => Ok(Async::Ready(None)),
Body::Bytes(ref mut bin) => { Body::Bytes(ref mut bin) => {
let len = bin.len(); let len = bin.len();
if len == 0 { if len == 0 {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new())))) Ok(Async::Ready(Some(mem::replace(bin, Bytes::new()))))
} }
} }
Body::Message(ref mut body) => body.poll_next(cx), Body::Message(ref mut body) => body.poll_next(),
} }
} }
} }
@@ -193,7 +182,7 @@ impl PartialEq for Body {
} }
impl fmt::Debug for Body { impl fmt::Debug for Body {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Body::None => write!(f, "Body::None"), Body::None => write!(f, "Body::None"),
Body::Empty => write!(f, "Body::Empty"), Body::Empty => write!(f, "Body::Empty"),
@@ -229,7 +218,7 @@ impl From<String> for Body {
impl<'a> From<&'a String> for Body { impl<'a> From<&'a String> for Body {
fn from(s: &'a String) -> Body { fn from(s: &'a String) -> Body {
Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) Body::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s)))
} }
} }
@@ -253,7 +242,7 @@ impl From<serde_json::Value> for Body {
impl<S> From<SizedStream<S>> for Body impl<S> From<SizedStream<S>> for Body
where where
S: Stream<Item = Result<Bytes, Error>> + 'static, S: Stream<Item = Bytes, Error = Error> + 'static,
{ {
fn from(s: SizedStream<S>) -> Body { fn from(s: SizedStream<S>) -> Body {
Body::from_message(s) Body::from_message(s)
@@ -262,7 +251,7 @@ where
impl<S, E> From<BodyStream<S, E>> for Body impl<S, E> From<BodyStream<S, E>> for Body
where where
S: Stream<Item = Result<Bytes, E>> + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
fn from(s: BodyStream<S, E>) -> Body { fn from(s: BodyStream<S, E>) -> Body {
@@ -275,11 +264,11 @@ impl MessageBody for Bytes {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(self, Bytes::new())))) Ok(Async::Ready(Some(mem::replace(self, Bytes::new()))))
} }
} }
} }
@@ -289,11 +278,13 @@ impl MessageBody for BytesMut {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze()))) Ok(Async::Ready(Some(
mem::replace(self, BytesMut::new()).freeze(),
)))
} }
} }
} }
@@ -303,11 +294,11 @@ impl MessageBody for &'static str {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from_static( Ok(Async::Ready(Some(Bytes::from_static(
mem::replace(self, "").as_ref(), mem::replace(self, "").as_ref(),
)))) ))))
} }
@@ -319,11 +310,13 @@ impl MessageBody for &'static [u8] {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b""))))) Ok(Async::Ready(Some(Bytes::from_static(mem::replace(
self, b"",
)))))
} }
} }
} }
@@ -333,11 +326,14 @@ impl MessageBody for Vec<u8> {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new()))))) Ok(Async::Ready(Some(Bytes::from(mem::replace(
self,
Vec::new(),
)))))
} }
} }
} }
@@ -347,11 +343,11 @@ impl MessageBody for String {
BodySize::Sized(self.len()) BodySize::Sized(self.len())
} }
fn poll_next(&mut self, _: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
if self.is_empty() { if self.is_empty() {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
Poll::Ready(Some(Ok(Bytes::from( Ok(Async::Ready(Some(Bytes::from(
mem::replace(self, String::new()).into_bytes(), mem::replace(self, String::new()).into_bytes(),
)))) ))))
} }
@@ -360,16 +356,14 @@ impl MessageBody for String {
/// Type represent streaming body. /// Type represent streaming body.
/// Response does not contain `content-length` header and appropriate transfer encoding is used. /// Response does not contain `content-length` header and appropriate transfer encoding is used.
#[pin_project]
pub struct BodyStream<S, E> { pub struct BodyStream<S, E> {
#[pin]
stream: S, stream: S,
_t: PhantomData<E>, _t: PhantomData<E>,
} }
impl<S, E> BodyStream<S, E> impl<S, E> BodyStream<S, E>
where where
S: Stream<Item = Result<Bytes, E>>, S: Stream<Item = Bytes, Error = E>,
E: Into<Error>, E: Into<Error>,
{ {
pub fn new(stream: S) -> Self { pub fn new(stream: S) -> Self {
@@ -382,34 +376,28 @@ where
impl<S, E> MessageBody for BodyStream<S, E> impl<S, E> MessageBody for BodyStream<S, E>
where where
S: Stream<Item = Result<Bytes, E>>, S: Stream<Item = Bytes, Error = E>,
E: Into<Error>, E: Into<Error>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Stream BodySize::Stream
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
unsafe { Pin::new_unchecked(self) } self.stream.poll().map_err(std::convert::Into::into)
.project()
.stream
.poll_next(cx)
.map(|res| res.map(|res| res.map_err(std::convert::Into::into)))
} }
} }
/// Type represent streaming body. This body implementation should be used /// Type represent streaming body. This body implementation should be used
/// if total size of stream is known. Data get sent as is without using transfer encoding. /// if total size of stream is known. Data get sent as is without using transfer encoding.
#[pin_project]
pub struct SizedStream<S> { pub struct SizedStream<S> {
size: u64, size: u64,
#[pin]
stream: S, stream: S,
} }
impl<S> SizedStream<S> impl<S> SizedStream<S>
where where
S: Stream<Item = Result<Bytes, Error>>, S: Stream<Item = Bytes, Error = Error>,
{ {
pub fn new(size: u64, stream: S) -> Self { pub fn new(size: u64, stream: S) -> Self {
SizedStream { size, stream } SizedStream { size, stream }
@@ -418,24 +406,20 @@ where
impl<S> MessageBody for SizedStream<S> impl<S> MessageBody for SizedStream<S>
where where
S: Stream<Item = Result<Bytes, Error>>, S: Stream<Item = Bytes, Error = Error>,
{ {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
BodySize::Sized64(self.size) BodySize::Sized64(self.size)
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
unsafe { Pin::new_unchecked(self) } self.stream.poll()
.project()
.stream
.poll_next(cx)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures_util::future::poll_fn;
impl Body { impl Body {
pub(crate) fn get_ref(&self) -> &[u8] { pub(crate) fn get_ref(&self) -> &[u8] {
@@ -455,21 +439,21 @@ mod tests {
} }
} }
#[actix_rt::test] #[test]
async fn test_static_str() { fn test_static_str() {
assert_eq!(Body::from("").size(), BodySize::Sized(0)); assert_eq!(Body::from("").size(), BodySize::Sized(0));
assert_eq!(Body::from("test").size(), BodySize::Sized(4)); assert_eq!(Body::from("test").size(), BodySize::Sized(4));
assert_eq!(Body::from("test").get_ref(), b"test"); assert_eq!(Body::from("test").get_ref(), b"test");
assert_eq!("test".size(), BodySize::Sized(4)); assert_eq!("test".size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| "test".poll_next(cx)).await.unwrap().ok(), "test".poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
#[actix_rt::test] #[test]
async fn test_static_bytes() { fn test_static_bytes() {
assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test");
assert_eq!( assert_eq!(
@@ -480,57 +464,51 @@ mod tests {
assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); assert_eq!((&b"test"[..]).size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| (&b"test"[..]).poll_next(cx)) (&b"test"[..]).poll_next().unwrap(),
.await Async::Ready(Some(Bytes::from("test")))
.unwrap()
.ok(),
Some(Bytes::from("test"))
); );
} }
#[actix_rt::test] #[test]
async fn test_vec() { fn test_vec() {
assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4));
assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test");
assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| Vec::from("test").poll_next(cx)) Vec::from("test").poll_next().unwrap(),
.await Async::Ready(Some(Bytes::from("test")))
.unwrap()
.ok(),
Some(Bytes::from("test"))
); );
} }
#[actix_rt::test] #[test]
async fn test_bytes() { fn test_bytes() {
let mut b = Bytes::from("test"); let mut b = Bytes::from("test");
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test"); assert_eq!(Body::from(b.clone()).get_ref(), b"test");
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
#[actix_rt::test] #[test]
async fn test_bytes_mut() { fn test_bytes_mut() {
let mut b = BytesMut::from("test"); let mut b = BytesMut::from("test");
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test"); assert_eq!(Body::from(b.clone()).get_ref(), b"test");
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
#[actix_rt::test] #[test]
async fn test_string() { fn test_string() {
let mut b = "test".to_owned(); let mut b = "test".to_owned();
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test"); assert_eq!(Body::from(b.clone()).get_ref(), b"test");
@@ -539,26 +517,26 @@ mod tests {
assert_eq!(b.size(), BodySize::Sized(4)); assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!( assert_eq!(
poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), b.poll_next().unwrap(),
Some(Bytes::from("test")) Async::Ready(Some(Bytes::from("test")))
); );
} }
#[actix_rt::test] #[test]
async fn test_unit() { fn test_unit() {
assert_eq!(().size(), BodySize::Empty); assert_eq!(().size(), BodySize::Empty);
assert!(poll_fn(|cx| ().poll_next(cx)).await.is_none()); assert_eq!(().poll_next().unwrap(), Async::Ready(None));
} }
#[actix_rt::test] #[test]
async fn test_box() { fn test_box() {
let mut val = Box::new(()); let mut val = Box::new(());
assert_eq!(val.size(), BodySize::Empty); assert_eq!(val.size(), BodySize::Empty);
assert!(poll_fn(|cx| val.poll_next(cx)).await.is_none()); assert_eq!(val.poll_next().unwrap(), Async::Ready(None));
} }
#[actix_rt::test] #[test]
async fn test_body_eq() { fn test_body_eq() {
assert!(Body::None == Body::None); assert!(Body::None == Body::None);
assert!(Body::None != Body::Empty); assert!(Body::None != Body::Empty);
assert!(Body::Empty == Body::Empty); assert!(Body::Empty == Body::Empty);
@@ -570,15 +548,15 @@ mod tests {
assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None);
} }
#[actix_rt::test] #[test]
async fn test_body_debug() { fn test_body_debug() {
assert!(format!("{:?}", Body::None).contains("Body::None")); assert!(format!("{:?}", Body::None).contains("Body::None"));
assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); assert!(format!("{:?}", Body::Empty).contains("Body::Empty"));
assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains("1")); assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains("1"));
} }
#[actix_rt::test] #[test]
async fn test_serde_json() { fn test_serde_json() {
use serde_json::json; use serde_json::json;
assert_eq!( assert_eq!(
Body::from(serde_json::Value::String("test".into())).size(), Body::from(serde_json::Value::String("test".into())).size(),

View File

@@ -1,9 +1,10 @@
use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::{fmt, net};
use actix_codec::Framed; use actix_codec::Framed;
use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use actix_server_config::ServerConfig as SrvConfig;
use actix_service::{IntoNewService, NewService, Service};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::{KeepAlive, ServiceConfig}; use crate::config::{KeepAlive, ServiceConfig};
@@ -23,8 +24,6 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
keep_alive: KeepAlive, keep_alive: KeepAlive,
client_timeout: u64, client_timeout: u64,
client_disconnect: u64, client_disconnect: u64,
secure: bool,
local_addr: Option<net::SocketAddr>,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
@@ -33,10 +32,9 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static,
{ {
/// Create instance of `ServiceConfigBuilder` /// Create instance of `ServiceConfigBuilder`
pub fn new() -> Self { pub fn new() -> Self {
@@ -44,8 +42,6 @@ where
keep_alive: KeepAlive::Timeout(5), keep_alive: KeepAlive::Timeout(5),
client_timeout: 5000, client_timeout: 5000,
client_disconnect: 0, client_disconnect: 0,
secure: false,
local_addr: None,
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -56,18 +52,19 @@ where
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U> impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
<S::Service as Service>::Future: 'static, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>, Config = SrvConfig,
Request = (Request, Framed<T, Codec>),
Response = (),
>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
/// Set server keep-alive setting. /// Set server keep-alive setting.
/// ///
@@ -77,18 +74,6 @@ where
self self
} }
/// Set connection secure state
pub fn secure(mut self) -> Self {
self.secure = true;
self
}
/// Set the local address that this service is bound to.
pub fn local_addr(mut self, addr: net::SocketAddr) -> Self {
self.local_addr = Some(addr);
self
}
/// Set server client timeout in milliseconds for first request. /// Set server client timeout in milliseconds for first request.
/// ///
/// Defines a timeout for reading client request header. If a client does not transmit /// Defines a timeout for reading client request header. If a client does not transmit
@@ -123,19 +108,16 @@ where
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U> pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
where where
F: IntoServiceFactory<X1>, F: IntoNewService<X1>,
X1: ServiceFactory<Config = (), Request = Request, Response = Request>, X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
secure: self.secure, expect: expect.into_new_service(),
local_addr: self.local_addr,
expect: expect.into_factory(),
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
@@ -148,24 +130,21 @@ where
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1> pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where where
F: IntoServiceFactory<U1>, F: IntoNewService<U1>,
U1: ServiceFactory< U1: NewService<
Config = (), Config = SrvConfig,
Request = (Request, Framed<T, Codec>), Request = (Request, Framed<T, Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpServiceBuilder { HttpServiceBuilder {
keep_alive: self.keep_alive, keep_alive: self.keep_alive,
client_timeout: self.client_timeout, client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect, client_disconnect: self.client_disconnect,
secure: self.secure,
local_addr: self.local_addr,
expect: self.expect, expect: self.expect,
upgrade: Some(upgrade.into_factory()), upgrade: Some(upgrade.into_new_service()),
on_connect: self.on_connect, on_connect: self.on_connect,
_t: PhantomData, _t: PhantomData,
} }
@@ -185,10 +164,10 @@ where
} }
/// Finish service configuration and create *http service* for HTTP/1 protocol. /// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U> pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where where
B: MessageBody, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -197,53 +176,48 @@ where
self.keep_alive, self.keep_alive,
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
self.secure,
self.local_addr,
); );
H1Service::with_config(cfg, service.into_factory()) H1Service::with_config(cfg, service.into_new_service())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)
} }
/// Finish service configuration and create *http service* for HTTP/2 protocol. /// Finish service configuration and create *http service* for HTTP/2 protocol.
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B> pub fn h2<F, P, B>(self, service: F) -> H2Service<T, P, S, B>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
self.keep_alive, self.keep_alive,
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
self.secure,
self.local_addr,
); );
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) H2Service::with_config(cfg, service.into_new_service())
.on_connect(self.on_connect)
} }
/// Finish service configuration and create `HttpService` instance. /// Finish service configuration and create `HttpService` instance.
pub fn finish<F, B>(self, service: F) -> HttpService<T, S, B, X, U> pub fn finish<F, P, B>(self, service: F) -> HttpService<T, P, S, B, X, U>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
F: IntoServiceFactory<S>, F: IntoNewService<S>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
{ {
let cfg = ServiceConfig::new( let cfg = ServiceConfig::new(
self.keep_alive, self.keep_alive,
self.client_timeout, self.client_timeout,
self.client_disconnect, self.client_disconnect,
self.secure,
self.local_addr,
); );
HttpService::with_config(cfg, service.into_factory()) HttpService::with_config(cfg, service.into_new_service())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)

View File

@@ -1,12 +1,10 @@
use std::pin::Pin; use std::{fmt, io, time};
use std::task::{Context, Poll};
use std::{fmt, io, mem, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures_util::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; use futures::future::{err, Either, Future, FutureResult};
use futures::Poll;
use h2::client::SendRequest; use h2::client::SendRequest;
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::h1::ClientCodec; use crate::h1::ClientCodec;
@@ -23,8 +21,8 @@ pub(crate) enum ConnectionType<Io> {
} }
pub trait Connection { pub trait Connection {
type Io: AsyncRead + AsyncWrite + Unpin; type Io: AsyncRead + AsyncWrite;
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>; type Future: Future<Item = (ResponseHead, Payload), Error = SendRequestError>;
fn protocol(&self) -> Protocol; fn protocol(&self) -> Protocol;
@@ -36,7 +34,8 @@ pub trait Connection {
) -> Self::Future; ) -> Self::Future;
type TunnelFuture: Future< type TunnelFuture: Future<
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
@@ -63,7 +62,7 @@ impl<T> fmt::Debug for IoConnection<T>
where where
T: fmt::Debug, T: fmt::Debug,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.io { match self.io {
Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io),
Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), Some(ConnectionType::H2(_)) => write!(f, "H2Connection"),
@@ -72,7 +71,7 @@ where
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> { impl<T: AsyncRead + AsyncWrite> IoConnection<T> {
pub(crate) fn new( pub(crate) fn new(
io: ConnectionType<T>, io: ConnectionType<T>,
created: time::Instant, created: time::Instant,
@@ -92,11 +91,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
impl<T> Connection for IoConnection<T> impl<T> Connection for IoConnection<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
type Io = T; type Io = T;
type Future = type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self.io { match self.io {
@@ -112,30 +111,38 @@ where
body: B, body: B,
) -> Self::Future { ) -> Self::Future {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => Box::new(h1proto::send_request(
h1proto::send_request(io, head.into(), body, self.created, self.pool) io,
.boxed_local() head.into(),
} body,
ConnectionType::H2(io) => { self.created,
h2proto::send_request(io, head.into(), body, self.created, self.pool) self.pool,
.boxed_local() )),
} ConnectionType::H2(io) => Box::new(h2proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
} }
} }
type TunnelFuture = Either< type TunnelFuture = Either<
LocalBoxFuture< Box<
'static, dyn Future<
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>, >,
Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>, >,
FutureResult<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => { ConnectionType::H1(io) => {
Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) Either::A(Box::new(h1proto::open_tunnel(io, head.into())))
} }
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -145,7 +152,7 @@ where
None, None,
)); ));
} }
Either::Right(err(SendRequestError::TunnelNotSupported)) Either::B(err(SendRequestError::TunnelNotSupported))
} }
} }
} }
@@ -159,12 +166,12 @@ pub(crate) enum EitherConnection<A, B> {
impl<A, B> Connection for EitherConnection<A, B> impl<A, B> Connection for EitherConnection<A, B>
where where
A: AsyncRead + AsyncWrite + Unpin + 'static, A: AsyncRead + AsyncWrite + 'static,
B: AsyncRead + AsyncWrite + Unpin + 'static, B: AsyncRead + AsyncWrite + 'static,
{ {
type Io = EitherIo<A, B>; type Io = EitherIo<A, B>;
type Future = type Future =
LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; Box<dyn Future<Item = (ResponseHead, Payload), Error = SendRequestError>>;
fn protocol(&self) -> Protocol { fn protocol(&self) -> Protocol {
match self { match self {
@@ -184,30 +191,44 @@ where
} }
} }
type TunnelFuture = LocalBoxFuture< type TunnelFuture = Box<
'static, dyn Future<
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>, Item = (ResponseHead, Framed<Self::Io, ClientCodec>),
Error = SendRequestError,
>,
>; >;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture { fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
match self { match self {
EitherConnection::A(con) => con EitherConnection::A(con) => Box::new(
.open_tunnel(head) con.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) .map(|(head, framed)| (head, framed.map_io(EitherIo::A))),
.boxed_local(), ),
EitherConnection::B(con) => con EitherConnection::B(con) => Box::new(
.open_tunnel(head) con.open_tunnel(head)
.map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) .map(|(head, framed)| (head, framed.map_io(EitherIo::B))),
.boxed_local(), ),
} }
} }
} }
#[pin_project]
pub enum EitherIo<A, B> { pub enum EitherIo<A, B> {
A(#[pin] A), A(A),
B(#[pin] B), B(B),
}
impl<A, B> io::Read for EitherIo<A, B>
where
A: io::Read,
B: io::Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.read(buf),
EitherIo::B(ref mut val) => val.read(buf),
}
}
} }
impl<A, B> AsyncRead for EitherIo<A, B> impl<A, B> AsyncRead for EitherIo<A, B>
@@ -215,23 +236,7 @@ where
A: AsyncRead, A: AsyncRead,
B: AsyncRead, B: AsyncRead,
{ {
#[project] unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_read(cx, buf),
EitherIo::B(val) => val.poll_read(cx, buf),
}
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [mem::MaybeUninit<u8>],
) -> bool {
match self { match self {
EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf),
@@ -239,58 +244,45 @@ where
} }
} }
impl<A, B> io::Write for EitherIo<A, B>
where
A: io::Write,
B: io::Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
EitherIo::A(ref mut val) => val.write(buf),
EitherIo::B(ref mut val) => val.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
EitherIo::A(ref mut val) => val.flush(),
EitherIo::B(ref mut val) => val.flush(),
}
}
}
impl<A, B> AsyncWrite for EitherIo<A, B> impl<A, B> AsyncWrite for EitherIo<A, B>
where where
A: AsyncWrite, A: AsyncWrite,
B: AsyncWrite, B: AsyncWrite,
{ {
#[project] fn shutdown(&mut self) -> Poll<(), io::Error> {
fn poll_write( match self {
self: Pin<&mut Self>, EitherIo::A(ref mut val) => val.shutdown(),
cx: &mut Context<'_>, EitherIo::B(ref mut val) => val.shutdown(),
buf: &[u8],
) -> Poll<io::Result<usize>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_write(cx, buf),
EitherIo::B(val) => val.poll_write(cx, buf),
} }
} }
#[project] fn write_buf<U: Buf>(&mut self, buf: &mut U) -> Poll<usize, io::Error>
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_flush(cx),
EitherIo::B(val) => val.poll_flush(cx),
}
}
#[project]
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
#[project]
match self.project() {
EitherIo::A(val) => val.poll_shutdown(cx),
EitherIo::B(val) => val.poll_shutdown(cx),
}
}
#[project]
fn poll_write_buf<U: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut U,
) -> Poll<Result<usize, io::Error>>
where where
Self: Sized, Self: Sized,
{ {
#[project] match self {
match self.project() { EitherIo::A(ref mut val) => val.write_buf(buf),
EitherIo::A(val) => val.poll_write_buf(cx, buf), EitherIo::B(ref mut val) => val.write_buf(buf),
EitherIo::B(val) => val.poll_write_buf(cx, buf),
} }
} }
} }

View File

@@ -6,32 +6,32 @@ use actix_codec::{AsyncRead, AsyncWrite};
use actix_connect::{ use actix_connect::{
default_connector, Connect as TcpConnect, Connection as TcpConnection, default_connector, Connect as TcpConnect, Connection as TcpConnection,
}; };
use actix_rt::net::TcpStream; use actix_service::{apply_fn, Service, ServiceExt};
use actix_service::{apply_fn, Service};
use actix_utils::timeout::{TimeoutError, TimeoutService}; use actix_utils::timeout::{TimeoutError, TimeoutService};
use http::Uri; use http::Uri;
use tokio_tcp::TcpStream;
use super::connection::Connection; use super::connection::Connection;
use super::error::ConnectError; use super::error::ConnectError;
use super::pool::{ConnectionPool, Protocol}; use super::pool::{ConnectionPool, Protocol};
use super::Connect; use super::Connect;
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use actix_connect::ssl::openssl::SslConnector as OpensslConnector; use openssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use actix_connect::ssl::rustls::ClientConfig; use rustls::ClientConfig;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use std::sync::Arc; use std::sync::Arc;
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
enum SslConnector { enum SslConnector {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
Openssl(OpensslConnector), Openssl(OpensslConnector),
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
Rustls(Arc<ClientConfig>), Rustls(Arc<ClientConfig>),
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
type SslConnector = (); type SslConnector = ();
/// Manages http client network connectivity /// Manages http client network connectivity
@@ -58,11 +58,11 @@ pub struct Connector<T, U> {
_t: PhantomData<U>, _t: PhantomData<U>,
} }
trait Io: AsyncRead + AsyncWrite + Unpin {} trait Io: AsyncRead + AsyncWrite {}
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {} impl<T: AsyncRead + AsyncWrite> Io for T {}
impl Connector<(), ()> { impl Connector<(), ()> {
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)] #[allow(clippy::new_ret_no_self)]
pub fn new() -> Connector< pub fn new() -> Connector<
impl Service< impl Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
@@ -72,9 +72,9 @@ impl Connector<(), ()> {
TcpStream, TcpStream,
> { > {
let ssl = { let ssl = {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
{ {
use actix_connect::ssl::openssl::SslMethod; use openssl::ssl::SslMethod;
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
let _ = ssl let _ = ssl
@@ -82,17 +82,17 @@ impl Connector<(), ()> {
.map_err(|e| error!("Can not set alpn protocol: {:?}", e)); .map_err(|e| error!("Can not set alpn protocol: {:?}", e));
SslConnector::Openssl(ssl.build()) SslConnector::Openssl(ssl.build())
} }
#[cfg(all(not(feature = "openssl"), feature = "rustls"))] #[cfg(all(not(feature = "ssl"), feature = "rust-tls"))]
{ {
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let mut config = ClientConfig::new(); let mut config = ClientConfig::new();
config.set_protocols(&protos); config.set_protocols(&protos);
config config
.root_store .root_store
.add_server_trust_anchors(&actix_tls::rustls::TLS_SERVER_ROOTS); .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
SslConnector::Rustls(Arc::new(config)) SslConnector::Rustls(Arc::new(config))
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
{} {}
}; };
@@ -113,7 +113,7 @@ impl<T, U> Connector<T, U> {
/// Use custom connector. /// Use custom connector.
pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1> pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
where where
U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, U1: AsyncRead + AsyncWrite + fmt::Debug,
T1: Service< T1: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U1>, Response = TcpConnection<Uri, U1>,
@@ -135,7 +135,7 @@ impl<T, U> Connector<T, U> {
impl<T, U> Connector<T, U> impl<T, U> Connector<T, U>
where where
U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, U: AsyncRead + AsyncWrite + fmt::Debug + 'static,
T: Service< T: Service<
Request = TcpConnect<Uri>, Request = TcpConnect<Uri>,
Response = TcpConnection<Uri, U>, Response = TcpConnection<Uri, U>,
@@ -150,14 +150,14 @@ where
self self
} }
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
/// Use custom `SslConnector` instance. /// Use custom `SslConnector` instance.
pub fn ssl(mut self, connector: OpensslConnector) -> Self { pub fn ssl(mut self, connector: OpensslConnector) -> Self {
self.ssl = SslConnector::Openssl(connector); self.ssl = SslConnector::Openssl(connector);
self self
} }
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self { pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
self.ssl = SslConnector::Rustls(connector); self.ssl = SslConnector::Rustls(connector);
self self
@@ -213,7 +213,7 @@ where
self, self,
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError> ) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
+ Clone { + Clone {
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
{ {
let connector = TimeoutService::new( let connector = TimeoutService::new(
self.timeout, self.timeout,
@@ -238,30 +238,32 @@ where
), ),
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
{ {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use actix_connect::ssl::openssl::OpensslConnector; use actix_connect::ssl::OpensslConnector;
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
use actix_connect::ssl::rustls::{RustlsConnector, Session}; use actix_connect::ssl::RustlsConnector;
use actix_service::{boxed::service, pipeline}; use actix_service::boxed::service;
#[cfg(feature = "rust-tls")]
use rustls::Session;
let ssl_service = TimeoutService::new( let ssl_service = TimeoutService::new(
self.timeout, self.timeout,
pipeline(
apply_fn(self.connector.clone(), |msg: Connect, srv| { apply_fn(self.connector.clone(), |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
}) })
.map_err(ConnectError::from), .map_err(ConnectError::from)
)
.and_then(match self.ssl { .and_then(match self.ssl {
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
SslConnector::Openssl(ssl) => service( SslConnector::Openssl(ssl) => service(
OpensslConnector::service(ssl) OpensslConnector::service(ssl)
.map_err(ConnectError::from)
.map(|stream| { .map(|stream| {
let sock = stream.into_parts().0; let sock = stream.into_parts().0;
let h2 = sock let h2 = sock
.get_ref()
.ssl() .ssl()
.selected_alpn_protocol() .selected_alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2)) .map(|protos| protos.windows(2).any(|w| w == H2))
@@ -271,10 +273,9 @@ where
} else { } else {
(Box::new(sock) as Box<dyn Io>, Protocol::Http1) (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
} }
}) }),
.map_err(ConnectError::from),
), ),
#[cfg(feature = "rustls")] #[cfg(feature = "rust-tls")]
SslConnector::Rustls(ssl) => service( SslConnector::Rustls(ssl) => service(
RustlsConnector::service(ssl) RustlsConnector::service(ssl)
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -302,7 +303,7 @@ where
let tcp_service = TimeoutService::new( let tcp_service = TimeoutService::new(
self.timeout, self.timeout,
apply_fn(self.connector, |msg: Connect, srv| { apply_fn(self.connector.clone(), |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
}) })
.map_err(ConnectError::from) .map_err(ConnectError::from)
@@ -333,19 +334,19 @@ where
} }
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] #[cfg(not(any(feature = "ssl", feature = "rust-tls")))]
mod connect_impl { mod connect_impl {
use std::task::{Context, Poll}; use futures::future::{err, Either, FutureResult};
use futures::Poll;
use futures_util::future::{err, Either, Ready};
use super::*; use super::*;
use crate::client::connection::IoConnection; use crate::client::connection::IoConnection;
pub(crate) struct InnerConnector<T, Io> pub(crate) struct InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) tcp_pool: ConnectionPool<T, Io>, pub(crate) tcp_pool: ConnectionPool<T, Io>,
@@ -353,8 +354,9 @@ mod connect_impl {
impl<T, Io> Clone for InnerConnector<T, Io> impl<T, Io> Clone for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -366,8 +368,9 @@ mod connect_impl {
impl<T, Io> Service for InnerConnector<T, Io> impl<T, Io> Service for InnerConnector<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
@@ -375,41 +378,38 @@ mod connect_impl {
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
<ConnectionPool<T, Io> as Service>::Future, <ConnectionPool<T, Io> as Service>::Future,
Ready<Result<IoConnection<Io>, ConnectError>>, FutureResult<IoConnection<Io>, ConnectError>,
>; >;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready(cx) self.tcp_pool.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => { Some("https") | Some("wss") => {
Either::Right(err(ConnectError::SslIsNotSupported)) Either::B(err(ConnectError::SslIsNotSupported))
} }
_ => Either::Left(self.tcp_pool.call(req)), _ => Either::A(self.tcp_pool.call(req)),
} }
} }
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "ssl", feature = "rust-tls"))]
mod connect_impl { mod connect_impl {
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::ready; use futures::future::{Either, FutureResult};
use futures_util::future::Either; use futures::{Async, Future, Poll};
use super::*; use super::*;
use crate::client::connection::EitherConnection; use crate::client::connection::EitherConnection;
pub(crate) struct InnerConnector<T1, T2, Io1, Io2> pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>, T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>, T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
{ {
@@ -419,11 +419,13 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@@ -436,47 +438,53 @@ mod connect_impl {
impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2> impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = EitherConnection<Io1, Io2>; type Response = EitherConnection<Io1, Io2>;
type Error = ConnectError; type Error = ConnectError;
type Future = Either< type Future = Either<
FutureResult<Self::Response, Self::Error>,
Either<
InnerConnectorResponseA<T1, Io1, Io2>, InnerConnectorResponseA<T1, Io1, Io2>,
InnerConnectorResponseB<T2, Io1, Io2>, InnerConnectorResponseB<T2, Io1, Io2>,
>,
>; >;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.tcp_pool.poll_ready(cx) self.tcp_pool.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { Some("https") | Some("wss") => {
Either::B(Either::B(InnerConnectorResponseB {
fut: self.ssl_pool.call(req), fut: self.ssl_pool.call(req),
_t: PhantomData, _t: PhantomData,
}), }))
_ => Either::Left(InnerConnectorResponseA { }
_ => Either::B(Either::A(InnerConnectorResponseA {
fut: self.tcp_pool.call(req), fut: self.tcp_pool.call(req),
_t: PhantomData, _t: PhantomData,
}), })),
} }
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseA<T, Io1, Io2> pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
where where
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io1> as Service>::Future, fut: <ConnectionPool<T, Io1> as Service>::Future,
_t: PhantomData<Io2>, _t: PhantomData<Io2>,
} }
@@ -484,28 +492,29 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
{ {
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>; type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Poll::Ready( match self.fut.poll()? {
ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) Async::NotReady => Ok(Async::NotReady),
.map(EitherConnection::A), Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))),
) }
} }
} }
#[pin_project::pin_project]
pub(crate) struct InnerConnectorResponseB<T, Io1, Io2> pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
where where
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
#[pin]
fut: <ConnectionPool<T, Io2> as Service>::Future, fut: <ConnectionPool<T, Io2> as Service>::Future,
_t: PhantomData<Io1>, _t: PhantomData<Io1>,
} }
@@ -513,17 +522,19 @@ mod connect_impl {
impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2> impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
where where
T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + 'static,
{ {
type Output = Result<EitherConnection<Io1, Io2>, ConnectError>; type Item = EitherConnection<Io1, Io2>;
type Error = ConnectError;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Poll::Ready( match self.fut.poll()? {
ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) Async::NotReady => Ok(Async::NotReady),
.map(EitherConnection::B), Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))),
) }
} }
} }
} }

View File

@@ -1,13 +1,14 @@
use std::io; use std::io;
use actix_connect::resolver::ResolveError;
use derive_more::{Display, From}; use derive_more::{Display, From};
use trust_dns_resolver::error::ResolveError;
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
use actix_connect::ssl::openssl::{HandshakeError, SslError}; use openssl::ssl::{Error as SslError, HandshakeError};
use crate::error::{Error, ParseError, ResponseError}; use crate::error::{Error, ParseError, ResponseError};
use crate::http::{Error as HttpError, StatusCode}; use crate::http::Error as HttpError;
use crate::response::Response;
/// A set of errors that can occur while connecting to an HTTP host /// A set of errors that can occur while connecting to an HTTP host
#[derive(Debug, Display, From)] #[derive(Debug, Display, From)]
@@ -17,15 +18,10 @@ pub enum ConnectError {
SslIsNotSupported, SslIsNotSupported,
/// SSL error /// SSL error
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
#[display(fmt = "{}", _0)] #[display(fmt = "{}", _0)]
SslError(SslError), SslError(SslError),
/// SSL Handshake error
#[cfg(feature = "openssl")]
#[display(fmt = "{}", _0)]
SslHandshakeError(String),
/// Failed to resolve the hostname /// Failed to resolve the hostname
#[display(fmt = "Failed resolving hostname: {}", _0)] #[display(fmt = "Failed resolving hostname: {}", _0)]
Resolver(ResolveError), Resolver(ResolveError),
@@ -67,10 +63,14 @@ impl From<actix_connect::ConnectError> for ConnectError {
} }
} }
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
impl<T: std::fmt::Debug> From<HandshakeError<T>> for ConnectError { impl<T> From<HandshakeError<T>> for ConnectError {
fn from(err: HandshakeError<T>) -> ConnectError { fn from(err: HandshakeError<T>) -> ConnectError {
ConnectError::SslHandshakeError(format!("{:?}", err)) match err {
HandshakeError::SetupFailure(stack) => SslError::from(stack).into(),
HandshakeError::Failure(stream) => stream.into_error().into(),
HandshakeError::WouldBlock(stream) => stream.into_error().into(),
}
} }
} }
@@ -117,14 +117,15 @@ pub enum SendRequestError {
/// Convert `SendRequestError` to a server `Response` /// Convert `SendRequestError` to a server `Response`
impl ResponseError for SendRequestError { impl ResponseError for SendRequestError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
match *self { match *self {
SendRequestError::Connect(ConnectError::Timeout) => { SendRequestError::Connect(ConnectError::Timeout) => {
StatusCode::GATEWAY_TIMEOUT Response::GatewayTimeout()
} }
SendRequestError::Connect(_) => StatusCode::BAD_REQUEST, SendRequestError::Connect(_) => Response::BadGateway(),
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => Response::InternalServerError(),
} }
.into()
} }
} }

View File

@@ -1,14 +1,10 @@
use std::io::Write; use std::io::Write;
use std::pin::Pin; use std::{io, time};
use std::task::{Context, Poll};
use std::{io, mem, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use bytes::buf::BufMutExt; use bytes::{BufMut, Bytes, BytesMut};
use bytes::{Bytes, BytesMut}; use futures::future::{ok, Either};
use futures_core::Stream; use futures::{Async, Future, Poll, Sink, Stream};
use futures_util::future::poll_fn;
use futures_util::{SinkExt, StreamExt};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@@ -22,15 +18,15 @@ use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired; use super::pool::Acquired;
use crate::body::{BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
pub(crate) async fn send_request<T, B>( pub(crate) fn send_request<T, B>(
io: T, io: T,
mut head: RequestHeadType, mut head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
B: MessageBody, B: MessageBody,
{ {
// set request host header // set request host header
@@ -45,7 +41,7 @@ where
Some(port) => write!(wrt, "{}:{}", host, port), Some(port) => write!(wrt, "{}:{}", host, port),
}; };
match wrt.get_mut().split().freeze().try_into() { match wrt.get_mut().take().freeze().try_into() {
Ok(value) => match head { Ok(value) => match head {
RequestHeadType::Owned(ref mut head) => { RequestHeadType::Owned(ref mut head) => {
head.headers.insert(HOST, value) head.headers.insert(HOST, value)
@@ -66,99 +62,68 @@ where
io: Some(io), io: Some(io),
}; };
let len = body.size();
// create Framed and send request // create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default()); Framed::new(io, h1::ClientCodec::default())
framed.send((head, body.size()).into()).await?; .send((head, len).into())
.from_err()
// send request body // send request body
match body.size() { .and_then(move |framed| match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
_ => send_body(body, &mut framed).await?, Either::A(ok(framed))
}; }
_ => Either::B(SendBody::new(body, framed)),
})
// read response and init read body // read response and init read body
let res = framed.into_future().await; .and_then(|framed| {
let (head, framed) = if let (Some(result), framed) = res { framed
let item = result.map_err(SendRequestError::from)?; .into_future()
(item, framed) .map_err(|(e, _)| SendRequestError::from(e))
} else { .and_then(|(item, framed)| {
return Err(SendRequestError::from(ConnectError::Disconnected)); if let Some(res) = item {
};
match framed.get_codec().message_type() { match framed.get_codec().message_type() {
h1::MessageType::None => { h1::MessageType::None => {
let force_close = !framed.get_codec().keepalive(); let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close); release_connection(framed, force_close);
Ok((head, Payload::None)) Ok((res, Payload::None))
} }
_ => { _ => {
let pl: PayloadStream = PlStream::new(framed).boxed_local(); let pl: PayloadStream = Box::new(PlStream::new(framed));
Ok((head, pl.into())) Ok((res, pl.into()))
} }
} }
} else {
Err(ConnectError::Disconnected.into())
}
})
})
} }
pub(crate) async fn open_tunnel<T>( pub(crate) fn open_tunnel<T>(
io: T, io: T,
head: RequestHeadType, head: RequestHeadType,
) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError> ) -> impl Future<Item = (ResponseHead, Framed<T, h1::ClientCodec>), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
// create Framed and send request // create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default()); Framed::new(io, h1::ClientCodec::default())
framed.send((head, BodySize::None).into()).await?; .send((head, BodySize::None).into())
.from_err()
// read response // read response
if let (Some(result), framed) = framed.into_future().await { .and_then(|framed| {
let head = result.map_err(SendRequestError::from)?; framed
.into_future()
.map_err(|(e, _)| SendRequestError::from(e))
.and_then(|(head, framed)| {
if let Some(head) = head {
Ok((head, framed)) Ok((head, framed))
} else { } else {
Err(SendRequestError::from(ConnectError::Disconnected)) Err(SendRequestError::from(ConnectError::Disconnected))
} }
}
/// send request body to the peer
pub(crate) async fn send_body<I, B>(
mut body: B,
framed: &mut Framed<I, h1::ClientCodec>,
) -> Result<(), SendRequestError>
where
I: ConnectionLifetime,
B: MessageBody,
{
let mut eof = false;
while !eof {
while !eof && !framed.is_write_buf_full() {
match poll_fn(|cx| body.poll_next(cx)).await {
Some(result) => {
framed.write(h1::Message::Chunk(Some(result?)))?;
}
None => {
eof = true;
framed.write(h1::Message::Chunk(None))?;
}
}
}
if !framed.is_write_buf_empty() {
poll_fn(|cx| match framed.flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.is_write_buf_full() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}) })
.await?; })
}
}
SinkExt::flush(framed).await?;
Ok(())
} }
#[doc(hidden)] #[doc(hidden)]
@@ -169,10 +134,7 @@ pub struct H1Connection<T> {
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
} }
impl<T> ConnectionLifetime for H1Connection<T> impl<T: AsyncRead + AsyncWrite + 'static> ConnectionLifetime for H1Connection<T> {
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close connection /// Close connection
fn close(&mut self) { fn close(&mut self) {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
@@ -200,44 +162,98 @@ where
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + 'static> io::Read for H1Connection<T> {
unsafe fn prepare_uninitialized_buffer( fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
&self, self.io.as_mut().unwrap().read(buf)
buf: &mut [mem::MaybeUninit<u8>],
) -> bool {
self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
} }
} }
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + 'static> AsyncRead for H1Connection<T> {}
fn poll_write(
mut self: Pin<&mut Self>, impl<T: AsyncRead + AsyncWrite + 'static> io::Write for H1Connection<T> {
cx: &mut Context<'_>, fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
buf: &[u8], self.io.as_mut().unwrap().write(buf)
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
} }
fn poll_flush( fn flush(&mut self) -> io::Result<()> {
mut self: Pin<&mut Self>, self.io.as_mut().unwrap().flush()
cx: &mut Context<'_>, }
) -> Poll<io::Result<()>> {
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
} }
fn poll_shutdown( impl<T: AsyncRead + AsyncWrite + 'static> AsyncWrite for H1Connection<T> {
mut self: Pin<&mut Self>, fn shutdown(&mut self) -> Poll<(), io::Error> {
cx: &mut Context<'_>, self.io.as_mut().unwrap().shutdown()
) -> Poll<Result<(), io::Error>> { }
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) }
/// Future responsible for sending request body to the peer
pub(crate) struct SendBody<I, B> {
body: Option<B>,
framed: Option<Framed<I, h1::ClientCodec>>,
flushed: bool,
}
impl<I, B> SendBody<I, B>
where
I: AsyncRead + AsyncWrite + 'static,
B: MessageBody,
{
pub(crate) fn new(body: B, framed: Framed<I, h1::ClientCodec>) -> Self {
SendBody {
body: Some(body),
framed: Some(framed),
flushed: true,
}
}
}
impl<I, B> Future for SendBody<I, B>
where
I: ConnectionLifetime,
B: MessageBody,
{
type Item = Framed<I, h1::ClientCodec>;
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut body_ready = true;
loop {
while body_ready
&& self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self.body.as_mut().unwrap().poll_next()? {
Async::Ready(item) => {
// check if body is done
if item.is_none() {
let _ = self.body.take();
}
self.flushed = false;
self.framed
.as_mut()
.unwrap()
.force_send(h1::Message::Chunk(item))?;
break;
}
Async::NotReady => body_ready = false,
}
}
if !self.flushed {
match self.framed.as_mut().unwrap().poll_complete()? {
Async::Ready(_) => {
self.flushed = true;
continue;
}
Async::NotReady => return Ok(Async::NotReady),
}
}
if self.body.is_none() {
return Ok(Async::Ready(self.framed.take().unwrap()));
}
return Ok(Async::NotReady);
}
} }
} }
@@ -254,27 +270,23 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
} }
impl<Io: ConnectionLifetime> Stream for PlStream<Io> { impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self: Pin<&mut Self>, match self.framed.as_mut().unwrap().poll()? {
cx: &mut Context<'_>, Async::NotReady => Ok(Async::NotReady),
) -> Poll<Option<Self::Item>> { Async::Ready(Some(chunk)) => {
let this = self.get_mut();
match this.framed.as_mut().unwrap().next_item(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk))) Ok(Async::Ready(Some(chunk)))
} else { } else {
let framed = this.framed.take().unwrap(); let framed = self.framed.take().unwrap();
let force_close = !framed.get_codec().keepalive(); let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close); release_connection(framed, force_close);
Poll::Ready(None) Ok(Async::Ready(None))
} }
} }
Poll::Ready(None) => Poll::Ready(None), Async::Ready(None) => Ok(Async::Ready(None)),
} }
} }
} }

View File

@@ -1,12 +1,12 @@
use std::convert::TryFrom;
use std::time; use std::time;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use bytes::Bytes; use bytes::Bytes;
use futures_util::future::poll_fn; use futures::future::{err, Either};
use futures::{Async, Future, Poll};
use h2::{client::SendRequest, SendStream}; use h2::{client::SendRequest, SendStream};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{request::Request, Method, Version}; use http::{request::Request, HttpTryFrom, Method, Version};
use crate::body::{BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
use crate::header::HeaderMap; use crate::header::HeaderMap;
@@ -17,15 +17,15 @@ use super::connection::{ConnectionType, IoConnection};
use super::error::SendRequestError; use super::error::SendRequestError;
use super::pool::Acquired; use super::pool::Acquired;
pub(crate) async fn send_request<T, B>( pub(crate) fn send_request<T, B>(
mut io: SendRequest<Bytes>, io: SendRequest<Bytes>,
head: RequestHeadType, head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> impl Future<Item = (ResponseHead, Payload), Error = SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
B: MessageBody, B: MessageBody,
{ {
trace!("Sending client request: {:?} {:?}", head, body.size()); trace!("Sending client request: {:?} {:?}", head, body.size());
@@ -36,6 +36,9 @@ where
_ => false, _ => false,
}; };
io.ready()
.map_err(SendRequestError::from)
.and_then(move |mut io| {
let mut req = Request::new(()); let mut req = Request::new(());
*req.uri_mut() = head.as_ref().uri.clone(); *req.uri_mut() = head.as_ref().uri.clone();
*req.method_mut() = head.as_ref().method.clone(); *req.method_mut() = head.as_ref().method.clone();
@@ -66,7 +69,9 @@ where
// Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate.
let (head, extra_headers) = match head { let (head, extra_headers) = match head {
RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), RequestHeadType::Owned(head) => {
(RequestHeadType::Owned(head), HeaderMap::new())
}
RequestHeadType::Rc(head, extra_headers) => ( RequestHeadType::Rc(head, extra_headers) => (
RequestHeadType::Rc(head, None), RequestHeadType::Rc(head, None),
extra_headers.unwrap_or_else(HeaderMap::new), extra_headers.unwrap_or_else(HeaderMap::new),
@@ -92,27 +97,30 @@ where
req.headers_mut().append(key, value.clone()); req.headers_mut().append(key, value.clone());
} }
let res = poll_fn(|cx| io.poll_ready(cx)).await; match io.send_request(req, eof) {
if let Err(e) = res { Ok((res, send)) => {
release(io, pool, created, e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
release(io, pool, created, false); release(io, pool, created, false);
if !eof { if !eof {
send_body(body, send).await?; Either::A(Either::B(
SendBody {
body,
send,
buf: None,
}
.and_then(move |_| res.map_err(SendRequestError::from)),
))
} else {
Either::B(res.map_err(SendRequestError::from))
} }
fut.await.map_err(SendRequestError::from)?
} }
Err(e) => { Err(e) => {
release(io, pool, created, e.is_io()); release(io, pool, created, e.is_io());
return Err(e.into()); Either::A(Either::A(err(e.into())))
} }
}; }
})
.and_then(move |resp| {
let (parts, body) = resp.into_parts(); let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() }; let payload = if head_req { Payload::None } else { body.into() };
@@ -120,56 +128,66 @@ where
head.version = parts.version; head.version = parts.version;
head.headers = parts.headers.into(); head.headers = parts.headers.into();
Ok((head, payload)) Ok((head, payload))
})
.from_err()
} }
async fn send_body<B: MessageBody>( struct SendBody<B: MessageBody> {
mut body: B, body: B,
mut send: SendStream<Bytes>, send: SendStream<Bytes>,
) -> Result<(), SendRequestError> { buf: Option<Bytes>,
let mut buf = None; }
impl<B: MessageBody> Future for SendBody<B> {
type Item = ();
type Error = SendRequestError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
if buf.is_none() { if self.buf.is_none() {
match poll_fn(|cx| body.poll_next(cx)).await { match self.body.poll_next() {
Some(Ok(b)) => { Ok(Async::Ready(Some(buf))) => {
send.reserve_capacity(b.len()); self.send.reserve_capacity(buf.len());
buf = Some(b); self.buf = Some(buf);
} }
Some(Err(e)) => return Err(e.into()), Ok(Async::Ready(None)) => {
None => { if let Err(e) = self.send.send_data(Bytes::new(), true) {
if let Err(e) = send.send_data(Bytes::new(), true) {
return Err(e.into()); return Err(e.into());
} }
send.reserve_capacity(0); self.send.reserve_capacity(0);
return Ok(()); return Ok(Async::Ready(()));
} }
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => return Err(e.into()),
} }
} }
match poll_fn(|cx| send.poll_capacity(cx)).await { match self.send.poll_capacity() {
None => return Ok(()), Ok(Async::NotReady) => return Ok(Async::NotReady),
Some(Ok(cap)) => { Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
let b = buf.as_mut().unwrap(); Ok(Async::Ready(Some(cap))) => {
let len = b.len(); let mut buf = self.buf.take().unwrap();
let bytes = b.split_to(std::cmp::min(cap, len)); let len = buf.len();
let bytes = buf.split_to(std::cmp::min(cap, len));
if let Err(e) = send.send_data(bytes, false) { if let Err(e) = self.send.send_data(bytes, false) {
return Err(e.into()); return Err(e.into());
} else { } else {
if !b.is_empty() { if !buf.is_empty() {
send.reserve_capacity(b.len()); self.send.reserve_capacity(buf.len());
} else { self.buf = Some(buf);
buf = None;
} }
continue; continue;
} }
} }
Some(Err(e)) => return Err(e.into()), Err(e) => return Err(e.into()),
}
} }
} }
} }
// release SendRequest object // release SendRequest object
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>( fn release<T: AsyncRead + AsyncWrite + 'static>(
io: SendRequest<Bytes>, io: SendRequest<Bytes>,
pool: Option<Acquired<T>>, pool: Option<Acquired<T>>,
created: time::Instant, created: time::Instant,

View File

@@ -1,22 +1,22 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future; use std::io;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::time::{delay_for, Delay};
use actix_service::Service; use actix_service::Service;
use actix_utils::{oneshot, task::LocalWaker};
use bytes::Bytes; use bytes::Bytes;
use futures_util::future::{poll_fn, FutureExt, LocalBoxFuture}; use futures::future::{err, ok, Either, FutureResult};
use fxhash::FxHashMap; use futures::task::AtomicTask;
use h2::client::{handshake, Connection, SendRequest}; use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use h2::client::{handshake, Handshake};
use hashbrown::HashMap;
use http::uri::Authority; use http::uri::Authority;
use indexmap::IndexSet; use indexmap::IndexSet;
use slab::Slab; use slab::Slab;
use tokio_timer::{sleep, Delay};
use super::connection::{ConnectionType, IoConnection}; use super::connection::{ConnectionType, IoConnection};
use super::error::ConnectError; use super::error::ConnectError;
@@ -41,12 +41,16 @@ impl From<Authority> for Key {
} }
/// Connections pool /// Connections pool
pub(crate) struct ConnectionPool<T, Io: 'static>(Rc<RefCell<T>>, Rc<RefCell<Inner<Io>>>); pub(crate) struct ConnectionPool<T, Io: AsyncRead + AsyncWrite + 'static>(
T,
Rc<RefCell<Inner<Io>>>,
);
impl<T, Io> ConnectionPool<T, Io> impl<T, Io> ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
pub(crate) fn new( pub(crate) fn new(
@@ -57,7 +61,7 @@ where
limit: usize, limit: usize,
) -> Self { ) -> Self {
ConnectionPool( ConnectionPool(
Rc::new(RefCell::new(connector)), connector,
Rc::new(RefCell::new(Inner { Rc::new(RefCell::new(Inner {
conn_lifetime, conn_lifetime,
conn_keep_alive, conn_keep_alive,
@@ -66,8 +70,8 @@ where
acquired: 0, acquired: 0,
waiters: Slab::new(), waiters: Slab::new(),
waiters_queue: IndexSet::new(), waiters_queue: IndexSet::new(),
available: FxHashMap::default(), available: HashMap::new(),
waker: LocalWaker::new(), task: None,
})), })),
) )
} }
@@ -75,7 +79,8 @@ where
impl<T, Io> Clone for ConnectionPool<T, Io> impl<T, Io> Clone for ConnectionPool<T, Io>
where where
Io: 'static, T: Clone,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ConnectionPool(self.0.clone(), self.1.clone()) ConnectionPool(self.0.clone(), self.1.clone())
@@ -84,116 +89,86 @@ where
impl<T, Io> Service for ConnectionPool<T, Io> impl<T, Io> Service for ConnectionPool<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError> T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static, + 'static,
{ {
type Request = Connect; type Request = Connect;
type Response = IoConnection<Io>; type Response = IoConnection<Io>;
type Error = ConnectError; type Error = ConnectError;
type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>; type Future = Either<
FutureResult<Self::Response, Self::Error>,
Either<WaitForConnection<Io>, OpenConnection<T::Future, Io>>,
>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.0.poll_ready(cx) self.0.poll_ready()
} }
fn call(&mut self, req: Connect) -> Self::Future { fn call(&mut self, req: Connect) -> Self::Future {
// start support future let key = if let Some(authority) = req.uri.authority_part() {
actix_rt::spawn(ConnectorPoolSupport {
connector: self.0.clone(),
inner: self.1.clone(),
});
let mut connector = self.0.clone();
let inner = self.1.clone();
let fut = async move {
let key = if let Some(authority) = req.uri.authority() {
authority.clone().into() authority.clone().into()
} else { } else {
return Err(ConnectError::Unresolverd); return Either::A(err(ConnectError::Unresolverd));
}; };
// acquire connection // acquire connection
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { match self.1.as_ref().borrow_mut().acquire(&key) {
Acquire::Acquired(io, created) => { Acquire::Acquired(io, created) => {
// use existing connection // use existing connection
return Ok(IoConnection::new( return Either::A(ok(IoConnection::new(
io, io,
created, created,
Some(Acquired(key, Some(inner))), Some(Acquired(key, Some(self.1.clone()))),
)); )));
} }
Acquire::Available => { Acquire::Available => {
// open tcp connection // open new connection
let (io, proto) = connector.call(req).await?; return Either::B(Either::B(OpenConnection::new(
key,
let guard = OpenGuard::new(key, inner); self.1.clone(),
self.0.call(req),
if proto == Protocol::Http1 { )));
Ok(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(guard.consume()),
))
} else {
let (snd, connection) = handshake(io).await?;
actix_rt::spawn(connection.map(|_| ()));
Ok(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(guard.consume()),
))
} }
_ => (),
} }
_ => {
// connection is not available, wait // connection is not available, wait
let (rx, token) = inner.borrow_mut().wait_for(req); let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req);
let guard = WaiterGuard::new(key, token, inner); // start support future
let res = match rx.await { if !support {
Err(_) => Err(ConnectError::Disconnected), self.1.as_ref().borrow_mut().task = Some(AtomicTask::new());
Ok(res) => res, tokio_current_thread::spawn(ConnectorPoolSupport {
}; connector: self.0.clone(),
guard.consume(); inner: self.1.clone(),
res })
} }
}
};
fut.boxed_local() Either::B(Either::A(WaitForConnection {
rx,
key,
token,
inner: Some(self.1.clone()),
}))
} }
} }
struct WaiterGuard<Io> #[doc(hidden)]
pub struct WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
key: Key, key: Key,
token: usize, token: usize,
rx: oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<Io> WaiterGuard<Io> impl<Io> Drop for WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{
fn new(key: Key, token: usize, inner: Rc<RefCell<Inner<Io>>>) -> Self {
Self {
key,
token,
inner: Some(inner),
}
}
fn consume(mut self) {
let _ = self.inner.take();
}
}
impl<Io> Drop for WaiterGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(i) = self.inner.take() { if let Some(i) = self.inner.take() {
@@ -204,43 +179,113 @@ where
} }
} }
struct OpenGuard<Io> impl<Io> Future for WaitForConnection<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite,
{ {
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.rx.poll() {
Ok(Async::Ready(item)) => match item {
Err(err) => Err(err),
Ok(conn) => {
let _ = self.inner.take();
Ok(Async::Ready(conn))
}
},
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => {
let _ = self.inner.take();
Err(ConnectError::Disconnected)
}
}
}
}
#[doc(hidden)]
pub struct OpenConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + 'static,
{
fut: F,
key: Key, key: Key,
h2: Option<Handshake<Io, Bytes>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<Io> OpenGuard<Io> impl<F, Io> OpenConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite + 'static,
{ {
fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>) -> Self { fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>, fut: F) -> Self {
Self { OpenConnection {
key, key,
fut,
inner: Some(inner), inner: Some(inner),
h2: None,
}
} }
} }
fn consume(mut self) -> Acquired<Io> { impl<F, Io> Drop for OpenConnection<F, Io>
Acquired(self.key.clone(), self.inner.take())
}
}
impl<Io> Drop for OpenGuard<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(i) = self.inner.take() { if let Some(inner) = self.inner.take() {
let mut inner = i.as_ref().borrow_mut(); let mut inner = inner.as_ref().borrow_mut();
inner.release(); inner.release();
inner.check_availibility(); inner.check_availibility();
} }
} }
} }
impl<F, Io> Future for OpenConnection<F, Io>
where
F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite,
{
type Item = IoConnection<Io>;
type Error = ConnectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
Ok(Async::Ready((snd, connection))) => {
tokio_current_thread::spawn(connection.map_err(|_| ()));
Ok(Async::Ready(IoConnection::new(
ConnectionType::H2(snd),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e.into()),
};
}
match self.fut.poll() {
Err(err) => Err(err),
Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 {
Ok(Async::Ready(IoConnection::new(
ConnectionType::H1(io),
Instant::now(),
Some(Acquired(self.key.clone(), self.inner.take())),
)))
} else {
self.h2 = Some(handshake(io));
self.poll()
}
}
Ok(Async::NotReady) => Ok(Async::NotReady),
}
}
}
enum Acquire<T> { enum Acquire<T> {
Acquired(ConnectionType<T>, Instant), Acquired(ConnectionType<T>, Instant),
Available, Available,
@@ -259,7 +304,7 @@ pub(crate) struct Inner<Io> {
disconnect_timeout: Option<Duration>, disconnect_timeout: Option<Duration>,
limit: usize, limit: usize,
acquired: usize, acquired: usize,
available: FxHashMap<Key, VecDeque<AvailableConnection<Io>>>, available: HashMap<Key, VecDeque<AvailableConnection<Io>>>,
waiters: Slab< waiters: Slab<
Option<( Option<(
Connect, Connect,
@@ -267,7 +312,7 @@ pub(crate) struct Inner<Io> {
)>, )>,
>, >,
waiters_queue: IndexSet<(Key, usize)>, waiters_queue: IndexSet<(Key, usize)>,
waker: LocalWaker, task: Option<AtomicTask>,
} }
impl<Io> Inner<Io> { impl<Io> Inner<Io> {
@@ -287,7 +332,7 @@ impl<Io> Inner<Io> {
impl<Io> Inner<Io> impl<Io> Inner<Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
/// connection is not available, wait /// connection is not available, wait
fn wait_for( fn wait_for(
@@ -296,19 +341,20 @@ where
) -> ( ) -> (
oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>, oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
usize, usize,
bool,
) { ) {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let key: Key = connect.uri.authority().unwrap().clone().into(); let key: Key = connect.uri.authority_part().unwrap().clone().into();
let entry = self.waiters.vacant_entry(); let entry = self.waiters.vacant_entry();
let token = entry.key(); let token = entry.key();
entry.insert(Some((connect, tx))); entry.insert(Some((connect, tx)));
assert!(self.waiters_queue.insert((key, token))); assert!(self.waiters_queue.insert((key, token)));
(rx, token) (rx, token, self.task.is_some())
} }
fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire<Io> { fn acquire(&mut self, key: &Key) -> Acquire<Io> {
// check limits // check limits
if self.limit > 0 && self.acquired >= self.limit { if self.limit > 0 && self.acquired >= self.limit {
return Acquire::NotAvailable; return Acquire::NotAvailable;
@@ -327,26 +373,28 @@ where
{ {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = conn.io { if let ConnectionType::H1(io) = conn.io {
actix_rt::spawn(CloseConnection::new(io, timeout)) tokio_current_thread::spawn(CloseConnection::new(
io, timeout,
))
} }
} }
} else { } else {
let mut io = conn.io; let mut io = conn.io;
let mut buf = [0; 2]; let mut buf = [0; 2];
if let ConnectionType::H1(ref mut s) = io { if let ConnectionType::H1(ref mut s) = io {
match Pin::new(s).poll_read(cx, &mut buf) { match s.read(&mut buf) {
Poll::Pending => (), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (),
Poll::Ready(Ok(n)) if n > 0 => { Ok(n) if n > 0 => {
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
actix_rt::spawn(CloseConnection::new( tokio_current_thread::spawn(
io, timeout, CloseConnection::new(io, timeout),
)) )
} }
} }
continue; continue;
} }
_ => continue, Ok(_) | Err(_) => continue,
} }
} }
return Acquire::Acquired(io, conn.created); return Acquire::Acquired(io, conn.created);
@@ -373,7 +421,7 @@ where
self.acquired -= 1; self.acquired -= 1;
if let Some(timeout) = self.disconnect_timeout { if let Some(timeout) = self.disconnect_timeout {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
actix_rt::spawn(CloseConnection::new(io, timeout)) tokio_current_thread::spawn(CloseConnection::new(io, timeout))
} }
} }
self.check_availibility(); self.check_availibility();
@@ -381,7 +429,9 @@ where
fn check_availibility(&self) { fn check_availibility(&self) {
if !self.waiters_queue.is_empty() && self.acquired < self.limit { if !self.waiters_queue.is_empty() && self.acquired < self.limit {
self.waker.wake(); if let Some(t) = self.task.as_ref() {
t.notify()
}
} }
} }
} }
@@ -393,30 +443,29 @@ struct CloseConnection<T> {
impl<T> CloseConnection<T> impl<T> CloseConnection<T>
where where
T: AsyncWrite + Unpin, T: AsyncWrite,
{ {
fn new(io: T, timeout: Duration) -> Self { fn new(io: T, timeout: Duration) -> Self {
CloseConnection { CloseConnection {
io, io,
timeout: delay_for(timeout), timeout: sleep(timeout),
} }
} }
} }
impl<T> Future for CloseConnection<T> impl<T> Future for CloseConnection<T>
where where
T: AsyncWrite + Unpin, T: AsyncWrite,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { fn poll(&mut self) -> Poll<(), ()> {
let this = self.get_mut(); match self.timeout.poll() {
Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
match Pin::new(&mut this.timeout).poll(cx) { Ok(Async::NotReady) => match self.io.shutdown() {
Poll::Ready(_) => Poll::Ready(()), Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())),
Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}, },
} }
} }
@@ -424,7 +473,7 @@ where
struct ConnectorPoolSupport<T, Io> struct ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
connector: T, connector: T,
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
@@ -432,17 +481,16 @@ where
impl<T, Io> Future for ConnectorPoolSupport<T, Io> impl<T, Io> Future for ConnectorPoolSupport<T, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>, T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>,
T::Future: 'static, T::Future: 'static,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = unsafe { self.get_unchecked_mut() }; let mut inner = self.inner.as_ref().borrow_mut();
inner.task.as_ref().unwrap().register();
let mut inner = this.inner.as_ref().borrow_mut();
inner.waker.register(cx.waker());
// check waiters // check waiters
loop { loop {
@@ -457,14 +505,14 @@ where
continue; continue;
} }
match inner.acquire(&key, cx) { match inner.acquire(&key) {
Acquire::NotAvailable => break, Acquire::NotAvailable => break,
Acquire::Acquired(io, created) => { Acquire::Acquired(io, created) => {
let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1;
if let Err(conn) = tx.send(Ok(IoConnection::new( if let Err(conn) = tx.send(Ok(IoConnection::new(
io, io,
created, created,
Some(Acquired(key.clone(), Some(this.inner.clone()))), Some(Acquired(key.clone(), Some(self.inner.clone()))),
))) { ))) {
let (io, created) = conn.unwrap().into_inner(); let (io, created) = conn.unwrap().into_inner();
inner.release_conn(&key, io, created); inner.release_conn(&key, io, created);
@@ -476,38 +524,33 @@ where
OpenWaitingConnection::spawn( OpenWaitingConnection::spawn(
key.clone(), key.clone(),
tx, tx,
this.inner.clone(), self.inner.clone(),
this.connector.call(connect), self.connector.call(connect),
); );
} }
} }
let _ = inner.waiters_queue.swap_remove_index(0); let _ = inner.waiters_queue.swap_remove_index(0);
} }
Poll::Pending Ok(Async::NotReady)
} }
} }
struct OpenWaitingConnection<F, Io> struct OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fut: F, fut: F,
key: Key, key: Key,
h2: Option< h2: Option<Handshake<Io, Bytes>>,
LocalBoxFuture<
'static,
Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
>,
>,
rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>, rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
inner: Option<Rc<RefCell<Inner<Io>>>>, inner: Option<Rc<RefCell<Inner<Io>>>>,
} }
impl<F, Io> OpenWaitingConnection<F, Io> impl<F, Io> OpenWaitingConnection<F, Io>
where where
F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static, F: Future<Item = (Io, Protocol), Error = ConnectError> + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn spawn( fn spawn(
key: Key, key: Key,
@@ -515,7 +558,7 @@ where
inner: Rc<RefCell<Inner<Io>>>, inner: Rc<RefCell<Inner<Io>>>,
fut: F, fut: F,
) { ) {
actix_rt::spawn(OpenWaitingConnection { tokio_current_thread::spawn(OpenWaitingConnection {
key, key,
fut, fut,
h2: None, h2: None,
@@ -527,7 +570,7 @@ where
impl<F, Io> Drop for OpenWaitingConnection<F, Io> impl<F, Io> Drop for OpenWaitingConnection<F, Io>
where where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + 'static,
{ {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(inner) = self.inner.take() { if let Some(inner) = self.inner.take() {
@@ -540,60 +583,59 @@ where
impl<F, Io> Future for OpenWaitingConnection<F, Io> impl<F, Io> Future for OpenWaitingConnection<F, Io>
where where
F: Future<Output = Result<(Io, Protocol), ConnectError>>, F: Future<Item = (Io, Protocol), Error = ConnectError>,
Io: AsyncRead + AsyncWrite + Unpin, Io: AsyncRead + AsyncWrite,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = unsafe { self.get_unchecked_mut() }; if let Some(ref mut h2) = self.h2 {
return match h2.poll() {
if let Some(ref mut h2) = this.h2 { Ok(Async::Ready((snd, connection))) => {
return match Pin::new(h2).poll(cx) { tokio_current_thread::spawn(connection.map_err(|_| ()));
Poll::Ready(Ok((snd, connection))) => { let rx = self.rx.take().unwrap();
actix_rt::spawn(connection.map(|_| ()));
let rx = this.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H2(snd), ConnectionType::H2(snd),
Instant::now(), Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())), Some(Acquired(self.key.clone(), self.inner.take())),
))); )));
Poll::Ready(()) Ok(Async::Ready(()))
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err(err)) => { Err(err) => {
let _ = this.inner.take(); let _ = self.inner.take();
if let Some(rx) = this.rx.take() { if let Some(rx) = self.rx.take() {
let _ = rx.send(Err(ConnectError::H2(err))); let _ = rx.send(Err(ConnectError::H2(err)));
} }
Poll::Ready(()) Err(())
} }
}; };
} }
match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) { match self.fut.poll() {
Poll::Ready(Err(err)) => { Err(err) => {
let _ = this.inner.take(); let _ = self.inner.take();
if let Some(rx) = this.rx.take() { if let Some(rx) = self.rx.take() {
let _ = rx.send(Err(err)); let _ = rx.send(Err(err));
} }
Poll::Ready(()) Err(())
} }
Poll::Ready(Ok((io, proto))) => { Ok(Async::Ready((io, proto))) => {
if proto == Protocol::Http1 { if proto == Protocol::Http1 {
let rx = this.rx.take().unwrap(); let rx = self.rx.take().unwrap();
let _ = rx.send(Ok(IoConnection::new( let _ = rx.send(Ok(IoConnection::new(
ConnectionType::H1(io), ConnectionType::H1(io),
Instant::now(), Instant::now(),
Some(Acquired(this.key.clone(), this.inner.take())), Some(Acquired(self.key.clone(), self.inner.take())),
))); )));
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
this.h2 = Some(handshake(io).boxed_local()); self.h2 = Some(handshake(io));
unsafe { Pin::new_unchecked(this) }.poll(cx) self.poll()
} }
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
} }
} }
} }
@@ -602,7 +644,7 @@ pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
impl<T> Acquired<T> impl<T> Acquired<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
pub(crate) fn close(&mut self, conn: IoConnection<T>) { pub(crate) fn close(&mut self, conn: IoConnection<T>) {
if let Some(inner) = self.1.take() { if let Some(inner) = self.1.take() {

View File

@@ -1,8 +1,8 @@
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service; use actix_service::Service;
use futures::Poll;
#[doc(hidden)] #[doc(hidden)]
/// Service that allows to turn non-clone service to a service with `Clone` impl /// Service that allows to turn non-clone service to a service with `Clone` impl
@@ -32,8 +32,8 @@ where
type Error = T::Error; type Error = T::Error;
type Future = T::Future; type Future = T::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx) unsafe { &mut *self.0.as_ref().get() }.poll_ready()
} }
fn call(&mut self, req: T::Request) -> Self::Future { fn call(&mut self, req: T::Request) -> Self::Future {

View File

@@ -1,13 +1,13 @@
use std::cell::UnsafeCell; use std::cell::UnsafeCell;
use std::fmt;
use std::fmt::Write; use std::fmt::Write;
use std::rc::Rc; use std::rc::Rc;
use std::time::Duration; use std::time::{Duration, Instant};
use std::{fmt, net};
use actix_rt::time::{delay_for, delay_until, Delay, Instant};
use bytes::BytesMut; use bytes::BytesMut;
use futures_util::{future, FutureExt}; use futures::{future, Future};
use time; use time;
use tokio_timer::{sleep, Delay};
// "Sun, 06 Nov 1994 08:49:37 GMT".len() // "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29; const DATE_VALUE_LENGTH: usize = 29;
@@ -47,8 +47,6 @@ struct Inner {
client_timeout: u64, client_timeout: u64,
client_disconnect: u64, client_disconnect: u64,
ka_enabled: bool, ka_enabled: bool,
secure: bool,
local_addr: Option<std::net::SocketAddr>,
timer: DateService, timer: DateService,
} }
@@ -60,7 +58,7 @@ impl Clone for ServiceConfig {
impl Default for ServiceConfig { impl Default for ServiceConfig {
fn default() -> Self { fn default() -> Self {
Self::new(KeepAlive::Timeout(5), 0, 0, false, None) Self::new(KeepAlive::Timeout(5), 0, 0)
} }
} }
@@ -70,8 +68,6 @@ impl ServiceConfig {
keep_alive: KeepAlive, keep_alive: KeepAlive,
client_timeout: u64, client_timeout: u64,
client_disconnect: u64, client_disconnect: u64,
secure: bool,
local_addr: Option<net::SocketAddr>,
) -> ServiceConfig { ) -> ServiceConfig {
let (keep_alive, ka_enabled) = match keep_alive { let (keep_alive, ka_enabled) = match keep_alive {
KeepAlive::Timeout(val) => (val as u64, true), KeepAlive::Timeout(val) => (val as u64, true),
@@ -89,24 +85,10 @@ impl ServiceConfig {
ka_enabled, ka_enabled,
client_timeout, client_timeout,
client_disconnect, client_disconnect,
secure,
local_addr,
timer: DateService::new(), timer: DateService::new(),
})) }))
} }
#[inline]
/// Returns true if connection is secure(https)
pub fn secure(&self) -> bool {
self.0.secure
}
#[inline]
/// Returns the local address that this server is bound to.
pub fn local_addr(&self) -> Option<net::SocketAddr> {
self.0.local_addr
}
#[inline] #[inline]
/// Keep alive duration if configured. /// Keep alive duration if configured.
pub fn keep_alive(&self) -> Option<Duration> { pub fn keep_alive(&self) -> Option<Duration> {
@@ -122,10 +104,10 @@ impl ServiceConfig {
#[inline] #[inline]
/// Client timeout for first request. /// Client timeout for first request.
pub fn client_timer(&self) -> Option<Delay> { pub fn client_timer(&self) -> Option<Delay> {
let delay_time = self.0.client_timeout; let delay = self.0.client_timeout;
if delay_time != 0 { if delay != 0 {
Some(delay_until( Some(Delay::new(
self.0.timer.now() + Duration::from_millis(delay_time), self.0.timer.now() + Duration::from_millis(delay),
)) ))
} else { } else {
None None
@@ -156,7 +138,7 @@ impl ServiceConfig {
/// Return keep-alive timer delay is configured. /// Return keep-alive timer delay is configured.
pub fn keep_alive_timer(&self) -> Option<Delay> { pub fn keep_alive_timer(&self) -> Option<Delay> {
if let Some(ka) = self.0.keep_alive { if let Some(ka) = self.0.keep_alive {
Some(delay_until(self.0.timer.now() + ka)) Some(Delay::new(self.0.timer.now() + ka))
} else { } else {
None None
} }
@@ -260,10 +242,12 @@ impl DateService {
// periodic date update // periodic date update
let s = self.clone(); let s = self.clone();
actix_rt::spawn(delay_for(Duration::from_millis(500)).then(move |_| { tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then(
move |_| {
s.0.reset(); s.0.reset();
future::ready(()) future::ok(())
})); },
));
} }
} }
@@ -281,19 +265,26 @@ impl DateService {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use actix_rt::System;
use futures::future;
#[test] #[test]
fn test_date_len() { fn test_date_len() {
assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len());
} }
#[actix_rt::test] #[test]
async fn test_date() { fn test_date() {
let settings = ServiceConfig::new(KeepAlive::Os, 0, 0, false, None); let mut rt = System::new("test");
let _ = rt.block_on(future::lazy(|| {
let settings = ServiceConfig::new(KeepAlive::Os, 0, 0);
let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10);
settings.set_date(&mut buf1); settings.set_date(&mut buf1);
let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10);
settings.set_date(&mut buf2); settings.set_date(&mut buf2);
assert_eq!(buf1, buf2); assert_eq!(buf1, buf2);
future::ok::<_, ()>(())
}));
} }
} }

View File

@@ -18,6 +18,7 @@ use super::{Cookie, SameSite};
/// ```rust /// ```rust
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// ///
/// # fn main() {
/// let cookie: Cookie = Cookie::build("name", "value") /// let cookie: Cookie = Cookie::build("name", "value")
/// .domain("www.rust-lang.org") /// .domain("www.rust-lang.org")
/// .path("/") /// .path("/")
@@ -25,6 +26,7 @@ use super::{Cookie, SameSite};
/// .http_only(true) /// .http_only(true)
/// .max_age(84600) /// .max_age(84600)
/// .finish(); /// .finish();
/// # }
/// ``` /// ```
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CookieBuilder { pub struct CookieBuilder {
@@ -63,11 +65,13 @@ impl CookieBuilder {
/// ```rust /// ```rust
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// ///
/// # fn main() {
/// let c = Cookie::build("foo", "bar") /// let c = Cookie::build("foo", "bar")
/// .expires(time::now()) /// .expires(time::now())
/// .finish(); /// .finish();
/// ///
/// assert!(c.expires().is_some()); /// assert!(c.expires().is_some());
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn expires(mut self, when: Tm) -> CookieBuilder { pub fn expires(mut self, when: Tm) -> CookieBuilder {
@@ -82,11 +86,13 @@ impl CookieBuilder {
/// ```rust /// ```rust
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// ///
/// # fn main() {
/// let c = Cookie::build("foo", "bar") /// let c = Cookie::build("foo", "bar")
/// .max_age(1800) /// .max_age(1800)
/// .finish(); /// .finish();
/// ///
/// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60)));
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn max_age(self, seconds: i64) -> CookieBuilder { pub fn max_age(self, seconds: i64) -> CookieBuilder {
@@ -100,11 +106,13 @@ impl CookieBuilder {
/// ```rust /// ```rust
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// ///
/// # fn main() {
/// let c = Cookie::build("foo", "bar") /// let c = Cookie::build("foo", "bar")
/// .max_age_time(time::Duration::minutes(30)) /// .max_age_time(time::Duration::minutes(30))
/// .finish(); /// .finish();
/// ///
/// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60)));
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn max_age_time(mut self, value: Duration) -> CookieBuilder { pub fn max_age_time(mut self, value: Duration) -> CookieBuilder {
@@ -214,12 +222,14 @@ impl CookieBuilder {
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// use chrono::Duration; /// use chrono::Duration;
/// ///
/// # fn main() {
/// let c = Cookie::build("foo", "bar") /// let c = Cookie::build("foo", "bar")
/// .permanent() /// .permanent()
/// .finish(); /// .finish();
/// ///
/// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20)));
/// # assert!(c.expires().is_some()); /// # assert!(c.expires().is_some());
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn permanent(mut self) -> CookieBuilder { pub fn permanent(mut self) -> CookieBuilder {

View File

@@ -88,7 +88,7 @@ impl SameSite {
} }
impl fmt::Display for SameSite { impl fmt::Display for SameSite {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
SameSite::Strict => write!(f, "Strict"), SameSite::Strict => write!(f, "Strict"),
SameSite::Lax => write!(f, "Lax"), SameSite::Lax => write!(f, "Lax"),

View File

@@ -190,6 +190,7 @@ impl CookieJar {
/// use actix_http::cookie::{CookieJar, Cookie}; /// use actix_http::cookie::{CookieJar, Cookie};
/// use chrono::Duration; /// use chrono::Duration;
/// ///
/// # fn main() {
/// let mut jar = CookieJar::new(); /// let mut jar = CookieJar::new();
/// ///
/// // Assume this cookie originally had a path of "/" and domain of "a.b". /// // Assume this cookie originally had a path of "/" and domain of "a.b".
@@ -203,6 +204,7 @@ impl CookieJar {
/// assert_eq!(delta.len(), 1); /// assert_eq!(delta.len(), 1);
/// assert_eq!(delta[0].name(), "name"); /// assert_eq!(delta[0].name(), "name");
/// assert_eq!(delta[0].max_age(), Some(Duration::seconds(0))); /// assert_eq!(delta[0].max_age(), Some(Duration::seconds(0)));
/// # }
/// ``` /// ```
/// ///
/// Removing a new cookie does not result in a _removal_ cookie: /// Removing a new cookie does not result in a _removal_ cookie:
@@ -241,6 +243,7 @@ impl CookieJar {
/// use actix_http::cookie::{CookieJar, Cookie}; /// use actix_http::cookie::{CookieJar, Cookie};
/// use chrono::Duration; /// use chrono::Duration;
/// ///
/// # fn main() {
/// let mut jar = CookieJar::new(); /// let mut jar = CookieJar::new();
/// ///
/// // Add an original cookie and a new cookie. /// // Add an original cookie and a new cookie.
@@ -258,6 +261,7 @@ impl CookieJar {
/// jar.force_remove(Cookie::new("key", "value")); /// jar.force_remove(Cookie::new("key", "value"));
/// assert_eq!(jar.delta().count(), 0); /// assert_eq!(jar.delta().count(), 0);
/// assert_eq!(jar.iter().count(), 0); /// assert_eq!(jar.iter().count(), 0);
/// # }
/// ``` /// ```
pub fn force_remove<'a>(&mut self, cookie: Cookie<'a>) { pub fn force_remove<'a>(&mut self, cookie: Cookie<'a>) {
self.original_cookies.remove(cookie.name()); self.original_cookies.remove(cookie.name());
@@ -303,7 +307,7 @@ impl CookieJar {
/// // Delta contains two new cookies ("new", "yac") and a removal ("name"). /// // Delta contains two new cookies ("new", "yac") and a removal ("name").
/// assert_eq!(jar.delta().count(), 3); /// assert_eq!(jar.delta().count(), 3);
/// ``` /// ```
pub fn delta(&self) -> Delta<'_> { pub fn delta(&self) -> Delta {
Delta { Delta {
iter: self.delta_cookies.iter(), iter: self.delta_cookies.iter(),
} }
@@ -339,7 +343,7 @@ impl CookieJar {
/// } /// }
/// } /// }
/// ``` /// ```
pub fn iter(&self) -> Iter<'_> { pub fn iter(&self) -> Iter {
Iter { Iter {
delta_cookies: self delta_cookies: self
.delta_cookies .delta_cookies
@@ -382,7 +386,7 @@ impl CookieJar {
/// assert!(jar.get("private").is_some()); /// assert!(jar.get("private").is_some());
/// ``` /// ```
#[cfg(feature = "secure-cookies")] #[cfg(feature = "secure-cookies")]
pub fn private(&mut self, key: &Key) -> PrivateJar<'_> { pub fn private(&mut self, key: &Key) -> PrivateJar {
PrivateJar::new(self, key) PrivateJar::new(self, key)
} }
@@ -420,7 +424,7 @@ impl CookieJar {
/// assert!(jar.get("signed").is_some()); /// assert!(jar.get("signed").is_some());
/// ``` /// ```
#[cfg(feature = "secure-cookies")] #[cfg(feature = "secure-cookies")]
pub fn signed(&mut self, key: &Key) -> SignedJar<'_> { pub fn signed(&mut self, key: &Key) -> SignedJar {
SignedJar::new(self, key) SignedJar::new(self, key)
} }
} }

View File

@@ -110,7 +110,7 @@ impl CookieStr {
/// # Panics /// # Panics
/// ///
/// Panics if `self` is an indexed string and `string` is None. /// Panics if `self` is an indexed string and `string` is None.
fn to_str<'s>(&'s self, string: Option<&'s Cow<'_, str>>) -> &'s str { fn to_str<'s>(&'s self, string: Option<&'s Cow<str>>) -> &'s str {
match *self { match *self {
CookieStr::Indexed(i, j) => { CookieStr::Indexed(i, j) => {
let s = string.expect( let s = string.expect(
@@ -647,11 +647,13 @@ impl<'c> Cookie<'c> {
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// use chrono::Duration; /// use chrono::Duration;
/// ///
/// # fn main() {
/// let mut c = Cookie::new("name", "value"); /// let mut c = Cookie::new("name", "value");
/// assert_eq!(c.max_age(), None); /// assert_eq!(c.max_age(), None);
/// ///
/// c.set_max_age(Duration::hours(10)); /// c.set_max_age(Duration::hours(10));
/// assert_eq!(c.max_age(), Some(Duration::hours(10))); /// assert_eq!(c.max_age(), Some(Duration::hours(10)));
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn set_max_age(&mut self, value: Duration) { pub fn set_max_age(&mut self, value: Duration) {
@@ -699,6 +701,7 @@ impl<'c> Cookie<'c> {
/// ```rust /// ```rust
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// ///
/// # fn main() {
/// let mut c = Cookie::new("name", "value"); /// let mut c = Cookie::new("name", "value");
/// assert_eq!(c.expires(), None); /// assert_eq!(c.expires(), None);
/// ///
@@ -707,6 +710,7 @@ impl<'c> Cookie<'c> {
/// ///
/// c.set_expires(now); /// c.set_expires(now);
/// assert!(c.expires().is_some()) /// assert!(c.expires().is_some())
/// # }
/// ``` /// ```
#[inline] #[inline]
pub fn set_expires(&mut self, time: Tm) { pub fn set_expires(&mut self, time: Tm) {
@@ -722,6 +726,7 @@ impl<'c> Cookie<'c> {
/// use actix_http::cookie::Cookie; /// use actix_http::cookie::Cookie;
/// use chrono::Duration; /// use chrono::Duration;
/// ///
/// # fn main() {
/// let mut c = Cookie::new("foo", "bar"); /// let mut c = Cookie::new("foo", "bar");
/// assert!(c.expires().is_none()); /// assert!(c.expires().is_none());
/// assert!(c.max_age().is_none()); /// assert!(c.max_age().is_none());
@@ -729,6 +734,7 @@ impl<'c> Cookie<'c> {
/// c.make_permanent(); /// c.make_permanent();
/// assert!(c.expires().is_some()); /// assert!(c.expires().is_some());
/// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20)));
/// # }
/// ``` /// ```
pub fn make_permanent(&mut self) { pub fn make_permanent(&mut self) {
let twenty_years = Duration::days(365 * 20); let twenty_years = Duration::days(365 * 20);
@@ -736,7 +742,7 @@ impl<'c> Cookie<'c> {
self.set_expires(time::now() + twenty_years); self.set_expires(time::now() + twenty_years);
} }
fn fmt_parameters(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt_parameters(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(true) = self.http_only() { if let Some(true) = self.http_only() {
write!(f, "; HttpOnly")?; write!(f, "; HttpOnly")?;
} }
@@ -918,10 +924,10 @@ impl<'c> Cookie<'c> {
/// let mut c = Cookie::new("my name", "this; value?"); /// let mut c = Cookie::new("my name", "this; value?");
/// assert_eq!(&c.encoded().to_string(), "my%20name=this%3B%20value%3F"); /// assert_eq!(&c.encoded().to_string(), "my%20name=this%3B%20value%3F");
/// ``` /// ```
pub struct EncodedCookie<'a, 'c>(&'a Cookie<'c>); pub struct EncodedCookie<'a, 'c: 'a>(&'a Cookie<'c>);
impl<'a, 'c: 'a> fmt::Display for EncodedCookie<'a, 'c> { impl<'a, 'c: 'a> fmt::Display for EncodedCookie<'a, 'c> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Percent-encode the name and value. // Percent-encode the name and value.
let name = percent_encode(self.0.name().as_bytes(), USERINFO); let name = percent_encode(self.0.name().as_bytes(), USERINFO);
let value = percent_encode(self.0.value().as_bytes(), USERINFO); let value = percent_encode(self.0.value().as_bytes(), USERINFO);
@@ -946,7 +952,7 @@ impl<'c> fmt::Display for Cookie<'c> {
/// ///
/// assert_eq!(&cookie.to_string(), "foo=bar; Path=/"); /// assert_eq!(&cookie.to_string(), "foo=bar; Path=/");
/// ``` /// ```
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}={}", self.name(), self.value())?; write!(f, "{}={}", self.name(), self.value())?;
self.fmt_parameters(f) self.fmt_parameters(f)
} }

View File

@@ -40,7 +40,7 @@ impl ParseError {
} }
impl fmt::Display for ParseError { impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.as_str()) write!(f, "{}", self.as_str())
} }
} }
@@ -51,7 +51,11 @@ impl From<Utf8Error> for ParseError {
} }
} }
impl Error for ParseError {} impl Error for ParseError {
fn description(&self) -> &str {
self.as_str()
}
}
fn indexes_of(needle: &str, haystack: &str) -> Option<(usize, usize)> { fn indexes_of(needle: &str, haystack: &str) -> Option<(usize, usize)> {
let haystack_start = haystack.as_ptr() as usize; let haystack_start = haystack.as_ptr() as usize;

View File

@@ -1,11 +1,13 @@
use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256}; use ring::digest::{Algorithm, SHA256};
use ring::hkdf::expand;
use ring::hmac::SigningKey;
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::private::KEY_LEN as PRIVATE_KEY_LEN; use super::private::KEY_LEN as PRIVATE_KEY_LEN;
use super::signed::KEY_LEN as SIGNED_KEY_LEN; use super::signed::KEY_LEN as SIGNED_KEY_LEN;
static HKDF_DIGEST: Algorithm = HKDF_SHA256; static HKDF_DIGEST: &Algorithm = &SHA256;
const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"]; const KEYS_INFO: &str = "COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM";
/// A cryptographic master key for use with `Signed` and/or `Private` jars. /// A cryptographic master key for use with `Signed` and/or `Private` jars.
/// ///
@@ -23,13 +25,6 @@ pub struct Key {
encryption_key: [u8; PRIVATE_KEY_LEN], encryption_key: [u8; PRIVATE_KEY_LEN],
} }
impl KeyType for &Key {
#[inline]
fn len(&self) -> usize {
SIGNED_KEY_LEN + PRIVATE_KEY_LEN
}
}
impl Key { impl Key {
/// Derives new signing/encryption keys from a master key. /// Derives new signing/encryption keys from a master key.
/// ///
@@ -61,26 +56,21 @@ impl Key {
); );
} }
// An empty `Key` structure; will be filled in with HKDF derived keys. // Expand the user's key into two.
let mut output_key = Key { let prk = SigningKey::new(HKDF_DIGEST, key);
signing_key: [0; SIGNED_KEY_LEN],
encryption_key: [0; PRIVATE_KEY_LEN],
};
// Expand the master key into two HKDF generated keys.
let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN];
let prk = Prk::new_less_safe(HKDF_DIGEST, key); expand(&prk, KEYS_INFO.as_bytes(), &mut both_keys);
let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand");
okm.fill(&mut both_keys).expect("fill keys");
// Copy the key parts into their respective fields. // Copy the keys into their respective arrays.
output_key let mut signing_key = [0; SIGNED_KEY_LEN];
.signing_key let mut encryption_key = [0; PRIVATE_KEY_LEN];
.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); signing_key.copy_from_slice(&both_keys[..SIGNED_KEY_LEN]);
output_key encryption_key.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]);
.encryption_key
.copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); Key {
output_key signing_key,
encryption_key,
}
} }
/// Generates signing/encryption keys from a secure, random source. Keys are /// Generates signing/encryption keys from a secure, random source. Keys are

View File

@@ -1,8 +1,8 @@
use std::str; use std::str;
use log::warn; use log::warn;
use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM}; use ring::aead::{open_in_place, seal_in_place, Aad, Algorithm, Nonce, AES_256_GCM};
use ring::aead::{LessSafeKey, UnboundKey}; use ring::aead::{OpeningKey, SealingKey};
use ring::rand::{SecureRandom, SystemRandom}; use ring::rand::{SecureRandom, SystemRandom};
use super::Key; use super::Key;
@@ -53,14 +53,11 @@ impl<'a> PrivateJar<'a> {
} }
let ad = Aad::from(name.as_bytes()); let ad = Aad::from(name.as_bytes());
let key = LessSafeKey::new( let key = OpeningKey::new(ALGO, &self.key).expect("opening key");
UnboundKey::new(&ALGO, &self.key).expect("matching key length"), let (nonce, sealed) = data.split_at_mut(NONCE_LEN);
);
let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN);
let nonce = let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
let unsealed = key let unsealed = open_in_place(&key, nonce, ad, 0, sealed)
.open_in_place(nonce, ad, &mut sealed)
.map_err(|_| "invalid key/nonce/value: bad seal")?; .map_err(|_| "invalid key/nonce/value: bad seal")?;
if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { if let Ok(unsealed_utf8) = str::from_utf8(unsealed) {
@@ -159,7 +156,7 @@ Please change it as soon as possible."
/// Encrypts the cookie's value with /// Encrypts the cookie's value with
/// authenticated encryption assuring confidentiality, integrity, and authenticity. /// authenticated encryption assuring confidentiality, integrity, and authenticity.
fn encrypt_cookie(&self, cookie: &mut Cookie<'_>) { fn encrypt_cookie(&self, cookie: &mut Cookie) {
let name = cookie.name().as_bytes(); let name = cookie.name().as_bytes();
let value = cookie.value().as_bytes(); let value = cookie.value().as_bytes();
let data = encrypt_name_value(name, value, &self.key); let data = encrypt_name_value(name, value, &self.key);
@@ -199,33 +196,30 @@ Please change it as soon as possible."
fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> { fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec<u8> {
// Create the `SealingKey` structure. // Create the `SealingKey` structure.
let unbound = UnboundKey::new(&ALGO, key).expect("matching key length"); let key = SealingKey::new(ALGO, key).expect("sealing key creation");
let key = LessSafeKey::new(unbound);
// Create a vec to hold the [nonce | cookie value | overhead]. // Create a vec to hold the [nonce | cookie value | overhead].
let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()]; let overhead = ALGO.tag_len();
let mut data = vec![0; NONCE_LEN + value.len() + overhead];
// Randomly generate the nonce, then copy the cookie value as input. // Randomly generate the nonce, then copy the cookie value as input.
let (nonce, in_out) = data.split_at_mut(NONCE_LEN); let (nonce, in_out) = data.split_at_mut(NONCE_LEN);
let (in_out, tag) = in_out.split_at_mut(value.len());
in_out.copy_from_slice(value);
// Randomly generate the nonce into the nonce piece.
SystemRandom::new() SystemRandom::new()
.fill(nonce) .fill(nonce)
.expect("couldn't random fill nonce"); .expect("couldn't random fill nonce");
let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length"); in_out[..value.len()].copy_from_slice(value);
let nonce =
Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`");
// Use cookie's name as associated data to prevent value swapping. // Use cookie's name as associated data to prevent value swapping.
let ad = Aad::from(name); let ad = Aad::from(name);
let ad_tag = key
.seal_in_place_separate_tag(nonce, ad, in_out)
.expect("in-place seal");
// Copy the tag into the tag piece. // Perform the actual sealing operation and get the output length.
tag.copy_from_slice(ad_tag.as_ref()); let output_len =
seal_in_place(&key, nonce, ad, in_out, overhead).expect("in-place seal");
// Remove the overhead and return the sealed content. // Remove the overhead and return the sealed content.
data.truncate(NONCE_LEN + output_len);
data data
} }

View File

@@ -1,11 +1,12 @@
use ring::hmac::{self, sign, verify}; use ring::digest::{Algorithm, SHA256};
use ring::hmac::{sign, verify_with_own_key as verify, SigningKey};
use super::Key; use super::Key;
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
// Keep these in sync, and keep the key len synced with the `signed` docs as // Keep these in sync, and keep the key len synced with the `signed` docs as
// well as the `KEYS_INFO` const in secure::Key. // well as the `KEYS_INFO` const in secure::Key.
static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256; static HMAC_DIGEST: &Algorithm = &SHA256;
const BASE64_DIGEST_LEN: usize = 44; const BASE64_DIGEST_LEN: usize = 44;
pub const KEY_LEN: usize = 32; pub const KEY_LEN: usize = 32;
@@ -20,7 +21,7 @@ pub const KEY_LEN: usize = 32;
/// This type is only available when the `secure` feature is enabled. /// This type is only available when the `secure` feature is enabled.
pub struct SignedJar<'a> { pub struct SignedJar<'a> {
parent: &'a mut CookieJar, parent: &'a mut CookieJar,
key: hmac::Key, key: SigningKey,
} }
impl<'a> SignedJar<'a> { impl<'a> SignedJar<'a> {
@@ -31,7 +32,7 @@ impl<'a> SignedJar<'a> {
pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> {
SignedJar { SignedJar {
parent, parent,
key: hmac::Key::new(HMAC_DIGEST, key.signing()), key: SigningKey::new(HMAC_DIGEST, key.signing()),
} }
} }
@@ -129,7 +130,7 @@ impl<'a> SignedJar<'a> {
} }
/// Signs the cookie's value assuring integrity and authenticity. /// Signs the cookie's value assuring integrity and authenticity.
fn sign_cookie(&self, cookie: &mut Cookie<'_>) { fn sign_cookie(&self, cookie: &mut Cookie) {
let digest = sign(&self.key, cookie.value().as_bytes()); let digest = sign(&self.key, cookie.value().as_bytes());
let mut new_value = base64::encode(digest.as_ref()); let mut new_value = base64::encode(digest.as_ref());
new_value.push_str(cookie.value()); new_value.push_str(cookie.value());

View File

@@ -1,13 +1,12 @@
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")]
use brotli2::write::BrotliDecoder; use brotli2::write::BrotliDecoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzDecoder, ZlibDecoder}; use flate2::write::{GzDecoder, ZlibDecoder};
use futures_core::{ready, Stream}; use futures::{try_ready, Async, Future, Poll, Stream};
use super::Writer; use super::Writer;
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -24,18 +23,21 @@ pub struct Decoder<S> {
impl<S> Decoder<S> impl<S> Decoder<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>>, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
/// Construct a decoder. /// Construct a decoder.
#[inline] #[inline]
pub fn new(stream: S, encoding: ContentEncoding) -> Decoder<S> { pub fn new(stream: S, encoding: ContentEncoding) -> Decoder<S> {
let decoder = match encoding { let decoder = match encoding {
#[cfg(feature = "brotli")]
ContentEncoding::Br => Some(ContentDecoder::Br(Box::new( ContentEncoding::Br => Some(ContentDecoder::Br(Box::new(
BrotliDecoder::new(Writer::new()), BrotliDecoder::new(Writer::new()),
))), ))),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new( ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new(
ZlibDecoder::new(Writer::new()), ZlibDecoder::new(Writer::new()),
))), ))),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new( ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new(
GzDecoder::new(Writer::new()), GzDecoder::new(Writer::new()),
))), ))),
@@ -69,40 +71,34 @@ where
impl<S> Stream for Decoder<S> impl<S> Stream for Decoder<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
loop { loop {
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { let (chunk, decoder) = try_ready!(fut.poll());
Ok(item) => item,
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
self.decoder = Some(decoder); self.decoder = Some(decoder);
self.fut.take(); self.fut.take();
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
if self.eof { if self.eof {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
match Pin::new(&mut self.stream).poll_next(cx) { match self.stream.poll()? {
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), Async::Ready(Some(chunk)) => {
Poll::Ready(Some(Ok(chunk))) => {
if let Some(mut decoder) = self.decoder.take() { if let Some(mut decoder) = self.decoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
let chunk = decoder.feed_data(chunk)?; let chunk = decoder.feed_data(chunk)?;
self.decoder = Some(decoder); self.decoder = Some(decoder);
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -112,40 +108,41 @@ where
} }
continue; continue;
} else { } else {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
Poll::Ready(None) => { Async::Ready(None) => {
self.eof = true; self.eof = true;
return if let Some(mut decoder) = self.decoder.take() { return if let Some(mut decoder) = self.decoder.take() {
match decoder.feed_eof() { Ok(Async::Ready(decoder.feed_eof()?))
Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
Ok(None) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(err.into()))),
}
} else { } else {
Poll::Ready(None) Ok(Async::Ready(None))
}; };
} }
Poll::Pending => break, Async::NotReady => break,
} }
} }
Poll::Pending Ok(Async::NotReady)
} }
} }
enum ContentDecoder { enum ContentDecoder {
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
Deflate(Box<ZlibDecoder<Writer>>), Deflate(Box<ZlibDecoder<Writer>>),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
Gzip(Box<GzDecoder<Writer>>), Gzip(Box<GzDecoder<Writer>>),
#[cfg(feature = "brotli")]
Br(Box<BrotliDecoder<Writer>>), Br(Box<BrotliDecoder<Writer>>),
} }
impl ContentDecoder { impl ContentDecoder {
#[allow(unreachable_patterns)]
fn feed_eof(&mut self) -> io::Result<Option<Bytes>> { fn feed_eof(&mut self) -> io::Result<Option<Bytes>> {
match self { match self {
ContentDecoder::Br(ref mut decoder) => match decoder.flush() { #[cfg(feature = "brotli")]
Ok(()) => { ContentDecoder::Br(ref mut decoder) => match decoder.finish() {
let b = decoder.get_mut().take(); Ok(mut writer) => {
let b = writer.take();
if !b.is_empty() { if !b.is_empty() {
Ok(Some(b)) Ok(Some(b))
} else { } else {
@@ -154,6 +151,7 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() {
Ok(_) => { Ok(_) => {
let b = decoder.get_mut().take(); let b = decoder.get_mut().take();
@@ -165,6 +163,7 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() { ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() {
Ok(_) => { Ok(_) => {
let b = decoder.get_mut().take(); let b = decoder.get_mut().take();
@@ -176,11 +175,14 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
_ => Ok(None),
} }
} }
#[allow(unreachable_patterns)]
fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> { fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> {
match self { match self {
#[cfg(feature = "brotli")]
ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) { ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) {
Ok(_) => { Ok(_) => {
decoder.flush()?; decoder.flush()?;
@@ -193,6 +195,7 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) {
Ok(_) => { Ok(_) => {
decoder.flush()?; decoder.flush()?;
@@ -205,6 +208,7 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) {
Ok(_) => { Ok(_) => {
decoder.flush()?; decoder.flush()?;
@@ -217,6 +221,7 @@ impl ContentDecoder {
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
_ => Ok(Some(data)),
} }
} }
} }

View File

@@ -1,23 +1,22 @@
//! Stream encoder //! Stream encoder
use std::future::Future;
use std::io::{self, Write}; use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_threadpool::{run, CpuFuture}; use actix_threadpool::{run, CpuFuture};
#[cfg(feature = "brotli")]
use brotli2::write::BrotliEncoder; use brotli2::write::BrotliEncoder;
use bytes::Bytes; use bytes::Bytes;
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
use flate2::write::{GzEncoder, ZlibEncoder}; use flate2::write::{GzEncoder, ZlibEncoder};
use futures_core::ready; use futures::{Async, Future, Poll};
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; use crate::http::header::{ContentEncoding, CONTENT_ENCODING};
use crate::http::{HeaderValue, StatusCode}; use crate::http::{HeaderValue, HttpTryFrom, StatusCode};
use crate::{Error, ResponseHead}; use crate::{Error, ResponseHead};
use super::Writer; use super::Writer;
const INPLACE: usize = 1024; const INPLACE: usize = 2049;
pub struct Encoder<B> { pub struct Encoder<B> {
eof: bool, eof: bool,
@@ -95,45 +94,43 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
} }
} }
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(&mut self) -> Poll<Option<Bytes>, Error> {
loop { loop {
if self.eof { if self.eof {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
if let Some(ref mut fut) = self.fut { if let Some(ref mut fut) = self.fut {
let mut encoder = match ready!(Pin::new(fut).poll(cx)) { let mut encoder = futures::try_ready!(fut.poll());
Ok(item) => item,
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
self.fut.take(); self.fut.take();
if !chunk.is_empty() { if !chunk.is_empty() {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
let result = match self.body { let result = match self.body {
EncoderBody::Bytes(ref mut b) => { EncoderBody::Bytes(ref mut b) => {
if b.is_empty() { if b.is_empty() {
Poll::Ready(None) Async::Ready(None)
} else { } else {
Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new())))) Async::Ready(Some(std::mem::replace(b, Bytes::new())))
} }
} }
EncoderBody::Stream(ref mut b) => b.poll_next(cx), EncoderBody::Stream(ref mut b) => b.poll_next()?,
EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx), EncoderBody::BoxedStream(ref mut b) => b.poll_next()?,
}; };
match result { match result {
Poll::Ready(Some(Ok(chunk))) => { Async::NotReady => return Ok(Async::NotReady),
Async::Ready(Some(chunk)) => {
if let Some(mut encoder) = self.encoder.take() { if let Some(mut encoder) = self.encoder.take() {
if chunk.len() < INPLACE { if chunk.len() < INPLACE {
encoder.write(&chunk)?; encoder.write(&chunk)?;
let chunk = encoder.take(); let chunk = encoder.take();
self.encoder = Some(encoder); self.encoder = Some(encoder);
if !chunk.is_empty() { if !chunk.is_empty() {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
self.fut = Some(run(move || { self.fut = Some(run(move || {
@@ -142,23 +139,22 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
})); }));
} }
} else { } else {
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} }
Poll::Ready(None) => { Async::Ready(None) => {
if let Some(encoder) = self.encoder.take() { if let Some(encoder) = self.encoder.take() {
let chunk = encoder.finish()?; let chunk = encoder.finish()?;
if chunk.is_empty() { if chunk.is_empty() {
return Poll::Ready(None); return Ok(Async::Ready(None));
} else { } else {
self.eof = true; self.eof = true;
return Poll::Ready(Some(Ok(chunk))); return Ok(Async::Ready(Some(chunk)));
} }
} else { } else {
return Poll::Ready(None); return Ok(Async::Ready(None));
} }
} }
val => return val,
} }
} }
} }
@@ -167,27 +163,33 @@ impl<B: MessageBody> MessageBody for Encoder<B> {
fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) {
head.headers_mut().insert( head.headers_mut().insert(
CONTENT_ENCODING, CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()), HeaderValue::try_from(Bytes::from_static(encoding.as_str().as_bytes())).unwrap(),
); );
} }
enum ContentEncoder { enum ContentEncoder {
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
Deflate(ZlibEncoder<Writer>), Deflate(ZlibEncoder<Writer>),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
Gzip(GzEncoder<Writer>), Gzip(GzEncoder<Writer>),
#[cfg(feature = "brotli")]
Br(BrotliEncoder<Writer>), Br(BrotliEncoder<Writer>),
} }
impl ContentEncoder { impl ContentEncoder {
fn encoder(encoding: ContentEncoding) -> Option<Self> { fn encoder(encoding: ContentEncoding) -> Option<Self> {
match encoding { match encoding {
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoding::Deflate => Some(ContentEncoder::Deflate(ZlibEncoder::new( ContentEncoding::Deflate => Some(ContentEncoder::Deflate(ZlibEncoder::new(
Writer::new(), Writer::new(),
flate2::Compression::fast(), flate2::Compression::fast(),
))), ))),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new( ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new(
Writer::new(), Writer::new(),
flate2::Compression::fast(), flate2::Compression::fast(),
))), ))),
#[cfg(feature = "brotli")]
ContentEncoding::Br => { ContentEncoding::Br => {
Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3)))
} }
@@ -198,22 +200,28 @@ impl ContentEncoder {
#[inline] #[inline]
pub(crate) fn take(&mut self) -> Bytes { pub(crate) fn take(&mut self) -> Bytes {
match *self { match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(),
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(),
} }
} }
fn finish(self) -> Result<Bytes, io::Error> { fn finish(self) -> Result<Bytes, io::Error> {
match self { match self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(encoder) => match encoder.finish() { ContentEncoder::Br(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()), Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err), Err(err) => Err(err),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Gzip(encoder) => match encoder.finish() { ContentEncoder::Gzip(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()), Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err), Err(err) => Err(err),
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Deflate(encoder) => match encoder.finish() { ContentEncoder::Deflate(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()), Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err), Err(err) => Err(err),
@@ -223,6 +231,7 @@ impl ContentEncoder {
fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
match *self { match *self {
#[cfg(feature = "brotli")]
ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
@@ -230,6 +239,7 @@ impl ContentEncoder {
Err(err) Err(err)
} }
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
@@ -237,6 +247,7 @@ impl ContentEncoder {
Err(err) Err(err)
} }
}, },
#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))]
ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {

View File

@@ -19,9 +19,8 @@ impl Writer {
buf: BytesMut::with_capacity(8192), buf: BytesMut::with_capacity(8192),
} }
} }
fn take(&mut self) -> Bytes { fn take(&mut self) -> Bytes {
self.buf.split().freeze() self.buf.take().freeze()
} }
} }
@@ -30,7 +29,6 @@ impl io::Write for Writer {
self.buf.extend_from_slice(buf); self.buf.extend_from_slice(buf);
Ok(buf.len()) Ok(buf.len())
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
Ok(()) Ok(())
} }

View File

@@ -6,25 +6,24 @@ use std::str::Utf8Error;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use std::{fmt, io, result}; use std::{fmt, io, result};
use actix_codec::{Decoder, Encoder};
pub use actix_threadpool::BlockingError; pub use actix_threadpool::BlockingError;
use actix_utils::framed::DispatcherError as FramedDispatcherError;
use actix_utils::timeout::TimeoutError; use actix_utils::timeout::TimeoutError;
use bytes::BytesMut; use bytes::BytesMut;
use derive_more::{Display, From}; use derive_more::{Display, From};
pub use futures_channel::oneshot::Canceled; use futures::Canceled;
use http::uri::InvalidUri; use http::uri::InvalidUri;
use http::{header, Error as HttpError, StatusCode}; use http::{header, Error as HttpError, StatusCode};
use httparse; use httparse;
use serde::de::value::Error as DeError; use serde::de::value::Error as DeError;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use serde_urlencoded::ser::Error as FormError; use serde_urlencoded::ser::Error as FormError;
use tokio_timer::Error as TimerError;
// re-export for convinience // re-export for convinience
use crate::body::Body; use crate::body::Body;
pub use crate::cookie::ParseError as CookieParseError; pub use crate::cookie::ParseError as CookieParseError;
use crate::helpers::Writer; use crate::helpers::Writer;
use crate::response::{Response, ResponseBuilder}; use crate::response::Response;
/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html)
/// for actix web operations /// for actix web operations
@@ -62,18 +61,16 @@ impl Error {
/// Error that can be converted to `Response` /// Error that can be converted to `Response`
pub trait ResponseError: fmt::Debug + fmt::Display { pub trait ResponseError: fmt::Debug + fmt::Display {
/// Response's status code
///
/// Internal server error is generated by default.
fn status_code(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}
/// Create response for error /// Create response for error
/// ///
/// Internal server error is generated by default. /// Internal server error is generated by default.
fn error_response(&self) -> Response { fn error_response(&self) -> Response {
let mut resp = Response::new(self.status_code()); Response::new(StatusCode::INTERNAL_SERVER_ERROR)
}
/// Constructs an error response
fn render_response(&self) -> Response {
let mut resp = self.error_response();
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
let _ = write!(Writer(&mut buf), "{}", self); let _ = write!(Writer(&mut buf), "{}", self);
resp.headers_mut().insert( resp.headers_mut().insert(
@@ -104,33 +101,37 @@ impl dyn ResponseError + 'static {
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.cause, f) fmt::Display::fmt(&self.cause, f)
} }
} }
impl fmt::Debug for Error { impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", &self.cause) write!(f, "{:?}", &self.cause)
} }
} }
impl std::error::Error for Error {
fn cause(&self) -> Option<&dyn std::error::Error> {
None
}
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
impl From<()> for Error { impl From<()> for Error {
fn from(_: ()) -> Self { fn from(_: ()) -> Self {
Error::from(UnitError) Error::from(UnitError)
} }
} }
impl std::error::Error for Error {
fn description(&self) -> &str {
"actix-http::Error"
}
fn cause(&self) -> Option<&dyn std::error::Error> {
None
}
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
impl From<std::convert::Infallible> for Error { impl From<std::convert::Infallible> for Error {
fn from(_: std::convert::Infallible) -> Self { fn from(_: std::convert::Infallible) -> Self {
// `std::convert::Infallible` indicates an error // `std::convert::Infallible` indicates an error
@@ -155,26 +156,12 @@ impl<T: ResponseError + 'static> From<T> for Error {
} }
} }
/// Convert Response to a Error
impl From<Response> for Error {
fn from(res: Response) -> Error {
InternalError::from_response("", res).into()
}
}
/// Convert ResponseBuilder to a Error
impl From<ResponseBuilder> for Error {
fn from(mut res: ResponseBuilder) -> Error {
InternalError::from_response("", res.finish()).into()
}
}
/// Return `GATEWAY_TIMEOUT` for `TimeoutError` /// Return `GATEWAY_TIMEOUT` for `TimeoutError`
impl<E: ResponseError> ResponseError for TimeoutError<E> { impl<E: ResponseError> ResponseError for TimeoutError<E> {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
match self { match self {
TimeoutError::Service(e) => e.status_code(), TimeoutError::Service(e) => e.error_response(),
TimeoutError::Timeout => StatusCode::GATEWAY_TIMEOUT, TimeoutError::Timeout => Response::new(StatusCode::GATEWAY_TIMEOUT),
} }
} }
} }
@@ -192,31 +179,31 @@ impl ResponseError for JsonError {}
/// `InternalServerError` for `FormError` /// `InternalServerError` for `FormError`
impl ResponseError for FormError {} impl ResponseError for FormError {}
#[cfg(feature = "openssl")] /// `InternalServerError` for `TimerError`
/// `InternalServerError` for `openssl::ssl::Error` impl ResponseError for TimerError {}
impl ResponseError for actix_connect::ssl::openssl::SslError {}
#[cfg(feature = "openssl")] #[cfg(feature = "ssl")]
/// `InternalServerError` for `openssl::ssl::Error`
impl ResponseError for openssl::ssl::Error {}
#[cfg(feature = "ssl")]
/// `InternalServerError` for `openssl::ssl::HandshakeError` /// `InternalServerError` for `openssl::ssl::HandshakeError`
impl<T: std::fmt::Debug> ResponseError for actix_tls::openssl::HandshakeError<T> {} impl ResponseError for openssl::ssl::HandshakeError<tokio_tcp::TcpStream> {}
/// Return `BAD_REQUEST` for `de::value::Error` /// Return `BAD_REQUEST` for `de::value::Error`
impl ResponseError for DeError { impl ResponseError for DeError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
/// `InternalServerError` for `Canceled`
impl ResponseError for Canceled {}
/// `InternalServerError` for `BlockingError` /// `InternalServerError` for `BlockingError`
impl<E: fmt::Debug> ResponseError for BlockingError<E> {} impl<E: fmt::Debug> ResponseError for BlockingError<E> {}
/// Return `BAD_REQUEST` for `Utf8Error` /// Return `BAD_REQUEST` for `Utf8Error`
impl ResponseError for Utf8Error { impl ResponseError for Utf8Error {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
@@ -226,22 +213,32 @@ impl ResponseError for HttpError {}
/// Return `InternalServerError` for `io::Error` /// Return `InternalServerError` for `io::Error`
impl ResponseError for io::Error { impl ResponseError for io::Error {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
match self.kind() { match self.kind() {
io::ErrorKind::NotFound => StatusCode::NOT_FOUND, io::ErrorKind::NotFound => Response::new(StatusCode::NOT_FOUND),
io::ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, io::ErrorKind::PermissionDenied => Response::new(StatusCode::FORBIDDEN),
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => Response::new(StatusCode::INTERNAL_SERVER_ERROR),
} }
} }
} }
/// `BadRequest` for `InvalidHeaderValue` /// `BadRequest` for `InvalidHeaderValue`
impl ResponseError for header::InvalidHeaderValue { impl ResponseError for header::InvalidHeaderValue {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
/// `BadRequest` for `InvalidHeaderValue`
impl ResponseError for header::InvalidHeaderValueBytes {
fn error_response(&self) -> Response {
Response::new(StatusCode::BAD_REQUEST)
}
}
/// `InternalServerError` for `futures::Canceled`
impl ResponseError for Canceled {}
/// A set of errors that can occur during parsing HTTP streams /// A set of errors that can occur during parsing HTTP streams
#[derive(Debug, Display)] #[derive(Debug, Display)]
pub enum ParseError { pub enum ParseError {
@@ -281,8 +278,8 @@ pub enum ParseError {
/// Return `BadRequest` for `ParseError` /// Return `BadRequest` for `ParseError`
impl ResponseError for ParseError { impl ResponseError for ParseError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
@@ -374,7 +371,7 @@ impl From<BlockingError<io::Error>> for PayloadError {
BlockingError::Error(e) => PayloadError::Io(e), BlockingError::Error(e) => PayloadError::Io(e),
BlockingError::Canceled => PayloadError::Io(io::Error::new( BlockingError::Canceled => PayloadError::Io(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"Operation is canceled", "Thread pool is gone",
)), )),
} }
} }
@@ -385,18 +382,18 @@ impl From<BlockingError<io::Error>> for PayloadError {
/// - `Overflow` returns `PayloadTooLarge` /// - `Overflow` returns `PayloadTooLarge`
/// - Other errors returns `BadRequest` /// - Other errors returns `BadRequest`
impl ResponseError for PayloadError { impl ResponseError for PayloadError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
match *self { match *self {
PayloadError::Overflow => StatusCode::PAYLOAD_TOO_LARGE, PayloadError::Overflow => Response::new(StatusCode::PAYLOAD_TOO_LARGE),
_ => StatusCode::BAD_REQUEST, _ => Response::new(StatusCode::BAD_REQUEST),
} }
} }
} }
/// Return `BadRequest` for `cookie::ParseError` /// Return `BadRequest` for `cookie::ParseError`
impl ResponseError for crate::cookie::ParseError { impl ResponseError for crate::cookie::ParseError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
@@ -460,19 +457,11 @@ pub enum ContentTypeError {
/// Return `BadRequest` for `ContentTypeError` /// Return `BadRequest` for `ContentTypeError`
impl ResponseError for ContentTypeError { impl ResponseError for ContentTypeError {
fn status_code(&self) -> StatusCode { fn error_response(&self) -> Response {
StatusCode::BAD_REQUEST Response::new(StatusCode::BAD_REQUEST)
} }
} }
impl<E, U: Encoder + Decoder> ResponseError for FramedDispatcherError<E, U>
where
E: fmt::Debug + fmt::Display,
<U as Encoder>::Error: fmt::Debug,
<U as Decoder>::Error: fmt::Debug,
{
}
/// Helper type that can wrap any error and generate custom response. /// Helper type that can wrap any error and generate custom response.
/// ///
/// In following example any `io::Error` will be converted into "BAD REQUEST" /// In following example any `io::Error` will be converted into "BAD REQUEST"
@@ -480,12 +469,14 @@ where
/// default. /// default.
/// ///
/// ```rust /// ```rust
/// # extern crate actix_http;
/// # use std::io; /// # use std::io;
/// # use actix_http::*; /// # use actix_http::*;
/// ///
/// fn index(req: Request) -> Result<&'static str> { /// fn index(req: Request) -> Result<&'static str> {
/// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error"))) /// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error")))
/// } /// }
/// # fn main() {}
/// ``` /// ```
pub struct InternalError<T> { pub struct InternalError<T> {
cause: T, cause: T,
@@ -519,7 +510,7 @@ impl<T> fmt::Debug for InternalError<T>
where where
T: fmt::Debug + 'static, T: fmt::Debug + 'static,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.cause, f) fmt::Debug::fmt(&self.cause, f)
} }
} }
@@ -528,7 +519,7 @@ impl<T> fmt::Display for InternalError<T>
where where
T: fmt::Display + 'static, T: fmt::Display + 'static,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.cause, f) fmt::Display::fmt(&self.cause, f)
} }
} }
@@ -537,19 +528,6 @@ impl<T> ResponseError for InternalError<T>
where where
T: fmt::Debug + fmt::Display + 'static, T: fmt::Debug + fmt::Display + 'static,
{ {
fn status_code(&self) -> StatusCode {
match self.status {
InternalErrorType::Status(st) => st,
InternalErrorType::Response(ref resp) => {
if let Some(resp) = resp.borrow().as_ref() {
resp.head().status
} else {
StatusCode::INTERNAL_SERVER_ERROR
}
}
}
}
fn error_response(&self) -> Response { fn error_response(&self) -> Response {
match self.status { match self.status {
InternalErrorType::Status(st) => { InternalErrorType::Status(st) => {
@@ -571,6 +549,18 @@ where
} }
} }
} }
/// Constructs an error response
fn render_response(&self) -> Response {
self.error_response()
}
}
/// Convert Response to a Error
impl From<Response> for Error {
fn from(res: Response) -> Error {
InternalError::from_response("", res).into()
}
} }
/// Helper function that creates wrapper of any error and generate *BAD /// Helper function that creates wrapper of any error and generate *BAD
@@ -963,15 +953,24 @@ where
InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into()
} }
#[cfg(feature = "failure")] #[cfg(feature = "fail")]
mod failure_integration {
use super::*;
/// Compatibility for `failure::Error` /// Compatibility for `failure::Error`
impl ResponseError for fail_ure::Error {} impl ResponseError for failure::Error {
fn error_response(&self) -> Response {
Response::new(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use http::{Error as HttpError, StatusCode}; use http::{Error as HttpError, StatusCode};
use httparse; use httparse;
use std::error::Error as StdError;
use std::io; use std::io;
#[test] #[test]
@@ -1000,7 +999,7 @@ mod tests {
#[test] #[test]
fn test_error_cause() { fn test_error_cause() {
let orig = io::Error::new(io::ErrorKind::Other, "other"); let orig = io::Error::new(io::ErrorKind::Other, "other");
let desc = orig.to_string(); let desc = orig.description().to_owned();
let e = Error::from(orig); let e = Error::from(orig);
assert_eq!(format!("{}", e.as_response_error()), desc); assert_eq!(format!("{}", e.as_response_error()), desc);
} }
@@ -1008,7 +1007,7 @@ mod tests {
#[test] #[test]
fn test_error_display() { fn test_error_display() {
let orig = io::Error::new(io::ErrorKind::Other, "other"); let orig = io::Error::new(io::ErrorKind::Other, "other");
let desc = orig.to_string(); let desc = orig.description().to_owned();
let e = Error::from(orig); let e = Error::from(orig);
assert_eq!(format!("{}", e), desc); assert_eq!(format!("{}", e), desc);
} }
@@ -1050,7 +1049,7 @@ mod tests {
match ParseError::from($from) { match ParseError::from($from) {
e @ $error => { e @ $error => {
let desc = format!("{}", e); let desc = format!("{}", e);
assert_eq!(desc, format!("IO error: {}", $from)); assert_eq!(desc, format!("IO error: {}", $from.description()));
} }
_ => unreachable!("{:?}", $from), _ => unreachable!("{:?}", $from),
} }

View File

@@ -1,12 +1,12 @@
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::fmt; use std::fmt;
use fxhash::FxHashMap; use hashbrown::HashMap;
#[derive(Default)] #[derive(Default)]
/// A type map of request extensions. /// A type map of request extensions.
pub struct Extensions { pub struct Extensions {
map: FxHashMap<TypeId, Box<dyn Any>>, map: HashMap<TypeId, Box<dyn Any>>,
} }
impl Extensions { impl Extensions {
@@ -14,7 +14,7 @@ impl Extensions {
#[inline] #[inline]
pub fn new() -> Extensions { pub fn new() -> Extensions {
Extensions { Extensions {
map: FxHashMap::default(), map: HashMap::default(),
} }
} }
@@ -65,7 +65,7 @@ impl Extensions {
} }
impl fmt::Debug for Extensions { impl fmt::Debug for Extensions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Extensions").finish() f.debug_struct("Extensions").finish()
} }
} }

View File

@@ -1,8 +1,13 @@
use std::io; #![allow(unused_imports, unused_variables, dead_code)]
use std::io::{self, Write};
use std::rc::Rc;
use actix_codec::{Decoder, Encoder}; use actix_codec::{Decoder, Encoder};
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE,
};
use http::{Method, Version}; use http::{Method, Version};
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
@@ -11,7 +16,11 @@ use super::{Message, MessageType};
use crate::body::BodySize; use crate::body::BodySize;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{ParseError, PayloadError}; use crate::error::{ParseError, PayloadError};
use crate::message::{ConnectionType, RequestHeadType, ResponseHead}; use crate::header::HeaderMap;
use crate::helpers;
use crate::message::{
ConnectionType, Head, MessagePool, RequestHead, RequestHeadType, ResponseHead,
};
bitflags! { bitflags! {
struct Flags: u8 { struct Flags: u8 {
@@ -21,6 +30,8 @@ bitflags! {
} }
} }
const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct ClientCodec { pub struct ClientCodec {
inner: ClientCodecInner, inner: ClientCodecInner,
@@ -40,6 +51,7 @@ struct ClientCodecInner {
// encoder part // encoder part
flags: Flags, flags: Flags,
headers_size: u32,
encoder: encoder::MessageEncoder<RequestHeadType>, encoder: encoder::MessageEncoder<RequestHeadType>,
} }
@@ -68,6 +80,7 @@ impl ClientCodec {
ctype: ConnectionType::Close, ctype: ConnectionType::Close,
flags, flags,
headers_size: 0,
encoder: encoder::MessageEncoder::default(), encoder: encoder::MessageEncoder::default(),
}, },
} }

View File

@@ -1,9 +1,12 @@
use std::{fmt, io}; #![allow(unused_imports, unused_variables, dead_code)]
use std::io::Write;
use std::{fmt, io, net};
use actix_codec::{Decoder, Encoder}; use actix_codec::{Decoder, Encoder};
use bitflags::bitflags; use bitflags::bitflags;
use bytes::BytesMut; use bytes::{BufMut, Bytes, BytesMut};
use http::{Method, Version}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, StatusCode, Version};
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::{decoder, encoder}; use super::{decoder, encoder};
@@ -11,7 +14,8 @@ use super::{Message, MessageType};
use crate::body::BodySize; use crate::body::BodySize;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::ParseError; use crate::error::ParseError;
use crate::message::ConnectionType; use crate::helpers;
use crate::message::{ConnectionType, Head, ResponseHead};
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
@@ -23,6 +27,8 @@ bitflags! {
} }
} }
const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct Codec { pub struct Codec {
config: ServiceConfig, config: ServiceConfig,
@@ -43,7 +49,7 @@ impl Default for Codec {
} }
impl fmt::Debug for Codec { impl fmt::Debug for Codec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "h1::Codec({:?})", self.flags) write!(f, "h1::Codec({:?})", self.flags)
} }
} }
@@ -170,6 +176,7 @@ impl Encoder for Codec {
}; };
// encode message // encode message
let len = dst.len();
self.encoder.encode( self.encoder.encode(
dst, dst,
&mut res, &mut res,
@@ -195,11 +202,17 @@ impl Encoder for Codec {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use bytes::BytesMut; use std::{cmp, io};
use http::Method;
use actix_codec::{AsyncRead, AsyncWrite};
use bytes::{Buf, Bytes, BytesMut};
use http::{Method, Version};
use super::*; use super::*;
use crate::error::ParseError;
use crate::h1::Message;
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::request::Request;
#[test] #[test]
fn test_http_request_chunked_payload_and_next_message() { fn test_http_request_chunked_payload_and_next_message() {

View File

@@ -1,13 +1,12 @@
use std::convert::TryFrom;
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::task::Poll;
use actix_codec::Decoder; use actix_codec::Decoder;
use bytes::{Buf, Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
use http::{header, Method, StatusCode, Uri, Version}; use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version};
use httparse; use httparse;
use log::{debug, error, trace}; use log::{debug, error, trace};
@@ -80,8 +79,8 @@ pub(crate) trait MessageType: Sized {
// Unsafe: httparse check header value for valid utf-8 // Unsafe: httparse check header value for valid utf-8
let value = unsafe { let value = unsafe {
HeaderValue::from_maybe_shared_unchecked( HeaderValue::from_shared_unchecked(
slice.slice(idx.value.0..idx.value.1), slice.slice(idx.value.0, idx.value.1),
) )
}; };
match name { match name {
@@ -185,7 +184,6 @@ impl MessageType for Request {
&mut self.head_mut().headers &mut self.head_mut().headers
} }
#[allow(clippy::uninit_assumed_init)]
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> { fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> {
// Unsafe: we read only this data only after httparse parses headers into. // Unsafe: we read only this data only after httparse parses headers into.
// performance bump for pipeline benchmarks. // performance bump for pipeline benchmarks.
@@ -193,7 +191,7 @@ impl MessageType for Request {
unsafe { MaybeUninit::uninit().assume_init() }; unsafe { MaybeUninit::uninit().assume_init() };
let (len, method, uri, ver, h_len) = { let (len, method, uri, ver, h_len) = {
let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = let mut parsed: [httparse::Header; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() }; unsafe { MaybeUninit::uninit().assume_init() };
let mut req = httparse::Request::new(&mut parsed); let mut req = httparse::Request::new(&mut parsed);
@@ -261,7 +259,6 @@ impl MessageType for ResponseHead {
&mut self.headers &mut self.headers
} }
#[allow(clippy::uninit_assumed_init)]
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> { fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> {
// Unsafe: we read only this data only after httparse parses headers into. // Unsafe: we read only this data only after httparse parses headers into.
// performance bump for pipeline benchmarks. // performance bump for pipeline benchmarks.
@@ -269,7 +266,7 @@ impl MessageType for ResponseHead {
unsafe { MaybeUninit::uninit().assume_init() }; unsafe { MaybeUninit::uninit().assume_init() };
let (len, ver, status, h_len) = { let (len, ver, status, h_len) = {
let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = let mut parsed: [httparse::Header; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() }; unsafe { MaybeUninit::uninit().assume_init() };
let mut res = httparse::Response::new(&mut parsed); let mut res = httparse::Response::new(&mut parsed);
@@ -328,7 +325,7 @@ pub(crate) struct HeaderIndex {
impl HeaderIndex { impl HeaderIndex {
pub(crate) fn record( pub(crate) fn record(
bytes: &[u8], bytes: &[u8],
headers: &[httparse::Header<'_>], headers: &[httparse::Header],
indices: &mut [HeaderIndex], indices: &mut [HeaderIndex],
) { ) {
let bytes_ptr = bytes.as_ptr() as usize; let bytes_ptr = bytes.as_ptr() as usize;
@@ -431,7 +428,7 @@ impl Decoder for PayloadDecoder {
let len = src.len() as u64; let len = src.len() as u64;
let buf; let buf;
if *remaining > len { if *remaining > len {
buf = src.split().freeze(); buf = src.take().freeze();
*remaining -= len; *remaining -= len;
} else { } else {
buf = src.split_to(*remaining as usize).freeze(); buf = src.split_to(*remaining as usize).freeze();
@@ -445,10 +442,9 @@ impl Decoder for PayloadDecoder {
loop { loop {
let mut buf = None; let mut buf = None;
// advances the chunked state // advances the chunked state
*state = match state.step(src, size, &mut buf) { *state = match state.step(src, size, &mut buf)? {
Poll::Pending => return Ok(None), Async::NotReady => return Ok(None),
Poll::Ready(Ok(state)) => state, Async::Ready(state) => state,
Poll::Ready(Err(e)) => return Err(e),
}; };
if *state == ChunkedState::End { if *state == ChunkedState::End {
trace!("End of chunked stream"); trace!("End of chunked stream");
@@ -466,7 +462,7 @@ impl Decoder for PayloadDecoder {
if src.is_empty() { if src.is_empty() {
Ok(None) Ok(None)
} else { } else {
Ok(Some(PayloadItem::Chunk(src.split().freeze()))) Ok(Some(PayloadItem::Chunk(src.take().freeze())))
} }
} }
} }
@@ -477,10 +473,10 @@ macro_rules! byte (
($rdr:ident) => ({ ($rdr:ident) => ({
if $rdr.len() > 0 { if $rdr.len() > 0 {
let b = $rdr[0]; let b = $rdr[0];
$rdr.advance(1); $rdr.split_to(1);
b b
} else { } else {
return Poll::Pending return Ok(Async::NotReady)
} }
}) })
); );
@@ -491,7 +487,7 @@ impl ChunkedState {
body: &mut BytesMut, body: &mut BytesMut,
size: &mut u64, size: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
use self::ChunkedState::*; use self::ChunkedState::*;
match *self { match *self {
Size => ChunkedState::read_size(body, size), Size => ChunkedState::read_size(body, size),
@@ -503,14 +499,10 @@ impl ChunkedState {
BodyLf => ChunkedState::read_body_lf(body), BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body), EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body), EndLf => ChunkedState::read_end_lf(body),
End => Poll::Ready(Ok(ChunkedState::End)), End => Ok(Async::Ready(ChunkedState::End)),
} }
} }
fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll<ChunkedState, io::Error> {
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16; let radix = 16;
match byte!(rdr) { match byte!(rdr) {
b @ b'0'..=b'9' => { b @ b'0'..=b'9' => {
@@ -525,49 +517,48 @@ impl ChunkedState {
*size *= radix; *size *= radix;
*size += u64::from(b + 10 - b'A'); *size += u64::from(b + 10 - b'A');
} }
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)), b';' => return Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)),
_ => { _ => {
return Poll::Ready(Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size", "Invalid chunk size line: Invalid Size",
))); ));
} }
} }
Poll::Ready(Ok(ChunkedState::Size)) Ok(Async::Ready(ChunkedState::Size))
} }
fn read_size_lws(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws"); trace!("read_size_lws");
match byte!(rdr) { match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come // LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)), b';' => Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space", "Invalid chunk size linear white space",
))), )),
} }
} }
fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_extension(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions
} }
} }
fn read_size_lf( fn read_size_lf(
rdr: &mut BytesMut, rdr: &mut BytesMut,
size: &mut u64, size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)),
b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk size LF", "Invalid chunk size LF",
))), )),
} }
} }
@@ -575,16 +566,16 @@ impl ChunkedState {
rdr: &mut BytesMut, rdr: &mut BytesMut,
rem: &mut u64, rem: &mut u64,
buf: &mut Option<Bytes>, buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> { ) -> Poll<ChunkedState, io::Error> {
trace!("Chunked read, remaining={:?}", rem); trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64; let len = rdr.len() as u64;
if len == 0 { if len == 0 {
Poll::Ready(Ok(ChunkedState::Body)) Ok(Async::Ready(ChunkedState::Body))
} else { } else {
let slice; let slice;
if *rem > len { if *rem > len {
slice = rdr.split().freeze(); slice = rdr.take().freeze();
*rem -= len; *rem -= len;
} else { } else {
slice = rdr.split_to(*rem as usize).freeze(); slice = rdr.split_to(*rem as usize).freeze();
@@ -592,47 +583,47 @@ impl ChunkedState {
} }
*buf = Some(slice); *buf = Some(slice);
if *rem > 0 { if *rem > 0 {
Poll::Ready(Ok(ChunkedState::Body)) Ok(Async::Ready(ChunkedState::Body))
} else { } else {
Poll::Ready(Ok(ChunkedState::BodyCr)) Ok(Async::Ready(ChunkedState::BodyCr))
} }
} }
} }
fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_body_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body CR", "Invalid chunk body CR",
))), )),
} }
} }
fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_body_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::Size)), b'\n' => Ok(Async::Ready(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk body LF", "Invalid chunk body LF",
))), )),
} }
} }
fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_end_cr(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), b'\r' => Ok(Async::Ready(ChunkedState::EndLf)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end CR", "Invalid chunk end CR",
))), )),
} }
} }
fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> { fn read_end_lf(rdr: &mut BytesMut) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::End)), b'\n' => Ok(Async::Ready(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new( _ => Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
"Invalid chunk end LF", "Invalid chunk end LF",
))), )),
} }
} }
} }

View File

@@ -1,15 +1,15 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::future::Future; use std::time::Instant;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, net}; use std::{fmt, io, net};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_codec::{Decoder, Encoder, Framed, FramedParts};
use actix_rt::time::{delay_until, Delay, Instant}; use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Buf, BytesMut}; use bytes::{BufMut, BytesMut};
use futures::{Async, Future, Poll};
use log::{error, trace}; use log::{error, trace};
use tokio_timer::Delay;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@@ -166,7 +166,7 @@ impl PartialEq for PollResponse {
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -184,7 +184,6 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
Dispatcher::with_timeout( Dispatcher::with_timeout(
stream, stream,
@@ -196,7 +195,6 @@ where
expect, expect,
upgrade, upgrade,
on_connect, on_connect,
peer_addr,
) )
} }
@@ -211,7 +209,6 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
let keepalive = config.keep_alive_enabled(); let keepalive = config.keep_alive_enabled();
let flags = if keepalive { let flags = if keepalive {
@@ -235,6 +232,7 @@ where
payload: None, payload: None,
state: State::None, state: State::None,
error: None, error: None,
peer_addr: io.peer_addr(),
messages: VecDeque::new(), messages: VecDeque::new(),
io, io,
codec, codec,
@@ -244,7 +242,6 @@ where
upgrade, upgrade,
on_connect, on_connect,
flags, flags,
peer_addr,
ka_expire, ka_expire,
ka_timer, ka_timer,
}), }),
@@ -254,7 +251,7 @@ where
impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U> impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -264,14 +261,14 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
fn can_read(&self, cx: &mut Context<'_>) -> bool { fn can_read(&self) -> bool {
if self if self
.flags .flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{ {
false false
} else if let Some(ref info) = self.payload { } else if let Some(ref info) = self.payload {
info.need_read(cx) == PayloadStatus::Read info.need_read() == PayloadStatus::Read
} else { } else {
true true
} }
@@ -290,7 +287,7 @@ where
/// ///
/// true - got whouldblock /// true - got whouldblock
/// false - didnt get whouldblock /// false - didnt get whouldblock
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> { fn poll_flush(&mut self) -> Result<bool, DispatchError> {
if self.write_buf.is_empty() { if self.write_buf.is_empty() {
return Ok(false); return Ok(false);
} }
@@ -298,31 +295,31 @@ where
let len = self.write_buf.len(); let len = self.write_buf.len();
let mut written = 0; let mut written = 0;
while written < len { while written < len {
match unsafe { Pin::new_unchecked(&mut self.io) } match self.io.write(&self.write_buf[written..]) {
.poll_write(cx, &self.write_buf[written..]) Ok(0) => {
{
Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new( return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"", "",
))); )));
} }
Poll::Ready(Ok(n)) => { Ok(n) => {
written += n; written += n;
} }
Poll::Pending => { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if written > 0 { if written > 0 {
self.write_buf.advance(written); let _ = self.write_buf.split_to(written);
} }
return Ok(true); return Ok(true);
} }
Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), Err(err) => return Err(DispatchError::Io(err)),
} }
} }
if written > 0 {
if written == self.write_buf.len() { if written == self.write_buf.len() {
unsafe { self.write_buf.set_len(0) } unsafe { self.write_buf.set_len(0) }
} else { } else {
self.write_buf.advance(written); let _ = self.write_buf.split_to(written);
}
} }
Ok(false) Ok(false)
} }
@@ -353,15 +350,12 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
} }
fn poll_response( fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
&mut self,
cx: &mut Context<'_>,
) -> Result<PollResponse, DispatchError> {
loop { loop {
let state = match self.state { let state = match self.state {
State::None => match self.messages.pop_front() { State::None => match self.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => { Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req, cx)?) Some(self.handle_request(req)?)
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
@@ -371,58 +365,54 @@ where
} }
None => None, None => None,
}, },
State::ExpectCall(ref mut fut) => { State::ExpectCall(ref mut fut) => match fut.poll() {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) { Ok(Async::Ready(req)) => {
Poll::Ready(Ok(req)) => {
self.send_continue(); self.send_continue();
self.state = State::ServiceCall(self.service.call(req)); self.state = State::ServiceCall(self.service.call(req));
continue; continue;
} }
Poll::Ready(Err(e)) => { Ok(Async::NotReady) => None,
Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.send_response(res, body.into_body())?)
} }
Poll::Pending => None, },
} State::ServiceCall(ref mut fut) => match fut.poll() {
} Ok(Async::Ready(res)) => {
State::ServiceCall(ref mut fut) => {
match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?; self.state = self.send_response(res, body)?;
continue; continue;
} }
Poll::Ready(Err(e)) => { Ok(Async::NotReady) => None,
Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.send_response(res, body.into_body())?)
} }
Poll::Pending => None, },
}
}
State::SendPayload(ref mut stream) => { State::SendPayload(ref mut stream) => {
loop { loop {
if self.write_buf.len() < HW_BUFFER_SIZE { if self.write_buf.len() < HW_BUFFER_SIZE {
match stream.poll_next(cx) { match stream
Poll::Ready(Some(Ok(item))) => { .poll_next()
.map_err(|_| DispatchError::Unknown)?
{
Async::Ready(Some(item)) => {
self.codec.encode( self.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
continue; continue;
} }
Poll::Ready(None) => { Async::Ready(None) => {
self.codec.encode( self.codec.encode(
Message::Chunk(None), Message::Chunk(None),
&mut self.write_buf, &mut self.write_buf,
)?; )?;
self.state = State::None; self.state = State::None;
} }
Poll::Ready(Some(Err(_))) => { Async::NotReady => return Ok(PollResponse::DoNothing),
return Err(DispatchError::Unknown)
}
Poll::Pending => return Ok(PollResponse::DoNothing),
} }
} else { } else {
return Ok(PollResponse::DrainWriteBuf); return Ok(PollResponse::DrainWriteBuf);
@@ -443,7 +433,7 @@ where
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more data and retry // we may read more data and retry
if self.state.is_call() { if self.state.is_call() {
if self.poll_request(cx)? { if self.poll_request()? {
continue; continue;
} }
} else if !self.messages.is_empty() { } else if !self.messages.is_empty() {
@@ -456,21 +446,17 @@ where
Ok(PollResponse::DoNothing) Ok(PollResponse::DoNothing)
} }
fn handle_request( fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
&mut self,
req: Request,
cx: &mut Context<'_>,
) -> Result<State<S, B, X>, DispatchError> {
// Handle `EXPECT: 100-Continue` header // Handle `EXPECT: 100-Continue` header
let req = if req.head().expect() { let req = if req.head().expect() {
let mut task = self.expect.call(req); let mut task = self.expect.call(req);
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { match task.poll() {
Poll::Ready(Ok(req)) => { Ok(Async::Ready(req)) => {
self.send_continue(); self.send_continue();
req req
} }
Poll::Pending => return Ok(State::ExpectCall(task)), Ok(Async::NotReady) => return Ok(State::ExpectCall(task)),
Poll::Ready(Err(e)) => { Err(e) => {
let e = e.into(); let e = e.into();
let res: Response = e.into(); let res: Response = e.into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
@@ -483,13 +469,13 @@ where
// Call service // Call service
let mut task = self.service.call(req); let mut task = self.service.call(req);
match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { match task.poll() {
Poll::Ready(Ok(res)) => { Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.send_response(res, body) self.send_response(res, body)
} }
Poll::Pending => Ok(State::ServiceCall(task)), Ok(Async::NotReady) => Ok(State::ServiceCall(task)),
Poll::Ready(Err(e)) => { Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.send_response(res, body.into_body()) self.send_response(res, body.into_body())
@@ -498,12 +484,9 @@ where
} }
/// Process one incoming requests /// Process one incoming requests
pub(self) fn poll_request( pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> {
&mut self,
cx: &mut Context<'_>,
) -> Result<bool, DispatchError> {
// limit a mount of non processed requests // limit a mount of non processed requests
if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() {
return Ok(false); return Ok(false);
} }
@@ -538,7 +521,7 @@ where
// handle request early // handle request early
if self.state.is_empty() { if self.state.is_empty() {
self.state = self.handle_request(req, cx)?; self.state = self.handle_request(req)?;
} else { } else {
self.messages.push_back(DispatcherMessage::Item(req)); self.messages.push_back(DispatcherMessage::Item(req));
} }
@@ -604,12 +587,12 @@ where
} }
/// keep-alive timer /// keep-alive timer
fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { fn poll_keepalive(&mut self) -> Result<(), DispatchError> {
if self.ka_timer.is_none() { if self.ka_timer.is_none() {
// shutdown timeout // shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.codec.config().client_disconnect_timer() { if let Some(interval) = self.codec.config().client_disconnect_timer() {
self.ka_timer = Some(delay_until(interval)); self.ka_timer = Some(Delay::new(interval));
} else { } else {
self.flags.insert(Flags::READ_DISCONNECT); self.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
@@ -622,8 +605,11 @@ where
} }
} }
match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { match self.ka_timer.as_mut().unwrap().poll().map_err(|e| {
Poll::Ready(()) => { error!("Timer error {:?}", e);
DispatchError::Unknown
})? {
Async::Ready(_) => {
// if we get timeout during shutdown, drop connection // if we get timeout during shutdown, drop connection
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
@@ -638,9 +624,9 @@ where
if let Some(deadline) = if let Some(deadline) =
self.codec.config().client_disconnect_timer() self.codec.config().client_disconnect_timer()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} else { } else {
// no shutdown timeout, drop socket // no shutdown timeout, drop socket
@@ -664,40 +650,26 @@ where
} else if let Some(deadline) = } else if let Some(deadline) =
self.codec.config().keep_alive_expire() self.codec.config().keep_alive_expire()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} }
} else if let Some(mut timer) = self.ka_timer.as_mut() { } else if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(self.ka_expire); timer.reset(self.ka_expire);
let _ = Pin::new(&mut timer).poll(cx); let _ = timer.poll();
} }
} }
Poll::Pending => (), Async::NotReady => (),
} }
Ok(()) Ok(())
} }
} }
impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: Into<Error>,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U> impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -707,28 +679,27 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.as_mut().inner { match self.inner {
DispatcherState::Normal(ref mut inner) => { DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive(cx)?; inner.poll_keepalive()?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
// flush buffer // flush buffer
inner.poll_flush(cx)?; inner.poll_flush()?;
if !inner.write_buf.is_empty() { if !inner.write_buf.is_empty() {
Poll::Pending Ok(Async::NotReady)
} else { } else {
match Pin::new(&mut inner.io).poll_shutdown(cx) { match inner.io.shutdown()? {
Poll::Ready(res) => { Async::Ready(_) => Ok(Async::Ready(())),
Poll::Ready(res.map_err(DispatchError::from)) Async::NotReady => Ok(Async::NotReady),
}
Poll::Pending => Poll::Pending,
} }
} }
} }
@@ -736,12 +707,12 @@ where
// read socket into a buf // read socket into a buf
let should_disconnect = let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
read_available(cx, &mut inner.io, &mut inner.read_buf)? read_available(&mut inner.io, &mut inner.read_buf)?
} else { } else {
None None
}; };
inner.poll_request(cx)?; inner.poll_request()?;
if let Some(true) = should_disconnect { if let Some(true) = should_disconnect {
inner.flags.insert(Flags::READ_DISCONNECT); inner.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = inner.payload.take() { if let Some(mut payload) = inner.payload.take() {
@@ -750,12 +721,10 @@ where
}; };
loop { loop {
let remaining = if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.capacity() - inner.write_buf.len(); inner.write_buf.reserve(HW_BUFFER_SIZE);
if remaining < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE - remaining);
} }
let result = inner.poll_response(cx)?; let result = inner.poll_response()?;
let drain = result == PollResponse::DrainWriteBuf; let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler // switch to upgrade handler
@@ -773,7 +742,7 @@ where
self.inner = DispatcherState::Upgrade( self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)), inner.upgrade.unwrap().call((req, framed)),
); );
return self.poll(cx); return self.poll();
} else { } else {
panic!() panic!()
} }
@@ -782,14 +751,14 @@ where
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX) // so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck // and we have to write again otherwise response can get stuck
if inner.poll_flush(cx)? || !drain { if inner.poll_flush()? || !drain {
break; break;
} }
} }
// client is gone // client is gone
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
return Poll::Ready(Ok(())); return Ok(Async::Ready(()));
} }
let is_empty = inner.state.is_empty(); let is_empty = inner.state.is_empty();
@@ -802,64 +771,58 @@ where
// keep-alive and stream errors // keep-alive and stream errors
if is_empty && inner.write_buf.is_empty() { if is_empty && inner.write_buf.is_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner.error.take() {
Poll::Ready(Err(err)) Err(err)
} }
// disconnect if keep-alive is not enabled // disconnect if keep-alive is not enabled
else if inner.flags.contains(Flags::STARTED) else if inner.flags.contains(Flags::STARTED)
&& !inner.flags.intersects(Flags::KEEPALIVE) && !inner.flags.intersects(Flags::KEEPALIVE)
{ {
inner.flags.insert(Flags::SHUTDOWN); inner.flags.insert(Flags::SHUTDOWN);
self.poll(cx) self.poll()
} }
// disconnect if shutdown // disconnect if shutdown
else if inner.flags.contains(Flags::SHUTDOWN) { else if inner.flags.contains(Flags::SHUTDOWN) {
self.poll(cx) self.poll()
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
} }
DispatcherState::Upgrade(ref mut fut) => { DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
error!("Upgrade handler error: {}", e); error!("Upgrade handler error: {}", e);
DispatchError::Upgrade DispatchError::Upgrade
}) }),
}
DispatcherState::None => panic!(), DispatcherState::None => panic!(),
} }
} }
} }
fn read_available<T>( fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
cx: &mut Context<'_>,
io: &mut T,
buf: &mut BytesMut,
) -> Result<Option<bool>, io::Error>
where where
T: AsyncRead + Unpin, T: io::Read,
{ {
let mut read_some = false; let mut read_some = false;
loop { loop {
let remaining = buf.capacity() - buf.len(); if buf.remaining_mut() < LW_BUFFER_SIZE {
if remaining < LW_BUFFER_SIZE { buf.reserve(HW_BUFFER_SIZE);
buf.reserve(HW_BUFFER_SIZE - remaining);
} }
match read(cx, io, buf) { let read = unsafe { io.read(buf.bytes_mut()) };
Poll::Pending => { match read {
return if read_some { Ok(Some(false)) } else { Ok(None) }; Ok(n) => {
}
Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
return Ok(Some(true)); return Ok(Some(true));
} else { } else {
read_some = true; read_some = true;
unsafe {
buf.advance_mut(n);
} }
} }
Poll::Ready(Err(e)) => { }
Err(e) => {
return if e.kind() == io::ErrorKind::WouldBlock { return if e.kind() == io::ErrorKind::WouldBlock {
if read_some { if read_some {
Ok(Some(false)) Ok(Some(false))
@@ -870,36 +833,26 @@ where
Ok(Some(true)) Ok(Some(true))
} else { } else {
Err(e) Err(e)
};
} }
} }
} }
} }
}
fn read<T>(
cx: &mut Context<'_>,
io: &mut T,
buf: &mut BytesMut,
) -> Poll<Result<usize, io::Error>>
where
T: AsyncRead + Unpin,
{
Pin::new(io).poll_read_buf(cx, buf)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::IntoService; use actix_service::IntoService;
use futures_util::future::{lazy, ok}; use futures::future::{lazy, ok};
use super::*; use super::*;
use crate::error::Error; use crate::error::Error;
use crate::h1::{ExpectHandler, UpgradeHandler}; use crate::h1::{ExpectHandler, UpgradeHandler};
use crate::test::TestBuffer; use crate::test::TestBuffer;
#[actix_rt::test] #[test]
async fn test_req_parse_err() { fn test_req_parse_err() {
lazy(|cx| { let mut sys = actix_rt::System::new("test");
let _ = sys.block_on(lazy(|| {
let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new( let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
@@ -911,18 +864,14 @@ mod tests {
CloneableService::new(ExpectHandler), CloneableService::new(ExpectHandler),
None, None,
None, None,
None,
); );
match Pin::new(&mut h1).poll(cx) { assert!(h1.poll().is_err());
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),
}
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal(ref inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert!(inner.flags.contains(Flags::READ_DISCONNECT));
assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
} }
}) ok::<_, ()>(())
.await; }));
} }
} }

View File

@@ -1,18 +1,23 @@
#![allow(unused_imports, unused_variables, dead_code)]
use std::fmt::Write as FmtWrite;
use std::io::Write; use std::io::Write;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ptr::copy_nonoverlapping; use std::rc::Rc;
use std::slice::from_raw_parts_mut; use std::str::FromStr;
use std::{cmp, io}; use std::{cmp, fmt, io, mem};
use bytes::{buf::BufMutExt, BufMut, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use crate::body::BodySize; use crate::body::BodySize;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::header::map; use crate::header::{map, ContentEncoding};
use crate::helpers; use crate::helpers;
use crate::http::header::{CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use crate::http::header::{
use crate::http::{HeaderMap, StatusCode, Version}; HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
use crate::message::{ConnectionType, RequestHeadType}; };
use crate::http::{HeaderMap, Method, StatusCode, Version};
use crate::message::{ConnectionType, Head, RequestHead, RequestHeadType, ResponseHead};
use crate::request::Request;
use crate::response::Response; use crate::response::Response;
const AVERAGE_HEADER_SIZE: usize = 30; const AVERAGE_HEADER_SIZE: usize = 30;
@@ -101,7 +106,6 @@ pub(crate) trait MessageType: Sized {
} else { } else {
dst.put_slice(b"\r\ncontent-length: "); dst.put_slice(b"\r\ncontent-length: ");
} }
#[allow(clippy::write_with_newline)]
write!(dst.writer(), "{}\r\n", len)?; write!(dst.writer(), "{}\r\n", len)?;
} }
BodySize::None => dst.put_slice(b"\r\n"), BodySize::None => dst.put_slice(b"\r\n"),
@@ -140,8 +144,8 @@ pub(crate) trait MessageType: Sized {
// write headers // write headers
let mut pos = 0; let mut pos = 0;
let mut has_date = false; let mut has_date = false;
let mut remaining = dst.capacity() - dst.len(); let mut remaining = dst.remaining_mut();
let mut buf = dst.bytes_mut().as_mut_ptr() as *mut u8; let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) };
for (key, value) in headers { for (key, value) in headers {
match *key { match *key {
CONNECTION => continue, CONNECTION => continue,
@@ -155,67 +159,61 @@ pub(crate) trait MessageType: Sized {
match value { match value {
map::Value::One(ref val) => { map::Value::One(ref val) => {
let v = val.as_ref(); let v = val.as_ref();
let v_len = v.len(); let len = k.len() + v.len() + 4;
let k_len = k.len();
let len = k_len + v_len + 4;
if len > remaining { if len > remaining {
unsafe { unsafe {
dst.advance_mut(pos); dst.advance_mut(pos);
} }
pos = 0; pos = 0;
dst.reserve(len * 2); dst.reserve(len * 2);
remaining = dst.capacity() - dst.len(); remaining = dst.remaining_mut();
buf = dst.bytes_mut().as_mut_ptr() as *mut u8; unsafe {
buf = &mut *(dst.bytes_mut() as *mut _);
}
} }
// use upper Camel-Case // use upper Camel-Case
unsafe {
if camel_case { if camel_case {
write_camel_case(k, from_raw_parts_mut(buf, k_len)) write_camel_case(k, &mut buf[pos..pos + k.len()]);
} else { } else {
write_data(k, buf, k_len) buf[pos..pos + k.len()].copy_from_slice(k);
} }
buf = buf.add(k_len); pos += k.len();
write_data(b": ", buf, 2); buf[pos..pos + 2].copy_from_slice(b": ");
buf = buf.add(2); pos += 2;
write_data(v, buf, v_len); buf[pos..pos + v.len()].copy_from_slice(v);
buf = buf.add(v_len); pos += v.len();
write_data(b"\r\n", buf, 2); buf[pos..pos + 2].copy_from_slice(b"\r\n");
buf = buf.add(2); pos += 2;
pos += len;
remaining -= len; remaining -= len;
} }
}
map::Value::Multi(ref vec) => { map::Value::Multi(ref vec) => {
for val in vec { for val in vec {
let v = val.as_ref(); let v = val.as_ref();
let v_len = v.len(); let len = k.len() + v.len() + 4;
let k_len = k.len();
let len = k_len + v_len + 4;
if len > remaining { if len > remaining {
unsafe { unsafe {
dst.advance_mut(pos); dst.advance_mut(pos);
} }
pos = 0; pos = 0;
dst.reserve(len * 2); dst.reserve(len * 2);
remaining = dst.capacity() - dst.len(); remaining = dst.remaining_mut();
buf = dst.bytes_mut().as_mut_ptr() as *mut u8; unsafe {
buf = &mut *(dst.bytes_mut() as *mut _);
}
} }
// use upper Camel-Case // use upper Camel-Case
unsafe {
if camel_case { if camel_case {
write_camel_case(k, from_raw_parts_mut(buf, k_len)); write_camel_case(k, &mut buf[pos..pos + k.len()]);
} else { } else {
write_data(k, buf, k_len); buf[pos..pos + k.len()].copy_from_slice(k);
} }
buf = buf.add(k_len); pos += k.len();
write_data(b": ", buf, 2); buf[pos..pos + 2].copy_from_slice(b": ");
buf = buf.add(2); pos += 2;
write_data(v, buf, v_len); buf[pos..pos + v.len()].copy_from_slice(v);
buf = buf.add(v_len); pos += v.len();
write_data(b"\r\n", buf, 2); buf[pos..pos + 2].copy_from_slice(b"\r\n");
buf = buf.add(2); pos += 2;
};
pos += len;
remaining -= len; remaining -= len;
} }
} }
@@ -300,12 +298,6 @@ impl MessageType for RequestHeadType {
Version::HTTP_10 => "HTTP/1.0", Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1", Version::HTTP_11 => "HTTP/1.1",
Version::HTTP_2 => "HTTP/2.0", Version::HTTP_2 => "HTTP/2.0",
Version::HTTP_3 => "HTTP/3.0",
_ =>
return Err(io::Error::new(
io::ErrorKind::Other,
"unsupported version"
)),
} }
) )
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
@@ -487,10 +479,6 @@ impl<'a> io::Write for Writer<'a> {
} }
} }
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
copy_nonoverlapping(value.as_ptr(), buf, len);
}
fn write_camel_case(value: &[u8], buffer: &mut [u8]) { fn write_camel_case(value: &[u8], buffer: &mut [u8]) {
let mut index = 0; let mut index = 0;
let key = value; let key = value;
@@ -521,14 +509,12 @@ fn write_camel_case(value: &[u8], buffer: &mut [u8]) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::rc::Rc;
use bytes::Bytes; use bytes::Bytes;
use http::header::AUTHORIZATION; //use std::rc::Rc;
use super::*; use super::*;
use crate::http::header::{HeaderValue, CONTENT_TYPE}; use crate::http::header::{HeaderValue, CONTENT_TYPE};
use crate::RequestHead; use http::header::AUTHORIZATION;
#[test] #[test]
fn test_chunked_te() { fn test_chunked_te() {
@@ -539,7 +525,7 @@ mod tests {
assert!(enc.encode(b"", &mut bytes).ok().unwrap()); assert!(enc.encode(b"", &mut bytes).ok().unwrap());
} }
assert_eq!( assert_eq!(
bytes.split().freeze(), bytes.take().freeze(),
Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")
); );
} }
@@ -562,8 +548,7 @@ mod tests {
ConnectionType::Close, ConnectionType::Close,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
let data = let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("Content-Length: 0\r\n")); assert!(data.contains("Content-Length: 0\r\n"));
assert!(data.contains("Connection: close\r\n")); assert!(data.contains("Connection: close\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n")); assert!(data.contains("Content-Type: plain/text\r\n"));
@@ -576,8 +561,7 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
let data = let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("Transfer-Encoding: chunked\r\n")); assert!(data.contains("Transfer-Encoding: chunked\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n")); assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n")); assert!(data.contains("Date: date\r\n"));
@@ -589,8 +573,7 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
let data = let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("Content-Length: 100\r\n")); assert!(data.contains("Content-Length: 100\r\n"));
assert!(data.contains("Content-Type: plain/text\r\n")); assert!(data.contains("Content-Type: plain/text\r\n"));
assert!(data.contains("Date: date\r\n")); assert!(data.contains("Date: date\r\n"));
@@ -611,8 +594,7 @@ mod tests {
ConnectionType::KeepAlive, ConnectionType::KeepAlive,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
let data = let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("transfer-encoding: chunked\r\n")); assert!(data.contains("transfer-encoding: chunked\r\n"));
assert!(data.contains("content-type: xml\r\n")); assert!(data.contains("content-type: xml\r\n"));
assert!(data.contains("content-type: plain/text\r\n")); assert!(data.contains("content-type: plain/text\r\n"));
@@ -645,8 +627,7 @@ mod tests {
ConnectionType::Close, ConnectionType::Close,
&ServiceConfig::default(), &ServiceConfig::default(),
); );
let data = let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap();
String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("content-length: 0\r\n")); assert!(data.contains("content-length: 0\r\n"));
assert!(data.contains("connection: close\r\n")); assert!(data.contains("connection: close\r\n"));
assert!(data.contains("authorization: another authorization\r\n")); assert!(data.contains("authorization: another authorization\r\n"));

View File

@@ -1,23 +1,23 @@
use std::task::{Context, Poll}; use actix_server_config::ServerConfig;
use actix_service::{NewService, Service};
use actix_service::{Service, ServiceFactory}; use futures::future::{ok, FutureResult};
use futures_util::future::{ok, Ready}; use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::request::Request; use crate::request::Request;
pub struct ExpectHandler; pub struct ExpectHandler;
impl ServiceFactory for ExpectHandler { impl NewService for ExpectHandler {
type Config = (); type Config = ServerConfig;
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Service = ExpectHandler; type Service = ExpectHandler;
type InitError = Error; type InitError = Error;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
ok(ExpectHandler) ok(ExpectHandler)
} }
} }
@@ -26,10 +26,10 @@ impl Service for ExpectHandler {
type Request = Request; type Request = Request;
type Response = Request; type Response = Request;
type Error = Error; type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Request) -> Self::Future { fn call(&mut self, req: Request) -> Self::Future {

View File

@@ -1,13 +1,12 @@
//! Payload stream //! Payload stream
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::pin::Pin;
use std::rc::{Rc, Weak}; use std::rc::{Rc, Weak};
use std::task::{Context, Poll};
use actix_utils::task::LocalWaker;
use bytes::Bytes; use bytes::Bytes;
use futures_core::Stream; use futures::task::current as current_task;
use futures::task::Task;
use futures::{Async, Poll, Stream};
use crate::error::PayloadError; use crate::error::PayloadError;
@@ -78,24 +77,15 @@ impl Payload {
pub fn unread_data(&mut self, data: Bytes) { pub fn unread_data(&mut self, data: Bytes) {
self.inner.borrow_mut().unread_data(data); self.inner.borrow_mut().unread_data(data);
} }
#[inline]
pub fn readany(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
}
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( #[inline]
self: Pin<&mut Self>, fn poll(&mut self) -> Poll<Option<Bytes>, PayloadError> {
cx: &mut Context<'_>, self.inner.borrow_mut().readany()
) -> Poll<Option<Result<Bytes, PayloadError>>> {
self.inner.borrow_mut().readany(cx)
} }
} }
@@ -127,14 +117,19 @@ impl PayloadSender {
} }
#[inline] #[inline]
pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus { pub fn need_read(&self) -> PayloadStatus {
// we check need_read only if Payload (other side) is alive, // we check need_read only if Payload (other side) is alive,
// otherwise always return true (consume payload) // otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() { if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read { if shared.borrow().need_read {
PayloadStatus::Read PayloadStatus::Read
} else { } else {
shared.borrow_mut().io_task.register(cx.waker()); #[cfg(not(test))]
{
if shared.borrow_mut().io_task.is_none() {
shared.borrow_mut().io_task = Some(current_task());
}
}
PayloadStatus::Pause PayloadStatus::Pause
} }
} else { } else {
@@ -150,8 +145,8 @@ struct Inner {
err: Option<PayloadError>, err: Option<PayloadError>,
need_read: bool, need_read: bool,
items: VecDeque<Bytes>, items: VecDeque<Bytes>,
task: LocalWaker, task: Option<Task>,
io_task: LocalWaker, io_task: Option<Task>,
} }
impl Inner { impl Inner {
@@ -162,8 +157,8 @@ impl Inner {
err: None, err: None,
items: VecDeque::new(), items: VecDeque::new(),
need_read: true, need_read: true,
task: LocalWaker::new(), task: None,
io_task: LocalWaker::new(), io_task: None,
} }
} }
@@ -183,7 +178,7 @@ impl Inner {
self.items.push_back(data); self.items.push_back(data);
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if let Some(task) = self.task.take() { if let Some(task) = self.task.take() {
task.wake() task.notify()
} }
} }
@@ -192,28 +187,34 @@ impl Inner {
self.len self.len
} }
fn readany( fn readany(&mut self) -> Poll<Option<Bytes>, PayloadError> {
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, PayloadError>>> {
if let Some(data) = self.items.pop_front() { if let Some(data) = self.items.pop_front() {
self.len -= data.len(); self.len -= data.len();
self.need_read = self.len < MAX_BUFFER_SIZE; self.need_read = self.len < MAX_BUFFER_SIZE;
if self.need_read && !self.eof { if self.need_read && self.task.is_none() && !self.eof {
self.task.register(cx.waker()); self.task = Some(current_task());
} }
self.io_task.wake(); if let Some(task) = self.io_task.take() {
Poll::Ready(Some(Ok(data))) task.notify()
}
Ok(Async::Ready(Some(data)))
} else if let Some(err) = self.err.take() { } else if let Some(err) = self.err.take() {
Poll::Ready(Some(Err(err))) Err(err)
} else if self.eof { } else if self.eof {
Poll::Ready(None) Ok(Async::Ready(None))
} else { } else {
self.need_read = true; self.need_read = true;
self.task.register(cx.waker()); #[cfg(not(test))]
self.io_task.wake(); {
Poll::Pending if self.task.is_none() {
self.task = Some(current_task());
}
if let Some(task) = self.io_task.take() {
task.notify()
}
}
Ok(Async::NotReady)
} }
} }
@@ -226,10 +227,14 @@ impl Inner {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures_util::future::poll_fn; use actix_rt::Runtime;
use futures::future::{lazy, result};
#[actix_rt::test] #[test]
async fn test_unread_data() { fn test_unread_data() {
Runtime::new()
.unwrap()
.block_on(lazy(|| {
let (_, mut payload) = Payload::create(false); let (_, mut payload) = Payload::create(false);
payload.unread_data(Bytes::from("data")); payload.unread_data(Bytes::from("data"));
@@ -237,8 +242,13 @@ mod tests {
assert_eq!(payload.len(), 4); assert_eq!(payload.len(), 4);
assert_eq!( assert_eq!(
Bytes::from("data"), Async::Ready(Some(Bytes::from("data"))),
poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() payload.poll().ok().unwrap()
); );
let res: Result<(), ()> = Ok(());
result(res)
}))
.unwrap();
} }
} }

View File

@@ -1,19 +1,16 @@
use std::future::Future; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, net};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::Framed;
use actix_rt::net::TcpStream; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{IntoNewService, NewService, Service};
use futures_core::ready; use futures::future::{ok, FutureResult};
use futures_util::future::{ok, Ready}; use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::{KeepAlive, ServiceConfig};
use crate::error::{DispatchError, Error, ParseError}; use crate::error::{DispatchError, Error, ParseError};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::request::Request; use crate::request::Request;
@@ -23,32 +20,43 @@ use super::codec::Codec;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
use super::{ExpectHandler, Message, UpgradeHandler}; use super::{ExpectHandler, Message, UpgradeHandler};
/// `ServiceFactory` implementation for HTTP1 transport /// `NewService` implementation for HTTP1 transport
pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> { pub struct H1Service<T, P, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B> H1Service<T, S, B> impl<T, P, S, B> H1Service<T, P, S, B>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
{ {
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with default config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>( pub fn new<F: IntoNewService<S>>(service: F) -> Self {
cfg: ServiceConfig, let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
service: F,
) -> Self {
H1Service { H1Service {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: ExpectHandler,
upgrade: None,
on_connect: None,
_t: PhantomData,
}
}
/// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
H1Service {
cfg,
srv: service.into_new_service(),
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -57,153 +65,17 @@ where
} }
} }
impl<S, B, X, U> H1Service<TcpStream, S, B, X, U> impl<T, P, S, B, X, U> H1Service<T, P, S, B, X, U>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TcpStream, Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
{
/// Create simple tcp stream service
pub fn tcp(
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = (),
> {
pipeline_factory(|io: TcpStream| {
let peer_addr = io.peer_addr().ok();
ok((io, peer_addr))
})
.and_then(self)
}
}
#[cfg(feature = "openssl")]
mod openssl {
use super::*;
use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream};
use actix_tls::{openssl::HandshakeError, SslError};
impl<S, B, X, U> H1Service<SslStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<SslStream<TcpStream>, Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
{
/// Create openssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<HandshakeError<TcpStream>, DispatchError>,
InitError = (),
> {
pipeline_factory(
Acceptor::new(acceptor)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: SslStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok();
ok((io, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream};
use actix_tls::SslError;
use std::{fmt, io};
impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TlsStream<TcpStream>, Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
{
/// Create rustls based service
pub fn rustls(
self,
config: ServerConfig,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<io::Error, DispatchError>,
InitError = (),
> {
pipeline_factory(
Acceptor::new(config)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().0.peer_addr().ok();
ok((io, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
impl<T, S, B, X, U> H1Service<T, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
{ {
pub fn expect<X1>(self, expect: X1) -> H1Service<T, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> H1Service<T, P, S, B, X1, U>
where where
X1: ServiceFactory<Request = Request, Response = Request>, X1: NewService<Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
{ {
@@ -217,9 +89,9 @@ where
} }
} }
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, P, S, B, X, U1>
where where
U1: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
{ {
@@ -243,34 +115,38 @@ where
} }
} }
impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U> impl<T, P, S, B, X, U> NewService for H1Service<T, P, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>, U: NewService<
U::Error: fmt::Display + Into<Error>, Config = SrvConfig,
Request = (Request, Framed<T, Codec>),
Response = (),
>,
U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
type Config = (); type Config = SrvConfig;
type Request = (T, Option<net::SocketAddr>); type Request = Io<T, P>;
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type InitError = (); type InitError = ();
type Service = H1ServiceHandler<T, S::Service, B, X::Service, U::Service>; type Service = H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
type Future = H1ServiceResponse<T, S, B, X, U>; type Future = H1ServiceResponse<T, P, S, B, X, U>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H1ServiceResponse { H1ServiceResponse {
fut: self.srv.new_service(()), fut: self.srv.new_service(cfg).into_future(),
fut_ex: Some(self.expect.new_service(())), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
upgrade: None, upgrade: None,
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
@@ -281,99 +157,88 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project] pub struct H1ServiceResponse<T, P, S, B, X, U>
pub struct H1ServiceResponse<T, S, B, X, U>
where where
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U> impl<T, P, S, B, X, U> Future for H1ServiceResponse<T, P, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
B: MessageBody, B: MessageBody,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>, U: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
type Output = Result<H1ServiceHandler<T, S::Service, B, X::Service, U::Service>, ()>; type Item = H1ServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
if let Some(fut) = this.fut_ex.as_pin_mut() { .poll()
let expect = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
}
if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
}
let result = ready!(this
.fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
self.expect = Some(expect);
self.fut_ex.take();
}
Poll::Ready(result.map(|service| { if let Some(ref mut fut) = self.fut_upg {
let this = self.as_mut().project(); let upgrade = try_ready!(fut
H1ServiceHandler::new( .poll()
this.cfg.take().unwrap(), .map_err(|e| log::error!("Init http service error: {:?}", e)));
self.upgrade = Some(upgrade);
self.fut_ex.take();
}
let service = try_ready!(self
.fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(H1ServiceHandler::new(
self.cfg.take().unwrap(),
service, service,
this.expect.take().unwrap(), self.expect.take().unwrap(),
this.upgrade.take(), self.upgrade.take(),
this.on_connect.clone(), self.on_connect.clone(),
) )))
}))
} }
} }
/// `Service` implementation for HTTP1 transport /// `Service` implementation for HTTP1 transport
pub struct H1ServiceHandler<T, S, B, X, U> { pub struct H1ServiceHandler<T, P, S, B, X, U> {
srv: CloneableService<S>, srv: CloneableService<S>,
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B, X, U> H1ServiceHandler<T, S, B, X, U> impl<T, P, S, B, X, U> H1ServiceHandler<T, P, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
@@ -390,7 +255,7 @@ where
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> H1ServiceHandler<T, S, B, X, U> { ) -> H1ServiceHandler<T, P, S, B, X, U> {
H1ServiceHandler { H1ServiceHandler {
srv: CloneableService::new(srv), srv: CloneableService::new(srv),
expect: CloneableService::new(expect), expect: CloneableService::new(expect),
@@ -402,9 +267,9 @@ where
} }
} }
impl<T, S, B, X, U> Service for H1ServiceHandler<T, S, B, X, U> impl<T, P, S, B, X, U> Service for H1ServiceHandler<T, P, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
@@ -412,17 +277,17 @@ where
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display + Into<Error>, U::Error: fmt::Display,
{ {
type Request = (T, Option<net::SocketAddr>); type Request = Io<T, P>;
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>; type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
let ready = self let ready = self
.expect .expect
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -432,7 +297,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -441,27 +306,16 @@ where
.is_ready() .is_ready()
&& ready; && ready;
let ready = if let Some(ref mut upg) = self.upgrade {
upg.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
DispatchError::Service(e)
})?
.is_ready()
&& ready
} else {
ready
};
if ready { if ready {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
fn call(&mut self, (io, addr): Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let io = req.into_parts().0;
let on_connect = if let Some(ref on_connect) = self.on_connect { let on_connect = if let Some(ref on_connect) = self.on_connect {
Some(on_connect(&io)) Some(on_connect(&io))
} else { } else {
@@ -475,21 +329,20 @@ where
self.expect.clone(), self.expect.clone(),
self.upgrade.clone(), self.upgrade.clone(),
on_connect, on_connect,
addr,
) )
} }
} }
/// `ServiceFactory` implementation for `OneRequestService` service /// `NewService` implementation for `OneRequestService` service
#[derive(Default)] #[derive(Default)]
pub struct OneRequest<T> { pub struct OneRequest<T, P> {
config: ServiceConfig, config: ServiceConfig,
_t: PhantomData<T>, _t: PhantomData<(T, P)>,
} }
impl<T> OneRequest<T> impl<T, P> OneRequest<T, P>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
{ {
/// Create new `H1SimpleService` instance. /// Create new `H1SimpleService` instance.
pub fn new() -> Self { pub fn new() -> Self {
@@ -500,49 +353,52 @@ where
} }
} }
impl<T> ServiceFactory for OneRequest<T> impl<T, P> NewService for OneRequest<T, P>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
{ {
type Config = (); type Config = SrvConfig;
type Request = T; type Request = Io<T, P>;
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = ParseError; type Error = ParseError;
type InitError = (); type InitError = ();
type Service = OneRequestService<T>; type Service = OneRequestService<T, P>;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &SrvConfig) -> Self::Future {
ok(OneRequestService { ok(OneRequestService {
_t: PhantomData,
config: self.config.clone(), config: self.config.clone(),
_t: PhantomData,
}) })
} }
} }
/// `Service` implementation for HTTP1 transport. Reads one request and returns /// `Service` implementation for HTTP1 transport. Reads one request and returns
/// request and framed object. /// request and framed object.
pub struct OneRequestService<T> { pub struct OneRequestService<T, P> {
_t: PhantomData<T>,
config: ServiceConfig, config: ServiceConfig,
_t: PhantomData<(T, P)>,
} }
impl<T> Service for OneRequestService<T> impl<T, P> Service for OneRequestService<T, P>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
{ {
type Request = T; type Request = Io<T, P>;
type Response = (Request, Framed<T, Codec>); type Response = (Request, Framed<T, Codec>);
type Error = ParseError; type Error = ParseError;
type Future = OneRequestServiceResponse<T>; type Future = OneRequestServiceResponse<T>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, req: Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
OneRequestServiceResponse { OneRequestServiceResponse {
framed: Some(Framed::new(req, Codec::new(self.config.clone()))), framed: Some(Framed::new(
req.into_parts().0,
Codec::new(self.config.clone()),
)),
} }
} }
} }
@@ -550,28 +406,28 @@ where
#[doc(hidden)] #[doc(hidden)]
pub struct OneRequestServiceResponse<T> pub struct OneRequestServiceResponse<T>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
{ {
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
} }
impl<T> Future for OneRequestServiceResponse<T> impl<T> Future for OneRequestServiceResponse<T>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
{ {
type Output = Result<(Request, Framed<T, Codec>), ParseError>; type Item = (Request, Framed<T, Codec>);
type Error = ParseError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.framed.as_mut().unwrap().next_item(cx) { match self.framed.as_mut().unwrap().poll()? {
Poll::Ready(Some(Ok(req))) => match req { Async::Ready(Some(req)) => match req {
Message::Item(req) => { Message::Item(req) => {
Poll::Ready(Ok((req, self.framed.take().unwrap()))) Ok(Async::Ready((req, self.framed.take().unwrap())))
} }
Message::Chunk(_) => unreachable!("Something is wrong"), Message::Chunk(_) => unreachable!("Something is wrong"),
}, },
Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), Async::Ready(None) => Err(ParseError::Incomplete),
Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)), Async::NotReady => Ok(Async::NotReady),
Poll::Pending => Poll::Pending,
} }
} }
} }

View File

@@ -1,9 +1,10 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_codec::Framed; use actix_codec::Framed;
use actix_service::{Service, ServiceFactory}; use actix_server_config::ServerConfig;
use futures_util::future::Ready; use actix_service::{NewService, Service};
use futures::future::FutureResult;
use futures::{Async, Poll};
use crate::error::Error; use crate::error::Error;
use crate::h1::Codec; use crate::h1::Codec;
@@ -11,16 +12,16 @@ use crate::request::Request;
pub struct UpgradeHandler<T>(PhantomData<T>); pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> ServiceFactory for UpgradeHandler<T> { impl<T> NewService for UpgradeHandler<T> {
type Config = (); type Config = ServerConfig;
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Service = UpgradeHandler<T>; type Service = UpgradeHandler<T>;
type InitError = Error; type InitError = Error;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: &ServerConfig) -> Self::Future {
unimplemented!() unimplemented!()
} }
} }
@@ -29,10 +30,10 @@ impl<T> Service for UpgradeHandler<T> {
type Request = (Request, Framed<T, Codec>); type Request = (Request, Framed<T, Codec>);
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} }
fn call(&mut self, _: Self::Request) -> Self::Future { fn call(&mut self, _: Self::Request) -> Self::Future {

View File

@@ -1,8 +1,5 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use futures::{Async, Future, Poll, Sink};
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::error::Error; use crate::error::Error;
@@ -10,7 +7,6 @@ use crate::h1::{Codec, Message};
use crate::response::Response; use crate::response::Response;
/// Send http/1 response /// Send http/1 response
#[pin_project::pin_project]
pub struct SendResponse<T, B> { pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>, res: Option<Message<(Response<()>, BodySize)>>,
body: Option<ResponseBody<B>>, body: Option<ResponseBody<B>>,
@@ -37,61 +33,60 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
B: MessageBody, B: MessageBody,
{ {
type Output = Result<Framed<T, Codec>, Error>; type Item = Framed<T, Codec>;
type Error = Error;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
let mut body_ready = this.body.is_some(); let mut body_ready = self.body.is_some();
let framed = this.framed.as_mut().unwrap(); let framed = self.framed.as_mut().unwrap();
// send body // send body
if this.res.is_none() && this.body.is_some() { if self.res.is_none() && self.body.is_some() {
while body_ready && this.body.is_some() && !framed.is_write_buf_full() { while body_ready && self.body.is_some() && !framed.is_write_buf_full() {
match this.body.as_mut().unwrap().poll_next(cx)? { match self.body.as_mut().unwrap().poll_next()? {
Poll::Ready(item) => { Async::Ready(item) => {
// body is done // body is done
if item.is_none() { if item.is_none() {
let _ = this.body.take(); let _ = self.body.take();
} }
framed.write(Message::Chunk(item))?; framed.force_send(Message::Chunk(item))?;
} }
Poll::Pending => body_ready = false, Async::NotReady => body_ready = false,
} }
} }
} }
// flush write buffer // flush write buffer
if !framed.is_write_buf_empty() { if !framed.is_write_buf_empty() {
match framed.flush(cx)? { match framed.poll_complete()? {
Poll::Ready(_) => { Async::Ready(_) => {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Poll::Pending; return Ok(Async::NotReady);
} }
} }
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
} }
} }
// send response // send response
if let Some(res) = this.res.take() { if let Some(res) = self.res.take() {
framed.write(res)?; framed.force_send(res)?;
continue; continue;
} }
if this.body.is_some() { if self.body.is_some() {
if body_ready { if body_ready {
continue; continue;
} else { } else {
return Poll::Pending; return Ok(Async::NotReady);
} }
} else { } else {
break; break;
} }
} }
Poll::Ready(Ok(this.framed.take().unwrap())) Ok(Async::Ready(self.framed.take().unwrap()))
} }
} }

View File

@@ -1,23 +1,27 @@
use std::convert::TryFrom; use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::net; use std::time::Instant;
use std::pin::Pin; use std::{fmt, mem, net};
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::time::{Delay, Instant}; use actix_server_config::IoStream;
use actix_service::Service; use actix_service::Service;
use bitflags::bitflags;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{try_ready, Async, Future, Poll, Sink, Stream};
use h2::server::{Connection, SendResponse}; use h2::server::{Connection, SendResponse};
use h2::SendStream; use h2::{RecvStream, SendStream};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::header::{
use log::{error, trace}; HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
};
use http::HttpTryFrom;
use log::{debug, error, trace};
use tokio_timer::Delay;
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::message::ResponseHead; use crate::message::ResponseHead;
@@ -28,11 +32,7 @@ use crate::response::Response;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol /// Dispatcher for HTTP/2 protocol
#[pin_project::pin_project] pub struct Dispatcher<T: IoStream, S: Service<Request = Request>, B: MessageBody> {
pub struct Dispatcher<T, S: Service<Request = Request>, B: MessageBody>
where
T: AsyncRead + AsyncWrite + Unpin,
{
service: CloneableService<S>, service: CloneableService<S>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
@@ -45,12 +45,12 @@ where
impl<T, S, B> Dispatcher<T, S, B> impl<T, S, B> Dispatcher<T, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error>, S::Error: Into<Error>,
// S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
pub(crate) fn new( pub(crate) fn new(
service: CloneableService<S>, service: CloneableService<S>,
@@ -91,77 +91,63 @@ where
impl<T, S, B> Future for Dispatcher<T, S, B> impl<T, S, B> Future for Dispatcher<T, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.get_mut();
loop { loop {
match Pin::new(&mut this.connection).poll_accept(cx) { match self.connection.poll()? {
Poll::Ready(None) => return Poll::Ready(Ok(())), Async::Ready(None) => return Ok(Async::Ready(())),
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), Async::Ready(Some((req, res))) => {
Poll::Ready(Some(Ok((req, res)))) => {
// update keep-alive expire // update keep-alive expire
if this.ka_timer.is_some() { if self.ka_timer.is_some() {
if let Some(expire) = this.config.keep_alive_expire() { if let Some(expire) = self.config.keep_alive_expire() {
this.ka_expire = expire; self.ka_expire = expire;
} }
} }
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
let mut req = Request::with_payload(Payload::< let mut req = Request::with_payload(body.into());
crate::payload::PayloadStream,
>::H2(
crate::h2::Payload::new(body)
));
let head = &mut req.head_mut(); let head = &mut req.head_mut();
head.uri = parts.uri; head.uri = parts.uri;
head.method = parts.method; head.method = parts.method;
head.version = parts.version; head.version = parts.version;
head.headers = parts.headers.into(); head.headers = parts.headers.into();
head.peer_addr = this.peer_addr; head.peer_addr = self.peer_addr;
// set on_connect data // set on_connect data
if let Some(ref on_connect) = this.on_connect { if let Some(ref on_connect) = self.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
actix_rt::spawn(ServiceResponse::< tokio_current_thread::spawn(ServiceResponse::<S::Future, B> {
S::Future,
S::Response,
S::Error,
B,
> {
state: ServiceResponseState::ServiceCall( state: ServiceResponseState::ServiceCall(
this.service.call(req), self.service.call(req),
Some(res), Some(res),
), ),
config: this.config.clone(), config: self.config.clone(),
buffer: None, buffer: None,
_t: PhantomData, })
});
} }
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
} }
} }
} }
} }
#[pin_project::pin_project] struct ServiceResponse<F, B> {
struct ServiceResponse<F, I, E, B> {
state: ServiceResponseState<F, B>, state: ServiceResponseState<F, B>,
config: ServiceConfig, config: ServiceConfig,
buffer: Option<Bytes>, buffer: Option<Bytes>,
_t: PhantomData<(I, E)>,
} }
enum ServiceResponseState<F, B> { enum ServiceResponseState<F, B> {
@@ -169,12 +155,12 @@ enum ServiceResponseState<F, B> {
SendPayload(SendStream<Bytes>, ResponseBody<B>), SendPayload(SendStream<Bytes>, ResponseBody<B>),
} }
impl<F, I, E, B> ServiceResponse<F, I, E, B> impl<F, B> ServiceResponse<F, B>
where where
F: Future<Output = Result<I, E>>, F: Future,
E: Into<Error>, F::Error: Into<Error>,
I: Into<Response<B>>, F::Item: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
fn prepare_response( fn prepare_response(
&self, &self,
@@ -229,130 +215,117 @@ where
if !has_date { if !has_date {
let mut bytes = BytesMut::with_capacity(29); let mut bytes = BytesMut::with_capacity(29);
self.config.set_date_header(&mut bytes); self.config.set_date_header(&mut bytes);
res.headers_mut().insert(DATE, unsafe { res.headers_mut()
HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap());
});
} }
res res
} }
} }
impl<F, I, E, B> Future for ServiceResponse<F, I, E, B> impl<F, B> Future for ServiceResponse<F, B>
where where
F: Future<Output = Result<I, E>>, F: Future,
E: Into<Error>, F::Error: Into<Error>,
I: Into<Response<B>>, F::Item: Into<Response<B>>,
B: MessageBody, B: MessageBody + 'static,
{ {
type Output = (); type Item = ();
type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); match self.state {
match this.state {
ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
match unsafe { Pin::new_unchecked(call) }.poll(cx) { match call.poll() {
Poll::Ready(Ok(res)) => { Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res = self.prepare_response(res.head(), &mut size);
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) { let stream =
Err(e) => { send.send_response(h2_res, size.is_eof()).map_err(|e| {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
return Poll::Ready(()); })?;
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
*this.state = self.state = ServiceResponseState::SendPayload(stream, body);
ServiceResponseState::SendPayload(stream, body); self.poll()
self.poll(cx)
} }
} }
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(Err(e)) => { Err(e) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
let mut send = send.take().unwrap(); let mut send = send.take().unwrap();
let mut size = body.size(); let mut size = body.size();
let h2_res = let h2_res = self.prepare_response(res.head(), &mut size);
self.as_mut().prepare_response(res.head(), &mut size);
this = self.as_mut().project();
let stream = match send.send_response(h2_res, size.is_eof()) { let stream =
Err(e) => { send.send_response(h2_res, size.is_eof()).map_err(|e| {
trace!("Error sending h2 response: {:?}", e); trace!("Error sending h2 response: {:?}", e);
return Poll::Ready(()); })?;
}
Ok(stream) => stream,
};
if size.is_eof() { if size.is_eof() {
Poll::Ready(()) Ok(Async::Ready(()))
} else { } else {
*this.state = ServiceResponseState::SendPayload( self.state = ServiceResponseState::SendPayload(
stream, stream,
body.into_body(), body.into_body(),
); );
self.poll(cx) self.poll()
} }
} }
} }
} }
ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
loop { loop {
if let Some(ref mut buffer) = this.buffer { if let Some(ref mut buffer) = self.buffer {
match stream.poll_capacity(cx) { match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? {
Poll::Pending => return Poll::Pending, Async::NotReady => return Ok(Async::NotReady),
Poll::Ready(None) => return Poll::Ready(()), Async::Ready(None) => return Ok(Async::Ready(())),
Poll::Ready(Some(Ok(cap))) => { Async::Ready(Some(cap)) => {
let len = buffer.len(); let len = buffer.len();
let bytes = buffer.split_to(std::cmp::min(cap, len)); let bytes = buffer.split_to(std::cmp::min(cap, len));
if let Err(e) = stream.send_data(bytes, false) { if let Err(e) = stream.send_data(bytes, false) {
warn!("{:?}", e); warn!("{:?}", e);
return Poll::Ready(()); return Err(());
} else if !buffer.is_empty() { } else if !buffer.is_empty() {
let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
stream.reserve_capacity(cap); stream.reserve_capacity(cap);
} else { } else {
this.buffer.take(); self.buffer.take();
} }
} }
Poll::Ready(Some(Err(e))) => {
warn!("{:?}", e);
return Poll::Ready(());
}
} }
} else { } else {
match body.poll_next(cx) { match body.poll_next() {
Poll::Pending => return Poll::Pending, Ok(Async::NotReady) => {
Poll::Ready(None) => { return Ok(Async::NotReady);
}
Ok(Async::Ready(None)) => {
if let Err(e) = stream.send_data(Bytes::new(), true) { if let Err(e) = stream.send_data(Bytes::new(), true) {
warn!("{:?}", e); warn!("{:?}", e);
return Err(());
} else {
return Ok(Async::Ready(()));
} }
return Poll::Ready(());
} }
Poll::Ready(Some(Ok(chunk))) => { Ok(Async::Ready(Some(chunk))) => {
stream.reserve_capacity(std::cmp::min( stream.reserve_capacity(std::cmp::min(
chunk.len(), chunk.len(),
CHUNK_SIZE, CHUNK_SIZE,
)); ));
*this.buffer = Some(chunk); self.buffer = Some(chunk);
} }
Poll::Ready(Some(Err(e))) => { Err(e) => {
error!("Response payload stream error: {:?}", e); error!("Response payload stream error: {:?}", e);
return Poll::Ready(()); return Err(());
} }
} }
} }

View File

@@ -1,9 +1,9 @@
//! HTTP/2 implementation #![allow(dead_code, unused_imports)]
use std::pin::Pin;
use std::task::{Context, Poll}; use std::fmt;
use bytes::Bytes; use bytes::Bytes;
use futures_core::Stream; use futures::{Async, Poll, Stream};
use h2::RecvStream; use h2::RecvStream;
mod dispatcher; mod dispatcher;
@@ -25,26 +25,22 @@ impl Payload {
} }
impl Stream for Payload { impl Stream for Payload {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self: Pin<&mut Self>, match self.pl.poll() {
cx: &mut Context<'_>, Ok(Async::Ready(Some(chunk))) => {
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match Pin::new(&mut this.pl).poll_data(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let len = chunk.len(); let len = chunk.len();
if let Err(err) = this.pl.flow_control().release_capacity(len) { if let Err(err) = self.pl.release_capacity().release_capacity(len) {
Poll::Ready(Some(Err(err.into()))) Err(err.into())
} else { } else {
Poll::Ready(Some(Ok(chunk))) Ok(Async::Ready(Some(chunk)))
} }
} }
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Poll::Pending => Poll::Pending, Ok(Async::NotReady) => Ok(Async::NotReady),
Poll::Ready(None) => Poll::Ready(None), Err(err) => Err(err.into()),
} }
} }
} }

View File

@@ -1,56 +1,62 @@
use std::future::Future; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::{io, net, rc};
use std::task::{Context, Poll};
use std::{net, rc};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_rt::net::TcpStream; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig};
use actix_service::{ use actix_service::{IntoNewService, NewService, Service};
fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service,
ServiceFactory,
};
use bytes::Bytes; use bytes::Bytes;
use futures_core::ready; use futures::future::{ok, FutureResult};
use futures_util::future::ok; use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream};
use h2::server::{self, Handshake}; use h2::server::{self, Connection, Handshake};
use h2::RecvStream;
use log::error; use log::error;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::{KeepAlive, ServiceConfig};
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error, ParseError, ResponseError};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::payload::Payload;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
/// `ServiceFactory` implementation for HTTP2 transport /// `NewService` implementation for HTTP2 transport
pub struct H2Service<T, S, B> { pub struct H2Service<T, P, S, B> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B> H2Service<T, S, B> impl<T, P, S, B> H2Service<T, P, S, B>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance.
pub(crate) fn with_config<F: IntoServiceFactory<S>>( pub fn new<F: IntoNewService<S>>(service: F) -> Self {
cfg: ServiceConfig, let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
service: F,
) -> Self {
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
srv: service.into_factory(), srv: service.into_new_service(),
_t: PhantomData,
}
}
/// Create new `HttpService` instance with config.
pub fn with_config<F: IntoNewService<S>>(cfg: ServiceConfig, service: F) -> Self {
H2Service {
cfg,
on_connect: None,
srv: service.into_new_service(),
_t: PhantomData, _t: PhantomData,
} }
} }
@@ -65,144 +71,26 @@ where
} }
} }
impl<S, B> H2Service<TcpStream, S, B> impl<T, P, S, B> NewService for H2Service<T, P, S, B>
where where
S: ServiceFactory<Config = (), Request = Request>, T: IoStream,
S::Error: Into<Error> + 'static, S: NewService<Config = SrvConfig, Request = Request>,
S::Response: Into<Response<B>> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create simple tcp based service type Config = SrvConfig;
pub fn tcp( type Request = Io<T, P>;
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = S::InitError,
> {
pipeline_factory(fn_factory(|| {
async {
Ok::<_, S::InitError>(fn_service(|io: TcpStream| {
let peer_addr = io.peer_addr().ok();
ok::<_, DispatchError>((io, peer_addr))
}))
}
}))
.and_then(self)
}
}
#[cfg(feature = "openssl")]
mod openssl {
use actix_service::{fn_factory, fn_service};
use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream};
use actix_tls::{openssl::HandshakeError, SslError};
use super::*;
impl<S, B> H2Service<SslStream<TcpStream>, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
/// Create ssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<HandshakeError<TcpStream>, DispatchError>,
InitError = S::InitError,
> {
pipeline_factory(
Acceptor::new(acceptor)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(fn_factory(|| {
ok::<_, S::InitError>(fn_service(|io: SslStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok();
ok((io, peer_addr))
}))
}))
.and_then(self.map_err(SslError::Service))
}
}
}
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use actix_tls::rustls::{Acceptor, ServerConfig, TlsStream};
use actix_tls::SslError;
use std::io;
impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
/// Create openssl based service
pub fn rustls(
self,
mut config: ServerConfig,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<io::Error, DispatchError>,
InitError = S::InitError,
> {
let protos = vec!["h2".to_string().into()];
config.set_protocols(&protos);
pipeline_factory(
Acceptor::new(config)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(fn_factory(|| {
ok::<_, S::InitError>(fn_service(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().0.peer_addr().ok();
ok((io, peer_addr))
}))
}))
.and_then(self.map_err(SslError::Service))
}
}
}
impl<T, S, B> ServiceFactory for H2Service<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
type Config = ();
type Request = (T, Option<net::SocketAddr>);
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type InitError = S::InitError; type InitError = S::InitError;
type Service = H2ServiceHandler<T, S::Service, B>; type Service = H2ServiceHandler<T, P, S::Service, B>;
type Future = H2ServiceResponse<T, S, B>; type Future = H2ServiceResponse<T, P, S, B>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
H2ServiceResponse { H2ServiceResponse {
fut: self.srv.new_service(()), fut: self.srv.new_service(cfg).into_future(),
cfg: Some(self.cfg.clone()), cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
_t: PhantomData, _t: PhantomData,
@@ -211,61 +99,56 @@ where
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project::pin_project] pub struct H2ServiceResponse<T, P, S: NewService, B> {
pub struct H2ServiceResponse<T, S: ServiceFactory, B> { fut: <S::Future as IntoFuture>::Future,
#[pin]
fut: S::Future,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B> Future for H2ServiceResponse<T, S, B> impl<T, P, S, B> Future for H2ServiceResponse<T, P, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Output = Result<H2ServiceHandler<T, S::Service, B>, S::InitError>; type Item = H2ServiceHandler<T, P, S::Service, B>;
type Error = S::InitError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let this = self.as_mut().project(); let service = try_ready!(self.fut.poll());
Ok(Async::Ready(H2ServiceHandler::new(
Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { self.cfg.take().unwrap(),
let this = self.as_mut().project(); self.on_connect.clone(),
H2ServiceHandler::new(
this.cfg.take().unwrap(),
this.on_connect.clone(),
service, service,
) )))
}))
} }
} }
/// `Service` implementation for http/2 transport /// `Service` implementation for http/2 transport
pub struct H2ServiceHandler<T, S, B> { pub struct H2ServiceHandler<T, P, S, B> {
srv: CloneableService<S>, srv: CloneableService<S>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B> H2ServiceHandler<T, S, B> impl<T, P, S, B> H2ServiceHandler<T, P, S, B>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
fn new( fn new(
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
srv: S, srv: S,
) -> H2ServiceHandler<T, S, B> { ) -> H2ServiceHandler<T, P, S, B> {
H2ServiceHandler { H2ServiceHandler {
cfg, cfg,
on_connect, on_connect,
@@ -275,29 +158,31 @@ where
} }
} }
impl<T, S, B> Service for H2ServiceHandler<T, S, B> impl<T, P, S, B> Service for H2ServiceHandler<T, P, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
type Request = (T, Option<net::SocketAddr>); type Request = Io<T, P>;
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type Future = H2ServiceHandlerResponse<T, S, B>; type Future = H2ServiceHandlerResponse<T, S, B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.srv.poll_ready(cx).map_err(|e| { self.srv.poll_ready().map_err(|e| {
let e = e.into(); let e = e.into();
error!("Service readiness error: {:?}", e); error!("Service readiness error: {:?}", e);
DispatchError::Service(e) DispatchError::Service(e)
}) })
} }
fn call(&mut self, (io, addr): Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let io = req.into_parts().0;
let peer_addr = io.peer_addr();
let on_connect = if let Some(ref on_connect) = self.on_connect { let on_connect = if let Some(ref on_connect) = self.on_connect {
Some(on_connect(&io)) Some(on_connect(&io))
} else { } else {
@@ -308,7 +193,7 @@ where
state: State::Handshake( state: State::Handshake(
Some(self.srv.clone()), Some(self.srv.clone()),
Some(self.cfg.clone()), Some(self.cfg.clone()),
addr, peer_addr,
on_connect, on_connect,
server::handshake(io), server::handshake(io),
), ),
@@ -316,9 +201,8 @@ where
} }
} }
enum State<T, S: Service<Request = Request>, B: MessageBody> enum State<T: IoStream, S: Service<Request = Request>, B: MessageBody>
where where
T: AsyncRead + AsyncWrite + Unpin,
S::Future: 'static, S::Future: 'static,
{ {
Incoming(Dispatcher<T, S, B>), Incoming(Dispatcher<T, S, B>),
@@ -333,11 +217,11 @@ where
pub struct H2ServiceHandlerResponse<T, S, B> pub struct H2ServiceHandlerResponse<T, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
state: State<T, S, B>, state: State<T, S, B>,
@@ -345,26 +229,27 @@ where
impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B> impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.state { match self.state {
State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), State::Incoming(ref mut disp) => disp.poll(),
State::Handshake( State::Handshake(
ref mut srv, ref mut srv,
ref mut config, ref mut config,
ref peer_addr, ref peer_addr,
ref mut on_connect, ref mut on_connect,
ref mut handshake, ref mut handshake,
) => match Pin::new(handshake).poll(cx) { ) => match handshake.poll() {
Poll::Ready(Ok(conn)) => { Ok(Async::Ready(conn)) => {
self.state = State::Incoming(Dispatcher::new( self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(), srv.take().unwrap(),
conn, conn,
@@ -373,13 +258,13 @@ where
None, None,
*peer_addr, *peer_addr,
)); ));
self.poll(cx) self.poll()
} }
Poll::Ready(Err(err)) => { Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into())) Err(err.into())
} }
Poll::Pending => Poll::Pending,
}, },
} }
} }

View File

@@ -74,18 +74,18 @@ impl Header for CacheControl {
} }
impl fmt::Display for CacheControl { impl fmt::Display for CacheControl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt_comma_delimited(f, &self[..]) fmt_comma_delimited(f, &self[..])
} }
} }
impl IntoHeaderValue for CacheControl { impl IntoHeaderValue for CacheControl {
type Error = header::InvalidHeaderValue; type Error = header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<header::HeaderValue, Self::Error> { fn try_into(self) -> Result<header::HeaderValue, Self::Error> {
let mut writer = Writer::new(); let mut writer = Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
header::HeaderValue::from_maybe_shared(writer.take()) header::HeaderValue::from_shared(writer.take())
} }
} }
@@ -126,7 +126,7 @@ pub enum CacheDirective {
} }
impl fmt::Display for CacheDirective { impl fmt::Display for CacheDirective {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::CacheDirective::*; use self::CacheDirective::*;
fmt::Display::fmt( fmt::Display::fmt(
match *self { match *self {

View File

@@ -462,12 +462,12 @@ impl ContentDisposition {
} }
impl IntoHeaderValue for ContentDisposition { impl IntoHeaderValue for ContentDisposition {
type Error = header::InvalidHeaderValue; type Error = header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<header::HeaderValue, Self::Error> { fn try_into(self) -> Result<header::HeaderValue, Self::Error> {
let mut writer = Writer::new(); let mut writer = Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
header::HeaderValue::from_maybe_shared(writer.take()) header::HeaderValue::from_shared(writer.take())
} }
} }
@@ -486,7 +486,7 @@ impl Header for ContentDisposition {
} }
impl fmt::Display for DispositionType { impl fmt::Display for DispositionType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
DispositionType::Inline => write!(f, "inline"), DispositionType::Inline => write!(f, "inline"),
DispositionType::Attachment => write!(f, "attachment"), DispositionType::Attachment => write!(f, "attachment"),
@@ -497,7 +497,7 @@ impl fmt::Display for DispositionType {
} }
impl fmt::Display for DispositionParam { impl fmt::Display for DispositionParam {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// All ASCII control characters (0-30, 127) including horizontal tab, double quote, and // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and
// backslash should be escaped in quoted-string (i.e. "foobar"). // backslash should be escaped in quoted-string (i.e. "foobar").
// Ref: RFC6266 S4.1 -> RFC2616 S3.6 // Ref: RFC6266 S4.1 -> RFC2616 S3.6
@@ -555,7 +555,7 @@ impl fmt::Display for DispositionParam {
} }
impl fmt::Display for ContentDisposition { impl fmt::Display for ContentDisposition {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.disposition)?; write!(f, "{}", self.disposition)?;
self.parameters self.parameters
.iter() .iter()
@@ -768,7 +768,8 @@ mod tests {
Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above.
(And now, only UTF-8 is handled by this implementation.) (And now, only UTF-8 is handled by this implementation.)
*/ */
let a = HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") let a =
HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"")
.unwrap(); .unwrap();
let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap();
let b = ContentDisposition { let b = ContentDisposition {
@@ -883,11 +884,7 @@ mod tests {
assert!(ContentDisposition::from_raw(&a).is_err()); assert!(ContentDisposition::from_raw(&a).is_err());
let a = HeaderValue::from_static("inline; filename=\"\""); let a = HeaderValue::from_static("inline; filename=\"\"");
assert!(ContentDisposition::from_raw(&a) assert!(ContentDisposition::from_raw(&a).expect("parse cd").get_filename().expect("filename").is_empty());
.expect("parse cd")
.get_filename()
.expect("filename")
.is_empty());
} }
#[test] #[test]

View File

@@ -3,7 +3,7 @@ use std::str::FromStr;
use crate::error::ParseError; use crate::error::ParseError;
use crate::header::{ use crate::header::{
HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer, CONTENT_RANGE, HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer, CONTENT_RANGE,
}; };
header! { header! {
@@ -166,7 +166,7 @@ impl FromStr for ContentRangeSpec {
} }
impl Display for ContentRangeSpec { impl Display for ContentRangeSpec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
ContentRangeSpec::Bytes { ContentRangeSpec::Bytes {
range, range,
@@ -198,11 +198,11 @@ impl Display for ContentRangeSpec {
} }
impl IntoHeaderValue for ContentRangeSpec { impl IntoHeaderValue for ContentRangeSpec {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut writer = Writer::new(); let mut writer = Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
HeaderValue::from_maybe_shared(writer.take()) HeaderValue::from_shared(writer.take())
} }
} }

View File

@@ -3,7 +3,7 @@ use std::fmt::{self, Display, Write};
use crate::error::ParseError; use crate::error::ParseError;
use crate::header::{ use crate::header::{
self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate,
IntoHeaderValue, InvalidHeaderValue, Writer, IntoHeaderValue, InvalidHeaderValueBytes, Writer,
}; };
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
@@ -87,7 +87,7 @@ impl Header for IfRange {
} }
impl Display for IfRange { impl Display for IfRange {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
IfRange::EntityTag(ref x) => Display::fmt(x, f), IfRange::EntityTag(ref x) => Display::fmt(x, f),
IfRange::Date(ref x) => Display::fmt(x, f), IfRange::Date(ref x) => Display::fmt(x, f),
@@ -96,12 +96,12 @@ impl Display for IfRange {
} }
impl IntoHeaderValue for IfRange { impl IntoHeaderValue for IfRange {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut writer = Writer::new(); let mut writer = Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
HeaderValue::from_maybe_shared(writer.take()) HeaderValue::from_shared(writer.take())
} }
} }

View File

@@ -159,18 +159,18 @@ macro_rules! header {
} }
impl std::fmt::Display for $id { impl std::fmt::Display for $id {
#[inline] #[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> ::std::fmt::Result {
$crate::http::header::fmt_comma_delimited(f, &self.0[..]) $crate::http::header::fmt_comma_delimited(f, &self.0[..])
} }
} }
impl $crate::http::header::IntoHeaderValue for $id { impl $crate::http::header::IntoHeaderValue for $id {
type Error = $crate::http::header::InvalidHeaderValue; type Error = $crate::http::header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> {
use std::fmt::Write; use std::fmt::Write;
let mut writer = $crate::http::header::Writer::new(); let mut writer = $crate::http::header::Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
$crate::http::header::HeaderValue::from_maybe_shared(writer.take()) $crate::http::header::HeaderValue::from_shared(writer.take())
} }
} }
}; };
@@ -195,18 +195,18 @@ macro_rules! header {
} }
impl std::fmt::Display for $id { impl std::fmt::Display for $id {
#[inline] #[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
$crate::http::header::fmt_comma_delimited(f, &self.0[..]) $crate::http::header::fmt_comma_delimited(f, &self.0[..])
} }
} }
impl $crate::http::header::IntoHeaderValue for $id { impl $crate::http::header::IntoHeaderValue for $id {
type Error = $crate::http::header::InvalidHeaderValue; type Error = $crate::http::header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> {
use std::fmt::Write; use std::fmt::Write;
let mut writer = $crate::http::header::Writer::new(); let mut writer = $crate::http::header::Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
$crate::http::header::HeaderValue::from_maybe_shared(writer.take()) $crate::http::header::HeaderValue::from_shared(writer.take())
} }
} }
}; };
@@ -231,12 +231,12 @@ macro_rules! header {
} }
impl std::fmt::Display for $id { impl std::fmt::Display for $id {
#[inline] #[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f) std::fmt::Display::fmt(&self.0, f)
} }
} }
impl $crate::http::header::IntoHeaderValue for $id { impl $crate::http::header::IntoHeaderValue for $id {
type Error = $crate::http::header::InvalidHeaderValue; type Error = $crate::http::header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> {
self.0.try_into() self.0.try_into()
@@ -276,7 +276,7 @@ macro_rules! header {
} }
impl std::fmt::Display for $id { impl std::fmt::Display for $id {
#[inline] #[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self { match *self {
$id::Any => f.write_str("*"), $id::Any => f.write_str("*"),
$id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited(
@@ -285,13 +285,13 @@ macro_rules! header {
} }
} }
impl $crate::http::header::IntoHeaderValue for $id { impl $crate::http::header::IntoHeaderValue for $id {
type Error = $crate::http::header::InvalidHeaderValue; type Error = $crate::http::header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> {
use std::fmt::Write; use std::fmt::Write;
let mut writer = $crate::http::header::Writer::new(); let mut writer = $crate::http::header::Writer::new();
let _ = write!(&mut writer, "{}", self); let _ = write!(&mut writer, "{}", self);
$crate::http::header::HeaderValue::from_maybe_shared(writer.take()) $crate::http::header::HeaderValue::from_shared(writer.take())
} }
} }
}; };

View File

@@ -1,9 +1,8 @@
use std::collections::hash_map::{self, Entry};
use std::convert::TryFrom;
use either::Either; use either::Either;
use fxhash::FxHashMap; use hashbrown::hash_map::{self, Entry};
use hashbrown::HashMap;
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
use http::HttpTryFrom;
/// A set of HTTP headers /// A set of HTTP headers
/// ///
@@ -12,7 +11,7 @@ use http::header::{HeaderName, HeaderValue};
/// [`HeaderName`]: struct.HeaderName.html /// [`HeaderName`]: struct.HeaderName.html
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HeaderMap { pub struct HeaderMap {
pub(crate) inner: FxHashMap<HeaderName, Value>, pub(crate) inner: HashMap<HeaderName, Value>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -57,7 +56,7 @@ impl HeaderMap {
/// allocate. /// allocate.
pub fn new() -> Self { pub fn new() -> Self {
HeaderMap { HeaderMap {
inner: FxHashMap::default(), inner: HashMap::new(),
} }
} }
@@ -71,7 +70,7 @@ impl HeaderMap {
/// More capacity than requested may be allocated. /// More capacity than requested may be allocated.
pub fn with_capacity(capacity: usize) -> HeaderMap { pub fn with_capacity(capacity: usize) -> HeaderMap {
HeaderMap { HeaderMap {
inner: FxHashMap::with_capacity_and_hasher(capacity, Default::default()), inner: HashMap::with_capacity(capacity),
} }
} }
@@ -143,7 +142,7 @@ impl HeaderMap {
/// Returns `None` if there are no values associated with the key. /// Returns `None` if there are no values associated with the key.
/// ///
/// [`GetAll`]: struct.GetAll.html /// [`GetAll`]: struct.GetAll.html
pub fn get_all<N: AsName>(&self, name: N) -> GetAll<'_> { pub fn get_all<N: AsName>(&self, name: N) -> GetAll {
GetAll { GetAll {
idx: 0, idx: 0,
item: self.get2(name), item: self.get2(name),
@@ -187,7 +186,7 @@ impl HeaderMap {
/// The iteration order is arbitrary, but consistent across platforms for /// The iteration order is arbitrary, but consistent across platforms for
/// the same crate version. Each key will be yielded once per associated /// the same crate version. Each key will be yielded once per associated
/// value. So, if a key has 3 associated values, it will be yielded 3 times. /// value. So, if a key has 3 associated values, it will be yielded 3 times.
pub fn iter(&self) -> Iter<'_> { pub fn iter(&self) -> Iter {
Iter::new(self.inner.iter()) Iter::new(self.inner.iter())
} }
@@ -196,7 +195,7 @@ impl HeaderMap {
/// The iteration order is arbitrary, but consistent across platforms for /// The iteration order is arbitrary, but consistent across platforms for
/// the same crate version. Each key will be yielded only once even if it /// the same crate version. Each key will be yielded only once even if it
/// has multiple associated values. /// has multiple associated values.
pub fn keys(&self) -> Keys<'_> { pub fn keys(&self) -> Keys {
Keys(self.inner.keys()) Keys(self.inner.keys())
} }

View File

@@ -1,7 +1,6 @@
//! Various http headers //! Various http headers
// This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) // This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header)
use std::convert::TryFrom;
use std::{fmt, str::FromStr}; use std::{fmt, str::FromStr};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
@@ -74,58 +73,58 @@ impl<'a> IntoHeaderValue for &'a [u8] {
} }
impl IntoHeaderValue for Bytes { impl IntoHeaderValue for Bytes {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
HeaderValue::from_maybe_shared(self) HeaderValue::from_shared(self)
} }
} }
impl IntoHeaderValue for Vec<u8> { impl IntoHeaderValue for Vec<u8> {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
HeaderValue::try_from(self) HeaderValue::from_shared(Bytes::from(self))
} }
} }
impl IntoHeaderValue for String { impl IntoHeaderValue for String {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
HeaderValue::try_from(self) HeaderValue::from_shared(Bytes::from(self))
} }
} }
impl IntoHeaderValue for usize { impl IntoHeaderValue for usize {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let s = format!("{}", self); let s = format!("{}", self);
HeaderValue::try_from(s) HeaderValue::from_shared(Bytes::from(s))
} }
} }
impl IntoHeaderValue for u64 { impl IntoHeaderValue for u64 {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let s = format!("{}", self); let s = format!("{}", self);
HeaderValue::try_from(s) HeaderValue::from_shared(Bytes::from(s))
} }
} }
impl IntoHeaderValue for Mime { impl IntoHeaderValue for Mime {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
#[inline] #[inline]
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
HeaderValue::try_from(format!("{}", self)) HeaderValue::from_shared(Bytes::from(format!("{}", self)))
} }
} }
@@ -205,7 +204,7 @@ impl Writer {
} }
} }
fn take(&mut self) -> Bytes { fn take(&mut self) -> Bytes {
self.buf.split().freeze() self.buf.take().freeze()
} }
} }
@@ -217,7 +216,7 @@ impl fmt::Write for Writer {
} }
#[inline] #[inline]
fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result {
fmt::write(self, args) fmt::write(self, args)
} }
} }
@@ -259,7 +258,7 @@ pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>) -> Result<T, Pars
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
/// Format an array into a comma-delimited string. /// Format an array into a comma-delimited string.
pub fn fmt_comma_delimited<T>(f: &mut fmt::Formatter<'_>, parts: &[T]) -> fmt::Result pub fn fmt_comma_delimited<T>(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result
where where
T: fmt::Display, T: fmt::Display,
{ {
@@ -361,7 +360,7 @@ pub fn parse_extended_value(
} }
impl fmt::Display for ExtendedValue { impl fmt::Display for ExtendedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let encoded_value = let encoded_value =
percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); percent_encoding::percent_encode(&self.value[..], HTTP_VALUE);
if let Some(ref lang) = self.language_tag { if let Some(ref lang) = self.language_tag {
@@ -376,7 +375,7 @@ impl fmt::Display for ExtendedValue {
/// [https://tools.ietf.org/html/rfc5987#section-3.2][url] /// [https://tools.ietf.org/html/rfc5987#section-3.2][url]
/// ///
/// [url]: https://tools.ietf.org/html/rfc5987#section-3.2 /// [url]: https://tools.ietf.org/html/rfc5987#section-3.2
pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { pub fn http_percent_encode(f: &mut fmt::Formatter, bytes: &[u8]) -> fmt::Result {
let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE);
fmt::Display::fmt(&encoded, f) fmt::Display::fmt(&encoded, f)
} }

View File

@@ -98,7 +98,7 @@ impl Charset {
} }
impl Display for Charset { impl Display for Charset {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.label()) f.write_str(self.label())
} }
} }

View File

@@ -27,7 +27,7 @@ pub enum Encoding {
} }
impl fmt::Display for Encoding { impl fmt::Display for Encoding {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(match *self { f.write_str(match *self {
Chunked => "chunked", Chunked => "chunked",
Brotli => "br", Brotli => "br",

View File

@@ -1,7 +1,7 @@
use std::fmt::{self, Display, Write}; use std::fmt::{self, Display, Write};
use std::str::FromStr; use std::str::FromStr;
use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer};
/// check that each char in the slice is either: /// check that each char in the slice is either:
/// 1. `%x21`, or /// 1. `%x21`, or
@@ -113,7 +113,7 @@ impl EntityTag {
} }
impl Display for EntityTag { impl Display for EntityTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.weak { if self.weak {
write!(f, "W/\"{}\"", self.tag) write!(f, "W/\"{}\"", self.tag)
} else { } else {
@@ -157,12 +157,12 @@ impl FromStr for EntityTag {
} }
impl IntoHeaderValue for EntityTag { impl IntoHeaderValue for EntityTag {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = Writer::new(); let mut wrt = Writer::new();
write!(wrt, "{}", self).unwrap(); write!(wrt, "{}", self).unwrap();
HeaderValue::from_maybe_shared(wrt.take()) HeaderValue::from_shared(wrt.take())
} }
} }

View File

@@ -3,8 +3,8 @@ use std::io::Write;
use std::str::FromStr; use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use bytes::{buf::BufMutExt, BytesMut}; use bytes::{BufMut, BytesMut};
use http::header::{HeaderValue, InvalidHeaderValue}; use http::header::{HeaderValue, InvalidHeaderValueBytes};
use crate::error::ParseError; use crate::error::ParseError;
use crate::header::IntoHeaderValue; use crate::header::IntoHeaderValue;
@@ -28,7 +28,7 @@ impl FromStr for HttpDate {
} }
impl Display for HttpDate { impl Display for HttpDate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0.to_utc().rfc822(), f) fmt::Display::fmt(&self.0.to_utc().rfc822(), f)
} }
} }
@@ -58,12 +58,12 @@ impl From<SystemTime> for HttpDate {
} }
impl IntoHeaderValue for HttpDate { impl IntoHeaderValue for HttpDate {
type Error = InvalidHeaderValue; type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> { fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer(); let mut wrt = BytesMut::with_capacity(29).writer();
write!(wrt, "{}", self.0.rfc822()).unwrap(); write!(wrt, "{}", self.0.rfc822()).unwrap();
HeaderValue::from_maybe_shared(wrt.get_mut().split().freeze()) HeaderValue::from_shared(wrt.get_mut().take().freeze())
} }
} }

View File

@@ -53,7 +53,7 @@ impl<T: PartialEq> cmp::PartialOrd for QualityItem<T> {
} }
impl<T: fmt::Display> fmt::Display for QualityItem<T> { impl<T: fmt::Display> fmt::Display for QualityItem<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.item, f)?; fmt::Display::fmt(&self.item, f)?;
match self.quality.0 { match self.quality.0 {
1000 => Ok(()), 1000 => Ok(()),

View File

@@ -60,7 +60,7 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM
bytes.put_slice(&buf); bytes.put_slice(&buf);
if four { if four {
bytes.put_u8(b' '); bytes.put(b' ');
} }
} }
@@ -203,33 +203,33 @@ mod tests {
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
bytes.reserve(50); bytes.reserve(50);
write_content_length(0, &mut bytes); write_content_length(0, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 0\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(9, &mut bytes); write_content_length(9, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 9\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(10, &mut bytes); write_content_length(10, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 10\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(99, &mut bytes); write_content_length(99, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 99\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(100, &mut bytes); write_content_length(100, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 100\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(101, &mut bytes); write_content_length(101, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 101\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(998, &mut bytes); write_content_length(998, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 998\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(1000, &mut bytes); write_content_length(1000, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(1001, &mut bytes); write_content_length(1001, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]);
bytes.reserve(50); bytes.reserve(50);
write_content_length(5909, &mut bytes); write_content_length(5909, &mut bytes);
assert_eq!(bytes.split().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]);
} }
} }

View File

@@ -25,10 +25,10 @@ pub trait HttpMessage: Sized {
fn take_payload(&mut self) -> Payload<Self::Stream>; fn take_payload(&mut self) -> Payload<Self::Stream>;
/// Request's extensions container /// Request's extensions container
fn extensions(&self) -> Ref<'_, Extensions>; fn extensions(&self) -> Ref<Extensions>;
/// Mutable reference to a the request's extensions container /// Mutable reference to a the request's extensions container
fn extensions_mut(&self) -> RefMut<'_, Extensions>; fn extensions_mut(&self) -> RefMut<Extensions>;
#[doc(hidden)] #[doc(hidden)]
/// Get a header /// Get a header
@@ -105,7 +105,7 @@ pub trait HttpMessage: Sized {
/// Load request cookies. /// Load request cookies.
#[inline] #[inline]
fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> { fn cookies(&self) -> Result<Ref<Vec<Cookie<'static>>>, CookieParseError> {
if self.extensions().get::<Cookies>().is_none() { if self.extensions().get::<Cookies>().is_none() {
let mut cookies = Vec::new(); let mut cookies = Vec::new();
for hdr in self.headers().get_all(header::COOKIE) { for hdr in self.headers().get_all(header::COOKIE) {
@@ -153,12 +153,12 @@ where
} }
/// Request's extensions container /// Request's extensions container
fn extensions(&self) -> Ref<'_, Extensions> { fn extensions(&self) -> Ref<Extensions> {
(**self).extensions() (**self).extensions()
} }
/// Mutable reference to a the request's extensions container /// Mutable reference to a the request's extensions container
fn extensions_mut(&self) -> RefMut<'_, Extensions> { fn extensions_mut(&self) -> RefMut<Extensions> {
(**self).extensions_mut() (**self).extensions_mut()
} }
} }

View File

@@ -1,10 +1,10 @@
//! Basic http primitives for actix-net framework. //! Basic http primitives for actix-net framework.
#![deny(rust_2018_idioms, warnings)]
#![allow( #![allow(
clippy::type_complexity, clippy::type_complexity,
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::new_without_default, clippy::new_without_default,
clippy::borrow_interior_mutable_const clippy::borrow_interior_mutable_const,
clippy::write_with_newline
)] )]
#[macro_use] #[macro_use]
@@ -15,7 +15,6 @@ mod builder;
pub mod client; pub mod client;
mod cloneable; mod cloneable;
mod config; mod config;
#[cfg(feature = "compress")]
pub mod encoding; pub mod encoding;
mod extensions; mod extensions;
mod header; mod header;
@@ -52,7 +51,7 @@ pub mod http {
// re-exports // re-exports
pub use http::header::{HeaderName, HeaderValue}; pub use http::header::{HeaderName, HeaderValue};
pub use http::uri::PathAndQuery; pub use http::uri::PathAndQuery;
pub use http::{uri, Error, Uri}; pub use http::{uri, Error, HttpTryFrom, Uri};
pub use http::{Method, StatusCode, Version}; pub use http::{Method, StatusCode, Version};
pub use crate::cookie::{Cookie, CookieBuilder}; pub use crate::cookie::{Cookie, CookieBuilder};
@@ -65,10 +64,3 @@ pub mod http {
pub use crate::header::ContentEncoding; pub use crate::header::ContentEncoding;
pub use crate::message::ConnectionType; pub use crate::message::ConnectionType;
} }
/// Http protocol
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Protocol {
Http1,
Http2,
}

View File

@@ -78,13 +78,13 @@ impl Head for RequestHead {
impl RequestHead { impl RequestHead {
/// Message extensions /// Message extensions
#[inline] #[inline]
pub fn extensions(&self) -> Ref<'_, Extensions> { pub fn extensions(&self) -> Ref<Extensions> {
self.extensions.borrow() self.extensions.borrow()
} }
/// Mutable reference to a the message's extensions /// Mutable reference to a the message's extensions
#[inline] #[inline]
pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { pub fn extensions_mut(&self) -> RefMut<Extensions> {
self.extensions.borrow_mut() self.extensions.borrow_mut()
} }
@@ -237,13 +237,13 @@ impl ResponseHead {
/// Message extensions /// Message extensions
#[inline] #[inline]
pub fn extensions(&self) -> Ref<'_, Extensions> { pub fn extensions(&self) -> Ref<Extensions> {
self.extensions.borrow() self.extensions.borrow()
} }
/// Mutable reference to a the message's extensions /// Mutable reference to a the message's extensions
#[inline] #[inline]
pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { pub fn extensions_mut(&self) -> RefMut<Extensions> {
self.extensions.borrow_mut() self.extensions.borrow_mut()
} }
@@ -388,12 +388,6 @@ impl BoxedResponseHead {
pub fn new(status: StatusCode) -> Self { pub fn new(status: StatusCode) -> Self {
RESPONSE_POOL.with(|p| p.get_message(status)) RESPONSE_POOL.with(|p| p.get_message(status))
} }
pub(crate) fn take(&mut self) -> Self {
BoxedResponseHead {
head: self.head.take(),
}
}
} }
impl std::ops::Deref for BoxedResponseHead { impl std::ops::Deref for BoxedResponseHead {
@@ -412,9 +406,7 @@ impl std::ops::DerefMut for BoxedResponseHead {
impl Drop for BoxedResponseHead { impl Drop for BoxedResponseHead {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(head) = self.head.take() { RESPONSE_POOL.with(|p| p.release(self.head.take().unwrap()))
RESPONSE_POOL.with(move |p| p.release(head))
}
} }
} }

View File

@@ -1,14 +1,11 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures_core::Stream; use futures::{Async, Poll, Stream};
use h2::RecvStream; use h2::RecvStream;
use crate::error::PayloadError; use crate::error::PayloadError;
/// Type represent boxed payload /// Type represent boxed payload
pub type PayloadStream = Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>>>>; pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
/// Type represent streaming payload /// Type represent streaming payload
pub enum Payload<S = PayloadStream> { pub enum Payload<S = PayloadStream> {
@@ -51,20 +48,18 @@ impl<S> Payload<S> {
impl<S> Stream for Payload<S> impl<S> Stream for Payload<S>
where where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin, S: Stream<Item = Bytes, Error = PayloadError>,
{ {
type Item = Result<Bytes, PayloadError>; type Item = Bytes;
type Error = PayloadError;
#[inline] #[inline]
fn poll_next( fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self: Pin<&mut Self>, match self {
cx: &mut Context<'_>, Payload::None => Ok(Async::Ready(None)),
) -> Poll<Option<Self::Item>> { Payload::H1(ref mut pl) => pl.poll(),
match self.get_mut() { Payload::H2(ref mut pl) => pl.poll(),
Payload::None => Poll::Ready(None), Payload::Stream(ref mut pl) => pl.poll(),
Payload::H1(ref mut pl) => pl.readany(cx),
Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx),
Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx),
} }
} }
} }

View File

@@ -25,13 +25,13 @@ impl<P> HttpMessage for Request<P> {
/// Request extensions /// Request extensions
#[inline] #[inline]
fn extensions(&self) -> Ref<'_, Extensions> { fn extensions(&self) -> Ref<Extensions> {
self.head.extensions() self.head.extensions()
} }
/// Mutable reference to a the request's extensions /// Mutable reference to a the request's extensions
#[inline] #[inline]
fn extensions_mut(&self) -> RefMut<'_, Extensions> { fn extensions_mut(&self) -> RefMut<Extensions> {
self.head.extensions_mut() self.head.extensions_mut()
} }
@@ -80,11 +80,6 @@ impl<P> Request<P> {
) )
} }
/// Get request's payload
pub fn payload(&mut self) -> &mut Payload<P> {
&mut self.payload
}
/// Get request's payload /// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> { pub fn take_payload(&mut self) -> Payload<P> {
std::mem::replace(&mut self.payload, Payload::None) std::mem::replace(&mut self.payload, Payload::None)
@@ -165,7 +160,7 @@ impl<P> Request<P> {
} }
impl<P> fmt::Debug for Request<P> { impl<P> fmt::Debug for Request<P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!( writeln!(
f, f,
"\nRequest {:?} {}:{}", "\nRequest {:?} {}:{}",
@@ -187,7 +182,7 @@ impl<P> fmt::Debug for Request<P> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::convert::TryFrom; use http::HttpTryFrom;
#[test] #[test]
fn test_basics() { fn test_basics() {
@@ -204,6 +199,7 @@ mod tests {
assert_eq!(req.uri().query(), Some("q=1")); assert_eq!(req.uri().query(), Some("q=1"));
let s = format!("{:?}", req); let s = format!("{:?}", req);
println!("T: {:?}", s);
assert!(s.contains("Request HTTP/1.1 GET:/index.html")); assert!(s.contains("Request HTTP/1.1 GET:/index.html"));
} }
} }

View File

@@ -1,13 +1,11 @@
//! Http response //! Http response
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::convert::TryFrom; use std::io::Write;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, str}; use std::{fmt, str};
use bytes::{Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures_core::Stream; use futures::future::{ok, FutureResult, IntoFuture};
use futures::Stream;
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
@@ -17,7 +15,7 @@ use crate::error::Error;
use crate::extensions::Extensions; use crate::extensions::Extensions;
use crate::header::{Header, IntoHeaderValue}; use crate::header::{Header, IntoHeaderValue};
use crate::http::header::{self, HeaderName, HeaderValue}; use crate::http::header::{self, HeaderName, HeaderValue};
use crate::http::{Error as HttpError, HeaderMap, StatusCode}; use crate::http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode};
use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead}; use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead};
/// An HTTP Response /// An HTTP Response
@@ -53,7 +51,7 @@ impl Response<Body> {
/// Constructs an error response /// Constructs an error response
#[inline] #[inline]
pub fn from_error(error: Error) -> Response { pub fn from_error(error: Error) -> Response {
let mut resp = error.as_response_error().error_response(); let mut resp = error.as_response_error().render_response();
if resp.head.status == StatusCode::INTERNAL_SERVER_ERROR { if resp.head.status == StatusCode::INTERNAL_SERVER_ERROR {
error!("Internal Server Error: {:?}", error); error!("Internal Server Error: {:?}", error);
} }
@@ -130,7 +128,7 @@ impl<B> Response<B> {
/// Get an iterator for the cookies set by this response /// Get an iterator for the cookies set by this response
#[inline] #[inline]
pub fn cookies(&self) -> CookieIter<'_> { pub fn cookies(&self) -> CookieIter {
CookieIter { CookieIter {
iter: self.head.headers.get_all(header::SET_COOKIE), iter: self.head.headers.get_all(header::SET_COOKIE),
} }
@@ -138,7 +136,7 @@ impl<B> Response<B> {
/// Add a cookie to this response /// Add a cookie to this response
#[inline] #[inline]
pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { pub fn add_cookie(&mut self, cookie: &Cookie) -> Result<(), HttpError> {
let h = &mut self.head.headers; let h = &mut self.head.headers;
HeaderValue::from_str(&cookie.to_string()) HeaderValue::from_str(&cookie.to_string())
.map(|c| { .map(|c| {
@@ -186,13 +184,13 @@ impl<B> Response<B> {
/// Responses extensions /// Responses extensions
#[inline] #[inline]
pub fn extensions(&self) -> Ref<'_, Extensions> { pub fn extensions(&self) -> Ref<Extensions> {
self.head.extensions.borrow() self.head.extensions.borrow()
} }
/// Mutable reference to a the response's extensions /// Mutable reference to a the response's extensions
#[inline] #[inline]
pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { pub fn extensions_mut(&mut self) -> RefMut<Extensions> {
self.head.extensions.borrow_mut() self.head.extensions.borrow_mut()
} }
@@ -265,7 +263,7 @@ impl<B> Response<B> {
} }
impl<B: MessageBody> fmt::Debug for Response<B> { impl<B: MessageBody> fmt::Debug for Response<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!( let res = writeln!(
f, f,
"\nResponse {:?} {}{}", "\nResponse {:?} {}{}",
@@ -282,15 +280,13 @@ impl<B: MessageBody> fmt::Debug for Response<B> {
} }
} }
impl Future for Response { impl IntoFuture for Response {
type Output = Result<Response, Error>; type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> { fn into_future(self) -> Self::Future {
Poll::Ready(Ok(Response { ok(self)
head: self.head.take(),
body: self.body.take_body(),
error: self.error.take(),
}))
} }
} }
@@ -354,6 +350,7 @@ impl ResponseBuilder {
/// )) /// ))
/// .finish()) /// .finish())
/// } /// }
/// fn main() {}
/// ``` /// ```
#[doc(hidden)] #[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self { pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
@@ -379,11 +376,11 @@ impl ResponseBuilder {
/// .header(http::header::CONTENT_TYPE, "application/json") /// .header(http::header::CONTENT_TYPE, "application/json")
/// .finish() /// .finish()
/// } /// }
/// fn main() {}
/// ``` /// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
if let Some(parts) = parts(&mut self.head, &self.err) { if let Some(parts) = parts(&mut self.head, &self.err) {
@@ -411,11 +408,11 @@ impl ResponseBuilder {
/// .set_header(http::header::CONTENT_TYPE, "application/json") /// .set_header(http::header::CONTENT_TYPE, "application/json")
/// .finish() /// .finish()
/// } /// }
/// fn main() {}
/// ``` /// ```
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
if let Some(parts) = parts(&mut self.head, &self.err) { if let Some(parts) = parts(&mut self.head, &self.err) {
@@ -484,8 +481,7 @@ impl ResponseBuilder {
#[inline] #[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self pub fn content_type<V>(&mut self, value: V) -> &mut Self
where where
HeaderValue: TryFrom<V>, HeaderValue: HttpTryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
{ {
if let Some(parts) = parts(&mut self.head, &self.err) { if let Some(parts) = parts(&mut self.head, &self.err) {
match HeaderValue::try_from(value) { match HeaderValue::try_from(value) {
@@ -501,7 +497,9 @@ impl ResponseBuilder {
/// Set content length /// Set content length
#[inline] #[inline]
pub fn content_length(&mut self, len: u64) -> &mut Self { pub fn content_length(&mut self, len: u64) -> &mut Self {
self.header(header::CONTENT_LENGTH, len) let mut wrt = BytesMut::new().writer();
let _ = write!(wrt, "{}", len);
self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
} }
/// Set a cookie /// Set a cookie
@@ -585,14 +583,14 @@ impl ResponseBuilder {
/// Responses extensions /// Responses extensions
#[inline] #[inline]
pub fn extensions(&self) -> Ref<'_, Extensions> { pub fn extensions(&self) -> Ref<Extensions> {
let head = self.head.as_ref().expect("cannot reuse response builder"); let head = self.head.as_ref().expect("cannot reuse response builder");
head.extensions.borrow() head.extensions.borrow()
} }
/// Mutable reference to a the response's extensions /// Mutable reference to a the response's extensions
#[inline] #[inline]
pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { pub fn extensions_mut(&mut self) -> RefMut<Extensions> {
let head = self.head.as_ref().expect("cannot reuse response builder"); let head = self.head.as_ref().expect("cannot reuse response builder");
head.extensions.borrow_mut() head.extensions.borrow_mut()
} }
@@ -637,7 +635,7 @@ impl ResponseBuilder {
/// `ResponseBuilder` can not be used after this call. /// `ResponseBuilder` can not be used after this call.
pub fn streaming<S, E>(&mut self, stream: S) -> Response pub fn streaming<S, E>(&mut self, stream: S) -> Response
where where
S: Stream<Item = Result<Bytes, E>> + 'static, S: Stream<Item = Bytes, Error = E> + 'static,
E: Into<Error> + 'static, E: Into<Error> + 'static,
{ {
self.body(Body::from_message(BodyStream::new(stream))) self.body(Body::from_message(BodyStream::new(stream)))
@@ -759,16 +757,18 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder {
} }
} }
impl Future for ResponseBuilder { impl IntoFuture for ResponseBuilder {
type Output = Result<Response, Error>; type Item = Response;
type Error = Error;
type Future = FutureResult<Response, Error>;
fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> { fn into_future(mut self) -> Self::Future {
Poll::Ready(Ok(self.finish())) ok(self.finish())
} }
} }
impl fmt::Debug for ResponseBuilder { impl fmt::Debug for ResponseBuilder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let head = self.head.as_ref().unwrap(); let head = self.head.as_ref().unwrap();
let res = writeln!( let res = writeln!(

View File

@@ -1,16 +1,14 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::{fmt, io, net, rc};
use std::task::{Context, Poll};
use std::{fmt, net, rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_rt::net::TcpStream; use actix_server_config::{
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig,
use bytes::Bytes; };
use futures_core::{ready, Future}; use actix_service::{IntoNewService, NewService, Service};
use futures_util::future::ok; use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{try_ready, Async, Future, IntoFuture, Poll};
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use pin_project::{pin_project, project};
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder; use crate::builder::HttpServiceBuilder;
@@ -20,24 +18,24 @@ use crate::error::{DispatchError, Error};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{h1, h2::Dispatcher, Protocol}; use crate::{h1, h2::Dispatcher};
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation /// `NewService` HTTP1.1/HTTP2 transport implementation
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> { pub struct HttpService<T, P, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B> HttpService<T, S, B> impl<T, S, B> HttpService<T, (), S, B>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
@@ -47,22 +45,22 @@ where
} }
} }
impl<T, S, B> HttpService<T, S, B> impl<T, P, S, B> HttpService<T, P, S, B>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
{ {
/// Create new `HttpService` instance. /// Create new `HttpService` instance.
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self { pub fn new<F: IntoNewService<S>>(service: F) -> Self {
let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0, false, None); let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0);
HttpService { HttpService {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -71,13 +69,13 @@ where
} }
/// Create new `HttpService` instance with config. /// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>( pub(crate) fn with_config<F: IntoNewService<S>>(
cfg: ServiceConfig, cfg: ServiceConfig,
service: F, service: F,
) -> Self { ) -> Self {
HttpService { HttpService {
cfg, cfg,
srv: service.into_factory(), srv: service.into_new_service(),
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
@@ -86,13 +84,12 @@ where
} }
} }
impl<T, S, B, X, U> HttpService<T, S, B, X, U> impl<T, P, S, B, X, U> HttpService<T, P, S, B, X, U>
where where
S: ServiceFactory<Config = (), Request = Request>, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static,
B: MessageBody, B: MessageBody,
{ {
/// Provide service for `EXPECT: 100-Continue` support. /// Provide service for `EXPECT: 100-Continue` support.
@@ -100,12 +97,11 @@ where
/// Service get called with request that contains `EXPECT` header. /// Service get called with request that contains `EXPECT` header.
/// Service must return request in case of success, in that case /// Service must return request in case of success, in that case
/// request will be forwarded to main service. /// request will be forwarded to main service.
pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U> pub fn expect<X1>(self, expect: X1) -> HttpService<T, P, S, B, X1, U>
where where
X1: ServiceFactory<Config = (), Request = Request, Response = Request>, X1: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X1::Error: Into<Error>, X1::Error: Into<Error>,
X1::InitError: fmt::Debug, X1::InitError: fmt::Debug,
<X1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
expect, expect,
@@ -121,16 +117,15 @@ where
/// ///
/// If service is provided then normal requests handling get halted /// If service is provided then normal requests handling get halted
/// and this service get called with original request and framed object. /// and this service get called with original request and framed object.
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, S, B, X, U1> pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, P, S, B, X, U1>
where where
U1: ServiceFactory< U1: NewService<
Config = (), Config = SrvConfig,
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U1::Error: fmt::Display, U1::Error: fmt::Display,
U1::InitError: fmt::Debug, U1::InitError: fmt::Debug,
<U1::Service as Service>::Future: 'static,
{ {
HttpService { HttpService {
upgrade, upgrade,
@@ -152,312 +147,126 @@ where
} }
} }
impl<S, B, X, U> HttpService<TcpStream, S, B, X, U> impl<T, P, S, B, X, U> NewService for HttpService<T, P, S, B, X, U>
where where
S: ServiceFactory<Config = (), Request = Request>, T: IoStream,
S::Error: Into<Error> + 'static, S: NewService<Config = SrvConfig, Request = Request>,
S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>, X: NewService<Config = SrvConfig, Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<
U: ServiceFactory< Config = SrvConfig,
Config = (),
Request = (Request, Framed<TcpStream, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
/// Create simple tcp stream service
pub fn tcp(
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = (),
> {
pipeline_factory(|io: TcpStream| {
let peer_addr = io.peer_addr().ok();
ok((io, Protocol::Http1, peer_addr))
})
.and_then(self)
}
}
#[cfg(feature = "openssl")]
mod openssl {
use super::*;
use actix_tls::openssl::{Acceptor, SslAcceptor, SslStream};
use actix_tls::{openssl::HandshakeError, SslError};
impl<S, B, X, U> HttpService<SslStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<SslStream<TcpStream>, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
/// Create openssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<HandshakeError<TcpStream>, DispatchError>,
InitError = (),
> {
pipeline_factory(
Acceptor::new(acceptor)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: SslStream<TcpStream>| {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
} else {
Protocol::Http1
}
} else {
Protocol::Http1
};
let peer_addr = io.get_ref().peer_addr().ok();
ok((io, proto, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use actix_tls::rustls::{Acceptor, ServerConfig, Session, TlsStream};
use actix_tls::SslError;
use std::io;
impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display + Into<Error>,
U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{
/// Create openssl based service
pub fn rustls(
self,
mut config: ServerConfig,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<io::Error, DispatchError>,
InitError = (),
> {
let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
config.set_protocols(&protos);
pipeline_factory(
Acceptor::new(config)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: TlsStream<TcpStream>| {
let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
} else {
Protocol::Http1
}
} else {
Protocol::Http1
};
let peer_addr = io.get_ref().0.peer_addr().ok();
ok((io, proto, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
impl<T, S, B, X, U> ServiceFactory for HttpService<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: Into<Error> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: Into<Error>,
X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<T, h1::Codec>), Request = (Request, Framed<T, h1::Codec>),
Response = (), Response = (),
>, >,
U::Error: fmt::Display + Into<Error>, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Config = (); type Config = SrvConfig;
type Request = (T, Protocol, Option<net::SocketAddr>); type Request = ServerIo<T, P>;
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type InitError = (); type InitError = ();
type Service = HttpServiceHandler<T, S::Service, B, X::Service, U::Service>; type Service = HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
type Future = HttpServiceResponse<T, S, B, X, U>; type Future = HttpServiceResponse<T, P, S, B, X, U>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, cfg: &SrvConfig) -> Self::Future {
HttpServiceResponse { HttpServiceResponse {
fut: self.srv.new_service(()), fut: self.srv.new_service(cfg).into_future(),
fut_ex: Some(self.expect.new_service(())), fut_ex: Some(self.expect.new_service(cfg)),
fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)),
expect: None, expect: None,
upgrade: None, upgrade: None,
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
cfg: self.cfg.clone(), cfg: Some(self.cfg.clone()),
_t: PhantomData, _t: PhantomData,
} }
} }
} }
#[doc(hidden)] #[doc(hidden)]
#[pin_project] pub struct HttpServiceResponse<T, P, S: NewService, B, X: NewService, U: NewService> {
pub struct HttpServiceResponse<
T,
S: ServiceFactory,
B,
X: ServiceFactory,
U: ServiceFactory,
> {
#[pin]
fut: S::Future, fut: S::Future,
#[pin]
fut_ex: Option<X::Future>, fut_ex: Option<X::Future>,
#[pin]
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
cfg: ServiceConfig, cfg: Option<ServiceConfig>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, P, B)>,
} }
impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U> impl<T, P, S, B, X, U> Future for HttpServiceResponse<T, P, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: ServiceFactory<Request = Request>, S: NewService<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::InitError: fmt::Debug, S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
<S::Service as Service>::Future: 'static, <S::Service as Service>::Future: 'static,
B: MessageBody + 'static, B: MessageBody + 'static,
X: ServiceFactory<Request = Request, Response = Request>, X: NewService<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
<X::Service as Service>::Future: 'static, U: NewService<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U: ServiceFactory<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
<U::Service as Service>::Future: 'static,
{ {
type Output = type Item = HttpServiceHandler<T, P, S::Service, B, X::Service, U::Service>;
Result<HttpServiceHandler<T, S::Service, B, X::Service, U::Service>, ()>; type Error = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut this = self.as_mut().project(); if let Some(ref mut fut) = self.fut_ex {
let expect = try_ready!(fut
if let Some(fut) = this.fut_ex.as_pin_mut() { .poll()
let expect = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.expect = Some(expect);
this.fut_ex.set(None);
}
if let Some(fut) = this.fut_upg.as_pin_mut() {
let upgrade = ready!(fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
}
let result = ready!(this
.fut
.poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Poll::Ready(result.map(|service| { self.expect = Some(expect);
let this = self.as_mut().project(); self.fut_ex.take();
HttpServiceHandler::new( }
this.cfg.clone(),
if let Some(ref mut fut) = self.fut_upg {
let upgrade = try_ready!(fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
self.upgrade = Some(upgrade);
self.fut_ex.take();
}
let service = try_ready!(self
.fut
.poll()
.map_err(|e| log::error!("Init http service error: {:?}", e)));
Ok(Async::Ready(HttpServiceHandler::new(
self.cfg.take().unwrap(),
service, service,
this.expect.take().unwrap(), self.expect.take().unwrap(),
this.upgrade.take(), self.upgrade.take(),
this.on_connect.clone(), self.on_connect.clone(),
) )))
}))
} }
} }
/// `Service` implementation for http transport /// `Service` implementation for http transport
pub struct HttpServiceHandler<T, S, B, X, U> { pub struct HttpServiceHandler<T, P, S, B, X, U> {
srv: CloneableService<S>, srv: CloneableService<S>,
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B, X)>, _t: PhantomData<(T, P, B, X)>,
} }
impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U> impl<T, P, S, B, X, U> HttpServiceHandler<T, P, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
@@ -470,7 +279,7 @@ where
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> HttpServiceHandler<T, S, B, X, U> { ) -> HttpServiceHandler<T, P, S, B, X, U> {
HttpServiceHandler { HttpServiceHandler {
cfg, cfg,
on_connect, on_connect,
@@ -482,28 +291,28 @@ where
} }
} }
impl<T, S, B, X, U> Service for HttpServiceHandler<T, S, B, X, U> impl<T, P, S, B, X, U> Service for HttpServiceHandler<T, P, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display + Into<Error>, U::Error: fmt::Display,
{ {
type Request = (T, Protocol, Option<net::SocketAddr>); type Request = ServerIo<T, P>;
type Response = (); type Response = ();
type Error = DispatchError; type Error = DispatchError;
type Future = HttpServiceHandlerResponse<T, S, B, X, U>; type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
let ready = self let ready = self
.expect .expect
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -513,7 +322,7 @@ where
let ready = self let ready = self
.srv .srv
.poll_ready(cx) .poll_ready()
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
log::error!("Http service readiness error: {:?}", e); log::error!("Http service readiness error: {:?}", e);
@@ -522,27 +331,16 @@ where
.is_ready() .is_ready()
&& ready; && ready;
let ready = if let Some(ref mut upg) = self.upgrade {
upg.poll_ready(cx)
.map_err(|e| {
let e = e.into();
log::error!("Http service readiness error: {:?}", e);
DispatchError::Service(e)
})?
.is_ready()
&& ready
} else {
ready
};
if ready { if ready {
Poll::Ready(Ok(())) Ok(Async::Ready(()))
} else { } else {
Poll::Pending Ok(Async::NotReady)
} }
} }
fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { fn call(&mut self, req: Self::Request) -> Self::Future {
let (io, _, proto) = req.into_parts();
let on_connect = if let Some(ref on_connect) = self.on_connect { let on_connect = if let Some(ref on_connect) = self.on_connect {
Some(on_connect(&io)) Some(on_connect(&io))
} else { } else {
@@ -550,16 +348,23 @@ where
}; };
match proto { match proto {
Protocol::Http2 => HttpServiceHandlerResponse { Protocol::Http2 => {
state: State::H2Handshake(Some(( let peer_addr = io.peer_addr();
let io = Io {
inner: io,
unread: None,
};
HttpServiceHandlerResponse {
state: State::Handshake(Some((
server::handshake(io), server::handshake(io),
self.cfg.clone(), self.cfg.clone(),
self.srv.clone(), self.srv.clone(),
on_connect,
peer_addr, peer_addr,
on_connect,
))), ))),
}, }
Protocol::Http1 => HttpServiceHandlerResponse { }
Protocol::Http10 | Protocol::Http11 => HttpServiceHandlerResponse {
state: State::H1(h1::Dispatcher::new( state: State::H1(h1::Dispatcher::new(
io, io,
self.cfg.clone(), self.cfg.clone(),
@@ -567,117 +372,234 @@ where
self.expect.clone(), self.expect.clone(),
self.upgrade.clone(), self.upgrade.clone(),
on_connect, on_connect,
peer_addr,
)), )),
}, },
_ => HttpServiceHandlerResponse {
state: State::Unknown(Some((
io,
BytesMut::with_capacity(14),
self.cfg.clone(),
self.srv.clone(),
self.expect.clone(),
self.upgrade.clone(),
on_connect,
))),
},
} }
} }
} }
#[pin_project]
enum State<T, S, B, X, U> enum State<T, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
S::Future: 'static, S::Future: 'static,
S::Error: Into<Error>, S::Error: Into<Error>,
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
B: MessageBody, B: MessageBody,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
H1(#[pin] h1::Dispatcher<T, S, B, X, U>), H1(h1::Dispatcher<T, S, B, X, U>),
H2(#[pin] Dispatcher<T, S, B>), H2(Dispatcher<Io<T>, S, B>),
H2Handshake( Unknown(
Option<( Option<(
Handshake<T, Bytes>, T,
BytesMut,
ServiceConfig, ServiceConfig,
CloneableService<S>, CloneableService<S>,
CloneableService<X>,
Option<CloneableService<U>>,
Option<Box<dyn DataFactory>>, Option<Box<dyn DataFactory>>,
)>,
),
Handshake(
Option<(
Handshake<Io<T>, Bytes>,
ServiceConfig,
CloneableService<S>,
Option<net::SocketAddr>, Option<net::SocketAddr>,
Option<Box<dyn DataFactory>>,
)>, )>,
), ),
} }
#[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U> pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody + 'static, B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
#[pin]
state: State<T, S, B, X, U>, state: State<T, S, B, X, U>,
} }
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: IoStream,
S: Service<Request = Request>, S: Service<Request = Request>,
S::Error: Into<Error> + 'static, S::Error: Into<Error>,
S::Future: 'static, S::Future: 'static,
S::Response: Into<Response<B>> + 'static, S::Response: Into<Response<B>>,
B: MessageBody, B: MessageBody,
X: Service<Request = Request, Response = Request>, X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>, X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
type Output = Result<(), DispatchError>; type Item = ();
type Error = DispatchError;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.project().state.poll(cx) match self.state {
State::H1(ref mut disp) => disp.poll(),
State::H2(ref mut disp) => disp.poll(),
State::Unknown(ref mut data) => {
if let Some(ref mut item) = data {
loop {
// Safety - we only write to the returned slice.
let b = unsafe { item.1.bytes_mut() };
let n = try_ready!(item.0.poll_read(b));
if n == 0 {
return Ok(Async::Ready(()));
}
// Safety - we know that 'n' bytes have
// been initialized via the contract of
// 'poll_read'
unsafe { item.1.advance_mut(n) };
if item.1.len() >= HTTP2_PREFACE.len() {
break;
} }
} }
} else {
impl<T, S, B, X, U> State<T, S, B, X, U> panic!()
where }
T: AsyncRead + AsyncWrite + Unpin, let (io, buf, cfg, srv, expect, upgrade, on_connect) =
S: Service<Request = Request>, data.take().unwrap();
S::Error: Into<Error> + 'static, if buf[..14] == HTTP2_PREFACE[..] {
S::Response: Into<Response<B>> + 'static, let peer_addr = io.peer_addr();
B: MessageBody + 'static, let io = Io {
X: Service<Request = Request, Response = Request>, inner: io,
X::Error: Into<Error>, unread: Some(buf),
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>, };
U::Error: fmt::Display, self.state = State::Handshake(Some((
{ server::handshake(io),
#[project] cfg,
fn poll( srv,
mut self: Pin<&mut Self>, peer_addr,
cx: &mut Context<'_>, on_connect,
) -> Poll<Result<(), DispatchError>> { )));
#[project] } else {
match self.as_mut().project() { self.state = State::H1(h1::Dispatcher::with_timeout(
State::H1(disp) => disp.poll(cx), io,
State::H2(disp) => disp.poll(cx), h1::Codec::new(cfg.clone()),
State::H2Handshake(ref mut data) => { cfg,
buf,
None,
srv,
expect,
upgrade,
on_connect,
))
}
self.poll()
}
State::Handshake(ref mut data) => {
let conn = if let Some(ref mut item) = data { let conn = if let Some(ref mut item) = data {
match Pin::new(&mut item.0).poll(cx) { match item.0.poll() {
Poll::Ready(Ok(conn)) => conn, Ok(Async::Ready(conn)) => conn,
Poll::Ready(Err(err)) => { Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
trace!("H2 handshake error: {}", err); trace!("H2 handshake error: {}", err);
return Poll::Ready(Err(err.into())); return Err(err.into());
} }
Poll::Pending => return Poll::Pending,
} }
} else { } else {
panic!() panic!()
}; };
let (_, cfg, srv, on_connect, peer_addr) = data.take().unwrap(); let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap();
self.set(State::H2(Dispatcher::new( self.state = State::H2(Dispatcher::new(
srv, conn, on_connect, cfg, None, peer_addr, srv, conn, on_connect, cfg, None, peer_addr,
))); ));
self.poll(cx) self.poll()
} }
} }
} }
} }
/// Wrapper for `AsyncRead + AsyncWrite` types
struct Io<T> {
unread: Option<BytesMut>,
inner: T,
}
impl<T: io::Read> io::Read for Io<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Some(mut bytes) = self.unread.take() {
let size = std::cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
self.unread = Some(bytes);
}
Ok(size)
} else {
self.inner.read(buf)
}
}
}
impl<T: io::Write> io::Write for Io<T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
impl<T: AsyncRead> AsyncRead for Io<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
}
impl<T: AsyncWrite> AsyncWrite for Io<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.inner.write_buf(buf)
}
}
impl<T: IoStream> IoStream for Io<T> {
#[inline]
fn peer_addr(&self) -> Option<net::SocketAddr> {
self.inner.peer_addr()
}
#[inline]
fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
self.inner.set_nodelay(nodelay)
}
#[inline]
fn set_linger(&mut self, dur: Option<std::time::Duration>) -> io::Result<()> {
self.inner.set_linger(dur)
}
#[inline]
fn set_keepalive(&mut self, dur: Option<std::time::Duration>) -> io::Result<()> {
self.inner.set_keepalive(dur)
}
}

View File

@@ -1,15 +1,14 @@
//! Test Various helpers for Actix applications to use during testing. //! Test Various helpers for Actix applications to use during testing.
use std::convert::TryFrom;
use std::fmt::Write as FmtWrite; use std::fmt::Write as FmtWrite;
use std::io::{self, Read, Write}; use std::io;
use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use bytes::{Bytes, BytesMut}; use actix_server_config::IoStream;
use bytes::{Buf, Bytes, BytesMut};
use futures::{Async, Poll};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use http::{Error as HttpError, Method, Uri, Version}; use http::{HttpTryFrom, Method, Uri, Version};
use percent_encoding::percent_encode; use percent_encoding::percent_encode;
use crate::cookie::{Cookie, CookieJar, USERINFO}; use crate::cookie::{Cookie, CookieJar, USERINFO};
@@ -21,6 +20,8 @@ use crate::Request;
/// Test `Request` builder /// Test `Request` builder
/// ///
/// ```rust,ignore /// ```rust,ignore
/// # extern crate http;
/// # extern crate actix_web;
/// # use http::{header, StatusCode}; /// # use http::{header, StatusCode};
/// # use actix_web::*; /// # use actix_web::*;
/// use actix_web::test::TestRequest; /// use actix_web::test::TestRequest;
@@ -33,6 +34,7 @@ use crate::Request;
/// } /// }
/// } /// }
/// ///
/// fn main() {
/// let resp = TestRequest::with_header("content-type", "text/plain") /// let resp = TestRequest::with_header("content-type", "text/plain")
/// .run(&index) /// .run(&index)
/// .unwrap(); /// .unwrap();
@@ -40,6 +42,7 @@ use crate::Request;
/// ///
/// let resp = TestRequest::default().run(&index).unwrap(); /// let resp = TestRequest::default().run(&index).unwrap();
/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
/// }
/// ``` /// ```
pub struct TestRequest(Option<Inner>); pub struct TestRequest(Option<Inner>);
@@ -79,8 +82,7 @@ impl TestRequest {
/// Create TestRequest and set header /// Create TestRequest and set header
pub fn with_header<K, V>(key: K, value: V) -> TestRequest pub fn with_header<K, V>(key: K, value: V) -> TestRequest
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
TestRequest::default().header(key, value).take() TestRequest::default().header(key, value).take()
@@ -116,8 +118,7 @@ impl TestRequest {
/// Set a header /// Set a header
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where where
HeaderName: TryFrom<K>, HeaderName: HttpTryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
if let Ok(key) = HeaderName::try_from(key) { if let Ok(key) = HeaderName::try_from(key) {
@@ -243,30 +244,27 @@ impl io::Write for TestBuffer {
} }
} }
impl AsyncRead for TestBuffer { impl AsyncRead for TestBuffer {}
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.get_mut().read(buf))
}
}
impl AsyncWrite for TestBuffer { impl AsyncWrite for TestBuffer {
fn poll_write( fn shutdown(&mut self) -> Poll<(), io::Error> {
self: Pin<&mut Self>, Ok(Async::Ready(()))
_: &mut Context<'_>, }
buf: &[u8], fn write_buf<B: Buf>(&mut self, _: &mut B) -> Poll<usize, io::Error> {
) -> Poll<io::Result<usize>> { Ok(Async::NotReady)
Poll::Ready(self.get_mut().write(buf)) }
} }
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { impl IoStream for TestBuffer {
Poll::Ready(Ok(())) fn set_nodelay(&mut self, _nodelay: bool) -> io::Result<()> {
Ok(())
} }
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { fn set_linger(&mut self, _dur: Option<std::time::Duration>) -> io::Result<()> {
Poll::Ready(Ok(())) Ok(())
}
fn set_keepalive(&mut self, _dur: Option<std::time::Duration>) -> io::Result<()> {
Ok(())
} }
} }

View File

@@ -12,12 +12,10 @@ pub enum Message {
Text(String), Text(String),
/// Binary message /// Binary message
Binary(Bytes), Binary(Bytes),
/// Continuation
Continuation(Item),
/// Ping message /// Ping message
Ping(Bytes), Ping(String),
/// Pong message /// Pong message
Pong(Bytes), Pong(String),
/// Close message with optional reason /// Close message with optional reason
Close(Option<CloseReason>), Close(Option<CloseReason>),
/// No-op. Useful for actix-net services /// No-op. Useful for actix-net services
@@ -28,41 +26,22 @@ pub enum Message {
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Frame { pub enum Frame {
/// Text frame, codec does not verify utf8 encoding /// Text frame, codec does not verify utf8 encoding
Text(Bytes), Text(Option<BytesMut>),
/// Binary frame /// Binary frame
Binary(Bytes), Binary(Option<BytesMut>),
/// Continuation
Continuation(Item),
/// Ping message /// Ping message
Ping(Bytes), Ping(String),
/// Pong message /// Pong message
Pong(Bytes), Pong(String),
/// Close message with optional reason /// Close message with optional reason
Close(Option<CloseReason>), Close(Option<CloseReason>),
} }
/// `WebSocket` continuation item
#[derive(Debug, PartialEq)]
pub enum Item {
FirstText(Bytes),
FirstBinary(Bytes),
Continue(Bytes),
Last(Bytes),
}
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
/// WebSockets protocol codec /// WebSockets protocol codec
pub struct Codec { pub struct Codec {
flags: Flags,
max_size: usize, max_size: usize,
} server: bool,
bitflags::bitflags! {
struct Flags: u8 {
const SERVER = 0b0000_0001;
const CONTINUATION = 0b0000_0010;
const W_CONTINUATION = 0b0000_0100;
}
} }
impl Codec { impl Codec {
@@ -70,7 +49,7 @@ impl Codec {
pub fn new() -> Codec { pub fn new() -> Codec {
Codec { Codec {
max_size: 65_536, max_size: 65_536,
flags: Flags::SERVER, server: true,
} }
} }
@@ -86,7 +65,7 @@ impl Codec {
/// ///
/// By default decoder works in server mode. /// By default decoder works in server mode.
pub fn client_mode(mut self) -> Self { pub fn client_mode(mut self) -> Self {
self.flags.remove(Flags::SERVER); self.server = false;
self self
} }
} }
@@ -97,94 +76,19 @@ impl Encoder for Codec {
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item { match item {
Message::Text(txt) => Parser::write_message( Message::Text(txt) => {
dst, Parser::write_message(dst, txt, OpCode::Text, true, !self.server)
txt,
OpCode::Text,
true,
!self.flags.contains(Flags::SERVER),
),
Message::Binary(bin) => Parser::write_message(
dst,
bin,
OpCode::Binary,
true,
!self.flags.contains(Flags::SERVER),
),
Message::Ping(txt) => Parser::write_message(
dst,
txt,
OpCode::Ping,
true,
!self.flags.contains(Flags::SERVER),
),
Message::Pong(txt) => Parser::write_message(
dst,
txt,
OpCode::Pong,
true,
!self.flags.contains(Flags::SERVER),
),
Message::Close(reason) => {
Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
} }
Message::Continuation(cont) => match cont { Message::Binary(bin) => {
Item::FirstText(data) => { Parser::write_message(dst, bin, OpCode::Binary, true, !self.server)
if self.flags.contains(Flags::W_CONTINUATION) {
return Err(ProtocolError::ContinuationStarted);
} else {
self.flags.insert(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Binary,
false,
!self.flags.contains(Flags::SERVER),
)
} }
Message::Ping(txt) => {
Parser::write_message(dst, txt, OpCode::Ping, true, !self.server)
} }
Item::FirstBinary(data) => { Message::Pong(txt) => {
if self.flags.contains(Flags::W_CONTINUATION) { Parser::write_message(dst, txt, OpCode::Pong, true, !self.server)
return Err(ProtocolError::ContinuationStarted);
} else {
self.flags.insert(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Text,
false,
!self.flags.contains(Flags::SERVER),
)
} }
} Message::Close(reason) => Parser::write_close(dst, reason, !self.server),
Item::Continue(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
Parser::write_message(
dst,
&data[..],
OpCode::Continue,
false,
!self.flags.contains(Flags::SERVER),
)
} else {
return Err(ProtocolError::ContinuationNotStarted);
}
}
Item::Last(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
self.flags.remove(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Continue,
true,
!self.flags.contains(Flags::SERVER),
)
} else {
return Err(ProtocolError::ContinuationNotStarted);
}
}
},
Message::Nop => (), Message::Nop => (),
} }
Ok(()) Ok(())
@@ -196,64 +100,15 @@ impl Decoder for Codec {
type Error = ProtocolError; type Error = ProtocolError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { match Parser::parse(src, self.server, self.max_size) {
Ok(Some((finished, opcode, payload))) => { Ok(Some((finished, opcode, payload))) => {
// continuation is not supported // continuation is not supported
if !finished { if !finished {
return match opcode { return Err(ProtocolError::NoContinuation);
OpCode::Continue => {
if self.flags.contains(Flags::CONTINUATION) {
Ok(Some(Frame::Continuation(Item::Continue(
payload
.map(|pl| pl.freeze())
.unwrap_or_else(Bytes::new),
))))
} else {
Err(ProtocolError::ContinuationNotStarted)
}
}
OpCode::Binary => {
if !self.flags.contains(Flags::CONTINUATION) {
self.flags.insert(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstBinary(
payload
.map(|pl| pl.freeze())
.unwrap_or_else(Bytes::new),
))))
} else {
Err(ProtocolError::ContinuationStarted)
}
}
OpCode::Text => {
if !self.flags.contains(Flags::CONTINUATION) {
self.flags.insert(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstText(
payload
.map(|pl| pl.freeze())
.unwrap_or_else(Bytes::new),
))))
} else {
Err(ProtocolError::ContinuationStarted)
}
}
_ => {
error!("Unfinished fragment {:?}", opcode);
Err(ProtocolError::ContinuationFragment(opcode))
}
};
} }
match opcode { match opcode {
OpCode::Continue => { OpCode::Continue => Err(ProtocolError::NoContinuation),
if self.flags.contains(Flags::CONTINUATION) {
self.flags.remove(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::Last(
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
))))
} else {
Err(ProtocolError::ContinuationNotStarted)
}
}
OpCode::Bad => Err(ProtocolError::BadOpCode), OpCode::Bad => Err(ProtocolError::BadOpCode),
OpCode::Close => { OpCode::Close => {
if let Some(ref pl) = payload { if let Some(ref pl) = payload {
@@ -263,18 +118,29 @@ impl Decoder for Codec {
Ok(Some(Frame::Close(None))) Ok(Some(Frame::Close(None)))
} }
} }
OpCode::Ping => Ok(Some(Frame::Ping( OpCode::Ping => {
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), if let Some(ref pl) = payload {
))), Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into())))
OpCode::Pong => Ok(Some(Frame::Pong( } else {
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), Ok(Some(Frame::Ping(String::new())))
))), }
OpCode::Binary => Ok(Some(Frame::Binary( }
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), OpCode::Pong => {
))), if let Some(ref pl) = payload {
OpCode::Text => Ok(Some(Frame::Text( Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into())))
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), } else {
))), Ok(Some(Frame::Pong(String::new())))
}
}
OpCode::Binary => Ok(Some(Frame::Binary(payload))),
OpCode::Text => {
Ok(Some(Frame::Text(payload)))
//let tmp = Vec::from(payload.as_ref());
//match String::from_utf8(tmp) {
// Ok(s) => Ok(Some(Message::Text(s))),
// Err(_) => Err(ProtocolError::BadEncoding),
//}
}
} }
} }
Ok(None) => Ok(None), Ok(None) => Ok(None),

View File

@@ -1,6 +1,6 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use bytes::{Buf, BufMut, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use log::debug; use log::debug;
use rand; use rand;
@@ -108,7 +108,7 @@ impl Parser {
} }
// remove prefix // remove prefix
src.advance(idx); src.split_to(idx);
// no need for body // no need for body
if length == 0 { if length == 0 {
@@ -154,14 +154,14 @@ impl Parser {
} }
/// Generate binary representation /// Generate binary representation
pub fn write_message<B: AsRef<[u8]>>( pub fn write_message<B: Into<Bytes>>(
dst: &mut BytesMut, dst: &mut BytesMut,
pl: B, pl: B,
op: OpCode, op: OpCode,
fin: bool, fin: bool,
mask: bool, mask: bool,
) { ) {
let payload = pl.as_ref(); let payload = pl.into();
let one: u8 = if fin { let one: u8 = if fin {
0x80 | Into::<u8>::into(op) 0x80 | Into::<u8>::into(op)
} else { } else {
@@ -180,11 +180,11 @@ impl Parser {
} else if payload_len <= 65_535 { } else if payload_len <= 65_535 {
dst.reserve(p_len + 4 + if mask { 4 } else { 0 }); dst.reserve(p_len + 4 + if mask { 4 } else { 0 });
dst.put_slice(&[one, two | 126]); dst.put_slice(&[one, two | 126]);
dst.put_u16(payload_len as u16); dst.put_u16_be(payload_len as u16);
} else { } else {
dst.reserve(p_len + 10 + if mask { 4 } else { 0 }); dst.reserve(p_len + 10 + if mask { 4 } else { 0 });
dst.put_slice(&[one, two | 127]); dst.put_slice(&[one, two | 127]);
dst.put_u64(payload_len as u64); dst.put_u64_be(payload_len as u64);
}; };
if mask { if mask {

View File

@@ -51,7 +51,7 @@ pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) {
// inefficient, it could be done better. The compiler does not understand that // inefficient, it could be done better. The compiler does not understand that
// a `ShortSlice` must be smaller than a u64. // a `ShortSlice` must be smaller than a u64.
#[allow(clippy::needless_pass_by_value)] #[allow(clippy::needless_pass_by_value)]
fn xor_short(buf: ShortSlice<'_>, mask: u64) { fn xor_short(buf: ShortSlice, mask: u64) {
// Unsafe: we know that a `ShortSlice` fits in a u64 // Unsafe: we know that a `ShortSlice` fits in a u64
unsafe { unsafe {
let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len()); let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len());
@@ -77,7 +77,7 @@ unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] {
#[inline] #[inline]
// Splits a slice into three parts: an unaligned short head and tail, plus an aligned // Splits a slice into three parts: an unaligned short head and tail, plus an aligned
// u64 mid section. // u64 mid section.
fn align_buf(buf: &mut [u8]) -> (ShortSlice<'_>, &mut [u64], ShortSlice<'_>) { fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) {
let start_ptr = buf.as_ptr() as usize; let start_ptr = buf.as_ptr() as usize;
let end_ptr = start_ptr + buf.len(); let end_ptr = start_ptr + buf.len();

View File

@@ -13,15 +13,15 @@ use crate::message::RequestHead;
use crate::response::{Response, ResponseBuilder}; use crate::response::{Response, ResponseBuilder};
mod codec; mod codec;
mod dispatcher;
mod frame; mod frame;
mod mask; mod mask;
mod proto; mod proto;
mod transport;
pub use self::codec::{Codec, Frame, Item, Message}; pub use self::codec::{Codec, Frame, Message};
pub use self::dispatcher::Dispatcher;
pub use self::frame::Parser; pub use self::frame::Parser;
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::transport::Transport;
/// Websocket protocol errors /// Websocket protocol errors
#[derive(Debug, Display, From)] #[derive(Debug, Display, From)]
@@ -44,15 +44,12 @@ pub enum ProtocolError {
/// A payload reached size limit. /// A payload reached size limit.
#[display(fmt = "A payload reached size limit.")] #[display(fmt = "A payload reached size limit.")]
Overflow, Overflow,
/// Continuation is not started /// Continuation is not supported
#[display(fmt = "Continuation is not started.")] #[display(fmt = "Continuation is not supported.")]
ContinuationNotStarted, NoContinuation,
/// Received new continuation but it is already started /// Bad utf-8 encoding
#[display(fmt = "Received new continuation but it is already started")] #[display(fmt = "Bad utf-8 encoding.")]
ContinuationStarted, BadEncoding,
/// Unknown continuation fragment
#[display(fmt = "Unknown continuation fragment.")]
ContinuationFragment(OpCode),
/// Io error /// Io error
#[display(fmt = "io error: {}", _0)] #[display(fmt = "io error: {}", _0)]
Io(io::Error), Io(io::Error),

View File

@@ -24,7 +24,7 @@ pub enum OpCode {
} }
impl fmt::Display for OpCode { impl fmt::Display for OpCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Continue => write!(f, "CONTINUE"), Continue => write!(f, "CONTINUE"),
Text => write!(f, "TEXT"), Text => write!(f, "TEXT"),
@@ -95,7 +95,7 @@ pub enum CloseCode {
Abnormal, Abnormal,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection
/// because it has received data within a message that was not /// because it has received data within a message that was not
/// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
/// data within a text message). /// data within a text message).
Invalid, Invalid,
/// Indicates that an endpoint is terminating the connection /// Indicates that an endpoint is terminating the connection

View File

@@ -1,22 +1,19 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service}; use actix_service::{IntoService, Service};
use actix_utils::framed; use actix_utils::framed::{FramedTransport, FramedTransportError};
use futures::{Future, Poll};
use super::{Codec, Frame, Message}; use super::{Codec, Frame, Message};
pub struct Dispatcher<S, T> pub struct Transport<S, T>
where where
S: Service<Request = Frame, Response = Message> + 'static, S: Service<Request = Frame, Response = Message> + 'static,
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
inner: framed::Dispatcher<S, T, Codec>, inner: FramedTransport<S, T, Codec>,
} }
impl<S, T> Dispatcher<S, T> impl<S, T> Transport<S, T>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Frame, Response = Message>, S: Service<Request = Frame, Response = Message>,
@@ -24,28 +21,29 @@ where
S::Error: 'static, S::Error: 'static,
{ {
pub fn new<F: IntoService<S>>(io: T, service: F) -> Self { pub fn new<F: IntoService<S>>(io: T, service: F) -> Self {
Dispatcher { Transport {
inner: framed::Dispatcher::new(Framed::new(io, Codec::new()), service), inner: FramedTransport::new(Framed::new(io, Codec::new()), service),
} }
} }
pub fn with<F: IntoService<S>>(framed: Framed<T, Codec>, service: F) -> Self { pub fn with<F: IntoService<S>>(framed: Framed<T, Codec>, service: F) -> Self {
Dispatcher { Transport {
inner: framed::Dispatcher::new(framed, service), inner: FramedTransport::new(framed, service),
} }
} }
} }
impl<S, T> Future for Dispatcher<S, T> impl<S, T> Future for Transport<S, T>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
S: Service<Request = Frame, Response = Message>, S: Service<Request = Frame, Response = Message>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
{ {
type Output = Result<(), framed::DispatcherError<S::Error, Codec>>; type Item = ();
type Error = FramedTransportError<S::Error, Codec>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
Pin::new(&mut self.inner).poll(cx) self.inner.poll()
} }
} }

View File

@@ -1,9 +1,9 @@
use actix_service::ServiceFactory; use actix_service::NewService;
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, ok}; use futures::future::{self, ok};
use actix_http::{http, HttpService, Request, Response}; use actix_http::{http, HttpService, Request, Response};
use actix_http_test::test_server; use actix_http_test::TestServer;
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
@@ -27,49 +27,45 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World"; Hello World Hello World Hello World Hello World Hello World";
#[actix_rt::test] #[test]
async fn test_h1_v2() { fn test_h1_v2() {
let srv = test_server(move || { env_logger::init();
HttpService::build() let mut srv = TestServer::new(move || {
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.tcp()
}); });
let response = srv.block_on(srv.get("/").send()).unwrap();
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let request = srv.get("/").header("x-test", "111").send(); let request = srv.get("/").header("x-test", "111").send();
let mut response = request.await.unwrap(); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
let bytes = response.body().await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let mut response = srv.post("/").send().await.unwrap(); let response = srv.block_on(srv.post("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
let bytes = response.body().await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[actix_rt::test] #[test]
async fn test_connection_close() { fn test_connection_close() {
let srv = test_server(move || { let mut srv = TestServer::new(move || {
HttpService::build() HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR))) .finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.tcp()
.map(|_| ()) .map(|_| ())
}); });
let response = srv.block_on(srv.get("/").force_close().send()).unwrap();
let response = srv.get("/").force_close().send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
#[actix_rt::test] #[test]
async fn test_with_query_parameter() { fn test_with_query_parameter() {
let srv = test_server(move || { let mut srv = TestServer::new(move || {
HttpService::build() HttpService::build()
.finish(|req: Request| { .finish(|req: Request| {
if req.uri().query().unwrap().contains("qp=") { if req.uri().query().unwrap().contains("qp=") {
@@ -78,11 +74,10 @@ async fn test_with_query_parameter() {
ok::<_, ()>(Response::BadRequest().finish()) ok::<_, ()>(Response::BadRequest().finish())
} }
}) })
.tcp()
.map(|_| ()) .map(|_| ())
}); });
let request = srv.request(http::Method::GET, srv.url("/?qp=5")); let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send();
let response = request.send().await.unwrap(); let response = srv.block_on(request).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }

View File

@@ -1,416 +0,0 @@
#![cfg(feature = "openssl")]
use std::io;
use actix_http_test::test_server;
use actix_service::{fn_service, ServiceFactory};
use bytes::{Bytes, BytesMut};
use futures::future::{err, ok, ready};
use futures::stream::{once, Stream, StreamExt};
use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
let body = stream
.map(|res| match res {
Ok(chunk) => chunk,
Err(_) => panic!(),
})
.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
ready(body)
})
.await;
Ok(body)
}
fn ssl_acceptor() -> SslAcceptor {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
const H11: &[u8] = b"\x08http/1.1";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else if protos.windows(9).any(|window| window == H11) {
Ok(b"http/1.1")
} else {
Err(AlpnError::NOACK)
}
});
builder
.set_alpn_protos(b"\x08http/1.1\x02h2")
.expect("Can not contrust SslAcceptor");
builder.build()
}
#[actix_rt::test]
async fn test_h2() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h2_1() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h2_body() -> io::Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let mut srv = test_server(move || {
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[actix_rt::test]
async fn test_h2_content_length() {
let srv = test_server(move || {
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[actix_rt::test]
async fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let mut srv = test_server(move || {
let data = data.clone();
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[actix_rt::test]
async fn test_h2_body2() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_head_empty() {
let mut srv = test_server(move || {
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
}
#[actix_rt::test]
async fn test_h2_head_binary() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR))
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
}
#[actix_rt::test]
async fn test_h2_head_binary2() {
let srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[actix_rt::test]
async fn test_h2_body_length() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_body_chunked_explicit() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_response_http_error_handling() {
let mut srv = test_server(move || {
HttpService::build()
.h2(fn_service(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
}))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[actix_rt::test]
async fn test_h2_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| err::<Response, Error>(ErrorBadRequest("error")))
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[actix_rt::test]
async fn test_h2_on_connect() {
let srv = test_server(move || {
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.openssl(ssl_acceptor())
.map_err(|_| ())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
}

View File

@@ -1,421 +0,0 @@
#![cfg(feature = "rustls")]
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::test_server;
use actix_service::{fn_factory_with_config, fn_service};
use bytes::{Bytes, BytesMut};
use futures::future::{self, err, ok};
use futures::stream::{once, Stream, StreamExt};
use rust_tls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{self, BufReader};
async fn load_body<S>(mut stream: S) -> Result<BytesMut, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
let mut body = BytesMut::new();
while let Some(item) = stream.next().await {
body.extend_from_slice(&item?)
}
Ok(body)
}
fn ssl_acceptor() -> RustlsServerConfig {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
config
}
#[actix_rt::test]
async fn test_h1() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h1(|_| future::ok::<_, Error>(Response::Ok().finish()))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h2() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h1_1() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h1(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_11);
future::ok::<_, Error>(Response::Ok().finish())
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h2_1() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
Ok(())
}
#[actix_rt::test]
async fn test_h2_body1() -> io::Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let mut srv = test_server(move || {
HttpService::build()
.h2(|mut req: Request<_>| {
async move {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::Ok().body(body))
}
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send_body(data.clone()).await.unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).await.unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[actix_rt::test]
async fn test_h2_content_length() {
let srv = test_server(move || {
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.rustls(ssl_acceptor())
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[actix_rt::test]
async fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let mut srv = test_server(move || {
let data = data.clone();
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[actix_rt::test]
async fn test_h2_body2() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_head_empty() {
let mut srv = test_server(move || {
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.rustls(ssl_acceptor())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
}
#[actix_rt::test]
async fn test_h2_head_binary() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR))
})
.rustls(ssl_acceptor())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
}
#[actix_rt::test]
async fn test_h2_head_binary2() {
let srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.rustls(ssl_acceptor())
});
let response = srv.shead("/").send().await.unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[actix_rt::test]
async fn test_h2_body_length() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_body_chunked_explicit() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| {
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).await.unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[actix_rt::test]
async fn test_h2_response_http_error_handling() {
let mut srv = test_server(move || {
HttpService::build()
.h2(fn_factory_with_config(|_: ()| {
ok::<_, ()>(fn_service(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
}))
}))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[actix_rt::test]
async fn test_h2_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| err::<Response, Error>(error::ErrorBadRequest("error")))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[actix_rt::test]
async fn test_h1_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h1(|_| err::<Response, Error>(error::ErrorBadRequest("error")))
.rustls(ssl_acceptor())
});
let response = srv.sget("/").send().await.unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}

View File

@@ -0,0 +1,462 @@
#![cfg(feature = "rust-tls")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http::error::PayloadError;
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::{body, error, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_server::ssl::RustlsAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok, Future};
use futures::stream::{once, Stream};
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig,
};
use std::fs::File;
use std::io::{BufReader, Result};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<RustlsAcceptor<T, ()>> {
// load ssl keys
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap());
let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let protos = vec![b"h2".to_vec()];
config.set_protocols(&protos);
Ok(RustlsAcceptor::new(config))
}
#[test]
fn test_h2() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
future::ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let rustls = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut config = Response::Ok();
for idx in 0..90 {
config.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, ()>(config.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response
.headers()
.get(http::header::CONTENT_LENGTH)
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let rustls = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
rustls
.clone()
.map_err(|e| println!("Rustls error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}

View File

@@ -2,22 +2,23 @@ use std::io::{Read, Write};
use std::time::Duration; use std::time::Duration;
use std::{net, thread}; use std::{net, thread};
use actix_http_test::test_server; use actix_http_test::TestServer;
use actix_rt::time::delay_for; use actix_server_config::ServerConfig;
use actix_service::fn_service; use actix_service::{new_service_cfg, service_fn, NewService};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, err, ok, ready, FutureExt}; use futures::future::{self, ok, Future};
use futures::stream::{once, StreamExt}; use futures::stream::{once, Stream};
use regex::Regex; use regex::Regex;
use tokio_timer::sleep;
use actix_http::httpmessage::HttpMessage; use actix_http::httpmessage::HttpMessage;
use actix_http::{ use actix_http::{
body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response, body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response,
}; };
#[actix_rt::test] #[test]
async fn test_h1() { fn test_h1() {
let srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.keep_alive(KeepAlive::Disabled) .keep_alive(KeepAlive::Disabled)
.client_timeout(1000) .client_timeout(1000)
@@ -26,16 +27,15 @@ async fn test_h1() {
assert!(req.peer_addr().is_some()); assert!(req.peer_addr().is_some());
future::ok::<_, ()>(Response::Ok().finish()) future::ok::<_, ()>(Response::Ok().finish())
}) })
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
#[actix_rt::test] #[test]
async fn test_h1_2() { fn test_h1_2() {
let srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.keep_alive(KeepAlive::Disabled) .keep_alive(KeepAlive::Disabled)
.client_timeout(1000) .client_timeout(1000)
@@ -45,26 +45,25 @@ async fn test_h1_2() {
assert_eq!(req.version(), http::Version::HTTP_11); assert_eq!(req.version(), http::Version::HTTP_11);
future::ok::<_, ()>(Response::Ok().finish()) future::ok::<_, ()>(Response::Ok().finish())
}) })
.tcp() .map(|_| ())
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
#[actix_rt::test] #[test]
async fn test_expect_continue() { fn test_expect_continue() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.expect(fn_service(|req: Request| { .expect(service_fn(|req: Request| {
if req.head().uri.query() == Some("yes=") { if req.head().uri.query() == Some("yes=") {
ok(req) Ok(req)
} else { } else {
err(error::ErrorPreconditionFailed("error")) Err(error::ErrorPreconditionFailed("error"))
} }
})) }))
.finish(|_| future::ok::<_, ()>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -80,21 +79,20 @@ async fn test_expect_continue() {
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
} }
#[actix_rt::test] #[test]
async fn test_expect_continue_h1() { fn test_expect_continue_h1() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.expect(fn_service(|req: Request| { .expect(service_fn(|req: Request| {
delay_for(Duration::from_millis(20)).then(move |_| { sleep(Duration::from_millis(20)).then(move |_| {
if req.head().uri.query() == Some("yes=") { if req.head().uri.query() == Some("yes=") {
ok(req) Ok(req)
} else { } else {
err(error::ErrorPreconditionFailed("error")) Err(error::ErrorPreconditionFailed("error"))
} }
}) })
})) }))
.h1(fn_service(|_| future::ok::<_, ()>(Response::Ok().finish()))) .h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -110,26 +108,19 @@ async fn test_expect_continue_h1() {
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
} }
#[actix_rt::test] #[test]
async fn test_chunked_payload() { fn test_chunked_payload() {
let chunk_sizes = vec![32768, 32, 32768]; let chunk_sizes = vec![32768, 32, 32768];
let total_size: usize = chunk_sizes.iter().sum(); let total_size: usize = chunk_sizes.iter().sum();
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|mut request: Request| {
.h1(fn_service(|mut request: Request| {
request request
.take_payload() .take_payload()
.map(|res| match res { .map_err(|e| panic!(format!("Error reading payload: {}", e)))
Ok(pl) => pl, .fold(0usize, |acc, chunk| future::ok::<_, ()>(acc + chunk.len()))
Err(e) => panic!(format!("Error reading payload: {}", e)), .map(|req_size| Response::Ok().body(format!("size={}", req_size)))
}) })
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
.map(|req_size| {
Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size)))
})
}))
.tcp()
}); });
let returned_size = { let returned_size = {
@@ -165,13 +156,12 @@ async fn test_chunked_payload() {
assert_eq!(returned_size, total_size); assert_eq!(returned_size, total_size);
} }
#[actix_rt::test] #[test]
async fn test_slow_request() { fn test_slow_request() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.client_timeout(100) .client_timeout(100)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -181,12 +171,10 @@ async fn test_slow_request() {
assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); assert!(data.starts_with("HTTP/1.1 408 Request Timeout"));
} }
#[actix_rt::test] #[test]
async fn test_http1_malformed_request() { fn test_http1_malformed_request() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -196,12 +184,10 @@ async fn test_http1_malformed_request() {
assert!(data.starts_with("HTTP/1.1 400 Bad Request")); assert!(data.starts_with("HTTP/1.1 400 Bad Request"));
} }
#[actix_rt::test] #[test]
async fn test_http1_keepalive() { fn test_http1_keepalive() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -216,13 +202,12 @@ async fn test_http1_keepalive() {
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
} }
#[actix_rt::test] #[test]
async fn test_http1_keepalive_timeout() { fn test_http1_keepalive_timeout() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.keep_alive(1) .keep_alive(1)
.h1(|_| future::ok::<_, ()>(Response::Ok().finish())) .h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -237,12 +222,10 @@ async fn test_http1_keepalive_timeout() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
#[actix_rt::test] #[test]
async fn test_http1_keepalive_close() { fn test_http1_keepalive_close() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -257,12 +240,10 @@ async fn test_http1_keepalive_close() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
#[actix_rt::test] #[test]
async fn test_http10_keepalive_default_close() { fn test_http10_keepalive_default_close() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -276,12 +257,10 @@ async fn test_http10_keepalive_default_close() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
#[actix_rt::test] #[test]
async fn test_http10_keepalive() { fn test_http10_keepalive() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -302,13 +281,12 @@ async fn test_http10_keepalive() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
#[actix_rt::test] #[test]
async fn test_http1_keepalive_disabled() { fn test_http1_keepalive_disabled() {
let srv = test_server(|| { let srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.keep_alive(KeepAlive::Disabled) .keep_alive(KeepAlive::Disabled)
.h1(|_| future::ok::<_, ()>(Response::Ok().finish())) .h1(|_| future::ok::<_, ()>(Response::Ok().finish()))
.tcp()
}); });
let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@@ -322,16 +300,15 @@ async fn test_http1_keepalive_disabled() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
#[actix_rt::test] #[test]
async fn test_content_length() { fn test_content_length() {
use actix_http::http::{ use actix_http::http::{
header::{HeaderName, HeaderValue}, header::{HeaderName, HeaderValue},
StatusCode, StatusCode,
}; };
let srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|req: Request| {
.h1(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap(); let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [ let statuses = [
StatusCode::NO_CONTENT, StatusCode::NO_CONTENT,
@@ -343,7 +320,6 @@ async fn test_content_length() {
]; ];
future::ok::<_, ()>(Response::new(statuses[indx])) future::ok::<_, ()>(Response::new(statuses[indx]))
}) })
.tcp()
}); });
let header = HeaderName::from_static("content-length"); let header = HeaderName::from_static("content-length");
@@ -351,29 +327,35 @@ async fn test_content_length() {
{ {
for i in 0..4 { for i in 0..4 {
let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); let req = srv
let response = req.send().await.unwrap(); .request(http::Method::GET, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None); assert_eq!(response.headers().get(&header), None);
let req = srv.request(http::Method::HEAD, srv.url(&format!("/{}", i))); let req = srv
let response = req.send().await.unwrap(); .request(http::Method::HEAD, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None); assert_eq!(response.headers().get(&header), None);
} }
for i in 4..6 { for i in 4..6 {
let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); let req = srv
let response = req.send().await.unwrap(); .request(http::Method::GET, srv.url(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value)); assert_eq!(response.headers().get(&header), Some(&value));
} }
} }
} }
#[actix_rt::test] #[test]
async fn test_h1_headers() { fn test_h1_headers() {
let data = STR.repeat(10); let data = STR.repeat(10);
let data2 = data.clone(); let data2 = data.clone();
let mut srv = test_server(move || { let mut srv = TestServer::new(move || {
let data = data.clone(); let data = data.clone();
HttpService::build().h1(move |_| { HttpService::build().h1(move |_| {
let mut builder = Response::Ok(); let mut builder = Response::Ok();
@@ -396,14 +378,14 @@ async fn test_h1_headers() {
); );
} }
future::ok::<_, ()>(builder.body(data.clone())) future::ok::<_, ()>(builder.body(data.clone()))
}).tcp() })
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2)); assert_eq!(bytes, Bytes::from(data2));
} }
@@ -429,31 +411,27 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World"; Hello World Hello World Hello World Hello World Hello World";
#[actix_rt::test] #[test]
async fn test_h1_body() { fn test_h1_body() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[actix_rt::test] #[test]
async fn test_h1_head_empty() { fn test_h1_head_empty() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
.h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
.tcp()
}); });
let response = srv.head("/").send().await.unwrap(); let response = srv.block_on(srv.head("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
@@ -465,21 +443,19 @@ async fn test_h1_head_empty() {
} }
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty()); assert!(bytes.is_empty());
} }
#[actix_rt::test] #[test]
async fn test_h1_head_binary() { fn test_h1_head_binary() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| {
.h1(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR))
}) })
.tcp()
}); });
let response = srv.head("/").send().await.unwrap(); let response = srv.block_on(srv.head("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
@@ -491,19 +467,17 @@ async fn test_h1_head_binary() {
} }
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty()); assert!(bytes.is_empty());
} }
#[actix_rt::test] #[test]
async fn test_h1_head_binary2() { fn test_h1_head_binary2() {
let srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
.h1(|_| ok::<_, ()>(Response::Ok().body(STR)))
.tcp()
}); });
let response = srv.head("/").send().await.unwrap(); let response = srv.block_on(srv.head("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
@@ -515,43 +489,39 @@ async fn test_h1_head_binary2() {
} }
} }
#[actix_rt::test] #[test]
async fn test_h1_body_length() { fn test_h1_body_length() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| {
.h1(|_| { let body = once(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>( ok::<_, ()>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
) )
}) })
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[actix_rt::test] #[test]
async fn test_h1_body_chunked_explicit() { fn test_h1_body_chunked_explicit() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| {
.h1(|_| { let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>( ok::<_, ()>(
Response::Ok() Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked") .header(header::TRANSFER_ENCODING, "chunked")
.streaming(body), .streaming(body),
) )
}) })
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!( assert_eq!(
response response
@@ -564,24 +534,22 @@ async fn test_h1_body_chunked_explicit() {
); );
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
// decode // decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[actix_rt::test] #[test]
async fn test_h1_body_chunked_implicit() { fn test_h1_body_chunked_implicit() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(|_| {
.h1(|_| { let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(Response::Ok().streaming(body)) ok::<_, ()>(Response::Ok().streaming(body))
}) })
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
assert_eq!( assert_eq!(
response response
@@ -594,61 +562,59 @@ async fn test_h1_body_chunked_implicit() {
); );
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[actix_rt::test] #[test]
async fn test_h1_response_http_error_handling() { fn test_h1_response_http_error_handling() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build().h1(new_service_cfg(|_: &ServerConfig| {
.h1(fn_service(|_| { Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0"); let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>( ok::<_, ()>(
Response::Ok() Response::Ok()
.header(http::header::CONTENT_TYPE, broken_header) .header(http::header::CONTENT_TYPE, broken_header)
.body(STR), .body(STR),
) )
})
})) }))
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
} }
#[actix_rt::test] #[test]
async fn test_h1_service_error() { fn test_h1_service_error() {
let mut srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.h1(|_| future::err::<Response, Error>(error::ErrorBadRequest("error"))) .h1(|_| Err::<Response, Error>(error::ErrorBadRequest("error")))
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response // read response
let bytes = srv.load_body(response).await.unwrap(); let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error")); assert_eq!(bytes, Bytes::from_static(b"error"));
} }
#[actix_rt::test] #[test]
async fn test_h1_on_connect() { fn test_h1_on_connect() {
let srv = test_server(|| { let mut srv = TestServer::new(|| {
HttpService::build() HttpService::build()
.on_connect(|_| 10usize) .on_connect(|_| 10usize)
.h1(|req: Request| { .h1(|req: Request| {
assert!(req.extensions().contains::<usize>()); assert!(req.extensions().contains::<usize>());
future::ok::<_, ()>(Response::Ok().finish()) future::ok::<_, ()>(Response::Ok().finish())
}) })
.tcp()
}); });
let response = srv.get("/").send().await.unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }

View File

@@ -0,0 +1,480 @@
#![cfg(feature = "ssl")]
use actix_codec::{AsyncRead, AsyncWrite};
use actix_http_test::TestServer;
use actix_server::ssl::OpensslAcceptor;
use actix_server_config::ServerConfig;
use actix_service::{new_service_cfg, NewService};
use bytes::{Bytes, BytesMut};
use futures::future::{ok, Future};
use futures::stream::{once, Stream};
use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::io::Result;
use actix_http::error::{ErrorBadRequest, PayloadError};
use actix_http::http::header::{self, HeaderName, HeaderValue};
use actix_http::http::{Method, StatusCode, Version};
use actix_http::httpmessage::HttpMessage;
use actix_http::{body, Error, HttpService, Request, Response};
fn load_body<S>(stream: S) -> impl Future<Item = BytesMut, Error = PayloadError>
where
S: Stream<Item = Bytes, Error = PayloadError>,
{
stream.fold(BytesMut::new(), move |mut body, chunk| {
body.extend_from_slice(&chunk);
Ok::<_, PayloadError>(body)
})
}
fn ssl_acceptor<T: AsyncRead + AsyncWrite>() -> Result<OpensslAcceptor<T, ()>> {
// load ssl keys
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
builder.set_alpn_select_callback(|_, protos| {
const H2: &[u8] = b"\x02h2";
if protos.windows(3).any(|window| window == H2) {
Ok(b"h2")
} else {
Err(AlpnError::NOACK)
}
});
builder.set_alpn_protos(b"\x02h2")?;
Ok(OpensslAcceptor::new(builder.build()))
}
#[test]
fn test_h2() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, Error>(Response::Ok().finish()))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_1() -> Result<()> {
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
Ok(())
}
#[test]
fn test_h2_body() -> Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let openssl = ssl_acceptor()?;
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|mut req: Request<_>| {
load_body(req.take_payload())
.and_then(|body| Ok(Response::Ok().body(body)))
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send_body(data.clone())).unwrap();
assert!(response.status().is_success());
let body = srv.load_body(response).unwrap();
assert_eq!(&body, data.as_bytes());
Ok(())
}
#[test]
fn test_h2_content_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
ok::<_, ()>(Response::new(statuses[indx]))
})
.map_err(|_| ()),
)
});
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
{
for i in 0..4 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), None);
}
for i in 4..6 {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = srv.block_on(req).unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
}
}
}
#[test]
fn test_h2_headers() {
let data = STR.repeat(10);
let data2 = data.clone();
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
let data = data.clone();
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build().h2(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
ok::<_, ()>(builder.body(data.clone()))
}).map_err(|_| ()))
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from(data2));
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World";
#[test]
fn test_h2_body2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_head_empty() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.finish(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
assert_eq!(response.version(), Version::HTTP_2);
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
// read response
let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty());
}
#[test]
fn test_h2_head_binary2() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| ok::<_, ()>(Response::Ok().body(STR)))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.shead("/").send()).unwrap();
assert!(response.status().is_success());
{
let len = response.headers().get(header::CONTENT_LENGTH).unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
}
#[test]
fn test_h2_body_length() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_body_chunked_explicit() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| {
let body =
once::<_, Error>(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
assert!(!response.headers().contains_key(header::TRANSFER_ENCODING));
// read response
let bytes = srv.load_body(response).unwrap();
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
}
#[test]
fn test_h2_response_http_error_handling() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(new_service_cfg(|_: &ServerConfig| {
Ok::<_, ()>(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, ()>(
Response::Ok()
.header(header::CONTENT_TYPE, broken_header)
.body(STR),
)
})
}))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"failed to parse header value"));
}
#[test]
fn test_h2_service_error() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.h2(|_| Err::<Response, Error>(ErrorBadRequest("error")))
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// read response
let bytes = srv.load_body(response).unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[test]
fn test_h2_on_connect() {
let openssl = ssl_acceptor().unwrap();
let mut srv = TestServer::new(move || {
openssl
.clone()
.map_err(|e| println!("Openssl error: {}", e))
.and_then(
HttpService::build()
.on_connect(|_| 10usize)
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
ok::<_, ()>(Response::Ok().finish())
})
.map_err(|_| ()),
)
});
let response = srv.block_on(srv.sget("/").send()).unwrap();
assert!(response.status().is_success());
}

View File

@@ -1,194 +1,76 @@
use std::cell::Cell;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::test_server; use actix_http_test::TestServer;
use actix_service::{fn_factory, Service}; use actix_utils::framed::FramedTransport;
use actix_utils::framed::Dispatcher; use bytes::{Bytes, BytesMut};
use bytes::Bytes; use futures::future::{self, ok};
use futures::future; use futures::{Future, Sink, Stream};
use futures::task::{Context, Poll};
use futures::{Future, SinkExt, StreamExt};
struct WsService<T>(Arc<Mutex<(PhantomData<T>, Cell<bool>)>>); fn ws_service<T: AsyncRead + AsyncWrite>(
(req, framed): (Request, Framed<T, h1::Codec>),
impl<T> WsService<T> { ) -> impl Future<Item = (), Error = Error> {
fn new() -> Self {
WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false)))))
}
fn set_polled(&mut self) {
*self.0.lock().unwrap().1.get_mut() = true;
}
fn was_polled(&self) -> bool {
self.0.lock().unwrap().1.get()
}
}
impl<T> Clone for WsService<T> {
fn clone(&self) -> Self {
WsService(self.0.clone())
}
}
impl<T> Service for WsService<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Request = (Request, Framed<T, h1::Codec>);
type Response = ();
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<(), Error>>>>;
fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.set_polled();
Poll::Ready(Ok(()))
}
fn call(&mut self, (req, mut framed): Self::Request) -> Self::Future {
let fut = async move {
let res = ws::handshake(req.head()).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())
.await
.unwrap();
Dispatcher::new(framed.into_framed(ws::Codec::new()), service)
.await
.map_err(|_| panic!()) .map_err(|_| panic!())
}; .and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
Box::pin(fut) .map_err(|_| panic!())
} })
} }
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> { fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg { let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg), ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => { ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string()) ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
} }
ws::Frame::Binary(bin) => ws::Message::Binary(bin), ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
ws::Frame::Continuation(item) => ws::Message::Continuation(item),
ws::Frame::Close(reason) => ws::Message::Close(reason), ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(), _ => panic!(),
}; };
Ok(msg) ok(msg)
} }
#[actix_rt::test] #[test]
async fn test_simple() { fn test_simple() {
let ws_service = WsService::new(); let mut srv = TestServer::new(|| {
let mut srv = test_server({
let ws_service = ws_service.clone();
move || {
let ws_service = ws_service.clone();
HttpService::build() HttpService::build()
.upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone()))) .upgrade(ws_service)
.finish(|_| future::ok::<_, ()>(Response::NotFound())) .finish(|_| future::ok::<_, ()>(Response::NotFound()))
.tcp()
}
}); });
// client service // client service
let mut framed = srv.ws().await.unwrap(); let framed = srv.ws().unwrap();
framed let framed = srv
.send(ws::Message::Text("text".to_string())) .block_on(framed.send(ws::Message::Text("text".to_string())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
item.unwrap().unwrap(),
ws::Frame::Text(Bytes::from_static(b"text"))
);
framed let framed = srv
.send(ws::Message::Binary("text".into())) .block_on(framed.send(ws::Message::Binary("text".into())))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Binary(Bytes::from_static(&b"text"[..])) Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
); );
framed.send(ws::Message::Ping("text".into())).await.unwrap(); let framed = srv
let (item, mut framed) = framed.into_future().await; .block_on(framed.send(ws::Message::Ping("text".into())))
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Pong("text".to_string().into())
);
framed
.send(ws::Message::Continuation(ws::Item::FirstText(
"text".into(),
)))
.await
.unwrap(); .unwrap();
let (item, mut framed) = framed.into_future().await; let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
item.unwrap().unwrap(),
ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text")))
);
assert!(framed let framed = srv
.send(ws::Message::Continuation(ws::Item::FirstText( .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
"text".into()
)))
.await
.is_err());
assert!(framed
.send(ws::Message::Continuation(ws::Item::FirstBinary(
"text".into()
)))
.await
.is_err());
framed
.send(ws::Message::Continuation(ws::Item::Continue("text".into())))
.await
.unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text")))
);
framed
.send(ws::Message::Continuation(ws::Item::Last("text".into())))
.await
.unwrap();
let (item, mut framed) = framed.into_future().await;
assert_eq!(
item.unwrap().unwrap(),
ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text")))
);
assert!(framed
.send(ws::Message::Continuation(ws::Item::Continue("text".into())))
.await
.is_err());
assert!(framed
.send(ws::Message::Continuation(ws::Item::Last("text".into())))
.await
.is_err());
framed
.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap(); .unwrap();
let (item, _framed) = framed.into_future().await; let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!( assert_eq!(
item.unwrap().unwrap(), item,
ws::Frame::Close(Some(ws::CloseCode::Normal.into())) Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
); );
assert!(ws_service.was_polled());
} }

View File

@@ -1,9 +1,5 @@
# Changes # Changes
## [0.2.0] - 2019-12-20
* Use actix-web 2.0
## [0.1.0] - 2019-06-xx ## [0.1.0] - 2019-06-xx
* Move identity middleware to separate crate * Move identity middleware to separate crate

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "actix-identity" name = "actix-identity"
version = "0.2.0" version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Identity service for actix web framework." description = "Identity service for actix web framework."
readme = "README.md" readme = "README.md"
@@ -17,14 +17,14 @@ name = "actix_identity"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "2.0.0-rc", default-features = false, features = ["secure-cookies"] } actix-web = { version = "1.0.0", default-features = false, features = ["secure-cookies"] }
actix-service = "1.0.0" actix-service = "0.4.0"
futures = "0.3.1" futures = "0.1.25"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.1.42" time = "0.1.42"
[dev-dependencies] [dev-dependencies]
actix-rt = "1.0.0" actix-rt = "0.2.2"
actix-http = "1.0.1" actix-http = "0.2.3"
bytes = "0.5.3" bytes = "0.4"

View File

@@ -16,7 +16,7 @@
//! use actix_web::*; //! use actix_web::*;
//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; //! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService};
//! //!
//! async fn index(id: Identity) -> String { //! fn index(id: Identity) -> String {
//! // access request identity //! // access request identity
//! if let Some(id) = id.identity() { //! if let Some(id) = id.identity() {
//! format!("Welcome! {}", id) //! format!("Welcome! {}", id)
@@ -25,12 +25,12 @@
//! } //! }
//! } //! }
//! //!
//! async fn login(id: Identity) -> HttpResponse { //! fn login(id: Identity) -> HttpResponse {
//! id.remember("User1".to_owned()); // <- remember identity //! id.remember("User1".to_owned()); // <- remember identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
//! //!
//! async fn logout(id: Identity) -> HttpResponse { //! fn logout(id: Identity) -> HttpResponse {
//! id.forget(); // <- remove identity //! id.forget(); // <- remove identity
//! HttpResponse::Ok().finish() //! HttpResponse::Ok().finish()
//! } //! }
@@ -47,13 +47,12 @@
//! } //! }
//! ``` //! ```
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::SystemTime; use std::time::SystemTime;
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, FutureResult};
use futures::{Future, IntoFuture, Poll};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use time::Duration; use time::Duration;
@@ -166,21 +165,21 @@ where
impl FromRequest for Identity { impl FromRequest for Identity {
type Config = (); type Config = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Identity, Error>>; type Future = Result<Identity, Error>;
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
ok(Identity(req.clone())) Ok(Identity(req.clone()))
} }
} }
/// Identity policy definition. /// Identity policy definition.
pub trait IdentityPolicy: Sized + 'static { pub trait IdentityPolicy: Sized + 'static {
/// The return type of the middleware /// The return type of the middleware
type Future: Future<Output = Result<Option<String>, Error>>; type Future: IntoFuture<Item = Option<String>, Error = Error>;
/// The return type of the middleware /// The return type of the middleware
type ResponseFuture: Future<Output = Result<(), Error>>; type ResponseFuture: IntoFuture<Item = (), Error = Error>;
/// Parse the session from request and load data from a service identity. /// Parse the session from request and load data from a service identity.
fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; fn from_request(&self, request: &mut ServiceRequest) -> Self::Future;
@@ -235,7 +234,7 @@ where
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = IdentityServiceMiddleware<S, T>; type Transform = IdentityServiceMiddleware<S, T>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = FutureResult<Self::Transform, Self::InitError>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(IdentityServiceMiddleware { ok(IdentityServiceMiddleware {
@@ -262,39 +261,46 @@ where
type Request = ServiceRequest; type Request = ServiceRequest;
type Response = ServiceResponse<B>; type Response = ServiceResponse<B>;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.borrow_mut().poll_ready(cx) self.service.borrow_mut().poll_ready()
} }
fn call(&mut self, mut req: ServiceRequest) -> Self::Future { fn call(&mut self, mut req: ServiceRequest) -> Self::Future {
let srv = self.service.clone(); let srv = self.service.clone();
let backend = self.backend.clone(); let backend = self.backend.clone();
let fut = self.backend.from_request(&mut req);
async move { Box::new(
match fut.await { self.backend.from_request(&mut req).into_future().then(
move |res| match res {
Ok(id) => { Ok(id) => {
req.extensions_mut() req.extensions_mut()
.insert(IdentityItem { id, changed: false }); .insert(IdentityItem { id, changed: false });
let mut res = srv.borrow_mut().call(req).await?; Either::A(srv.borrow_mut().call(req).and_then(move |mut res| {
let id = res.request().extensions_mut().remove::<IdentityItem>(); let id =
res.request().extensions_mut().remove::<IdentityItem>();
if let Some(id) = id { if let Some(id) = id {
match backend.to_response(id.id, id.changed, &mut res).await { Either::A(
backend
.to_response(id.id, id.changed, &mut res)
.into_future()
.then(move |t| match t {
Ok(_) => Ok(res), Ok(_) => Ok(res),
Err(e) => Ok(res.error_response(e)), Err(e) => Ok(res.error_response(e)),
} }),
)
} else { } else {
Ok(res) Either::B(ok(res))
} }
}))
} }
Err(err) => Ok(req.error_response(err)), Err(err) => Either::B(ok(req.error_response(err))),
} },
} ),
.boxed_local() )
} }
} }
@@ -541,11 +547,11 @@ impl CookieIdentityPolicy {
} }
impl IdentityPolicy for CookieIdentityPolicy { impl IdentityPolicy for CookieIdentityPolicy {
type Future = Ready<Result<Option<String>, Error>>; type Future = Result<Option<String>, Error>;
type ResponseFuture = Ready<Result<(), Error>>; type ResponseFuture = Result<(), Error>;
fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { fn from_request(&self, req: &mut ServiceRequest) -> Self::Future {
ok(self.0.load(req).map( Ok(self.0.load(req).map(
|CookieValue { |CookieValue {
identity, identity,
login_timestamp, login_timestamp,
@@ -597,7 +603,7 @@ impl IdentityPolicy for CookieIdentityPolicy {
} else { } else {
Ok(()) Ok(())
}; };
ok(()) Ok(())
} }
} }
@@ -614,8 +620,8 @@ mod tests {
const COOKIE_NAME: &'static str = "actix_auth"; const COOKIE_NAME: &'static str = "actix_auth";
const COOKIE_LOGIN: &'static str = "test"; const COOKIE_LOGIN: &'static str = "test";
#[actix_rt::test] #[test]
async fn test_identity() { fn test_identity() {
let mut srv = test::init_service( let mut srv = test::init_service(
App::new() App::new()
.wrap(IdentityService::new( .wrap(IdentityService::new(
@@ -644,16 +650,13 @@ mod tests {
HttpResponse::BadRequest() HttpResponse::BadRequest()
} }
})), })),
) );
.await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()) test::call_service(&mut srv, TestRequest::with_uri("/index").to_request());
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
@@ -662,8 +665,7 @@ mod tests {
TestRequest::with_uri("/index") TestRequest::with_uri("/index")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
) );
.await;
assert_eq!(resp.status(), StatusCode::CREATED); assert_eq!(resp.status(), StatusCode::CREATED);
let resp = test::call_service( let resp = test::call_service(
@@ -671,14 +673,13 @@ mod tests {
TestRequest::with_uri("/logout") TestRequest::with_uri("/logout")
.cookie(c.clone()) .cookie(c.clone())
.to_request(), .to_request(),
) );
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)) assert!(resp.headers().contains_key(header::SET_COOKIE))
} }
#[actix_rt::test] #[test]
async fn test_identity_max_age_time() { fn test_identity_max_age_time() {
let duration = Duration::days(1); let duration = Duration::days(1);
let mut srv = test::init_service( let mut srv = test::init_service(
App::new() App::new()
@@ -694,19 +695,17 @@ mod tests {
id.remember("test".to_string()); id.remember("test".to_string());
HttpResponse::Ok() HttpResponse::Ok()
})), })),
) );
.await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)); assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(duration, c.max_age().unwrap()); assert_eq!(duration, c.max_age().unwrap());
} }
#[actix_rt::test] #[test]
async fn test_identity_max_age() { fn test_identity_max_age() {
let seconds = 60; let seconds = 60;
let mut srv = test::init_service( let mut srv = test::init_service(
App::new() App::new()
@@ -722,18 +721,16 @@ mod tests {
id.remember("test".to_string()); id.remember("test".to_string());
HttpResponse::Ok() HttpResponse::Ok()
})), })),
) );
.await;
let resp = let resp =
test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) test::call_service(&mut srv, TestRequest::with_uri("/login").to_request());
.await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
assert!(resp.headers().contains_key(header::SET_COOKIE)); assert!(resp.headers().contains_key(header::SET_COOKIE));
let c = resp.response().cookies().next().unwrap().to_owned(); let c = resp.response().cookies().next().unwrap().to_owned();
assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap());
} }
async fn create_identity_server< fn create_identity_server<
F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static,
>( >(
f: F, f: F,
@@ -750,16 +747,13 @@ mod tests {
.secure(false) .secure(false)
.name(COOKIE_NAME)))) .name(COOKIE_NAME))))
.service(web::resource("/").to(|id: Identity| { .service(web::resource("/").to(|id: Identity| {
async move {
let identity = id.identity(); let identity = id.identity();
if identity.is_none() { if identity.is_none() {
id.remember(COOKIE_LOGIN.to_string()) id.remember(COOKIE_LOGIN.to_string())
} }
web::Json(identity) web::Json(identity)
}
})), })),
) )
.await
} }
fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> {
@@ -792,8 +786,15 @@ mod tests {
jar.get(COOKIE_NAME).unwrap().clone() jar.get(COOKIE_NAME).unwrap().clone()
} }
async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { fn assert_logged_in(response: &mut ServiceResponse, identity: Option<&str>) {
let bytes = test::read_body(response).await; use bytes::BytesMut;
use futures::Stream;
let bytes =
test::block_on(response.take_body().fold(BytesMut::new(), |mut b, c| {
b.extend(c);
Ok::<_, Error>(b)
}))
.unwrap();
let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap(); let resp: Option<String> = serde_json::from_slice(&bytes[..]).unwrap();
assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); assert_eq!(resp.as_ref().map(|s| s.borrow()), identity);
} }
@@ -871,118 +872,108 @@ mod tests {
assert!(cookies.get(COOKIE_NAME).is_none()); assert!(cookies.get(COOKIE_NAME).is_none());
} }
#[actix_rt::test] #[test]
async fn test_identity_legacy_cookie_is_set() { fn test_identity_legacy_cookie_is_set() {
let mut srv = create_identity_server(|c| c).await; let mut srv = create_identity_server(|c| c);
let mut resp = let mut resp =
test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await; test::call_service(&mut srv, TestRequest::with_uri("/").to_request());
assert_logged_in(&mut resp, None);
assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN);
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_legacy_cookie_works() { fn test_identity_legacy_cookie_works() {
let mut srv = create_identity_server(|c| c).await; let mut srv = create_identity_server(|c| c);
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
assert_no_login_cookie(&mut resp); assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp, LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp, VisitTimeStampCheck::NewTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() {
let mut srv = let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = legacy_login_cookie(COOKIE_LOGIN); let cookie = legacy_login_cookie(COOKIE_LOGIN);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp, LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp, VisitTimeStampCheck::NoTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_rejected_if_login_timestamp_needed() { fn test_identity_cookie_rejected_if_login_timestamp_needed() {
let mut srv = let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now()));
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp, LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp, VisitTimeStampCheck::NoTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_rejected_if_visit_timestamp_needed() { fn test_identity_cookie_rejected_if_visit_timestamp_needed() {
let mut srv = let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp, LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp, VisitTimeStampCheck::NewTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_rejected_if_login_timestamp_too_old() { fn test_identity_cookie_rejected_if_login_timestamp_too_old() {
let mut srv = let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie( let cookie = login_cookie(
COOKIE_LOGIN, COOKIE_LOGIN,
Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), Some(SystemTime::now() - Duration::days(180).to_std().unwrap()),
@@ -993,21 +984,19 @@ mod tests {
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NewTimestamp, LoginTimestampCheck::NewTimestamp,
VisitTimeStampCheck::NoTimestamp, VisitTimeStampCheck::NoTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { fn test_identity_cookie_rejected_if_visit_timestamp_too_old() {
let mut srv = let mut srv = create_identity_server(|c| c.visit_deadline(Duration::days(90)));
create_identity_server(|c| c.visit_deadline(Duration::days(90))).await;
let cookie = login_cookie( let cookie = login_cookie(
COOKIE_LOGIN, COOKIE_LOGIN,
None, None,
@@ -1018,40 +1007,36 @@ mod tests {
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, None);
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::NoTimestamp, LoginTimestampCheck::NoTimestamp,
VisitTimeStampCheck::NewTimestamp, VisitTimeStampCheck::NewTimestamp,
); );
assert_logged_in(resp, None).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_not_updated_on_login_deadline() { fn test_identity_cookie_not_updated_on_login_deadline() {
let mut srv = let mut srv = create_identity_server(|c| c.login_deadline(Duration::days(90)));
create_identity_server(|c| c.login_deadline(Duration::days(90))).await;
let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None);
let mut resp = test::call_service( let mut resp = test::call_service(
&mut srv, &mut srv,
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
assert_no_login_cookie(&mut resp); assert_no_login_cookie(&mut resp);
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
} }
#[actix_rt::test] #[test]
async fn test_identity_cookie_updated_on_visit_deadline() { fn test_identity_cookie_updated_on_visit_deadline() {
let mut srv = create_identity_server(|c| { let mut srv = create_identity_server(|c| {
c.visit_deadline(Duration::days(90)) c.visit_deadline(Duration::days(90))
.login_deadline(Duration::days(90)) .login_deadline(Duration::days(90))
}) });
.await;
let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap();
let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp));
let mut resp = test::call_service( let mut resp = test::call_service(
@@ -1059,14 +1044,13 @@ mod tests {
TestRequest::with_uri("/") TestRequest::with_uri("/")
.cookie(cookie.clone()) .cookie(cookie.clone())
.to_request(), .to_request(),
) );
.await; assert_logged_in(&mut resp, Some(COOKIE_LOGIN));
assert_login_cookie( assert_login_cookie(
&mut resp, &mut resp,
COOKIE_LOGIN, COOKIE_LOGIN,
LoginTimestampCheck::OldTimestamp(timestamp), LoginTimestampCheck::OldTimestamp(timestamp),
VisitTimeStampCheck::NewTimestamp, VisitTimeStampCheck::NewTimestamp,
); );
assert_logged_in(resp, Some(COOKIE_LOGIN)).await;
} }
} }

View File

@@ -1,16 +1,8 @@
# Changes # Changes
## [0.2.0] - 2019-12-20 ## [0.1.5] - 2019-12-07
* Release * Multipart handling now handles NotReady during read of boundary #1189
## [0.2.0-alpha.4] - 2019-12-xx
* Multipart handling now handles Pending during read of boundary #1205
## [0.2.0-alpha.2] - 2019-12-03
* Migrate to `std::future`
## [0.1.4] - 2019-09-12 ## [0.1.4] - 2019-09-12

Some files were not shown because too many files have changed in this diff Show More