diff --git a/.appveyor.yml b/.appveyor.yml index f9e79ce7c..2f0a4a7dd 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,43 +1,21 @@ environment: global: - PROJECT_NAME: actix + PROJECT_NAME: actix-web matrix: # Stable channel - - TARGET: i686-pc-windows-gnu - CHANNEL: 1.21.0 - - TARGET: i686-pc-windows-msvc - CHANNEL: 1.21.0 - - TARGET: x86_64-pc-windows-gnu - CHANNEL: 1.21.0 - - TARGET: x86_64-pc-windows-msvc - CHANNEL: 1.21.0 - # Stable channel - - TARGET: i686-pc-windows-gnu - CHANNEL: stable - TARGET: i686-pc-windows-msvc CHANNEL: stable - TARGET: x86_64-pc-windows-gnu CHANNEL: stable - TARGET: x86_64-pc-windows-msvc CHANNEL: stable - # Beta channel - - TARGET: i686-pc-windows-gnu - CHANNEL: beta - - TARGET: i686-pc-windows-msvc - CHANNEL: beta - - TARGET: x86_64-pc-windows-gnu - CHANNEL: beta - - TARGET: x86_64-pc-windows-msvc - CHANNEL: beta # Nightly channel - - TARGET: i686-pc-windows-gnu - CHANNEL: nightly-2017-12-21 - TARGET: i686-pc-windows-msvc - CHANNEL: nightly-2017-12-21 + CHANNEL: nightly - TARGET: x86_64-pc-windows-gnu - CHANNEL: nightly-2017-12-21 + CHANNEL: nightly - TARGET: x86_64-pc-windows-msvc - CHANNEL: nightly-2017-12-21 + CHANNEL: nightly # Install Rust and Cargo # (Based on from https://github.com/rust-lang/libc/blob/master/appveyor.yml) @@ -59,4 +37,5 @@ build: false # Equivalent to Travis' `script` phase test_script: - - cargo test --no-default-features + - cargo clean + - cargo test --no-default-features --features="flate2-rust" diff --git a/.travis.yml b/.travis.yml index 7aa8ebaa9..f10f82a48 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,25 +1,18 @@ language: rust -sudo: false +sudo: required dist: trusty cache: - cargo: true + # cargo: true apt: true matrix: include: - - rust: 1.21.0 - rust: stable - rust: beta - - rust: nightly + - rust: nightly-2019-11-20 allow_failures: - - rust: nightly - -#rust: -# - 1.21.0 -# - stable -# - beta -# - nightly-2018-01-03 + - rust: nightly-2019-11-20 env: global: @@ -29,67 +22,40 @@ env: before_install: - sudo add-apt-repository -y ppa:0k53d-karl-f830m/openssl - sudo apt-get update -qq - - sudo apt-get install -qq libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev + - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev + +before_cache: | + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then + RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install --version 0.6.11 cargo-tarpaulin + fi # Add clippy before_script: - - | - if [[ "$TRAVIS_RUST_VERSION" == "nightly" ]]; then - ( ( cargo install clippy && export CLIPPY=true ) || export CLIPPY=false ); - fi - export PATH=$PATH:~/.cargo/bin script: + - cargo update + - cargo check --all --no-default-features - | - if [[ "$TRAVIS_RUST_VERSION" == "stable" ]]; then - cargo clean - USE_SKEPTIC=1 cargo test --features=alpn - else - cargo clean - cargo test -- --nocapture - # --features=alpn - fi - - - | - if [[ "$TRAVIS_RUST_VERSION" == "stable" ]]; then - cd examples/basics && cargo check && cd ../.. - cd examples/hello-world && cargo check && cd ../.. - cd examples/http-proxy && cargo check && cd ../.. - cd examples/multipart && cargo check && cd ../.. - cd examples/json && cargo check && cd ../.. - cd examples/juniper && cargo check && cd ../.. - cd examples/protobuf && cargo check && cd ../.. - cd examples/state && cargo check && cd ../.. - cd examples/template_tera && cargo check && cd ../.. - cd examples/diesel && cargo check && cd ../.. - cd examples/r2d2 && cargo check && cd ../.. - cd examples/tls && cargo check && cd ../.. - cd examples/websocket-chat && cargo check && cd ../.. - cd examples/websocket && cargo check && cd ../.. - cd examples/unix-socket && cargo check && cd ../.. - fi - - | - if [[ "$TRAVIS_RUST_VERSION" == "nightly" && $CLIPPY ]]; then - cargo clippy + if [[ "$TRAVIS_RUST_VERSION" == "stable" || "$TRAVIS_RUST_VERSION" == "beta" ]]; then + cargo test --all-features --all -- --nocapture + cd actix-http; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. + cd awc; cargo test --no-default-features --features="rustls" -- --nocapture; cd .. fi # Upload docs after_success: - | - if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "beta" ]]; then - cargo doc --features "alpn, tls, session" --no-deps && + if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "stable" ]]; then + cargo doc --no-deps --all-features && echo "" > target/doc/index.html && - curl -sL https://github.com/rust-lang-nursery/mdBook/releases/download/v0.1.5/mdbook-v0.1.5-x86_64-unknown-linux-gnu.tar.gz | tar xvz -C $HOME/.cargo/bin && - cd guide && mdbook build -d ../target/doc/guide && cd .. && git clone https://github.com/davisp/ghp-import.git && ./ghp-import/ghp_import.py -n -p -f -m "Documentation upload" -r https://"$GH_TOKEN"@github.com/"$TRAVIS_REPO_SLUG.git" target/doc && echo "Uploaded documentation" fi - - | - if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_RUST_VERSION" == "1.21.0" ]]; then - bash <(curl https://raw.githubusercontent.com/xd009642/tarpaulin/master/travis-install.sh) - USE_SKEPTIC=1 cargo tarpaulin --out Xml - bash <(curl -s https://codecov.io/bash) - echo "Uploaded code coverage" + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-11-20" ]]; then + taskset -c 0 cargo tarpaulin --out Xml --all --all-features + bash <(curl -s https://codecov.io/bash) + echo "Uploaded code coverage" fi diff --git a/CHANGES.md b/CHANGES.md index 03f5b5e94..a7569862d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,257 +1,357 @@ # Changes -## 0.5.0 +## [2.0.0-alpha.1] - 2019-11-22 -* Type-safe path/query/form parameter handling, using serde #70 +### Changed -* HttpResponse builder's methods `.body()`, `.finish()`, `.json()` - return `HttpResponse` instead of `Result` +* Migrated to `std::future` -* Use more ergonomic `actix_web::Error` instead of `http::Error` for `ClientRequestBuilder::body()` +* Remove implementation of `Responder` for `()`. (#1167) -* Added `HttpRequest::resource()`, returns current matched resource -* Added `ErrorHandlers` middleware +## [1.0.9] - 2019-11-14 -* Router cannot parse Non-ASCII characters in URL #137 +### Added -* Fix long client urls #129 +* Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) -* Fix panic on invalid URL characters #130 +### Changed -* Fix client connection pooling +* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129) -* Fix logger request duration calculation #152 +## [1.0.8] - 2019-09-25 -## 0.4.10 (2018-03-20) +### Added -* Use `Error` instead of `InternalError` for `error::ErrorXXXX` methods +* Add `Scope::register_data` and `Resource::register_data` methods, parallel to + `App::register_data`. -* Allow to set client request timeout +* Add `middleware::Condition` that conditionally enables another middleware -* Allow to set client websocket handshake timeout +* Allow to re-construct `ServiceRequest` from `HttpRequest` and `Payload` -* Refactor `TestServer` configuration +* Add `HttpServer::listen_uds` for ability to listen on UDS FD rather than path, + which is useful for example with systemd. -* Fix server websockets big payloads support +### Changed -* Fix http/2 date header generation +* Make UrlEncodedError::Overflow more informativve +* Use actix-testing for testing utils -## 0.4.9 (2018-03-16) -* Allow to disable http/2 support +## [1.0.7] - 2019-08-29 -* Wake payload reading task when data is available +### Fixed -* Fix server keep-alive handling +* Request Extensions leak #1062 -* Send Query Parameters in client requests #120 -* Move brotli encoding to a feature +## [1.0.6] - 2019-08-28 -* Add option of default handler for `StaticFiles` handler #57 +### Added -* Add basic client connection pooling +* Re-implement Host predicate (#989) +* Form immplements Responder, returning a `application/x-www-form-urlencoded` response -## 0.4.8 (2018-03-12) +* Add `into_inner` to `Data` -* Allow to set read buffer capacity for server request +* Add `test::TestRequest::set_form()` convenience method to automatically serialize data and set + the header in test requests. -* Handle WouldBlock error for socket accept call +### Changed +* `Query` payload made `pub`. Allows user to pattern-match the payload. -## 0.4.7 (2018-03-11) +* Enable `rust-tls` feature for client #1045 -* Fix panic on unknown content encoding +* Update serde_urlencoded to 0.6.1 -* Fix connection get closed too early +* Update url to 2.1 -* Fix streaming response handling for http/2 -* Better sleep on error support +## [1.0.5] - 2019-07-18 +### Added -## 0.4.6 (2018-03-10) +* Unix domain sockets (HttpServer::bind_uds) #92 -* Fix client cookie handling +* Actix now logs errors resulting in "internal server error" responses always, with the `error` + logging level -* Fix json content type detection +### Fixed -* Fix CORS middleware #117 +* Restored logging of errors through the `Logger` middleware -* Optimize websockets stream support +## [1.0.4] - 2019-07-17 -## 0.4.5 (2018-03-07) +### Added -* Fix compression #103 and #104 +* Add `Responder` impl for `(T, StatusCode) where T: Responder` -* Fix client cookie handling #111 +* Allow to access app's resource map via + `ServiceRequest::resource_map()` and `HttpRequest::resource_map()` methods. -* Non-blocking processing of a `NamedFile` +### Changed -* Enable compression support for `NamedFile` +* Upgrade `rand` dependency version to 0.7 -* Better support for `NamedFile` type -* Add `ResponseError` impl for `SendRequestError`. This improves ergonomics of the client. +## [1.0.3] - 2019-06-28 -* Add native-tls support for client +### Added -* Allow client connection timeout to be set #108 +* Support asynchronous data factories #850 -* Allow to use std::net::TcpListener for HttpServer +### Changed -* Handle panics in worker threads +* Use `encoding_rs` crate instead of unmaintained `encoding` crate -## 0.4.4 (2018-03-04) +## [1.0.2] - 2019-06-17 -* Allow to use Arc> as response/request body +### Changed -* Fix handling of requests with an encoded body with a length > 8192 #93 +* Move cors middleware to `actix-cors` crate. -## 0.4.3 (2018-03-03) +* Move identity middleware to `actix-identity` crate. -* Fix request body read bug -* Fix segmentation fault #79 +## [1.0.1] - 2019-06-17 -* Set reuse address before bind #90 +### Added +* Add support for PathConfig #903 -## 0.4.2 (2018-03-02) +* Add `middleware::identity::RequestIdentity` trait to `get_identity` from `HttpMessage`. -* Better naming for websockets implementation +### Changed -* Add `Pattern::with_prefix()`, make it more usable outside of actix +* Move cors middleware to `actix-cors` crate. -* Add csrf middleware for filter for cross-site request forgery #89 +* Move identity middleware to `actix-identity` crate. -* Fix disconnect on idle connections +* Disable default feature `secure-cookies`. +* Allow to test an app that uses async actors #897 -## 0.4.1 (2018-03-01) +* Re-apply patch from #637 #894 -* Rename `Route::p()` to `Route::filter()` +### Fixed -* Better naming for http codes +* HttpRequest::url_for is broken with nested scopes #915 -* Fix payload parse in situation when socket data is not ready. -* Fix Session mutable borrow lifetime #87 +## [1.0.0] - 2019-06-05 +### Added -## 0.4.0 (2018-02-28) +* Add `Scope::configure()` method. -* Actix 0.5 compatibility +* Add `ServiceRequest::set_payload()` method. -* Fix request json/urlencoded loaders +* Add `test::TestRequest::set_json()` convenience method to automatically + serialize data and set header in test requests. -* Simplify HttpServer type definition +* Add macros for head, options, trace, connect and patch http methods -* Added HttpRequest::encoding() method +### Changed -* Added HttpRequest::mime_type() method +* Drop an unnecessary `Option<_>` indirection around `ServerBuilder` from `HttpServer`. #863 -* Added HttpRequest::uri_mut(), allows to modify request uri +### Fixed -* Added StaticFiles::index_file() +* Fix Logger request time format, and use rfc3339. #867 -* Added http client +* Clear http requests pool on app service drop #860 -* Added websocket client -* Added TestServer::ws(), test websockets client +## [1.0.0-rc] - 2019-05-18 -* Added TestServer http client support +### Add -* Allow to override content encoding on application level +* Add `Query::from_query()` to extract parameters from a query string. #846 +* `QueryConfig`, similar to `JsonConfig` for customizing error handling of query extractors. +### Changed -## 0.3.3 (2018-01-25) +* `JsonConfig` is now `Send + Sync`, this implies that `error_handler` must be `Send + Sync` too. -* Stop processing any events after context stop +### Fixed -* Re-enable write back-pressure for h1 connections +* Codegen with parameters in the path only resolves the first registered endpoint #841 -* Refactor HttpServer::start_ssl() method -* Upgrade openssl to 0.10 +## [1.0.0-beta.4] - 2019-05-12 +### Add -## 0.3.2 (2018-01-21) +* Allow to set/override app data on scope level -* Fix HEAD requests handling +### Changed -* Log request processing errors +* `App::configure` take an `FnOnce` instead of `Fn` +* Upgrade actix-net crates -* Always enable content encoding if encoding explicitly selected -* Allow multiple Applications on a single server with different state #49 +## [1.0.0-beta.3] - 2019-05-04 -* CORS middleware: allowed_headers is defaulting to None #50 +### Added +* Add helper function for executing futures `test::block_fn()` -## 0.3.1 (2018-01-13) +### Changed -* Fix directory entry path #47 +* Extractor configuration could be registered with `App::data()` + or with `Resource::data()` #775 -* Do not enable chunked encoding for HTTP/1.0 +* Route data is unified with app data, `Route::data()` moved to resource + level to `Resource::data()` -* Allow explicitly disable chunked encoding +* CORS handling without headers #702 +* Allow to construct `Data` instances to avoid double `Arc` for `Send + Sync` types. -## 0.3.0 (2018-01-12) +### Fixed -* HTTP/2 Support +* Fix `NormalizePath` middleware impl #806 -* Refactor streaming responses +### Deleted -* Refactor error handling +* `App::data_factory()` is deleted. -* Asynchronous middlewares -* Refactor logger middleware +## [1.0.0-beta.2] - 2019-04-24 -* Content compression/decompression (br, gzip, deflate) +### Added -* Server multi-threading +* Add raw services support via `web::service()` -* Gracefull shutdown support +* Add helper functions for reading response body `test::read_body()` +* Add support for `remainder match` (i.e "/path/{tail}*") -## 0.2.1 (2017-11-03) +* Extend `Responder` trait, allow to override status code and headers. -* Allow to start tls server with `HttpServer::serve_tls` +* Store visit and login timestamp in the identity cookie #502 -* Export `Frame` enum +### Changed -* Add conversion impl from `HttpResponse` and `BinaryBody` to a `Frame` +* `.to_async()` handler can return `Responder` type #792 +### Fixed -## 0.2.0 (2017-10-30) +* Fix async web::Data factory handling -* Do not use `http::Uri` as it can not parse some valid paths -* Refactor response `Body` +## [1.0.0-beta.1] - 2019-04-20 -* Refactor `RouteRecognizer` usability +### Added -* Refactor `HttpContext::write` +* Add helper functions for reading test response body, + `test::read_response()` and test::read_response_json()` -* Refactor `Payload` stream +* Add `.peer_addr()` #744 -* Re-use `BinaryBody` for `Frame::Payload` +* Add `NormalizePath` middleware -* Stop http actor on `write_eof` +### Changed -* Fix disconnection handling. +* Rename `RouterConfig` to `ServiceConfig` +* Rename `test::call_success` to `test::call_service` -## 0.1.0 (2017-10-23) +* Removed `ServiceRequest::from_parts()` as it is unsafe to create from parts. -* First release +* `CookieIdentityPolicy::max_age()` accepts value in seconds + +### Fixed + +* Fixed `TestRequest::app_data()` + + +## [1.0.0-alpha.6] - 2019-04-14 + +### Changed + +* Allow to use any service as default service. + +* Remove generic type for request payload, always use default. + +* Removed `Decompress` middleware. Bytes, String, Json, Form extractors + automatically decompress payload. + +* Make extractor config type explicit. Add `FromRequest::Config` associated type. + + +## [1.0.0-alpha.5] - 2019-04-12 + +### Added + +* Added async io `TestBuffer` for testing. + +### Deleted + +* Removed native-tls support + + +## [1.0.0-alpha.4] - 2019-04-08 + +### Added + +* `App::configure()` allow to offload app configuration to different methods + +* Added `URLPath` option for logger + +* Added `ServiceRequest::app_data()`, returns `Data` + +* Added `ServiceFromRequest::app_data()`, returns `Data` + +### Changed + +* `FromRequest` trait refactoring + +* Move multipart support to actix-multipart crate + +### Fixed + +* Fix body propagation in Response::from_error. #760 + + +## [1.0.0-alpha.3] - 2019-04-02 + +### Changed + +* Renamed `TestRequest::to_service()` to `TestRequest::to_srv_request()` + +* Renamed `TestRequest::to_response()` to `TestRequest::to_srv_response()` + +* Removed `Deref` impls + +### Removed + +* Removed unused `actix_web::web::md()` + + +## [1.0.0-alpha.2] - 2019-03-29 + +### Added + +* rustls support + +### Changed + +* use forked cookie + +* multipart::Field renamed to MultipartField + +## [1.0.0-alpha.1] - 2019-03-28 + +### Changed + +* Complete architecture re-design. + +* Return 405 response if no matching route found within resource #538 diff --git a/Cargo.toml b/Cargo.toml index e5a17e9ea..689f7b147 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,130 +1,142 @@ [package] name = "actix-web" -version = "0.5.0-dev" +version = "2.0.0-alpha.1" authors = ["Nikolay Kim "] -description = "Actix web is a simple, pragmatic, extremely fast, web framework for Rust." +description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." readme = "README.md" -keywords = ["http", "web", "framework", "async", "futures"] -homepage = "https://github.com/actix/actix-web" +keywords = ["actix", "http", "web", "framework", "async"] +homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-web/" categories = ["network-programming", "asynchronous", "web-programming::http-server", - "web-programming::http-client", "web-programming::websocket"] license = "MIT/Apache-2.0" -exclude = [".gitignore", ".travis.yml", ".cargo/config", - "appveyor.yml", "/examples/**"] -build = "build.rs" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +edition = "2018" + +[package.metadata.docs.rs] +features = ["openssl", "brotli", "flate2-zlib", "secure-cookies", "client"] [badges] travis-ci = { repository = "actix/actix-web", branch = "master" } -appveyor = { repository = "fafhrd91/actix-web-hdy9d" } codecov = { repository = "actix/actix-web", branch = "master", service = "github" } [lib] name = "actix_web" path = "src/lib.rs" -[features] -default = ["session", "brotli"] +[workspace] +members = [ + ".", + "awc", + "actix-http", + "actix-cors", + "actix-files", + "actix-framed", + "actix-session", + "actix-identity", + "actix-multipart", + "actix-web-actors", + "actix-web-codegen", + "test-server", +] -# tls -tls = ["native-tls", "tokio-tls"] +[features] +default = ["brotli", "flate2-zlib", "client", "fail"] + +# http client +client = ["awc"] + +# brotli encoding, requires c compiler +brotli = ["actix-http/brotli"] + +# miniz-sys backend for flate2 crate +flate2-zlib = ["actix-http/flate2-zlib"] + +# rust backend for flate2 crate +flate2-rust = ["actix-http/flate2-rust"] + +# sessions feature, session require "ring" crate and c compiler +secure-cookies = ["actix-http/secure-cookies"] + +fail = ["actix-http/fail"] # openssl -alpn = ["openssl", "openssl/v102", "openssl/v110", "tokio-openssl"] +openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"] -# sessions -session = ["cookie/secure"] - -# brotli encoding -brotli = ["brotli2"] +# rustls +# rustls = ["rust-tls", "actix-server/rustls", "awc/rustls"] [dependencies] -actix = "^0.5.5" +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +actix-router = "0.1.5" +actix-rt = "1.0.0-alpha.1" +actix-web-codegen = "0.2.0-alpha.1" +actix-http = "0.3.0-alpha.1" +actix-server = "0.8.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-testing = "0.3.0-alpha.1" +actix-threadpool = "0.2.0-alpha.1" +awc = { version = "0.3.0-alpha.1", optional = true } -base64 = "0.9" -bitflags = "1.0" -failure = "0.1.1" -flate2 = "1.0" -h2 = "0.1" -http = "^0.1.5" -httparse = "1.2" -http-range = "0.1" -libc = "0.2" +bytes = "0.4" +derive_more = "0.99.2" +encoding_rs = "0.8" +futures = "0.3.1" +hashbrown = "0.6.3" log = "0.4" mime = "0.3" -mime_guess = "2.0.0-alpha" -num_cpus = "1.0" -percent-encoding = "1.0" -rand = "0.4" -regex = "0.2" -serde = "1.0" +net2 = "0.2.33" +parking_lot = "0.9" +pin-project = "0.4.5" +regex = "1.0" +serde = { version = "1.0", features=["derive"] } serde_json = "1.0" -serde_urlencoded = "0.5" -sha1 = "0.6" -smallvec = "0.6" -time = "0.1" -encoding = "0.2" -language-tags = "0.2" -lazy_static = "1.0" -url = { version="1.7", features=["query_encoding"] } -cookie = { version="0.10", features=["percent-encode"] } -brotli2 = { version="^0.3.2", optional = true } +serde_urlencoded = "0.6.1" +time = "0.1.42" +url = "2.1" -# io -mio = "^0.6.13" -net2 = "0.2" -bytes = "0.4" -byteorder = "1" -futures = "0.1" -futures-cpupool = "0.1" -tokio-io = "0.1" -tokio-core = "0.1" -trust-dns-resolver = "0.8" - -# native-tls -native-tls = { version="0.1", optional = true } -tokio-tls = { version="0.1", optional = true } - -# openssl -openssl = { version="0.10", optional = true } -tokio-openssl = { version="0.2", optional = true } +# ssl support +open-ssl = { version="0.10", package="openssl", optional = true } +# rust-tls = { version = "0.16", package="rustls", optional = true } [dev-dependencies] -env_logger = "0.5" -skeptic = "0.13" +# actix = "0.8.3" +actix-connect = "0.3.0-alpha.1" +actix-http-test = "0.3.0-alpha.1" +rand = "0.7" +env_logger = "0.6" serde_derive = "1.0" - -[build-dependencies] -skeptic = "0.13" -version_check = "0.1" +brotli2 = "0.3.2" +flate2 = "1.0.2" [profile.release] lto = true opt-level = 3 codegen-units = 1 -[workspace] -members = [ - "./", - "examples/basics", - "examples/juniper", - "examples/diesel", - "examples/r2d2", - "examples/json", - "examples/protobuf", - "examples/hello-world", - "examples/http-proxy", - "examples/multipart", - "examples/state", - "examples/redis-session", - "examples/template_tera", - "examples/tls", - "examples/websocket", - "examples/websocket-chat", - "examples/web-cors/backend", - "examples/unix-socket", - "tools/wsload/", -] +[patch.crates-io] +actix-web = { path = "." } +actix-http = { path = "actix-http" } +actix-http-test = { path = "test-server" } +actix-web-codegen = { path = "actix-web-codegen" } +# actix-web-actors = { path = "actix-web-actors" } +actix-cors = { path = "actix-cors" } +actix-identity = { path = "actix-identity" } +actix-session = { path = "actix-session" } +actix-files = { path = "actix-files" } +actix-multipart = { path = "actix-multipart" } +awc = { path = "awc" } + +actix-codec = { git = "https://github.com/actix/actix-net.git" } +actix-connect = { git = "https://github.com/actix/actix-net.git" } +actix-rt = { git = "https://github.com/actix/actix-net.git" } +actix-macros = { git = "https://github.com/actix/actix-net.git" } +actix-server = { git = "https://github.com/actix/actix-net.git" } +actix-server-config = { git = "https://github.com/actix/actix-net.git" } +actix-service = { git = "https://github.com/actix/actix-net.git" } +actix-testing = { git = "https://github.com/actix/actix-net.git" } +actix-utils = { git = "https://github.com/actix/actix-net.git" } diff --git a/MIGRATION.md b/MIGRATION.md new file mode 100644 index 000000000..dd3a1b043 --- /dev/null +++ b/MIGRATION.md @@ -0,0 +1,556 @@ +## 2.0.0 + +* Sync handlers has been removed. `.to_async()` method has been renamed to `.to()` + + replace `fn` with `async fn` to convert sync handler to async + +* `TestServer::new()` renamed to `TestServer::start()` + + +## 1.0.1 + +* Cors middleware has been moved to `actix-cors` crate + + instead of + + ```rust + use actix_web::middleware::cors::Cors; + ``` + + use + + ```rust + use actix_cors::Cors; + ``` + +* Identity middleware has been moved to `actix-identity` crate + + instead of + + ```rust + use actix_web::middleware::identity::{Identity, CookieIdentityPolicy, IdentityService}; + ``` + + use + + ```rust + use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; + ``` + + +## 1.0.0 + +* Extractor configuration. In version 1.0 this is handled with the new `Data` mechanism for both setting and retrieving the configuration + + instead of + + ```rust + + #[derive(Default)] + struct ExtractorConfig { + config: String, + } + + impl FromRequest for YourExtractor { + type Config = ExtractorConfig; + type Result = Result; + + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { + println!("use the config: {:?}", cfg.config); + ... + } + } + + App::new().resource("/route_with_config", |r| { + r.post().with_config(handler_fn, |cfg| { + cfg.0.config = "test".to_string(); + }) + }) + + ``` + + use the HttpRequest to get the configuration like any other `Data` with `req.app_data::()` and set it with the `data()` method on the `resource` + + ```rust + #[derive(Default)] + struct ExtractorConfig { + config: String, + } + + impl FromRequest for YourExtractor { + type Error = Error; + type Future = Result; + type Config = ExtractorConfig; + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + let cfg = req.app_data::(); + println!("config data?: {:?}", cfg.unwrap().role); + ... + } + } + + App::new().service( + resource("/route_with_config") + .data(ExtractorConfig { + config: "test".to_string(), + }) + .route(post().to(handler_fn)), + ) + ``` + +* Resource registration. 1.0 version uses generalized resource + registration via `.service()` method. + + instead of + + ```rust + App.new().resource("/welcome", |r| r.f(welcome)) + ``` + + use App's or Scope's `.service()` method. `.service()` method accepts + object that implements `HttpServiceFactory` trait. By default + actix-web provides `Resource` and `Scope` services. + + ```rust + App.new().service( + web::resource("/welcome") + .route(web::get().to(welcome)) + .route(web::post().to(post_handler)) + ``` + +* Scope registration. + + instead of + + ```rust + let app = App::new().scope("/{project_id}", |scope| { + scope + .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) + .resource("/path2", |r| r.f(|_| HttpResponse::Ok())) + .resource("/path3", |r| r.f(|_| HttpResponse::MethodNotAllowed())) + }); + ``` + + use `.service()` for registration and `web::scope()` as scope object factory. + + ```rust + let app = App::new().service( + web::scope("/{project_id}") + .service(web::resource("/path1").to(|| HttpResponse::Ok())) + .service(web::resource("/path2").to(|| HttpResponse::Ok())) + .service(web::resource("/path3").to(|| HttpResponse::MethodNotAllowed())) + ); + ``` + +* `.with()`, `.with_async()` registration methods have been renamed to `.to()` and `.to_async()`. + + instead of + + ```rust + App.new().resource("/welcome", |r| r.with(welcome)) + ``` + + use `.to()` or `.to_async()` methods + + ```rust + App.new().service(web::resource("/welcome").to(welcome)) + ``` + +* Passing arguments to handler with extractors, multiple arguments are allowed + + instead of + + ```rust + fn welcome((body, req): (Bytes, HttpRequest)) -> ... { + ... + } + ``` + + use multiple arguments + + ```rust + fn welcome(body: Bytes, req: HttpRequest) -> ... { + ... + } + ``` + +* `.f()`, `.a()` and `.h()` handler registration methods have been removed. + Use `.to()` for handlers and `.to_async()` for async handlers. Handler function + must use extractors. + + instead of + + ```rust + App.new().resource("/welcome", |r| r.f(welcome)) + ``` + + use App's `to()` or `to_async()` methods + + ```rust + App.new().service(web::resource("/welcome").to(welcome)) + ``` + +* `HttpRequest` does not provide access to request's payload stream. + + instead of + + ```rust + fn index(req: &HttpRequest) -> Box> { + req + .payload() + .from_err() + .fold((), |_, chunk| { + ... + }) + .map(|_| HttpResponse::Ok().finish()) + .responder() + } + ``` + + use `Payload` extractor + + ```rust + fn index(stream: web::Payload) -> impl Future { + stream + .from_err() + .fold((), |_, chunk| { + ... + }) + .map(|_| HttpResponse::Ok().finish()) + } + ``` + +* `State` is now `Data`. You register Data during the App initialization process + and then access it from handlers either using a Data extractor or using + HttpRequest's api. + + instead of + + ```rust + App.with_state(T) + ``` + + use App's `data` method + + ```rust + App.new() + .data(T) + ``` + + and either use the Data extractor within your handler + + ```rust + use actix_web::web::Data; + + fn endpoint_handler(Data)){ + ... + } + ``` + + .. or access your Data element from the HttpRequest + + ```rust + fn endpoint_handler(req: HttpRequest) { + let data: Option> = req.app_data::(); + } + ``` + + +* AsyncResponder is removed, use `.to_async()` registration method and `impl Future<>` as result type. + + instead of + + ```rust + use actix_web::AsyncResponder; + + fn endpoint_handler(...) -> impl Future{ + ... + .responder() + } + ``` + + .. simply omit AsyncResponder and the corresponding responder() finish method + + +* Middleware + + instead of + + ```rust + let app = App::new() + .middleware(middleware::Logger::default()) + ``` + + use `.wrap()` method + + ```rust + let app = App::new() + .wrap(middleware::Logger::default()) + .route("/index.html", web::get().to(index)); + ``` + +* `HttpRequest::body()`, `HttpRequest::urlencoded()`, `HttpRequest::json()`, `HttpRequest::multipart()` + method have been removed. Use `Bytes`, `String`, `Form`, `Json`, `Multipart` extractors instead. + + instead of + + ```rust + fn index(req: &HttpRequest) -> Responder { + req.body() + .and_then(|body| { + ... + }) + } + ``` + + use + + ```rust + fn index(body: Bytes) -> Responder { + ... + } + ``` + +* `actix_web::server` module has been removed. To start http server use `actix_web::HttpServer` type + +* StaticFiles and NamedFile has been move to separate create. + + instead of `use actix_web::fs::StaticFile` + + use `use actix_files::Files` + + instead of `use actix_web::fs::Namedfile` + + use `use actix_files::NamedFile` + +* Multipart has been move to separate create. + + instead of `use actix_web::multipart::Multipart` + + use `use actix_multipart::Multipart` + +* Response compression is not enabled by default. + To enable, use `Compress` middleware, `App::new().wrap(Compress::default())`. + +* Session middleware moved to actix-session crate + +* Actors support have been moved to `actix-web-actors` crate + +* Custom Error + + Instead of error_response method alone, ResponseError now provides two methods: error_response and render_response respectively. Where, error_response creates the error response and render_response returns the error response to the caller. + + Simplest migration from 0.7 to 1.0 shall include below method to the custom implementation of ResponseError: + + ```rust + fn render_response(&self) -> HttpResponse { + self.error_response() + } + ``` + +## 0.7.15 + +* The `' '` character is not percent decoded anymore before matching routes. If you need to use it in + your routes, you should use `%20`. + + instead of + + ```rust + fn main() { + let app = App::new().resource("/my index", |r| { + r.method(http::Method::GET) + .with(index); + }); + } + ``` + + use + + ```rust + fn main() { + let app = App::new().resource("/my%20index", |r| { + r.method(http::Method::GET) + .with(index); + }); + } + ``` + +* If you used `AsyncResult::async` you need to replace it with `AsyncResult::future` + + +## 0.7.4 + +* `Route::with_config()`/`Route::with_async_config()` always passes configuration objects as tuple + even for handler with one parameter. + + +## 0.7 + +* `HttpRequest` does not implement `Stream` anymore. If you need to read request payload + use `HttpMessage::payload()` method. + + instead of + + ```rust + fn index(req: HttpRequest) -> impl Responder { + req + .from_err() + .fold(...) + .... + } + ``` + + use `.payload()` + + ```rust + fn index(req: HttpRequest) -> impl Responder { + req + .payload() // <- get request payload stream + .from_err() + .fold(...) + .... + } + ``` + +* [Middleware](https://actix.rs/actix-web/actix_web/middleware/trait.Middleware.html) + trait uses `&HttpRequest` instead of `&mut HttpRequest`. + +* Removed `Route::with2()` and `Route::with3()` use tuple of extractors instead. + + instead of + + ```rust + fn index(query: Query<..>, info: Json impl Responder {} + ``` + + use tuple of extractors and use `.with()` for registration: + + ```rust + fn index((query, json): (Query<..>, Json impl Responder {} + ``` + +* `Handler::handle()` uses `&self` instead of `&mut self` + +* `Handler::handle()` accepts reference to `HttpRequest<_>` instead of value + +* Removed deprecated `HttpServer::threads()`, use + [HttpServer::workers()](https://actix.rs/actix-web/actix_web/server/struct.HttpServer.html#method.workers) instead. + +* Renamed `client::ClientConnectorError::Connector` to + `client::ClientConnectorError::Resolver` + +* `Route::with()` does not return `ExtractorConfig`, to configure + extractor use `Route::with_config()` + + instead of + + ```rust + fn main() { + let app = App::new().resource("/index.html", |r| { + r.method(http::Method::GET) + .with(index) + .limit(4096); // <- limit size of the payload + }); + } + ``` + + use + + ```rust + + fn main() { + let app = App::new().resource("/index.html", |r| { + r.method(http::Method::GET) + .with_config(index, |cfg| { // <- register handler + cfg.limit(4096); // <- limit size of the payload + }) + }); + } + ``` + +* `Route::with_async()` does not return `ExtractorConfig`, to configure + extractor use `Route::with_async_config()` + + +## 0.6 + +* `Path` extractor return `ErrorNotFound` on failure instead of `ErrorBadRequest` + +* `ws::Message::Close` now includes optional close reason. + `ws::CloseCode::Status` and `ws::CloseCode::Empty` have been removed. + +* `HttpServer::threads()` renamed to `HttpServer::workers()`. + +* `HttpServer::start_ssl()` and `HttpServer::start_tls()` deprecated. + Use `HttpServer::bind_ssl()` and `HttpServer::bind_tls()` instead. + +* `HttpRequest::extensions()` returns read only reference to the request's Extension + `HttpRequest::extensions_mut()` returns mutable reference. + +* Instead of + + `use actix_web::middleware::{ + CookieSessionBackend, CookieSessionError, RequestSession, + Session, SessionBackend, SessionImpl, SessionStorage};` + + use `actix_web::middleware::session` + + `use actix_web::middleware::session{CookieSessionBackend, CookieSessionError, + RequestSession, Session, SessionBackend, SessionImpl, SessionStorage};` + +* `FromRequest::from_request()` accepts mutable reference to a request + +* `FromRequest::Result` has to implement `Into>` + +* [`Responder::respond_to()`]( + https://actix.rs/actix-web/actix_web/trait.Responder.html#tymethod.respond_to) + is generic over `S` + +* Use `Query` extractor instead of HttpRequest::query()`. + + ```rust + fn index(q: Query>) -> Result<..> { + ... + } + ``` + + or + + ```rust + let q = Query::>::extract(req); + ``` + +* Websocket operations are implemented as `WsWriter` trait. + you need to use `use actix_web::ws::WsWriter` + + +## 0.5 + +* `HttpResponseBuilder::body()`, `.finish()`, `.json()` + methods return `HttpResponse` instead of `Result` + +* `actix_web::Method`, `actix_web::StatusCode`, `actix_web::Version` + moved to `actix_web::http` module + +* `actix_web::header` moved to `actix_web::http::header` + +* `NormalizePath` moved to `actix_web::http` module + +* `HttpServer` moved to `actix_web::server`, added new `actix_web::server::new()` function, + shortcut for `actix_web::server::HttpServer::new()` + +* `DefaultHeaders` middleware does not use separate builder, all builder methods moved to type itself + +* `StaticFiles::new()`'s show_index parameter removed, use `show_files_listing()` method instead. + +* `CookieSessionBackendBuilder` removed, all methods moved to `CookieSessionBackend` type + +* `actix_web::httpcodes` module is deprecated, `HttpResponse::Ok()`, `HttpResponse::Found()` and other `HttpResponse::XXX()` + functions should be used instead + +* `ClientRequestBuilder::body()` returns `Result<_, actix_web::Error>` + instead of `Result<_, http::Error>` + +* `Application` renamed to a `App` + +* `actix_web::Reply`, `actix_web::Resource` moved to `actix_web::dev` diff --git a/Makefile b/Makefile deleted file mode 100644 index fdc3cbbc0..000000000 --- a/Makefile +++ /dev/null @@ -1,26 +0,0 @@ -.PHONY: default build test doc book clean - -CARGO_FLAGS := --features "$(FEATURES) alpn" - -default: test - -build: - cargo build $(CARGO_FLAGS) - -test: build clippy - cargo test $(CARGO_FLAGS) - -skeptic: - USE_SKEPTIC=1 cargo test $(CARGO_FLAGS) - -# cd examples/word-count && python setup.py install && pytest -v tests - -clippy: - if $$CLIPPY; then cargo clippy $(CARGO_FLAGS); fi - -doc: build - cargo doc --no-deps $(CARGO_FLAGS) - cd guide; mdbook build -d ../target/doc/guide/; cd .. - -book: - cd guide; mdbook build -d ../target/doc/guide/; cd .. diff --git a/README.md b/README.md index 46f589d6f..b7a1bf28f 100644 --- a/README.md +++ b/README.md @@ -1,74 +1,67 @@ -# Actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![Build status](https://ci.appveyor.com/api/projects/status/kkdb4yce7qhm5w85/branch/master?svg=true)](https://ci.appveyor.com/project/fafhrd91/actix-web-hdy9d/branch/master) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +# Actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -Actix web is a simple, pragmatic, extremely fast, web framework for Rust. +Actix web is a simple, pragmatic and extremely fast web framework for Rust. -* Supported *HTTP/1.x* and [*HTTP/2.0*](https://actix.github.io/actix-web/guide/qs_13.html) protocols +* Supported *HTTP/1.x* and *HTTP/2.0* protocols * Streaming and pipelining * Keep-alive and slow requests handling -* Client/server [WebSockets](https://actix.github.io/actix-web/guide/qs_9.html) support +* Client/server [WebSockets](https://actix.rs/docs/websockets/) support * Transparent content compression/decompression (br, gzip, deflate) -* Configurable [request routing](https://actix.github.io/actix-web/guide/qs_5.html) -* Graceful server shutdown +* Configurable [request routing](https://actix.rs/docs/url-dispatch/) * Multipart streams * Static assets -* SSL support with openssl or native-tls -* Middlewares ([Logger](https://actix.github.io/actix-web/guide/qs_10.html#logging), - [Session](https://actix.github.io/actix-web/guide/qs_10.html#user-sessions), - [Redis sessions](https://github.com/actix/actix-redis), - [DefaultHeaders](https://actix.github.io/actix-web/guide/qs_10.html#default-headers), - [CORS](https://actix.github.io/actix-web/actix_web/middleware/cors/index.html), - [CSRF](https://actix.github.io/actix-web/actix_web/middleware/csrf/index.html)) -* Built on top of [Actix actor framework](https://github.com/actix/actix). +* SSL support with OpenSSL or Rustls +* Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) +* Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) +* Supports [Actix actor framework](https://github.com/actix/actix) -## Documentation +## Documentation & community resources -* [User Guide](http://actix.github.io/actix-web/guide/) -* [API Documentation (Development)](http://actix.github.io/actix-web/actix_web/) -* [API Documentation (Releases)](https://docs.rs/actix-web/) +* [User Guide](https://actix.rs/docs/) +* [API Documentation (1.0)](https://docs.rs/actix-web/) * [Chat on gitter](https://gitter.im/actix/actix) * Cargo package: [actix-web](https://crates.io/crates/actix-web) -* Minimum supported Rust version: 1.21 or later +* Minimum supported Rust version: 1.39 or later ## Example ```rust -extern crate actix_web; -use actix_web::{App, HttpServer, Path}; +use actix_web::{get, App, HttpServer, Responder}; -fn index(info: Path<(String, u32)>) -> String { - format!("Hello {}! id:{}", info.0, info.1) +#[get("/{id}/{name}/index.html")] +async fn index(info: web::Path<(u32, String)>) -> impl Responder { + format!("Hello {}! id:{}", info.1, info.0) } -fn main() { - HttpServer::new( - || App::new() - .resource("/{name}/{id}/index.html", |r| r.with(index))) - .bind("127.0.0.1:8080").unwrap() - .run(); +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + HttpServer::new(|| App::new().service(index)) + .bind("127.0.0.1:8080")? + .start() + .await } ``` ### More examples -* [Basics](https://github.com/actix/actix-web/tree/master/examples/basics/) -* [Stateful](https://github.com/actix/actix-web/tree/master/examples/state/) -* [Protobuf support](https://github.com/actix/actix-web/tree/master/examples/protobuf/) -* [Multipart streams](https://github.com/actix/actix-web/tree/master/examples/multipart/) -* [Simple websocket session](https://github.com/actix/actix-web/tree/master/examples/websocket/) -* [Tera templates](https://github.com/actix/actix-web/tree/master/examples/template_tera/) -* [Diesel integration](https://github.com/actix/actix-web/tree/master/examples/diesel/) -* [SSL / HTTP/2.0](https://github.com/actix/actix-web/tree/master/examples/tls/) -* [Tcp/Websocket chat](https://github.com/actix/actix-web/tree/master/examples/websocket-chat/) -* [Json](https://github.com/actix/actix-web/tree/master/examples/json/) +* [Basics](https://github.com/actix/examples/tree/master/basics/) +* [Stateful](https://github.com/actix/examples/tree/master/state/) +* [Multipart streams](https://github.com/actix/examples/tree/master/multipart/) +* [Simple websocket](https://github.com/actix/examples/tree/master/websocket/) +* [Tera](https://github.com/actix/examples/tree/master/template_tera/) / + [Askama](https://github.com/actix/examples/tree/master/template_askama/) templates +* [Diesel integration](https://github.com/actix/examples/tree/master/diesel/) +* [r2d2](https://github.com/actix/examples/tree/master/r2d2/) +* [SSL / HTTP/2.0](https://github.com/actix/examples/tree/master/tls/) +* [Tcp/Websocket chat](https://github.com/actix/examples/tree/master/websocket-chat/) +* [Json](https://github.com/actix/examples/tree/master/json/) You may consider checking out -[this directory](https://github.com/actix/actix-web/tree/master/examples) for more examples. +[this directory](https://github.com/actix/examples/tree/master/) for more examples. ## Benchmarks -* [TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r15&hw=ph&test=plaintext) - -* Some basic benchmarks could be found in this [repository](https://github.com/fafhrd91/benchmarks). +* [TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r18) ## License diff --git a/actix-cors/CHANGES.md b/actix-cors/CHANGES.md new file mode 100644 index 000000000..10e408ede --- /dev/null +++ b/actix-cors/CHANGES.md @@ -0,0 +1,9 @@ +# Changes + +## [0.1.1] - unreleased + +* Bump `derive_more` crate version to 0.15.0 + +## [0.1.0] - 2019-06-15 + +* Move cors middleware to separate crate diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml new file mode 100644 index 000000000..ddb5f307e --- /dev/null +++ b/actix-cors/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "actix-cors" +version = "0.2.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Cross-origin resource sharing (CORS) for Actix applications." +readme = "README.md" +keywords = ["web", "framework"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-cors/" +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[lib] +name = "actix_cors" +path = "src/lib.rs" + +[dependencies] +actix-web = "2.0.0-alpha.1" +actix-service = "1.0.0-alpha.1" +derive_more = "0.99.2" +futures = "0.3.1" + +[dev-dependencies] +actix-rt = "1.0.0-alpha.1" diff --git a/actix-cors/LICENSE-APACHE b/actix-cors/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-cors/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-cors/LICENSE-MIT b/actix-cors/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-cors/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-cors/README.md b/actix-cors/README.md new file mode 100644 index 000000000..a77f6c6d3 --- /dev/null +++ b/actix-cors/README.md @@ -0,0 +1,9 @@ +# Cors Middleware for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-cors)](https://crates.io/crates/actix-cors) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-cors/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-cors](https://crates.io/crates/actix-cors) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs new file mode 100644 index 000000000..d3607aa8e --- /dev/null +++ b/actix-cors/src/lib.rs @@ -0,0 +1,1199 @@ +#![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] +//! Cross-origin resource sharing (CORS) for Actix applications +//! +//! CORS middleware could be used with application and with resource. +//! Cors middleware could be used as parameter for `App::wrap()`, +//! `Resource::wrap()` or `Scope::wrap()` methods. +//! +//! # Example +//! +//! ```rust +//! use actix_cors::Cors; +//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; +//! +//! async fn index(req: HttpRequest) -> &'static str { +//! "Hello world" +//! } +//! +//! fn main() -> std::io::Result<()> { +//! HttpServer::new(|| App::new() +//! .wrap( +//! Cors::new() // <- Construct CORS middleware builder +//! .allowed_origin("https://www.rust-lang.org/") +//! .allowed_methods(vec!["GET", "POST"]) +//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) +//! .allowed_header(http::header::CONTENT_TYPE) +//! .max_age(3600) +//! .finish()) +//! .service( +//! web::resource("/index.html") +//! .route(web::get().to(index)) +//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +//! )) +//! .bind("127.0.0.1:8080")?; +//! +//! Ok(()) +//! } +//! ``` +//! In this example custom *CORS* middleware get registered for "/index.html" +//! endpoint. +//! +//! Cors middleware automatically handle *OPTIONS* preflight request. +use std::collections::HashSet; +use std::iter::FromIterator; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::{Service, Transform}; +use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; +use actix_web::error::{Error, ResponseError, Result}; +use actix_web::http::header::{self, HeaderName, HeaderValue}; +use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; +use actix_web::HttpResponse; +use derive_more::Display; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; + +/// A set of errors that can occur during processing CORS +#[derive(Debug, Display)] +pub enum CorsError { + /// The HTTP request header `Origin` is required but was not provided + #[display( + fmt = "The HTTP request header `Origin` is required but was not provided" + )] + MissingOrigin, + /// The HTTP request header `Origin` could not be parsed correctly. + #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")] + BadOrigin, + /// The request header `Access-Control-Request-Method` is required but is + /// missing + #[display( + fmt = "The request header `Access-Control-Request-Method` is required but is missing" + )] + MissingRequestMethod, + /// The request header `Access-Control-Request-Method` has an invalid value + #[display( + fmt = "The request header `Access-Control-Request-Method` has an invalid value" + )] + BadRequestMethod, + /// The request header `Access-Control-Request-Headers` has an invalid + /// value + #[display( + fmt = "The request header `Access-Control-Request-Headers` has an invalid value" + )] + BadRequestHeaders, + /// Origin is not allowed to make this request + #[display(fmt = "Origin is not allowed to make this request")] + OriginNotAllowed, + /// Requested method is not allowed + #[display(fmt = "Requested method is not allowed")] + MethodNotAllowed, + /// One or more headers requested are not allowed + #[display(fmt = "One or more headers requested are not allowed")] + HeadersNotAllowed, +} + +impl ResponseError for CorsError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) + } +} + +/// An enum signifying that some of type T is allowed, or `All` (everything is +/// allowed). +/// +/// `Default` is implemented for this enum and is `All`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AllOrSome { + /// Everything is allowed. Usually equivalent to the "*" value. + All, + /// Only some of `T` is allowed + Some(T), +} + +impl Default for AllOrSome { + fn default() -> Self { + AllOrSome::All + } +} + +impl AllOrSome { + /// Returns whether this is an `All` variant + pub fn is_all(&self) -> bool { + match *self { + AllOrSome::All => true, + AllOrSome::Some(_) => false, + } + } + + /// Returns whether this is a `Some` variant + pub fn is_some(&self) -> bool { + !self.is_all() + } + + /// Returns &T + pub fn as_ref(&self) -> Option<&T> { + match *self { + AllOrSome::All => None, + AllOrSome::Some(ref t) => Some(t), + } + } +} + +/// Structure that follows the builder pattern for building `Cors` middleware +/// structs. +/// +/// To construct a cors: +/// +/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. +/// 2. Use any of the builder methods to set fields in the backend. +/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the +/// constructed backend. +/// +/// # Example +/// +/// ```rust +/// use actix_cors::Cors; +/// use actix_web::http::header; +/// +/// # fn main() { +/// let cors = Cors::new() +/// .allowed_origin("https://www.rust-lang.org/") +/// .allowed_methods(vec!["GET", "POST"]) +/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) +/// .allowed_header(header::CONTENT_TYPE) +/// .max_age(3600); +/// # } +/// ``` +#[derive(Default)] +pub struct Cors { + cors: Option, + methods: bool, + error: Option, + expose_hdrs: HashSet, +} + +impl Cors { + /// Build a new CORS middleware instance + pub fn new() -> Cors { + Cors { + cors: Some(Inner { + origins: AllOrSome::All, + origins_str: None, + methods: HashSet::new(), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }), + methods: false, + error: None, + expose_hdrs: HashSet::new(), + } + } + + /// Build a new CORS default middleware + pub fn default() -> CorsFactory { + let inner = Inner { + origins: AllOrSome::default(), + origins_str: None, + methods: HashSet::from_iter( + vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ] + .into_iter(), + ), + headers: AllOrSome::All, + expose_hdrs: None, + max_age: None, + preflight: true, + send_wildcard: false, + supports_credentials: false, + vary_header: true, + }; + CorsFactory { + inner: Rc::new(inner), + } + } + + /// Add an origin that are allowed to make requests. + /// Will be verified against the `Origin` request header. + /// + /// When `All` is set, and `send_wildcard` is set, "*" will be sent in + /// the `Access-Control-Allow-Origin` response header. Otherwise, the + /// client's `Origin` request header will be echoed back in the + /// `Access-Control-Allow-Origin` response header. + /// + /// When `Some` is set, the client's `Origin` request header will be + /// checked in a case-sensitive manner. + /// + /// This is the `list of origins` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `All`. + /// + /// Builder panics if supplied origin is not valid uri. + pub fn allowed_origin(mut self, origin: &str) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match Uri::try_from(origin) { + Ok(_) => { + if cors.origins.is_all() { + cors.origins = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut origins) = cors.origins { + origins.insert(origin.to_owned()); + } + } + Err(e) => { + self.error = Some(e.into()); + } + } + } + self + } + + /// Set a list of methods which the allowed origins are allowed to access + /// for requests. + /// + /// This is the `list of methods` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` + pub fn allowed_methods(mut self, methods: U) -> Cors + where + U: IntoIterator, + Method: HttpTryFrom, + { + self.methods = true; + if let Some(cors) = cors(&mut self.cors, &self.error) { + for m in methods { + match Method::try_from(m) { + Ok(method) => { + cors.methods.insert(method); + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + self + } + + /// Set an allowed header + pub fn allowed_header(mut self, header: H) -> Cors + where + HeaderName: HttpTryFrom, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + match HeaderName::try_from(header) { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + Err(e) => self.error = Some(e.into()), + } + } + self + } + + /// Set a list of header field names which can be used when + /// this resource is accessed by allowed origins. + /// + /// If `All` is set, whatever is requested by the client in + /// `Access-Control-Request-Headers` will be echoed back in the + /// `Access-Control-Allow-Headers` header. + /// + /// This is the `list of headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// Defaults to `All`. + pub fn allowed_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + HeaderName: HttpTryFrom, + { + if let Some(cors) = cors(&mut self.cors, &self.error) { + for h in headers { + match HeaderName::try_from(h) { + Ok(method) => { + if cors.headers.is_all() { + cors.headers = AllOrSome::Some(HashSet::new()); + } + if let AllOrSome::Some(ref mut headers) = cors.headers { + headers.insert(method); + } + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + } + self + } + + /// Set a list of headers which are safe to expose to the API of a CORS API + /// specification. This corresponds to the + /// `Access-Control-Expose-Headers` response header. + /// + /// This is the `list of exposed headers` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This defaults to an empty set. + pub fn expose_headers(mut self, headers: U) -> Cors + where + U: IntoIterator, + HeaderName: HttpTryFrom, + { + for h in headers { + match HeaderName::try_from(h) { + Ok(method) => { + self.expose_hdrs.insert(method); + } + Err(e) => { + self.error = Some(e.into()); + break; + } + } + } + self + } + + /// Set a maximum time for which this CORS request maybe cached. + /// This value is set as the `Access-Control-Max-Age` header. + /// + /// This defaults to `None` (unset). + pub fn max_age(mut self, max_age: usize) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.max_age = Some(max_age) + } + self + } + + /// Set a wildcard origins + /// + /// If send wildcard is set and the `allowed_origins` parameter is `All`, a + /// wildcard `Access-Control-Allow-Origin` response header is sent, + /// rather than the request’s `Origin` header. + /// + /// This is the `supports credentials flag` in the + /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). + /// + /// This **CANNOT** be used in conjunction with `allowed_origins` set to + /// `All` and `allow_credentials` set to `true`. Depending on the mode + /// of usage, this will either result in an `Error:: + /// CredentialsWithWildcardOrigin` error during actix launch or runtime. + /// + /// Defaults to `false`. + pub fn send_wildcard(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.send_wildcard = true + } + self + } + + /// Allows users to make authenticated requests + /// + /// If true, injects the `Access-Control-Allow-Credentials` header in + /// responses. This allows cookies and credentials to be submitted + /// across domains. + /// + /// This option cannot be used in conjunction with an `allowed_origin` set + /// to `All` and `send_wildcards` set to `true`. + /// + /// Defaults to `false`. + /// + /// Builder panics if credentials are allowed, but the Origin is set to "*". + /// This is not allowed by W3C + pub fn supports_credentials(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.supports_credentials = true + } + self + } + + /// Disable `Vary` header support. + /// + /// When enabled the header `Vary: Origin` will be returned as per the W3 + /// implementation guidelines. + /// + /// Setting this header when the `Access-Control-Allow-Origin` is + /// dynamically generated (e.g. when there is more than one allowed + /// origin, and an Origin than '*' is returned) informs CDNs and other + /// caches that the CORS headers are dynamic, and cannot be cached. + /// + /// By default `vary` header support is enabled. + pub fn disable_vary_header(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.vary_header = false + } + self + } + + /// Disable *preflight* request support. + /// + /// When enabled cors middleware automatically handles *OPTIONS* request. + /// This is useful application level middleware. + /// + /// By default *preflight* support is enabled. + pub fn disable_preflight(mut self) -> Cors { + if let Some(cors) = cors(&mut self.cors, &self.error) { + cors.preflight = false + } + self + } + + /// Construct cors middleware + pub fn finish(self) -> CorsFactory { + let mut slf = if !self.methods { + self.allowed_methods(vec![ + Method::GET, + Method::HEAD, + Method::POST, + Method::OPTIONS, + Method::PUT, + Method::PATCH, + Method::DELETE, + ]) + } else { + self + }; + + if let Some(e) = slf.error.take() { + panic!("{}", e); + } + + let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder"); + + if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { + panic!("Credentials are allowed, but the Origin is set to \"*\""); + } + + if let AllOrSome::Some(ref origins) = cors.origins { + let s = origins + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v)); + cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); + } + + if !slf.expose_hdrs.is_empty() { + cors.expose_hdrs = Some( + slf.expose_hdrs + .iter() + .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] + .to_owned(), + ); + } + + CorsFactory { + inner: Rc::new(cors), + } + } +} + +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +pub struct CorsFactory { + inner: Rc, +} + +impl Transform for CorsFactory +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = CorsMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CorsMiddleware { + service, + inner: self.inner.clone(), + }) + } +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +#[derive(Clone)] +pub struct CorsMiddleware { + service: S, + inner: Rc, +} + +struct Inner { + methods: HashSet, + origins: AllOrSome>, + origins_str: Option, + headers: AllOrSome>, + expose_hdrs: Option, + max_age: Option, + preflight: bool, + send_wildcard: bool, + supports_credentials: bool, + vary_header: bool, +} + +impl Inner { + fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(&header::ORIGIN) { + if let Ok(origin) = hdr.to_str() { + return match self.origins { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_origins) => allowed_origins + .get(origin) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::OriginNotAllowed), + }; + } + Err(CorsError::BadOrigin) + } else { + match self.origins { + AllOrSome::All => Ok(()), + _ => Err(CorsError::MissingOrigin), + } + } + } + + fn access_control_allow_origin(&self, req: &RequestHead) -> Option { + match self.origins { + AllOrSome::All => { + if self.send_wildcard { + Some(HeaderValue::from_static("*")) + } else if let Some(origin) = req.headers().get(&header::ORIGIN) { + Some(origin.clone()) + } else { + None + } + } + AllOrSome::Some(ref origins) => { + if let Some(origin) = + req.headers() + .get(&header::ORIGIN) + .filter(|o| match o.to_str() { + Ok(os) => origins.contains(os), + _ => false, + }) + { + Some(origin.clone()) + } else { + Some(self.origins_str.as_ref().unwrap().clone()) + } + } + } + } + + fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> { + if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) { + if let Ok(meth) = hdr.to_str() { + if let Ok(method) = Method::try_from(meth) { + return self + .methods + .get(&method) + .and_then(|_| Some(())) + .ok_or_else(|| CorsError::MethodNotAllowed); + } + } + Err(CorsError::BadRequestMethod) + } else { + Err(CorsError::MissingRequestMethod) + } + } + + fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> { + match self.headers { + AllOrSome::All => Ok(()), + AllOrSome::Some(ref allowed_headers) => { + if let Some(hdr) = + req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) + { + if let Ok(headers) = hdr.to_str() { + let mut hdrs = HashSet::new(); + for hdr in headers.split(',') { + match HeaderName::try_from(hdr.trim()) { + Ok(hdr) => hdrs.insert(hdr), + Err(_) => return Err(CorsError::BadRequestHeaders), + }; + } + // `Access-Control-Request-Headers` must contain 1 or more + // `field-name`. + if !hdrs.is_empty() { + if !hdrs.is_subset(allowed_headers) { + return Err(CorsError::HeadersNotAllowed); + } + return Ok(()); + } + } + Err(CorsError::BadRequestHeaders) + } else { + Ok(()) + } + } + } + } +} + +impl Service for CorsMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either< + Ready>, + LocalBoxFuture<'static, Result>, + >; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + if self.inner.preflight && Method::OPTIONS == *req.method() { + if let Err(e) = self + .inner + .validate_origin(req.head()) + .and_then(|_| self.inner.validate_allowed_method(req.head())) + .and_then(|_| self.inner.validate_allowed_headers(req.head())) + { + return Either::Left(ok(req.error_response(e))); + } + + // allowed headers + let headers = if let Some(headers) = self.inner.headers.as_ref() { + Some( + HeaderValue::try_from( + &headers + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .unwrap(), + ) + } else if let Some(hdr) = + req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) + { + Some(hdr.clone()) + } else { + None + }; + + let res = HttpResponse::Ok() + .if_some(self.inner.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( + header::ACCESS_CONTROL_MAX_AGE, + format!("{}", max_age).as_str(), + ); + }) + .if_some(headers, |headers, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); + }) + .if_some( + self.inner.access_control_allow_origin(req.head()), + |origin, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }, + ) + .if_true(self.inner.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + &self + .inner + .methods + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .finish() + .into_body(); + + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } + } + + let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let fut = self.service.call(req); + + Either::Right( + async move { + let res = fut.await; + + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = if let Some(hdr) = + res.headers_mut().get(&header::VARY) + { + let mut val: Vec = + Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + HeaderValue::try_from(&val[..]).unwrap() + } else { + HeaderValue::from_static("Origin") + }; + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res + } + } + .boxed_local(), + ) + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::{service_fn2, Transform}; + use actix_web::test::{self, TestRequest}; + + use super::*; + + #[actix_rt::test] + #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] + async fn cors_validates_illegal_allow_credentials() { + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); + } + + #[actix_rt::test] + async fn validate_origin_allows_all_origins() { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn default() { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_preflight() { + let mut cors = Cors::new() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .to_srv_request(); + + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"3600"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_MAX_AGE) + .unwrap() + .as_bytes() + ); + let hdr = resp + .headers() + .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) + .unwrap() + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) + .unwrap() + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); + + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + // #[actix_rt::test] + // #[should_panic(expected = "MissingOrigin")] + // async fn test_validate_missing_origin() { + // let cors = Cors::build() + // .allowed_origin("https://www.example.com") + // .finish(); + // let mut req = HttpRequest::default(); + // cors.start(&req).unwrap(); + // } + + #[actix_rt::test] + #[should_panic(expected = "OriginNotAllowed")] + async fn test_validate_not_allowed_origin() { + let cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + cors.inner.validate_origin(req.head()).unwrap(); + cors.inner.validate_allowed_method(req.head()).unwrap(); + cors.inner.validate_allowed_headers(req.head()).unwrap(); + } + + #[actix_rt::test] + async fn test_validate_origin() { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_no_origin_response() { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::default().method(Method::GET).to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert!(resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://www.example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } + + #[actix_rt::test] + async fn test_response() { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(service_fn2(|req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + let mut cors = Cors::new() + .disable_vary_header() + .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + + let origins_str = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!("https://www.example.com", origins_str); + } + + #[actix_rt::test] + async fn test_multiple_origins() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } + + #[actix_rt::test] + async fn test_multiple_origins_preflight() { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); + + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + } +} diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md new file mode 100644 index 000000000..5ec56593c --- /dev/null +++ b/actix-files/CHANGES.md @@ -0,0 +1,64 @@ +# Changes + +## [0.1.7] - 2019-11-06 + +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +## [0.1.6] - 2019-10-14 + +* Add option to redirect to a slash-ended path `Files` #1132 + +## [0.1.5] - 2019-10-08 + +* Bump up `mime_guess` crate version to 2.0.1 + +* Bump up `percent-encoding` crate version to 2.1 + +* Allow user defined request guards for `Files` #1113 + +## [0.1.4] - 2019-07-20 + +* Allow to disable `Content-Disposition` header #686 + +## [0.1.3] - 2019-06-28 + +* Do not set `Content-Length` header, let actix-http set it #930 + +## [0.1.2] - 2019-06-13 + +* Content-Length is 0 for NamedFile HEAD request #914 + +* Fix ring dependency from actix-web default features for #741 + +## [0.1.1] - 2019-06-01 + +* Static files are incorrectly served as both chunked and with length #812 + +## [0.1.0] - 2019-05-25 + +* NamedFile last-modified check always fails due to nano-seconds + in file modified date #820 + +## [0.1.0-beta.4] - 2019-05-12 + +* Update actix-web to beta.4 + +## [0.1.0-beta.1] - 2019-04-20 + +* Update actix-web to beta.1 + +## [0.1.0-alpha.6] - 2019-04-14 + +* Update actix-web to alpha6 + +## [0.1.0-alpha.4] - 2019-04-08 + +* Update actix-web to alpha4 + +## [0.1.0-alpha.2] - 2019-04-02 + +* Add default handler support + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml new file mode 100644 index 000000000..19366b902 --- /dev/null +++ b/actix-files/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "actix-files" +version = "0.2.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Static files support for actix web." +readme = "README.md" +keywords = ["actix", "http", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-files/" +categories = ["asynchronous", "web-programming::http-server"] +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[lib] +name = "actix_files" +path = "src/lib.rs" + +[dependencies] +actix-web = { version = "2.0.0-alpha.1", default-features = false } +actix-http = "0.3.0-alpha.1" +actix-service = "1.0.0-alpha.1" +bitflags = "1" +bytes = "0.4" +futures = "0.3.1" +derive_more = "0.99.2" +log = "0.4" +mime = "0.3" +mime_guess = "2.0.1" +percent-encoding = "2.1" +v_htmlescape = "0.4" + +[dev-dependencies] +actix-rt = "1.0.0-alpha.1" +actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } diff --git a/actix-files/LICENSE-APACHE b/actix-files/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-files/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-files/LICENSE-MIT b/actix-files/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-files/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-files/README.md b/actix-files/README.md new file mode 100644 index 000000000..9585e67a8 --- /dev/null +++ b/actix-files/README.md @@ -0,0 +1,9 @@ +# Static files support for actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-files)](https://crates.io/crates/actix-files) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-files/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-files](https://crates.io/crates/actix-files) +* Minimum supported Rust version: 1.33 or later diff --git a/actix-files/src/error.rs b/actix-files/src/error.rs new file mode 100644 index 000000000..49a46e58d --- /dev/null +++ b/actix-files/src/error.rs @@ -0,0 +1,41 @@ +use actix_web::{http::StatusCode, HttpResponse, ResponseError}; +use derive_more::Display; + +/// Errors which can occur when serving static files. +#[derive(Display, Debug, PartialEq)] +pub enum FilesError { + /// Path is not a directory + #[display(fmt = "Path is not a directory. Unable to serve static files")] + IsNotDirectory, + + /// Cannot render directory + #[display(fmt = "Unable to render directory without index file")] + IsDirectory, +} + +/// Return `NotFound` for `FilesError` +impl ResponseError for FilesError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::NOT_FOUND) + } +} + +#[derive(Display, Debug, PartialEq)] +pub enum UriSegmentError { + /// The segment started with the wrapped invalid character. + #[display(fmt = "The segment started with the wrapped invalid character")] + BadStart(char), + /// The segment contained the wrapped invalid character. + #[display(fmt = "The segment contained the wrapped invalid character")] + BadChar(char), + /// The segment ended with the wrapped invalid character. + #[display(fmt = "The segment ended with the wrapped invalid character")] + BadEnd(char), +} + +/// Return `BadRequest` for `UriSegmentError` +impl ResponseError for UriSegmentError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs new file mode 100644 index 000000000..ed8b6c3b9 --- /dev/null +++ b/actix-files/src/lib.rs @@ -0,0 +1,1421 @@ +#![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] + +//! Static files support +use std::cell::RefCell; +use std::fmt::Write; +use std::fs::{DirEntry, File}; +use std::future::Future; +use std::io::{Read, Seek}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{cmp, io}; + +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use actix_web::dev::{ + AppService, HttpServiceFactory, Payload, ResourceDef, ServiceRequest, + ServiceResponse, +}; +use actix_web::error::{Canceled, Error, ErrorInternalServerError}; +use actix_web::guard::Guard; +use actix_web::http::header::{self, DispositionType}; +use actix_web::http::Method; +use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; +use bytes::Bytes; +use futures::future::{ok, ready, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::Stream; +use mime; +use mime_guess::from_ext; +use percent_encoding::{utf8_percent_encode, CONTROLS}; +use v_htmlescape::escape as escape_html_entity; + +mod error; +mod named; +mod range; + +use self::error::{FilesError, UriSegmentError}; +pub use crate::named::NamedFile; +pub use crate::range::HttpRange; + +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; + +/// Return the MIME type associated with a filename extension (case-insensitive). +/// If `ext` is empty or no associated type for the extension was found, returns +/// the type `application/octet-stream`. +#[inline] +pub fn file_extension_to_mime(ext: &str) -> mime::Mime { + from_ext(ext).first_or_octet_stream() +} + +#[doc(hidden)] +/// A helper created from a `std::fs::File` which reads the file +/// chunk-by-chunk on a `ThreadPool`. +pub struct ChunkedReadFile { + size: u64, + offset: u64, + file: Option, + fut: Option< + LocalBoxFuture<'static, Result, Canceled>>, + >, + counter: u64, +} + +impl Stream for ChunkedReadFile { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + if let Some(ref mut fut) = self.fut { + return match Pin::new(fut).poll(cx) { + Poll::Ready(Err(_)) => Poll::Ready(Some(Err(ErrorInternalServerError( + "Unexpected error", + ) + .into()))), + Poll::Ready(Ok(Ok((file, bytes)))) => { + self.fut.take(); + self.file = Some(file); + self.offset += bytes.len() as u64; + self.counter += bytes.len() as u64; + Poll::Ready(Some(Ok(bytes))) + } + Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Pending => Poll::Pending, + }; + } + + let size = self.size; + let offset = self.offset; + let counter = self.counter; + + if size == counter { + Poll::Ready(None) + } else { + let mut file = self.file.take().expect("Use after completion"); + self.fut = Some( + web::block(move || { + let max_bytes: usize; + max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + let mut buf = Vec::with_capacity(max_bytes); + file.seek(io::SeekFrom::Start(offset))?; + let nbytes = + file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + if nbytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + Ok((file, Bytes::from(buf))) + }) + .boxed_local(), + ); + self.poll_next(cx) + } + } +} + +type DirectoryRenderer = + dyn Fn(&Directory, &HttpRequest) -> Result; + +/// A directory; responds with the generated directory listing. +#[derive(Debug)] +pub struct Directory { + /// Base directory + pub base: PathBuf, + /// Path of subdirectory to generate listing for + pub path: PathBuf, +} + +impl Directory { + /// Create a new directory + pub fn new(base: PathBuf, path: PathBuf) -> Directory { + Directory { base, path } + } + + /// Is this entry visible from this directory? + pub fn is_visible(&self, entry: &io::Result) -> bool { + if let Ok(ref entry) = *entry { + if let Some(name) = entry.file_name().to_str() { + if name.starts_with('.') { + return false; + } + } + if let Ok(ref md) = entry.metadata() { + let ft = md.file_type(); + return ft.is_dir() || ft.is_file() || ft.is_symlink(); + } + } + false + } +} + +// show file url as relative to static path +macro_rules! encode_file_url { + ($path:ident) => { + utf8_percent_encode(&$path.to_string_lossy(), CONTROLS) + }; +} + +// " -- " & -- & ' -- ' < -- < > -- > / -- / +macro_rules! encode_file_name { + ($entry:ident) => { + escape_html_entity(&$entry.file_name().to_string_lossy()) + }; +} + +fn directory_listing( + dir: &Directory, + req: &HttpRequest, +) -> Result { + let index_of = format!("Index of {}", req.path()); + let mut body = String::new(); + let base = Path::new(req.path()); + + for entry in dir.path.read_dir()? { + if dir.is_visible(&entry) { + let entry = entry.unwrap(); + let p = match entry.path().strip_prefix(&dir.path) { + Ok(p) => base.join(p), + Err(_) => continue, + }; + + // if file is a directory, add '/' to the end of the name + if let Ok(metadata) = entry.metadata() { + if metadata.is_dir() { + let _ = write!( + body, + "
  • {}/
  • ", + encode_file_url!(p), + encode_file_name!(entry), + ); + } else { + let _ = write!( + body, + "
  • {}
  • ", + encode_file_url!(p), + encode_file_name!(entry), + ); + } + } else { + continue; + } + } + } + + let html = format!( + "\ + {}\ +

    {}

    \ +
      \ + {}\ +
    \n", + index_of, index_of, body + ); + Ok(ServiceResponse::new( + req.clone(), + HttpResponse::Ok() + .content_type("text/html; charset=utf-8") + .body(html), + )) +} + +type MimeOverride = dyn Fn(&mime::Name) -> DispositionType; + +/// Static files handling +/// +/// `Files` service must be registered with `App::service()` method. +/// +/// ```rust +/// use actix_web::App; +/// use actix_files as fs; +/// +/// fn main() { +/// let app = App::new() +/// .service(fs::Files::new("/static", ".")); +/// } +/// ``` +pub struct Files { + path: String, + directory: PathBuf, + index: Option, + show_index: bool, + redirect_to_slash: bool, + default: Rc>>>, + renderer: Rc, + mime_override: Option>, + file_flags: named::Flags, + guards: Option>>, +} + +impl Clone for Files { + fn clone(&self) -> Self { + Self { + directory: self.directory.clone(), + index: self.index.clone(), + show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, + default: self.default.clone(), + renderer: self.renderer.clone(), + file_flags: self.file_flags, + path: self.path.clone(), + mime_override: self.mime_override.clone(), + guards: self.guards.clone(), + } + } +} + +impl Files { + /// Create new `Files` instance for specified base directory. + /// + /// `File` uses `ThreadPool` for blocking filesystem operations. + /// By default pool with 5x threads of available cpus is used. + /// Pool size can be changed by setting ACTIX_CPU_POOL environment variable. + pub fn new>(path: &str, dir: T) -> Files { + let dir = dir.into().canonicalize().unwrap_or_else(|_| PathBuf::new()); + if !dir.is_dir() { + log::error!("Specified path is not a directory: {:?}", dir); + } + + Files { + path: path.to_string(), + directory: dir, + index: None, + show_index: false, + redirect_to_slash: false, + default: Rc::new(RefCell::new(None)), + renderer: Rc::new(directory_listing), + mime_override: None, + file_flags: named::Flags::default(), + guards: None, + } + } + + /// Show files listing for directories. + /// + /// By default show files listing is disabled. + pub fn show_files_listing(mut self) -> Self { + self.show_index = true; + self + } + + /// Redirects to a slash-ended path when browsing a directory. + /// + /// By default never redirect. + pub fn redirect_to_slash_directory(mut self) -> Self { + self.redirect_to_slash = true; + self + } + + /// Set custom directory renderer + pub fn files_listing_renderer(mut self, f: F) -> Self + where + for<'r, 's> F: Fn(&'r Directory, &'s HttpRequest) -> Result + + 'static, + { + self.renderer = Rc::new(f); + self + } + + /// Specifies mime override callback + pub fn mime_override(mut self, f: F) -> Self + where + F: Fn(&mime::Name) -> DispositionType + 'static, + { + self.mime_override = Some(Rc::new(f)); + self + } + + /// Set index file + /// + /// Shows specific index file for directory "/" instead of + /// showing files listing. + pub fn index_file>(mut self, index: T) -> Self { + self.index = Some(index.into()); + self + } + + #[inline] + /// Specifies whether to use ETag or not. + /// + /// Default is true. + pub fn use_etag(mut self, value: bool) -> Self { + self.file_flags.set(named::Flags::ETAG, value); + self + } + + #[inline] + /// Specifies whether to use Last-Modified or not. + /// + /// Default is true. + pub fn use_last_modified(mut self, value: bool) -> Self { + self.file_flags.set(named::Flags::LAST_MD, value); + self + } + + /// Specifies custom guards to use for directory listings and files. + /// + /// Default behaviour allows GET and HEAD. + #[inline] + pub fn use_guards(mut self, guards: G) -> Self { + self.guards = Some(Rc::new(Box::new(guards))); + self + } + + /// Disable `Content-Disposition` header. + /// + /// By default Content-Disposition` header is enabled. + #[inline] + pub fn disable_content_disposition(mut self) -> Self { + self.file_flags.remove(named::Flags::CONTENT_DISPOSITION); + self + } + + /// Sets default handler which is used when no matched file could be found. + pub fn default_handler(mut self, f: F) -> Self + where + F: IntoServiceFactory, + U: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + > + 'static, + { + // create and configure default resource + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|_| ()), + ))))); + + self + } +} + +impl HttpServiceFactory for Files { + fn register(self, config: &mut AppService) { + if self.default.borrow().is_none() { + *self.default.borrow_mut() = Some(config.default_service()); + } + let rdef = if config.is_root() { + ResourceDef::root_prefix(&self.path) + } else { + ResourceDef::prefix(&self.path) + }; + config.register_service(rdef, None, self, None) + } +} + +impl ServiceFactory for Files { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Config = (); + type Service = FilesService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; + + fn new_service(&self, _: &()) -> Self::Future { + let mut srv = FilesService { + directory: self.directory.clone(), + index: self.index.clone(), + show_index: self.show_index, + redirect_to_slash: self.redirect_to_slash, + default: None, + renderer: self.renderer.clone(), + mime_override: self.mime_override.clone(), + file_flags: self.file_flags, + guards: self.guards.clone(), + }; + + if let Some(ref default) = *self.default.borrow() { + default + .new_service(&()) + .map(move |result| match result { + Ok(default) => { + srv.default = Some(default); + Ok(srv) + } + Err(_) => Err(()), + }) + .boxed_local() + } else { + ok(srv).boxed_local() + } + } +} + +pub struct FilesService { + directory: PathBuf, + index: Option, + show_index: bool, + redirect_to_slash: bool, + default: Option, + renderer: Rc, + mime_override: Option>, + file_flags: named::Flags, + guards: Option>>, +} + +impl FilesService { + fn handle_err( + &mut self, + e: io::Error, + req: ServiceRequest, + ) -> Either< + Ready>, + LocalBoxFuture<'static, Result>, + > { + log::debug!("Files: Failed to handle {}: {}", req.path(), e); + if let Some(ref mut default) = self.default { + Either::Right(default.call(req)) + } else { + Either::Left(ok(req.error_response(e))) + } + } +} + +impl Service for FilesService { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either< + Ready>, + LocalBoxFuture<'static, Result>, + >; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + let is_method_valid = if let Some(guard) = &self.guards { + // execute user defined guards + (**guard).check(req.head()) + } else { + // default behaviour + match *req.method() { + Method::HEAD | Method::GET => true, + _ => false, + } + }; + + if !is_method_valid { + return Either::Left(ok(req.into_response( + actix_web::HttpResponse::MethodNotAllowed() + .header(header::CONTENT_TYPE, "text/plain") + .body("Request did not meet this resource's requirements."), + ))); + } + + let real_path = match PathBufWrp::get_pathbuf(req.match_info().path()) { + Ok(item) => item, + Err(e) => return Either::Left(ok(req.error_response(e))), + }; + + // full filepath + let path = match self.directory.join(&real_path.0).canonicalize() { + Ok(path) => path, + Err(e) => return self.handle_err(e, req), + }; + + if path.is_dir() { + if let Some(ref redir_index) = self.index { + if self.redirect_to_slash && !req.path().ends_with('/') { + let redirect_to = format!("{}/", req.path()); + return Either::Left(ok(req.into_response( + HttpResponse::Found() + .header(header::LOCATION, redirect_to) + .body("") + .into_body(), + ))); + } + + let path = path.join(redir_index); + + match NamedFile::open(path) { + Ok(mut named_file) => { + if let Some(ref mime_override) = self.mime_override { + let new_disposition = + mime_override(&named_file.content_type.type_()); + named_file.content_disposition.disposition = new_disposition; + } + + named_file.flags = self.file_flags; + let (req, _) = req.into_parts(); + Either::Left(ok(match named_file.into_response(&req) { + Ok(item) => ServiceResponse::new(req, item), + Err(e) => ServiceResponse::from_err(e, req), + })) + } + Err(e) => self.handle_err(e, req), + } + } else if self.show_index { + let dir = Directory::new(self.directory.clone(), path); + let (req, _) = req.into_parts(); + let x = (self.renderer)(&dir, &req); + match x { + Ok(resp) => Either::Left(ok(resp)), + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), + } + } else { + Either::Left(ok(ServiceResponse::from_err( + FilesError::IsDirectory, + req.into_parts().0, + ))) + } + } else { + match NamedFile::open(path) { + Ok(mut named_file) => { + if let Some(ref mime_override) = self.mime_override { + let new_disposition = + mime_override(&named_file.content_type.type_()); + named_file.content_disposition.disposition = new_disposition; + } + + named_file.flags = self.file_flags; + let (req, _) = req.into_parts(); + match named_file.into_response(&req) { + Ok(item) => { + Either::Left(ok(ServiceResponse::new(req.clone(), item))) + } + Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), + } + } + Err(e) => self.handle_err(e, req), + } + } + } +} + +#[derive(Debug)] +struct PathBufWrp(PathBuf); + +impl PathBufWrp { + fn get_pathbuf(path: &str) -> Result { + let mut buf = PathBuf::new(); + for segment in path.split('/') { + if segment == ".." { + buf.pop(); + } else if segment.starts_with('.') { + return Err(UriSegmentError::BadStart('.')); + } else if segment.starts_with('*') { + return Err(UriSegmentError::BadStart('*')); + } else if segment.ends_with(':') { + return Err(UriSegmentError::BadEnd(':')); + } else if segment.ends_with('>') { + return Err(UriSegmentError::BadEnd('>')); + } else if segment.ends_with('<') { + return Err(UriSegmentError::BadEnd('<')); + } else if segment.is_empty() { + continue; + } else if cfg!(windows) && segment.contains('\\') { + return Err(UriSegmentError::BadChar('\\')); + } else { + buf.push(segment) + } + } + + Ok(PathBufWrp(buf)) + } +} + +impl FromRequest for PathBufWrp { + type Error = UriSegmentError; + type Future = Ready>; + type Config = (); + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ready(PathBufWrp::get_pathbuf(req.match_info().path())) + } +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::iter::FromIterator; + use std::ops::Add; + use std::time::{Duration, SystemTime}; + + use super::*; + use actix_web::guard; + use actix_web::http::header::{ + self, ContentDisposition, DispositionParam, DispositionType, + }; + use actix_web::http::{Method, StatusCode}; + use actix_web::middleware::Compress; + use actix_web::test::{self, TestRequest}; + use actix_web::{App, Responder}; + + #[actix_rt::test] + async fn test_file_extension_to_mime() { + let m = file_extension_to_mime("jpg"); + assert_eq!(m, mime::IMAGE_JPEG); + + let m = file_extension_to_mime("invalid extension!!"); + assert_eq!(m, mime::APPLICATION_OCTET_STREAM); + + let m = file_extension_to_mime(""); + assert_eq!(m, mime::APPLICATION_OCTET_STREAM); + } + + #[actix_rt::test] + async fn test_if_modified_since_without_if_none_match() { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = + header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + + let req = TestRequest::default() + .header(header::IF_MODIFIED_SINCE, since) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_MODIFIED); + } + + #[actix_rt::test] + async fn test_if_modified_since_with_if_none_match() { + let file = NamedFile::open("Cargo.toml").unwrap(); + let since = + header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); + + let req = TestRequest::default() + .header(header::IF_NONE_MATCH, "miss_etag") + .header(header::IF_MODIFIED_SINCE, since) + .to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_ne!(resp.status(), StatusCode::NOT_MODIFIED); + } + + #[actix_rt::test] + async fn test_named_file_text() { + assert!(NamedFile::open("test--").is_err()); + let mut file = NamedFile::open("Cargo.toml").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + } + + #[actix_rt::test] + async fn test_named_file_content_disposition() { + assert!(NamedFile::open("test--").is_err()); + let mut file = NamedFile::open("Cargo.toml").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + + let file = NamedFile::open("Cargo.toml") + .unwrap() + .disable_content_disposition(); + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert!(resp.headers().get(header::CONTENT_DISPOSITION).is_none()); + } + + #[actix_rt::test] + async fn test_named_file_non_ascii_file_name() { + let mut file = + NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml") + .unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"貨物.toml\"; filename*=UTF-8''%E8%B2%A8%E7%89%A9.toml" + ); + } + + #[actix_rt::test] + async fn test_named_file_set_content_type() { + let mut file = NamedFile::open("Cargo.toml") + .unwrap() + .set_content_type(mime::TEXT_XML); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/xml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + } + + #[actix_rt::test] + async fn test_named_file_image() { + let mut file = NamedFile::open("tests/test.png").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "image/png" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"test.png\"" + ); + } + + #[actix_rt::test] + async fn test_named_file_image_attachment() { + let cd = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Filename(String::from("test.png"))], + }; + let mut file = NamedFile::open("tests/test.png") + .unwrap() + .set_content_disposition(cd); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "image/png" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "attachment; filename=\"test.png\"" + ); + } + + #[actix_rt::test] + async fn test_named_file_binary() { + let mut file = NamedFile::open("tests/test.binary").unwrap(); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "application/octet-stream" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "attachment; filename=\"test.binary\"" + ); + } + + #[actix_rt::test] + async fn test_named_file_status_code_text() { + let mut file = NamedFile::open("Cargo.toml") + .unwrap() + .set_status_code(StatusCode::NOT_FOUND); + { + file.file(); + let _f: &File = &file; + } + { + let _f: &mut File = &mut file; + } + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-toml" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"Cargo.toml\"" + ); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_mime_override() { + fn all_attachment(_: &mime::Name) -> DispositionType { + DispositionType::Attachment + } + + let mut srv = test::init_service( + App::new().service( + Files::new("/", ".") + .mime_override(all_attachment) + .index_file("Cargo.toml"), + ), + ) + .await; + + let request = TestRequest::get().uri("/").to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::OK); + + let content_disposition = response + .headers() + .get(header::CONTENT_DISPOSITION) + .expect("To have CONTENT_DISPOSITION"); + let content_disposition = content_disposition + .to_str() + .expect("Convert CONTENT_DISPOSITION to str"); + assert_eq!(content_disposition, "attachment; filename=\"Cargo.toml\""); + } + + #[actix_rt::test] + async fn test_named_file_ranges_status_code() { + let mut srv = test::init_service( + App::new().service(Files::new("/test", ".").index_file("Cargo.toml")), + ) + .await; + + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/Cargo.toml") + .header(header::RANGE, "bytes=10-20") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT); + + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/Cargo.toml") + .header(header::RANGE, "bytes=1-0") + .to_request(); + let response = test::call_service(&mut srv, request).await; + + assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); + } + + #[actix_rt::test] + async fn test_named_file_content_range_headers() { + let mut srv = test::init_service( + App::new().service(Files::new("/test", ".").index_file("tests/test.binary")), + ) + .await; + + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-20") + .to_request(); + + let response = test::call_service(&mut srv, request).await; + let contentrange = response + .headers() + .get(header::CONTENT_RANGE) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!(contentrange, "bytes 10-20/100"); + + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-5") + .to_request(); + let response = test::call_service(&mut srv, request).await; + + let contentrange = response + .headers() + .get(header::CONTENT_RANGE) + .unwrap() + .to_str() + .unwrap(); + + assert_eq!(contentrange, "bytes */100"); + } + + #[actix_rt::test] + async fn test_named_file_content_length_headers() { + // use actix_web::body::{MessageBody, ResponseBody}; + + let mut srv = test::init_service( + App::new().service(Files::new("test", ".").index_file("tests/test.binary")), + ) + .await; + + // Valid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-20") + .to_request(); + let _response = test::call_service(&mut srv, request).await; + + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "11"); + + // Invalid range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .header(header::RANGE, "bytes=10-8") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE); + + // Without range header + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + // .no_default_headers() + .to_request(); + let _response = test::call_service(&mut srv, request).await; + + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "100"); + + // chunked + let request = TestRequest::get() + .uri("/t%65st/tests/test.binary") + .to_request(); + let response = test::call_service(&mut srv, request).await; + + // with enabled compression + // { + // let te = response + // .headers() + // .get(header::TRANSFER_ENCODING) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(te, "chunked"); + // } + + let bytes = test::read_body(response).await; + let data = Bytes::from(fs::read("tests/test.binary").unwrap()); + assert_eq!(bytes, data); + } + + #[actix_rt::test] + async fn test_head_content_length_headers() { + let mut srv = test::init_service( + App::new().service(Files::new("test", ".").index_file("tests/test.binary")), + ) + .await; + + // Valid range header + let request = TestRequest::default() + .method(Method::HEAD) + .uri("/t%65st/tests/test.binary") + .to_request(); + let _response = test::call_service(&mut srv, request).await; + + // TODO: fix check + // let contentlength = response + // .headers() + // .get(header::CONTENT_LENGTH) + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(contentlength, "100"); + } + + #[actix_rt::test] + async fn test_static_files_with_spaces() { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").index_file("Cargo.toml")), + ) + .await; + let request = TestRequest::get() + .uri("/tests/test%20space.binary") + .to_request(); + let response = test::call_service(&mut srv, request).await; + assert_eq!(response.status(), StatusCode::OK); + + let bytes = test::read_body(response).await; + let data = Bytes::from(fs::read("tests/test space.binary").unwrap()); + assert_eq!(bytes, data); + } + + #[actix_rt::test] + async fn test_files_not_allowed() { + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + + let req = TestRequest::default() + .uri("/Cargo.toml") + .method(Method::POST) + .to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + let req = TestRequest::default() + .method(Method::PUT) + .uri("/Cargo.toml") + .to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[actix_rt::test] + async fn test_files_guards() { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").use_guards(guard::Post())), + ) + .await; + + let req = TestRequest::default() + .uri("/Cargo.toml") + .method(Method::POST) + .to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_named_file_content_encoding() { + let mut srv = test::init_service(App::new().wrap(Compress::default()).service( + web::resource("/").to(|| { + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Identity) + } + }), + )) + .await; + + let request = TestRequest::get() + .uri("/") + .header(header::ACCEPT_ENCODING, "gzip") + .to_request(); + let res = test::call_service(&mut srv, request).await; + assert_eq!(res.status(), StatusCode::OK); + assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); + } + + #[actix_rt::test] + async fn test_named_file_content_encoding_gzip() { + let mut srv = test::init_service(App::new().wrap(Compress::default()).service( + web::resource("/").to(|| { + async { + NamedFile::open("Cargo.toml") + .unwrap() + .set_content_encoding(header::ContentEncoding::Gzip) + } + }), + )) + .await; + + let request = TestRequest::get() + .uri("/") + .header(header::ACCEPT_ENCODING, "gzip") + .to_request(); + let res = test::call_service(&mut srv, request).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers() + .get(header::CONTENT_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "gzip" + ); + } + + #[actix_rt::test] + async fn test_named_file_allowed_method() { + let req = TestRequest::default().method(Method::GET).to_http_request(); + let file = NamedFile::open("Cargo.toml").unwrap(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_static_files() { + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").show_files_listing()), + ) + .await; + let req = TestRequest::with_uri("/missing").to_request(); + + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let mut srv = test::init_service(App::new().service(Files::new("/", "."))).await; + + let req = TestRequest::default().to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").show_files_listing()), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/html; charset=utf-8" + ); + + let bytes = test::read_body(resp).await; + assert!(format!("{:?}", bytes).contains("/tests/test.png")); + } + + #[actix_rt::test] + async fn test_redirect_to_slash_directory() { + // should not redirect if no index + let mut srv = test::init_service( + App::new().service(Files::new("/", ".").redirect_to_slash_directory()), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + // should redirect if index present + let mut srv = test::init_service( + App::new().service( + Files::new("/", ".") + .index_file("test.png") + .redirect_to_slash_directory(), + ), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::FOUND); + + // should not redirect if the path is wrong + let req = TestRequest::with_uri("/not_existing").to_request(); + let resp = test::call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_static_files_bad_directory() { + let _st: Files = Files::new("/", "missing"); + let _st: Files = Files::new("/", "Cargo.toml"); + } + + #[actix_rt::test] + async fn test_default_handler_file_missing() { + let mut st = Files::new("/", ".") + .default_handler(|req: ServiceRequest| { + ok(req.into_response(HttpResponse::Ok().body("default content"))) + }) + .new_service(&()) + .await + .unwrap(); + let req = TestRequest::with_uri("/missing").to_srv_request(); + + let resp = test::call_service(&mut st, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let bytes = test::read_body(resp).await; + assert_eq!(bytes, Bytes::from_static(b"default content")); + } + + // #[actix_rt::test] + // async fn test_serve_index() { + // let st = Files::new(".").index_file("test.binary"); + // let req = TestRequest::default().uri("/tests").finish(); + + // let resp = st.handle(&req).respond_to(&req).unwrap(); + // let resp = resp.as_msg(); + // assert_eq!(resp.status(), StatusCode::OK); + // assert_eq!( + // resp.headers() + // .get(header::CONTENT_TYPE) + // .expect("content type"), + // "application/octet-stream" + // ); + // assert_eq!( + // resp.headers() + // .get(header::CONTENT_DISPOSITION) + // .expect("content disposition"), + // "attachment; filename=\"test.binary\"" + // ); + + // let req = TestRequest::default().uri("/tests/").finish(); + // let resp = st.handle(&req).respond_to(&req).unwrap(); + // let resp = resp.as_msg(); + // assert_eq!(resp.status(), StatusCode::OK); + // assert_eq!( + // resp.headers().get(header::CONTENT_TYPE).unwrap(), + // "application/octet-stream" + // ); + // assert_eq!( + // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + // "attachment; filename=\"test.binary\"" + // ); + + // // nonexistent index file + // let req = TestRequest::default().uri("/tests/unknown").finish(); + // let resp = st.handle(&req).respond_to(&req).unwrap(); + // let resp = resp.as_msg(); + // assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + // let req = TestRequest::default().uri("/tests/unknown/").finish(); + // let resp = st.handle(&req).respond_to(&req).unwrap(); + // let resp = resp.as_msg(); + // assert_eq!(resp.status(), StatusCode::NOT_FOUND); + // } + + // #[actix_rt::test] + // async fn test_serve_index_nested() { + // let st = Files::new(".").index_file("mod.rs"); + // let req = TestRequest::default().uri("/src/client").finish(); + // let resp = st.handle(&req).respond_to(&req).unwrap(); + // let resp = resp.as_msg(); + // assert_eq!(resp.status(), StatusCode::OK); + // assert_eq!( + // resp.headers().get(header::CONTENT_TYPE).unwrap(), + // "text/x-rust" + // ); + // assert_eq!( + // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + // "inline; filename=\"mod.rs\"" + // ); + // } + + // #[actix_rt::test] + // fn integration_serve_index() { + // let mut srv = test::TestServer::with_factory(|| { + // App::new().handler( + // "test", + // Files::new(".").index_file("Cargo.toml"), + // ) + // }); + + // let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + // let response = srv.execute(request.send()).unwrap(); + // assert_eq!(response.status(), StatusCode::OK); + // let bytes = srv.execute(response.body()).unwrap(); + // let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + // assert_eq!(bytes, data); + + // let request = srv.get().uri(srv.url("/test/")).finish().unwrap(); + // let response = srv.execute(request.send()).unwrap(); + // assert_eq!(response.status(), StatusCode::OK); + // let bytes = srv.execute(response.body()).unwrap(); + // let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + // assert_eq!(bytes, data); + + // // nonexistent index file + // let request = srv.get().uri(srv.url("/test/unknown")).finish().unwrap(); + // let response = srv.execute(request.send()).unwrap(); + // assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // let request = srv.get().uri(srv.url("/test/unknown/")).finish().unwrap(); + // let response = srv.execute(request.send()).unwrap(); + // assert_eq!(response.status(), StatusCode::NOT_FOUND); + // } + + // #[actix_rt::test] + // fn integration_percent_encoded() { + // let mut srv = test::TestServer::with_factory(|| { + // App::new().handler( + // "test", + // Files::new(".").index_file("Cargo.toml"), + // ) + // }); + + // let request = srv + // .get() + // .uri(srv.url("/test/%43argo.toml")) + // .finish() + // .unwrap(); + // let response = srv.execute(request.send()).unwrap(); + // assert_eq!(response.status(), StatusCode::OK); + // } + + #[actix_rt::test] + async fn test_path_buf() { + assert_eq!( + PathBufWrp::get_pathbuf("/test/.tt").map(|t| t.0), + Err(UriSegmentError::BadStart('.')) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/test/*tt").map(|t| t.0), + Err(UriSegmentError::BadStart('*')) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/test/tt:").map(|t| t.0), + Err(UriSegmentError::BadEnd(':')) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/test/tt<").map(|t| t.0), + Err(UriSegmentError::BadEnd('<')) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/test/tt>").map(|t| t.0), + Err(UriSegmentError::BadEnd('>')) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/seg1/seg2/").unwrap().0, + PathBuf::from_iter(vec!["seg1", "seg2"]) + ); + assert_eq!( + PathBufWrp::get_pathbuf("/seg1/../seg2/").unwrap().0, + PathBuf::from_iter(vec!["seg2"]) + ); + } +} diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs new file mode 100644 index 000000000..0dcbd93b8 --- /dev/null +++ b/actix-files/src/named.rs @@ -0,0 +1,455 @@ +use std::fs::{File, Metadata}; +use std::io; +use std::ops::{Deref, DerefMut}; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[cfg(unix)] +use std::os::unix::fs::MetadataExt; + +use bitflags::bitflags; +use mime; +use mime_guess::from_path; + +use actix_http::body::SizedStream; +use actix_web::http::header::{ + self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, +}; +use actix_web::http::{ContentEncoding, StatusCode}; +use actix_web::middleware::BodyEncoding; +use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, Responder}; +use futures::future::{ready, Ready}; + +use crate::range::HttpRange; +use crate::ChunkedReadFile; + +bitflags! { + pub(crate) struct Flags: u8 { + const ETAG = 0b0000_0001; + const LAST_MD = 0b0000_0010; + const CONTENT_DISPOSITION = 0b0000_0100; + } +} + +impl Default for Flags { + fn default() -> Self { + Flags::all() + } +} + +/// A file with an associated name. +#[derive(Debug)] +pub struct NamedFile { + path: PathBuf, + file: File, + modified: Option, + pub(crate) md: Metadata, + pub(crate) flags: Flags, + pub(crate) status_code: StatusCode, + pub(crate) content_type: mime::Mime, + pub(crate) content_disposition: header::ContentDisposition, + pub(crate) encoding: Option, +} + +impl NamedFile { + /// Creates an instance from a previously opened file. + /// + /// The given `path` need not exist and is only used to determine the `ContentType` and + /// `ContentDisposition` headers. + /// + /// # Examples + /// + /// ```rust + /// use actix_files::NamedFile; + /// use std::io::{self, Write}; + /// use std::env; + /// use std::fs::File; + /// + /// fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt")?; + /// file.write_all(b"Hello, world!")?; + /// let named_file = NamedFile::from_file(file, "bar.txt")?; + /// # std::fs::remove_file("foo.txt"); + /// Ok(()) + /// } + /// ``` + pub fn from_file>(file: File, path: P) -> io::Result { + let path = path.as_ref().to_path_buf(); + + // Get the name of the file and use it to construct default Content-Type + // and Content-Disposition values + let (content_type, content_disposition) = { + let filename = match path.file_name() { + Some(name) => name.to_string_lossy(), + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Provided path has no filename", + )); + } + }; + + let ct = from_path(&path).first_or_octet_stream(); + let disposition_type = match ct.type_() { + mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline, + _ => DispositionType::Attachment, + }; + let mut parameters = + vec![DispositionParam::Filename(String::from(filename.as_ref()))]; + if !filename.is_ascii() { + parameters.push(DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: filename.into_owned().into_bytes(), + })) + } + let cd = ContentDisposition { + disposition: disposition_type, + parameters: parameters, + }; + (ct, cd) + }; + + let md = file.metadata()?; + let modified = md.modified().ok(); + let encoding = None; + Ok(NamedFile { + path, + file, + content_type, + content_disposition, + md, + modified, + encoding, + status_code: StatusCode::OK, + flags: Flags::default(), + }) + } + + /// Attempts to open a file in read-only mode. + /// + /// # Examples + /// + /// ```rust + /// use actix_files::NamedFile; + /// + /// let file = NamedFile::open("foo.txt"); + /// ``` + pub fn open>(path: P) -> io::Result { + Self::from_file(File::open(&path)?, path) + } + + /// Returns reference to the underlying `File` object. + #[inline] + pub fn file(&self) -> &File { + &self.file + } + + /// Retrieve the path of this file. + /// + /// # Examples + /// + /// ```rust + /// # use std::io; + /// use actix_files::NamedFile; + /// + /// # fn path() -> io::Result<()> { + /// let file = NamedFile::open("test.txt")?; + /// assert_eq!(file.path().as_os_str(), "foo.txt"); + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub fn path(&self) -> &Path { + self.path.as_path() + } + + /// Set response **Status Code** + pub fn set_status_code(mut self, status: StatusCode) -> Self { + self.status_code = status; + self + } + + /// Set the MIME Content-Type for serving this file. By default + /// the Content-Type is inferred from the filename extension. + #[inline] + pub fn set_content_type(mut self, mime_type: mime::Mime) -> Self { + self.content_type = mime_type; + self + } + + /// Set the Content-Disposition for serving this file. This allows + /// changing the inline/attachment disposition as well as the filename + /// sent to the peer. By default the disposition is `inline` for text, + /// image, and video content types, and `attachment` otherwise, and + /// the filename is taken from the path provided in the `open` method + /// after converting it to UTF-8 using. + /// [to_string_lossy](https://doc.rust-lang.org/std/ffi/struct.OsStr.html#method.to_string_lossy). + #[inline] + pub fn set_content_disposition(mut self, cd: header::ContentDisposition) -> Self { + self.content_disposition = cd; + self.flags.insert(Flags::CONTENT_DISPOSITION); + self + } + + /// Disable `Content-Disposition` header. + /// + /// By default Content-Disposition` header is enabled. + #[inline] + pub fn disable_content_disposition(mut self) -> Self { + self.flags.remove(Flags::CONTENT_DISPOSITION); + self + } + + /// Set content encoding for serving this file + #[inline] + pub fn set_content_encoding(mut self, enc: ContentEncoding) -> Self { + self.encoding = Some(enc); + self + } + + #[inline] + ///Specifies whether to use ETag or not. + /// + ///Default is true. + pub fn use_etag(mut self, value: bool) -> Self { + self.flags.set(Flags::ETAG, value); + self + } + + #[inline] + ///Specifies whether to use Last-Modified or not. + /// + ///Default is true. + pub fn use_last_modified(mut self, value: bool) -> Self { + self.flags.set(Flags::LAST_MD, value); + self + } + + pub(crate) fn etag(&self) -> Option { + // This etag format is similar to Apache's. + self.modified.as_ref().map(|mtime| { + let ino = { + #[cfg(unix)] + { + self.md.ino() + } + #[cfg(not(unix))] + { + 0 + } + }; + + let dur = mtime + .duration_since(UNIX_EPOCH) + .expect("modification time must be after epoch"); + header::EntityTag::strong(format!( + "{:x}:{:x}:{:x}:{:x}", + ino, + self.md.len(), + dur.as_secs(), + dur.subsec_nanos() + )) + }) + } + + pub(crate) fn last_modified(&self) -> Option { + self.modified.map(|mtime| mtime.into()) + } + + pub fn into_response(self, req: &HttpRequest) -> Result { + if self.status_code != StatusCode::OK { + let mut resp = HttpResponse::build(self.status_code); + resp.set(header::ContentType(self.content_type.clone())) + .if_true(self.flags.contains(Flags::CONTENT_DISPOSITION), |res| { + res.header( + header::CONTENT_DISPOSITION, + self.content_disposition.to_string(), + ); + }); + if let Some(current_encoding) = self.encoding { + resp.encoding(current_encoding); + } + let reader = ChunkedReadFile { + size: self.md.len(), + offset: 0, + file: Some(self.file), + fut: None, + counter: 0, + }; + return Ok(resp.streaming(reader)); + } + + let etag = if self.flags.contains(Flags::ETAG) { + self.etag() + } else { + None + }; + let last_modified = if self.flags.contains(Flags::LAST_MD) { + self.last_modified() + } else { + None + }; + + // check preconditions + let precondition_failed = if !any_match(etag.as_ref(), req) { + true + } else if let (Some(ref m), Some(header::IfUnmodifiedSince(ref since))) = + (last_modified, req.get_header()) + { + let t1: SystemTime = m.clone().into(); + let t2: SystemTime = since.clone().into(); + match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { + (Ok(t1), Ok(t2)) => t1 > t2, + _ => false, + } + } else { + false + }; + + // check last modified + let not_modified = if !none_match(etag.as_ref(), req) { + true + } else if req.headers().contains_key(&header::IF_NONE_MATCH) { + false + } else if let (Some(ref m), Some(header::IfModifiedSince(ref since))) = + (last_modified, req.get_header()) + { + let t1: SystemTime = m.clone().into(); + let t2: SystemTime = since.clone().into(); + match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { + (Ok(t1), Ok(t2)) => t1 <= t2, + _ => false, + } + } else { + false + }; + + let mut resp = HttpResponse::build(self.status_code); + resp.set(header::ContentType(self.content_type.clone())) + .if_true(self.flags.contains(Flags::CONTENT_DISPOSITION), |res| { + res.header( + header::CONTENT_DISPOSITION, + self.content_disposition.to_string(), + ); + }); + // default compressing + if let Some(current_encoding) = self.encoding { + resp.encoding(current_encoding); + } + + resp.if_some(last_modified, |lm, resp| { + resp.set(header::LastModified(lm)); + }) + .if_some(etag, |etag, resp| { + resp.set(header::ETag(etag)); + }); + + resp.header(header::ACCEPT_RANGES, "bytes"); + + let mut length = self.md.len(); + let mut offset = 0; + + // check for range header + if let Some(ranges) = req.headers().get(&header::RANGE) { + if let Ok(rangesheader) = ranges.to_str() { + if let Ok(rangesvec) = HttpRange::parse(rangesheader, length) { + length = rangesvec[0].length; + offset = rangesvec[0].start; + resp.encoding(ContentEncoding::Identity); + resp.header( + header::CONTENT_RANGE, + format!( + "bytes {}-{}/{}", + offset, + offset + length - 1, + self.md.len() + ), + ); + } else { + resp.header(header::CONTENT_RANGE, format!("bytes */{}", length)); + return Ok(resp.status(StatusCode::RANGE_NOT_SATISFIABLE).finish()); + }; + } else { + return Ok(resp.status(StatusCode::BAD_REQUEST).finish()); + }; + }; + + if precondition_failed { + return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish()); + } else if not_modified { + return Ok(resp.status(StatusCode::NOT_MODIFIED).finish()); + } + + let reader = ChunkedReadFile { + offset, + size: length, + file: Some(self.file), + fut: None, + counter: 0, + }; + if offset != 0 || length != self.md.len() { + Ok(resp.status(StatusCode::PARTIAL_CONTENT).streaming(reader)) + } else { + Ok(resp.body(SizedStream::new(length, reader))) + } + } +} + +impl Deref for NamedFile { + type Target = File; + + fn deref(&self) -> &File { + &self.file + } +} + +impl DerefMut for NamedFile { + fn deref_mut(&mut self) -> &mut File { + &mut self.file + } +} + +/// Returns true if `req` has no `If-Match` header or one which matches `etag`. +fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + None | Some(header::IfMatch::Any) => true, + Some(header::IfMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.strong_eq(some_etag) { + return true; + } + } + } + false + } + } +} + +/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. +fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { + match req.get_header::() { + Some(header::IfNoneMatch::Any) => false, + Some(header::IfNoneMatch::Items(ref items)) => { + if let Some(some_etag) = etag { + for item in items { + if item.weak_eq(some_etag) { + return false; + } + } + } + true + } + None => true, + } +} + +impl Responder for NamedFile { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + ready(self.into_response(req)) + } +} diff --git a/actix-files/src/range.rs b/actix-files/src/range.rs new file mode 100644 index 000000000..47673b0b0 --- /dev/null +++ b/actix-files/src/range.rs @@ -0,0 +1,375 @@ +/// HTTP Range header representation. +#[derive(Debug, Clone, Copy)] +pub struct HttpRange { + pub start: u64, + pub length: u64, +} + +static PREFIX: &str = "bytes="; +const PREFIX_LEN: usize = 6; + +impl HttpRange { + /// Parses Range HTTP header string as per RFC 2616. + /// + /// `header` is HTTP Range header (e.g. `bytes=bytes=0-9`). + /// `size` is full size of response (file). + pub fn parse(header: &str, size: u64) -> Result, ()> { + if header.is_empty() { + return Ok(Vec::new()); + } + if !header.starts_with(PREFIX) { + return Err(()); + } + + let size_sig = size as i64; + let mut no_overlap = false; + + let all_ranges: Vec> = header[PREFIX_LEN..] + .split(',') + .map(|x| x.trim()) + .filter(|x| !x.is_empty()) + .map(|ra| { + let mut start_end_iter = ra.split('-'); + + let start_str = start_end_iter.next().ok_or(())?.trim(); + let end_str = start_end_iter.next().ok_or(())?.trim(); + + if start_str.is_empty() { + // If no start is specified, end specifies the + // range start relative to the end of the file. + let mut length: i64 = end_str.parse().map_err(|_| ())?; + + if length > size_sig { + length = size_sig; + } + + Ok(Some(HttpRange { + start: (size_sig - length) as u64, + length: length as u64, + })) + } else { + let start: i64 = start_str.parse().map_err(|_| ())?; + + if start < 0 { + return Err(()); + } + if start >= size_sig { + no_overlap = true; + return Ok(None); + } + + let length = if end_str.is_empty() { + // If no end is specified, range extends to end of the file. + size_sig - start + } else { + let mut end: i64 = end_str.parse().map_err(|_| ())?; + + if start > end { + return Err(()); + } + + if end >= size_sig { + end = size_sig - 1; + } + + end - start + 1 + }; + + Ok(Some(HttpRange { + start: start as u64, + length: length as u64, + })) + } + }) + .collect::>()?; + + let ranges: Vec = all_ranges.into_iter().filter_map(|x| x).collect(); + + if no_overlap && ranges.is_empty() { + return Err(()); + } + + Ok(ranges) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct T(&'static str, u64, Vec); + + #[test] + fn test_parse() { + let tests = vec![ + T("", 0, vec![]), + T("", 1000, vec![]), + T("foo", 0, vec![]), + T("bytes=", 0, vec![]), + T("bytes=7", 10, vec![]), + T("bytes= 7 ", 10, vec![]), + T("bytes=1-", 0, vec![]), + T("bytes=5-4", 10, vec![]), + T("bytes=0-2,5-4", 10, vec![]), + T("bytes=2-5,4-3", 10, vec![]), + T("bytes=--5,4--3", 10, vec![]), + T("bytes=A-", 10, vec![]), + T("bytes=A- ", 10, vec![]), + T("bytes=A-Z", 10, vec![]), + T("bytes= -Z", 10, vec![]), + T("bytes=5-Z", 10, vec![]), + T("bytes=Ran-dom, garbage", 10, vec![]), + T("bytes=0x01-0x02", 10, vec![]), + T("bytes= ", 10, vec![]), + T("bytes= , , , ", 10, vec![]), + T( + "bytes=0-9", + 10, + vec![HttpRange { + start: 0, + length: 10, + }], + ), + T( + "bytes=0-", + 10, + vec![HttpRange { + start: 0, + length: 10, + }], + ), + T( + "bytes=5-", + 10, + vec![HttpRange { + start: 5, + length: 5, + }], + ), + T( + "bytes=0-20", + 10, + vec![HttpRange { + start: 0, + length: 10, + }], + ), + T( + "bytes=15-,0-5", + 10, + vec![HttpRange { + start: 0, + length: 6, + }], + ), + T( + "bytes=1-2,5-", + 10, + vec![ + HttpRange { + start: 1, + length: 2, + }, + HttpRange { + start: 5, + length: 5, + }, + ], + ), + T( + "bytes=-2 , 7-", + 11, + vec![ + HttpRange { + start: 9, + length: 2, + }, + HttpRange { + start: 7, + length: 4, + }, + ], + ), + T( + "bytes=0-0 ,2-2, 7-", + 11, + vec![ + HttpRange { + start: 0, + length: 1, + }, + HttpRange { + start: 2, + length: 1, + }, + HttpRange { + start: 7, + length: 4, + }, + ], + ), + T( + "bytes=-5", + 10, + vec![HttpRange { + start: 5, + length: 5, + }], + ), + T( + "bytes=-15", + 10, + vec![HttpRange { + start: 0, + length: 10, + }], + ), + T( + "bytes=0-499", + 10000, + vec![HttpRange { + start: 0, + length: 500, + }], + ), + T( + "bytes=500-999", + 10000, + vec![HttpRange { + start: 500, + length: 500, + }], + ), + T( + "bytes=-500", + 10000, + vec![HttpRange { + start: 9500, + length: 500, + }], + ), + T( + "bytes=9500-", + 10000, + vec![HttpRange { + start: 9500, + length: 500, + }], + ), + T( + "bytes=0-0,-1", + 10000, + vec![ + HttpRange { + start: 0, + length: 1, + }, + HttpRange { + start: 9999, + length: 1, + }, + ], + ), + T( + "bytes=500-600,601-999", + 10000, + vec![ + HttpRange { + start: 500, + length: 101, + }, + HttpRange { + start: 601, + length: 399, + }, + ], + ), + T( + "bytes=500-700,601-999", + 10000, + vec![ + HttpRange { + start: 500, + length: 201, + }, + HttpRange { + start: 601, + length: 399, + }, + ], + ), + // Match Apache laxity: + T( + "bytes= 1 -2 , 4- 5, 7 - 8 , ,,", + 11, + vec![ + HttpRange { + start: 1, + length: 2, + }, + HttpRange { + start: 4, + length: 2, + }, + HttpRange { + start: 7, + length: 2, + }, + ], + ), + ]; + + for t in tests { + let header = t.0; + let size = t.1; + let expected = t.2; + + let res = HttpRange::parse(header, size); + + if res.is_err() { + if expected.is_empty() { + continue; + } else { + assert!( + false, + "parse({}, {}) returned error {:?}", + header, + size, + res.unwrap_err() + ); + } + } + + let got = res.unwrap(); + + if got.len() != expected.len() { + assert!( + false, + "len(parseRange({}, {})) = {}, want {}", + header, + size, + got.len(), + expected.len() + ); + continue; + } + + for i in 0..expected.len() { + if got[i].start != expected[i].start { + assert!( + false, + "parseRange({}, {})[{}].start = {}, want {}", + header, size, i, got[i].start, expected[i].start + ) + } + if got[i].length != expected[i].length { + assert!( + false, + "parseRange({}, {})[{}].length = {}, want {}", + header, size, i, got[i].length, expected[i].length + ) + } + } + } + } +} diff --git a/actix-files/tests/test space.binary b/actix-files/tests/test space.binary new file mode 100644 index 000000000..ef8ff0245 --- /dev/null +++ b/actix-files/tests/test space.binary @@ -0,0 +1 @@ +ÂTÇ‘É‚Vù2þvI ª–\ÇRË™–ˆæeÞvDØ:è—½¬RVÖYpíÿ;ÍÏGñùp!2÷CŒ.– û®õpA !ûߦÙx j+Uc÷±©X”c%Û;ï"yì­AI \ No newline at end of file diff --git a/actix-files/tests/test.binary b/actix-files/tests/test.binary new file mode 100644 index 000000000..ef8ff0245 --- /dev/null +++ b/actix-files/tests/test.binary @@ -0,0 +1 @@ +ÂTÇ‘É‚Vù2þvI ª–\ÇRË™–ˆæeÞvDØ:è—½¬RVÖYpíÿ;ÍÏGñùp!2÷CŒ.– û®õpA !ûߦÙx j+Uc÷±©X”c%Û;ï"yì­AI \ No newline at end of file diff --git a/actix-files/tests/test.png b/actix-files/tests/test.png new file mode 100644 index 000000000..6b7cdc0b8 Binary files /dev/null and b/actix-files/tests/test.png differ diff --git a/actix-framed/Cargo.toml b/actix-framed/Cargo.toml new file mode 100644 index 000000000..4783daefd --- /dev/null +++ b/actix-framed/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "actix-framed" +version = "0.3.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Actix framed app server" +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-framed/" +categories = ["network-programming", "asynchronous", + "web-programming::http-server", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +edition = "2018" +workspace =".." + +[lib] +name = "actix_framed" +path = "src/lib.rs" + +[dependencies] +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" +actix-router = "0.1.2" +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" + +bytes = "0.4" +futures = "0.3.1" +pin-project = "0.4.6" +log = "0.4" + +[dev-dependencies] +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } +actix-connect = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-utils = "0.5.0-alpha.1" diff --git a/actix-framed/LICENSE-APACHE b/actix-framed/LICENSE-APACHE new file mode 100644 index 000000000..6cdf2d16c --- /dev/null +++ b/actix-framed/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017-NOW Nikolay Kim + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/actix-framed/LICENSE-MIT b/actix-framed/LICENSE-MIT new file mode 100644 index 000000000..0f80296ae --- /dev/null +++ b/actix-framed/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2017 Nikolay Kim + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/actix-framed/README.md b/actix-framed/README.md new file mode 100644 index 000000000..1714b3640 --- /dev/null +++ b/actix-framed/README.md @@ -0,0 +1,8 @@ +# Framed app for actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-framed)](https://crates.io/crates/actix-framed) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [API Documentation](https://docs.rs/actix-framed/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-framed](https://crates.io/crates/actix-framed) +* Minimum supported Rust version: 1.33 or later diff --git a/actix-framed/changes.md b/actix-framed/changes.md new file mode 100644 index 000000000..6e67e00d8 --- /dev/null +++ b/actix-framed/changes.md @@ -0,0 +1,20 @@ +# Changes + +## [0.2.1] - 2019-07-20 + +* Remove unneeded actix-utils dependency + + +## [0.2.0] - 2019-05-12 + +* Update dependencies + + +## [0.1.0] - 2019-04-16 + +* Update tests + + +## [0.1.0-alpha.1] - 2019-04-12 + +* Initial release diff --git a/actix-framed/src/app.rs b/actix-framed/src/app.rs new file mode 100644 index 000000000..f3e746e9f --- /dev/null +++ b/actix-framed/src/app.rs @@ -0,0 +1,222 @@ +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::h1::{Codec, SendResponse}; +use actix_http::{Error, Request, Response}; +use actix_router::{Path, Router, Url}; +use actix_server_config::ServerConfig; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture}; + +use crate::helpers::{BoxedHttpNewService, BoxedHttpService, HttpNewService}; +use crate::request::FramedRequest; +use crate::state::State; + +type BoxedResponse = LocalBoxFuture<'static, Result<(), Error>>; + +pub trait HttpServiceFactory { + type Factory: ServiceFactory; + + fn path(&self) -> &str; + + fn create(self) -> Self::Factory; +} + +/// Application builder +pub struct FramedApp { + state: State, + services: Vec<(String, BoxedHttpNewService>)>, +} + +impl FramedApp { + pub fn new() -> Self { + FramedApp { + state: State::new(()), + services: Vec::new(), + } + } +} + +impl FramedApp { + pub fn with(state: S) -> FramedApp { + FramedApp { + services: Vec::new(), + state: State::new(state), + } + } + + pub fn service(mut self, factory: U) -> Self + where + U: HttpServiceFactory, + U::Factory: ServiceFactory< + Config = (), + Request = FramedRequest, + Response = (), + Error = Error, + InitError = (), + > + 'static, + ::Future: 'static, + ::Service: Service< + Request = FramedRequest, + Response = (), + Error = Error, + Future = LocalBoxFuture<'static, Result<(), Error>>, + >, + { + let path = factory.path().to_string(); + self.services + .push((path, Box::new(HttpNewService::new(factory.create())))); + self + } +} + +impl IntoServiceFactory> for FramedApp +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + S: 'static, +{ + fn into_factory(self) -> FramedAppFactory { + FramedAppFactory { + state: self.state, + services: Rc::new(self.services), + } + } +} + +#[derive(Clone)] +pub struct FramedAppFactory { + state: State, + services: Rc>)>>, +} + +impl ServiceFactory for FramedAppFactory +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + S: 'static, +{ + type Config = ServerConfig; + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type InitError = (); + type Service = FramedAppService; + type Future = CreateService; + + fn new_service(&self, _: &ServerConfig) -> Self::Future { + CreateService { + fut: self + .services + .iter() + .map(|(path, service)| { + CreateServiceItem::Future( + Some(path.clone()), + service.new_service(&()), + ) + }) + .collect(), + state: self.state.clone(), + } + } +} + +#[doc(hidden)] +pub struct CreateService { + fut: Vec>, + state: State, +} + +enum CreateServiceItem { + Future( + Option, + LocalBoxFuture<'static, Result>, ()>>, + ), + Service(String, BoxedHttpService>), +} + +impl Future for CreateService +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result, ()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut done = true; + + // poll http services + for item in &mut self.fut { + let res = match item { + CreateServiceItem::Future(ref mut path, ref mut fut) => { + match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { + Some((path.take().unwrap(), service)) + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + done = false; + None + } + } + } + CreateServiceItem::Service(_, _) => continue, + }; + + if let Some((path, service)) = res { + *item = CreateServiceItem::Service(path, service); + } + } + + if done { + let router = self + .fut + .drain(..) + .fold(Router::build(), |mut router, item| { + match item { + CreateServiceItem::Service(path, service) => { + router.path(&path, service); + } + CreateServiceItem::Future(_, _) => unreachable!(), + } + router + }); + Poll::Ready(Ok(FramedAppService { + router: router.finish(), + state: self.state.clone(), + })) + } else { + Poll::Pending + } + } +} + +pub struct FramedAppService { + state: State, + router: Router>>, +} + +impl Service for FramedAppService +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type Future = BoxedResponse; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { + let mut path = Path::new(Url::new(req.uri().clone())); + + if let Some((srv, _info)) = self.router.recognize_mut(&mut path) { + return srv.call(FramedRequest::new(req, framed, path, self.state.clone())); + } + SendResponse::new(framed, Response::NotFound().finish()) + .then(|_| ok(())) + .boxed_local() + } +} diff --git a/actix-framed/src/helpers.rs b/actix-framed/src/helpers.rs new file mode 100644 index 000000000..b654f9cd7 --- /dev/null +++ b/actix-framed/src/helpers.rs @@ -0,0 +1,98 @@ +use std::task::{Context, Poll}; + +use actix_http::Error; +use actix_service::{Service, ServiceFactory}; +use futures::future::{FutureExt, LocalBoxFuture}; + +pub(crate) type BoxedHttpService = Box< + dyn Service< + Request = Req, + Response = (), + Error = Error, + Future = LocalBoxFuture<'static, Result<(), Error>>, + >, +>; + +pub(crate) type BoxedHttpNewService = Box< + dyn ServiceFactory< + Config = (), + Request = Req, + Response = (), + Error = Error, + InitError = (), + Service = BoxedHttpService, + Future = LocalBoxFuture<'static, Result, ()>>, + >, +>; + +pub(crate) struct HttpNewService(T); + +impl HttpNewService +where + T: ServiceFactory, + T::Response: 'static, + T::Future: 'static, + T::Service: Service>> + 'static, + ::Future: 'static, +{ + pub fn new(service: T) -> Self { + HttpNewService(service) + } +} + +impl ServiceFactory for HttpNewService +where + T: ServiceFactory, + T::Request: 'static, + T::Future: 'static, + T::Service: Service>> + 'static, + ::Future: 'static, +{ + type Config = (); + type Request = T::Request; + type Response = (); + type Error = Error; + type InitError = (); + type Service = BoxedHttpService; + type Future = LocalBoxFuture<'static, Result>; + + fn new_service(&self, _: &()) -> Self::Future { + let fut = self.0.new_service(&()); + + async move { + fut.await.map_err(|_| ()).map(|service| { + let service: BoxedHttpService<_> = + Box::new(HttpServiceWrapper { service }); + service + }) + } + .boxed_local() + } +} + +struct HttpServiceWrapper { + service: T, +} + +impl Service for HttpServiceWrapper +where + T: Service< + Response = (), + Future = LocalBoxFuture<'static, Result<(), Error>>, + Error = Error, + >, + T::Request: 'static, +{ + type Request = T::Request; + type Response = (); + type Error = Error; + type Future = LocalBoxFuture<'static, Result<(), Error>>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + self.service.call(req) + } +} diff --git a/actix-framed/src/lib.rs b/actix-framed/src/lib.rs new file mode 100644 index 000000000..250533f39 --- /dev/null +++ b/actix-framed/src/lib.rs @@ -0,0 +1,17 @@ +#![allow(clippy::type_complexity, clippy::new_without_default, dead_code)] +mod app; +mod helpers; +mod request; +mod route; +mod service; +mod state; +pub mod test; + +// re-export for convinience +pub use actix_http::{http, Error, HttpMessage, Response, ResponseError}; + +pub use self::app::{FramedApp, FramedAppService}; +pub use self::request::FramedRequest; +pub use self::route::FramedRoute; +pub use self::service::{SendError, VerifyWebSockets}; +pub use self::state::State; diff --git a/actix-framed/src/request.rs b/actix-framed/src/request.rs new file mode 100644 index 000000000..bdcdd7026 --- /dev/null +++ b/actix-framed/src/request.rs @@ -0,0 +1,170 @@ +use std::cell::{Ref, RefMut}; + +use actix_codec::Framed; +use actix_http::http::{HeaderMap, Method, Uri, Version}; +use actix_http::{h1::Codec, Extensions, Request, RequestHead}; +use actix_router::{Path, Url}; + +use crate::state::State; + +pub struct FramedRequest { + req: Request, + framed: Framed, + state: State, + pub(crate) path: Path, +} + +impl FramedRequest { + pub fn new( + req: Request, + framed: Framed, + path: Path, + state: State, + ) -> Self { + Self { + req, + framed, + state, + path, + } + } +} + +impl FramedRequest { + /// Split request into a parts + pub fn into_parts(self) -> (Request, Framed, State) { + (self.req, self.framed, self.state) + } + + /// This method returns reference to the request head + #[inline] + pub fn head(&self) -> &RequestHead { + self.req.head() + } + + /// This method returns muttable reference to the request head. + /// panics if multiple references of http request exists. + #[inline] + pub fn head_mut(&mut self) -> &mut RequestHead { + self.req.head_mut() + } + + /// Shared application state + #[inline] + pub fn state(&self) -> &S { + self.state.get_ref() + } + + /// Request's uri. + #[inline] + pub fn uri(&self) -> &Uri { + &self.head().uri + } + + /// Read the Request method. + #[inline] + pub fn method(&self) -> &Method { + &self.head().method + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// The target path of this Request. + #[inline] + pub fn path(&self) -> &str { + self.head().uri.path() + } + + /// The query string in the URL. + /// + /// E.g., id=10 + #[inline] + pub fn query_string(&self) -> &str { + if let Some(query) = self.uri().query().as_ref() { + query + } else { + "" + } + } + + /// Get a reference to the Path parameters. + /// + /// Params is a container for url parameters. + /// A variable segment is specified in the form `{identifier}`, + /// where the identifier can be used later in a request handler to + /// access the matched value for that segment. + #[inline] + pub fn match_info(&self) -> &Path { + &self.path + } + + /// Request extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.head().extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.head().extensions_mut() + } +} + +#[cfg(test)] +mod tests { + use actix_http::http::{HeaderName, HeaderValue, HttpTryFrom}; + use actix_http::test::{TestBuffer, TestRequest}; + + use super::*; + + #[test] + fn test_reqest() { + let buf = TestBuffer::empty(); + let framed = Framed::new(buf, Codec::default()); + let req = TestRequest::with_uri("/index.html?q=1") + .header("content-type", "test") + .finish(); + let path = Path::new(Url::new(req.uri().clone())); + + let mut freq = FramedRequest::new(req, framed, path, State::new(10u8)); + assert_eq!(*freq.state(), 10); + assert_eq!(freq.version(), Version::HTTP_11); + assert_eq!(freq.method(), Method::GET); + assert_eq!(freq.path(), "/index.html"); + assert_eq!(freq.query_string(), "q=1"); + assert_eq!( + freq.headers() + .get("content-type") + .unwrap() + .to_str() + .unwrap(), + "test" + ); + + freq.head_mut().headers.insert( + HeaderName::try_from("x-hdr").unwrap(), + HeaderValue::from_static("test"), + ); + assert_eq!( + freq.headers().get("x-hdr").unwrap().to_str().unwrap(), + "test" + ); + + freq.extensions_mut().insert(100usize); + assert_eq!(*freq.extensions().get::().unwrap(), 100usize); + + let (_, _, state) = freq.into_parts(); + assert_eq!(*state, 10); + } +} diff --git a/actix-framed/src/route.rs b/actix-framed/src/route.rs new file mode 100644 index 000000000..783039684 --- /dev/null +++ b/actix-framed/src/route.rs @@ -0,0 +1,159 @@ +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::{http::Method, Error}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use log::error; + +use crate::app::HttpServiceFactory; +use crate::request::FramedRequest; + +/// Resource route definition +/// +/// Route uses builder-like pattern for configuration. +/// If handler is not explicitly set, default *404 Not Found* handler is used. +pub struct FramedRoute { + handler: F, + pattern: String, + methods: Vec, + state: PhantomData<(Io, S, R, E)>, +} + +impl FramedRoute { + pub fn new(pattern: &str) -> Self { + FramedRoute { + handler: (), + pattern: pattern.to_string(), + methods: Vec::new(), + state: PhantomData, + } + } + + pub fn get(path: &str) -> FramedRoute { + FramedRoute::new(path).method(Method::GET) + } + + pub fn post(path: &str) -> FramedRoute { + FramedRoute::new(path).method(Method::POST) + } + + pub fn put(path: &str) -> FramedRoute { + FramedRoute::new(path).method(Method::PUT) + } + + pub fn delete(path: &str) -> FramedRoute { + FramedRoute::new(path).method(Method::DELETE) + } + + pub fn method(mut self, method: Method) -> Self { + self.methods.push(method); + self + } + + pub fn to(self, handler: F) -> FramedRoute + where + F: FnMut(FramedRequest) -> R, + R: Future> + 'static, + + E: fmt::Debug, + { + FramedRoute { + handler, + pattern: self.pattern, + methods: self.methods, + state: PhantomData, + } + } +} + +impl HttpServiceFactory for FramedRoute +where + Io: AsyncRead + AsyncWrite + 'static, + F: FnMut(FramedRequest) -> R + Clone, + R: Future> + 'static, + E: fmt::Display, +{ + type Factory = FramedRouteFactory; + + fn path(&self) -> &str { + &self.pattern + } + + fn create(self) -> Self::Factory { + FramedRouteFactory { + handler: self.handler, + methods: self.methods, + _t: PhantomData, + } + } +} + +pub struct FramedRouteFactory { + handler: F, + methods: Vec, + _t: PhantomData<(Io, S, R, E)>, +} + +impl ServiceFactory for FramedRouteFactory +where + Io: AsyncRead + AsyncWrite + 'static, + F: FnMut(FramedRequest) -> R + Clone, + R: Future> + 'static, + E: fmt::Display, +{ + type Config = (); + type Request = FramedRequest; + type Response = (); + type Error = Error; + type InitError = (); + type Service = FramedRouteService; + type Future = Ready>; + + fn new_service(&self, _: &()) -> Self::Future { + ok(FramedRouteService { + handler: self.handler.clone(), + methods: self.methods.clone(), + _t: PhantomData, + }) + } +} + +pub struct FramedRouteService { + handler: F, + methods: Vec, + _t: PhantomData<(Io, S, R, E)>, +} + +impl Service for FramedRouteService +where + Io: AsyncRead + AsyncWrite + 'static, + F: FnMut(FramedRequest) -> R + Clone, + R: Future> + 'static, + E: fmt::Display, +{ + type Request = FramedRequest; + type Response = (); + type Error = Error; + type Future = LocalBoxFuture<'static, Result<(), Error>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: FramedRequest) -> Self::Future { + let fut = (self.handler)(req); + + async move { + let res = fut.await; + if let Err(e) = res { + error!("Error in request handler: {}", e); + } + Ok(()) + } + .boxed_local() + } +} diff --git a/actix-framed/src/service.rs b/actix-framed/src/service.rs new file mode 100644 index 000000000..ed3a75ff5 --- /dev/null +++ b/actix-framed/src/service.rs @@ -0,0 +1,156 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::body::BodySize; +use actix_http::error::ResponseError; +use actix_http::h1::{Codec, Message}; +use actix_http::ws::{verify_handshake, HandshakeError}; +use actix_http::{Request, Response}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{err, ok, Either, Ready}; +use futures::Future; + +/// Service that verifies incoming request if it is valid websocket +/// upgrade request. In case of error returns `HandshakeError` +pub struct VerifyWebSockets { + _t: PhantomData<(T, C)>, +} + +impl Default for VerifyWebSockets { + fn default() -> Self { + VerifyWebSockets { _t: PhantomData } + } +} + +impl ServiceFactory for VerifyWebSockets { + type Config = C; + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type InitError = (); + type Service = VerifyWebSockets; + type Future = Ready>; + + fn new_service(&self, _: &C) -> Self::Future { + ok(VerifyWebSockets { _t: PhantomData }) + } +} + +impl Service for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { + match verify_handshake(req.head()) { + Err(e) => err((e, framed)), + Ok(_) => ok((req, framed)), + } + } +} + +/// Send http/1 error response +pub struct SendError(PhantomData<(T, R, E, C)>); + +impl Default for SendError +where + T: AsyncRead + AsyncWrite, + E: ResponseError, +{ + fn default() -> Self { + SendError(PhantomData) + } +} + +impl ServiceFactory for SendError +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + R: 'static, + E: ResponseError + 'static, +{ + type Config = C; + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type InitError = (); + type Service = SendError; + type Future = Ready>; + + fn new_service(&self, _: &C) -> Self::Future { + ok(SendError(PhantomData)) + } +} + +impl Service for SendError +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + R: 'static, + E: ResponseError + 'static, +{ + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type Future = Either)>>, SendErrorFut>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Result)>) -> Self::Future { + match req { + Ok(r) => Either::Left(ok(r)), + Err((e, framed)) => { + let res = e.error_response().drop_body(); + Either::Right(SendErrorFut { + framed: Some(framed), + res: Some((res, BodySize::Empty).into()), + err: Some(e), + _t: PhantomData, + }) + } + } + } +} + +#[pin_project::pin_project] +pub struct SendErrorFut { + res: Option, BodySize)>>, + framed: Option>, + err: Option, + _t: PhantomData, +} + +impl Future for SendErrorFut +where + E: ResponseError, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Some(res) = self.res.take() { + if self.framed.as_mut().unwrap().write(res).is_err() { + return Poll::Ready(Err(( + self.err.take().unwrap(), + self.framed.take().unwrap(), + ))); + } + } + match self.framed.as_mut().unwrap().flush(cx) { + Poll::Ready(Ok(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) + } + Poll::Ready(Err(_)) => { + Poll::Ready(Err((self.err.take().unwrap(), self.framed.take().unwrap()))) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/actix-framed/src/state.rs b/actix-framed/src/state.rs new file mode 100644 index 000000000..600a639ca --- /dev/null +++ b/actix-framed/src/state.rs @@ -0,0 +1,29 @@ +use std::ops::Deref; +use std::sync::Arc; + +/// Application state +pub struct State(Arc); + +impl State { + pub fn new(state: S) -> State { + State(Arc::new(state)) + } + + pub fn get_ref(&self) -> &S { + self.0.as_ref() + } +} + +impl Deref for State { + type Target = S; + + fn deref(&self) -> &S { + self.0.as_ref() + } +} + +impl Clone for State { + fn clone(&self) -> State { + State(self.0.clone()) + } +} diff --git a/actix-framed/src/test.rs b/actix-framed/src/test.rs new file mode 100644 index 000000000..7969d51ff --- /dev/null +++ b/actix-framed/src/test.rs @@ -0,0 +1,152 @@ +//! Various helpers for Actix applications to use during testing. +use std::future::Future; + +use actix_codec::Framed; +use actix_http::h1::Codec; +use actix_http::http::header::{Header, HeaderName, IntoHeaderValue}; +use actix_http::http::{HttpTryFrom, Method, Uri, Version}; +use actix_http::test::{TestBuffer, TestRequest as HttpTestRequest}; +use actix_router::{Path, Url}; + +use crate::{FramedRequest, State}; + +/// Test `Request` builder. +pub struct TestRequest { + req: HttpTestRequest, + path: Path, + state: State, +} + +impl Default for TestRequest<()> { + fn default() -> TestRequest { + TestRequest { + req: HttpTestRequest::default(), + path: Path::new(Url::new(Uri::default())), + state: State::new(()), + } + } +} + +impl TestRequest<()> { + /// Create TestRequest and set request uri + pub fn with_uri(path: &str) -> Self { + Self::get().uri(path) + } + + /// Create TestRequest and set header + pub fn with_hdr(hdr: H) -> Self { + Self::default().set(hdr) + } + + /// Create TestRequest and set header + pub fn with_header(key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + Self::default().header(key, value) + } + + /// Create TestRequest and set method to `Method::GET` + pub fn get() -> Self { + Self::default().method(Method::GET) + } + + /// Create TestRequest and set method to `Method::POST` + pub fn post() -> Self { + Self::default().method(Method::POST) + } +} + +impl TestRequest { + /// Create TestRequest and set request uri + pub fn with_state(state: S) -> TestRequest { + let req = TestRequest::get(); + TestRequest { + state: State::new(state), + req: req.req, + path: req.path, + } + } + + /// Set HTTP version of this request + pub fn version(mut self, ver: Version) -> Self { + self.req.version(ver); + self + } + + /// Set HTTP method of this request + pub fn method(mut self, meth: Method) -> Self { + self.req.method(meth); + self + } + + /// Set HTTP Uri of this request + pub fn uri(mut self, path: &str) -> Self { + self.req.uri(path); + self + } + + /// Set a header + pub fn set(mut self, hdr: H) -> Self { + self.req.set(hdr); + self + } + + /// Set a header + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + self.req.header(key, value); + self + } + + /// Set request path pattern parameter + pub fn param(mut self, name: &'static str, value: &'static str) -> Self { + self.path.add_static(name, value); + self + } + + /// Complete request creation and generate `Request` instance + pub fn finish(mut self) -> FramedRequest { + let req = self.req.finish(); + self.path.get_mut().update(req.uri()); + let framed = Framed::new(TestBuffer::empty(), Codec::default()); + FramedRequest::new(req, framed, self.path, self.state) + } + + /// This method generates `FramedRequest` instance and executes async handler + pub async fn run(self, f: F) -> Result + where + F: FnOnce(FramedRequest) -> R, + R: Future>, + { + f(self.finish()).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + let req = TestRequest::with_uri("/index.html") + .header("x-test", "test") + .param("test", "123") + .finish(); + + assert_eq!(*req.state(), ()); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(req.method(), Method::GET); + assert_eq!(req.path(), "/index.html"); + assert_eq!(req.query_string(), ""); + assert_eq!( + req.headers().get("x-test").unwrap().to_str().unwrap(), + "test" + ); + assert_eq!(&req.match_info()["test"], "123"); + } +} diff --git a/actix-framed/tests/test_server.rs b/actix-framed/tests/test_server.rs new file mode 100644 index 000000000..4d1028d31 --- /dev/null +++ b/actix-framed/tests/test_server.rs @@ -0,0 +1,158 @@ +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::{body, http::StatusCode, ws, Error, HttpService, Response}; +use actix_http_test::TestServer; +use actix_service::{pipeline_factory, IntoServiceFactory, ServiceFactory}; +use actix_utils::framed::FramedTransport; +use bytes::{Bytes, BytesMut}; +use futures::{future, SinkExt, StreamExt}; + +use actix_framed::{FramedApp, FramedRequest, FramedRoute, SendError, VerifyWebSockets}; + +async fn ws_service( + req: FramedRequest, +) -> Result<(), Error> { + let (req, mut framed, _) = req.into_parts(); + let res = ws::handshake(req.head()).unwrap().message_body(()); + + framed + .send((res, body::BodySize::None).into()) + .await + .unwrap(); + FramedTransport::new(framed.into_framed(ws::Codec::new()), service) + .await + .unwrap(); + + Ok(()) +} + +async fn service(msg: ws::Frame) -> Result { + let msg = match msg { + ws::Frame::Ping(msg) => ws::Message::Pong(msg), + ws::Frame::Text(text) => { + ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + } + ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Close(reason) => ws::Message::Close(reason), + _ => panic!(), + }; + Ok(msg) +} + +#[actix_rt::test] +async fn test_simple() { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade( + FramedApp::new().service(FramedRoute::get("/index.html").to(ws_service)), + ) + .finish(|_| future::ok::<_, Error>(Response::NotFound())) + }); + + assert!(srv.ws_at("/test").await.is_err()); + + // client service + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); + + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); +} + +#[actix_rt::test] +async fn test_service() { + let mut srv = TestServer::start(|| { + pipeline_factory(actix_http::h1::OneRequest::new().map_err(|_| ())).and_then( + pipeline_factory( + pipeline_factory(VerifyWebSockets::default()) + .then(SendError::default()) + .map_err(|_| ()), + ) + .and_then( + FramedApp::new() + .service(FramedRoute::get("/index.html").to(ws_service)) + .into_factory() + .map_err(|_| ()), + ), + ) + }); + + // non ws request + let res = srv.get("/index.html").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + + // not found + assert!(srv.ws_at("/test").await.is_err()); + + // client service + let mut framed = srv.ws_at("/index.html").await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); + + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); +} diff --git a/actix-http/.appveyor.yml b/actix-http/.appveyor.yml new file mode 100644 index 000000000..780fdd6b5 --- /dev/null +++ b/actix-http/.appveyor.yml @@ -0,0 +1,41 @@ +environment: + global: + PROJECT_NAME: actix-http + matrix: + # Stable channel + - TARGET: i686-pc-windows-msvc + CHANNEL: stable + - TARGET: x86_64-pc-windows-gnu + CHANNEL: stable + - TARGET: x86_64-pc-windows-msvc + CHANNEL: stable + # Nightly channel + - TARGET: i686-pc-windows-msvc + CHANNEL: nightly + - TARGET: x86_64-pc-windows-gnu + CHANNEL: nightly + - TARGET: x86_64-pc-windows-msvc + CHANNEL: nightly + +# Install Rust and Cargo +# (Based on from https://github.com/rust-lang/libc/blob/master/appveyor.yml) +install: + - ps: >- + If ($Env:TARGET -eq 'x86_64-pc-windows-gnu') { + $Env:PATH += ';C:\msys64\mingw64\bin' + } ElseIf ($Env:TARGET -eq 'i686-pc-windows-gnu') { + $Env:PATH += ';C:\MinGW\bin' + } + - curl -sSf -o rustup-init.exe https://win.rustup.rs + - rustup-init.exe --default-host %TARGET% --default-toolchain %CHANNEL% -y + - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin + - rustc -Vv + - cargo -V + +# 'cargo test' takes care of building for us, so disable Appveyor's build stage. +build: false + +# Equivalent to Travis' `script` phase +test_script: + - cargo clean + - cargo test diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md new file mode 100644 index 000000000..4cb5644c3 --- /dev/null +++ b/actix-http/CHANGES.md @@ -0,0 +1,259 @@ +# Changes + +## [0.2.11] - 2019-11-06 + +### Added + +* Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() + +* Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) + +* Allow to use `std::convert::Infallible` as `actix_http::error::Error` + +### Fixed + +* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header #1118 + + +## [0.2.10] - 2019-09-11 + +### Added + +* Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` + +### Fixed + +* h2 will use error response #1080 + +* on_connect result isn't added to request extensions for http2 requests #1009 + + +## [0.2.9] - 2019-08-13 + +### Changed + +* Dropped the `byteorder`-dependency in favor of `stdlib`-implementation + +* Update percent-encoding to 2.1 + +* Update serde_urlencoded to 0.6.1 + +### Fixed + +* Fixed a panic in the HTTP2 handshake in client HTTP requests (#1031) + + +## [0.2.8] - 2019-08-01 + +### Added + +* Add `rustls` support + +* Add `Clone` impl for `HeaderMap` + +### Fixed + +* awc client panic #1016 + +* Invalid response with compression middleware enabled, but compression-related features disabled #997 + + +## [0.2.7] - 2019-07-18 + +### Added + +* Add support for downcasting response errors #986 + + +## [0.2.6] - 2019-07-17 + +### Changed + +* Replace `ClonableService` with local copy + +* Upgrade `rand` dependency version to 0.7 + + +## [0.2.5] - 2019-06-28 + +### Added + +* Add `on-connect` callback, `HttpServiceBuilder::on_connect()` #946 + +### Changed + +* Use `encoding_rs` crate instead of unmaintained `encoding` crate + +* Add `Copy` and `Clone` impls for `ws::Codec` + + +## [0.2.4] - 2019-06-16 + +### Fixed + +* Do not compress NoContent (204) responses #918 + + +## [0.2.3] - 2019-06-02 + +### Added + +* Debug impl for ResponseBuilder + +* From SizedStream and BodyStream for Body + +### Changed + +* SizedStream uses u64 + + +## [0.2.2] - 2019-05-29 + +### Fixed + +* Parse incoming stream before closing stream on disconnect #868 + + +## [0.2.1] - 2019-05-25 + +### Fixed + +* Handle socket read disconnect + + +## [0.2.0] - 2019-05-12 + +### Changed + +* Update actix-service to 0.4 + +* Expect and upgrade services accept `ServerConfig` config. + +### Deleted + +* `OneRequest` service + + +## [0.1.5] - 2019-05-04 + +### Fixed + +* Clean up response extensions in response pool #817 + + +## [0.1.4] - 2019-04-24 + +### Added + +* Allow to render h1 request headers in `Camel-Case` + +### Fixed + +* Read until eof for http/1.0 responses #771 + + +## [0.1.3] - 2019-04-23 + +### Fixed + +* Fix http client pool management + +* Fix http client wait queue management #794 + + +## [0.1.2] - 2019-04-23 + +### Fixed + +* Fix BorrowMutError panic in client connector #793 + + +## [0.1.1] - 2019-04-19 + +### Changed + +* Cookie::max_age() accepts value in seconds + +* Cookie::max_age_time() accepts value in time::Duration + +* Allow to specify server address for client connector + + +## [0.1.0] - 2019-04-16 + +### Added + +* Expose peer addr via `Request::peer_addr()` and `RequestHead::peer_addr` + +### Changed + +* `actix_http::encoding` always available + +* use trust-dns-resolver 0.11.0 + + +## [0.1.0-alpha.5] - 2019-04-12 + +### Added + +* Allow to use custom service for upgrade requests + +* Added `h1::SendResponse` future. + +### Changed + +* MessageBody::length() renamed to MessageBody::size() for consistency + +* ws handshake verification functions take RequestHead instead of Request + + +## [0.1.0-alpha.4] - 2019-04-08 + +### Added + +* Allow to use custom `Expect` handler + +* Add minimal `std::error::Error` impl for `Error` + +### Changed + +* Export IntoHeaderValue + +* Render error and return as response body + +* Use thread pool for response body comression + +### Deleted + +* Removed PayloadBuffer + + +## [0.1.0-alpha.3] - 2019-04-02 + +### Added + +* Warn when an unsealed private cookie isn't valid UTF-8 + +### Fixed + +* Rust 1.31.0 compatibility + +* Preallocate read buffer for h1 codec + +* Detect socket disconnection during protocol selection + + +## [0.1.0-alpha.2] - 2019-03-29 + +### Added + +* Added ws::Message::Nop, no-op websockets message + +### Changed + +* Do not use thread pool for decomression if chunk size is smaller than 2048. + + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-http/CODE_OF_CONDUCT.md b/actix-http/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..599b28c0d --- /dev/null +++ b/actix-http/CODE_OF_CONDUCT.md @@ -0,0 +1,46 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fafhrd91@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml new file mode 100644 index 000000000..9a14abefe --- /dev/null +++ b/actix-http/Cargo.toml @@ -0,0 +1,112 @@ +[package] +name = "actix-http" +version = "0.3.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Actix http primitives" +readme = "README.md" +keywords = ["actix", "http", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-http/" +categories = ["network-programming", "asynchronous", + "web-programming::http-server", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[package.metadata.docs.rs] +features = ["openssl", "fail", "brotli", "flate2-zlib", "secure-cookies"] + +[lib] +name = "actix_http" +path = "src/lib.rs" + +[features] +default = [] + +# openssl +openssl = ["open-ssl", "actix-connect/openssl", "tokio-openssl"] + +# rustls support +# rustls = ["rust-tls", "webpki-roots", "actix-connect/rustls"] + +# brotli encoding, requires c compiler +brotli = ["brotli2"] + +# miniz-sys backend for flate2 crate +flate2-zlib = ["flate2/miniz-sys"] + +# rust backend for flate2 crate +flate2-rust = ["flate2/rust_backend"] + +# failure integration. actix does not use failure anymore +fail = ["failure"] + +# support for secure cookies +secure-cookies = ["ring"] + +[dependencies] +actix-service = "1.0.0-alpha.1" +actix-codec = "0.2.0-alpha.1" +actix-connect = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-rt = "1.0.0-alpha.1" +actix-threadpool = "0.2.0-alpha.1" + +base64 = "0.10" +bitflags = "1.0" +bytes = "0.4" +copyless = "0.1.4" +chrono = "0.4.6" +derive_more = "0.99.2" +either = "1.5.2" +encoding_rs = "0.8" +futures = "0.3.1" +hashbrown = "0.6.3" +h2 = "0.2.0-alpha.3" +http = "0.1.17" +httparse = "1.3" +indexmap = "1.2" +lazy_static = "1.0" +language-tags = "0.2" +log = "0.4" +mime = "0.3" +percent-encoding = "2.1" +pin-project = "0.4.5" +rand = "0.7" +regex = "1.0" +serde = "1.0" +serde_json = "1.0" +sha1 = "0.6" +slab = "0.4" +serde_urlencoded = "0.6.1" +time = "0.1.42" + +tokio-net = "=0.2.0-alpha.6" +trust-dns-resolver = { version="0.18.0-alpha.1", default-features = false } + +# for secure cookie +ring = { version = "0.16.9", optional = true } + +# compression +brotli2 = { version="0.3.2", optional = true } +flate2 = { version="1.0.7", optional = true, default-features = false } + +# optional deps +failure = { version = "0.1.5", optional = true } +open-ssl = { version="0.10", package="openssl", optional = true } +tokio-openssl = { version = "0.4.0-alpha.6", optional = true } + +# rust-tls = { version = "0.16.0", package="rustls", optional = true } +# webpki-roots = { version = "0.18", optional = true } + +[dev-dependencies] +#actix-server = { version = "0.8.0-alpha.1", features=["openssl", "rustls"] } +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } +actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +env_logger = "0.6" +serde_derive = "1.0" +open-ssl = { version="0.10", package="openssl" } diff --git a/actix-http/LICENSE-APACHE b/actix-http/LICENSE-APACHE new file mode 100644 index 000000000..6cdf2d16c --- /dev/null +++ b/actix-http/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017-NOW Nikolay Kim + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/actix-http/LICENSE-MIT b/actix-http/LICENSE-MIT new file mode 100644 index 000000000..0f80296ae --- /dev/null +++ b/actix-http/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2017 Nikolay Kim + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/actix-http/README.md b/actix-http/README.md new file mode 100644 index 000000000..d75e822ba --- /dev/null +++ b/actix-http/README.md @@ -0,0 +1,46 @@ +# Actix http [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-http)](https://crates.io/crates/actix-http) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +Actix http + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-http/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-http](https://crates.io/crates/actix-http) +* Minimum supported Rust version: 1.31 or later + +## Example + +```rust +// see examples/framed_hello.rs for complete list of used crates. +extern crate actix_http; +use actix_http::{h1, Response, ServiceConfig}; + +fn main() { + Server::new().bind("framed_hello", "127.0.0.1:8080", || { + IntoFramed::new(|| h1::Codec::new(ServiceConfig::default())) // <- create h1 codec + .and_then(TakeItem::new().map_err(|_| ())) // <- read one request + .and_then(|(_req, _framed): (_, Framed<_, _>)| { // <- send response and close conn + SendResponse::send(_framed, Response::Ok().body("Hello world!")) + .map_err(|_| ()) + .map(|_| ()) + }) + }).unwrap().run(); +} +``` + +## License + +This project is licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)) + +at your option. + +## Code of Conduct + +Contribution to the actix-http crate is organized under the terms of the +Contributor Covenant, the maintainer of actix-http, @fafhrd91, promises to +intervene to uphold that code of conduct. diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs new file mode 100644 index 000000000..ba81020ca --- /dev/null +++ b/actix-http/examples/echo.rs @@ -0,0 +1,39 @@ +use std::{env, io}; + +use actix_http::{Error, HttpService, Request, Response}; +use actix_server::Server; +use bytes::BytesMut; +use futures::StreamExt; +use http::header::HeaderValue; +use log::info; + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "echo=info"); + env_logger::init(); + + Server::build() + .bind("echo", "127.0.0.1:8080", || { + HttpService::build() + .client_timeout(1000) + .client_disconnect(1000) + .finish(|mut req: Request| { + async move { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?); + } + + info!("request body: {:?}", body); + Ok::<_, Error>( + Response::Ok() + .header( + "x-head", + HeaderValue::from_static("dummy value!"), + ) + .body(body), + ) + } + }) + })? + .run() +} diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs new file mode 100644 index 000000000..3776c7d58 --- /dev/null +++ b/actix-http/examples/echo2.rs @@ -0,0 +1,31 @@ +use std::{env, io}; + +use actix_http::http::HeaderValue; +use actix_http::{Error, HttpService, Request, Response}; +use actix_server::Server; +use bytes::BytesMut; +use futures::StreamExt; +use log::info; + +async fn handle_request(mut req: Request) -> Result { + let mut body = BytesMut::new(); + while let Some(item) = req.payload().next().await { + body.extend_from_slice(&item?) + } + + info!("request body: {:?}", body); + Ok(Response::Ok() + .header("x-head", HeaderValue::from_static("dummy value!")) + .body(body)) +} + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "echo=info"); + env_logger::init(); + + Server::build() + .bind("echo", "127.0.0.1:8080", || { + HttpService::build().finish(handle_request) + })? + .run() +} diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs new file mode 100644 index 000000000..6e3820390 --- /dev/null +++ b/actix-http/examples/hello-world.rs @@ -0,0 +1,26 @@ +use std::{env, io}; + +use actix_http::{HttpService, Response}; +use actix_server::Server; +use futures::future; +use http::header::HeaderValue; +use log::info; + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "hello_world=info"); + env_logger::init(); + + Server::build() + .bind("hello-world", "127.0.0.1:8080", || { + HttpService::build() + .client_timeout(1000) + .client_disconnect(1000) + .finish(|_req| { + info!("{:?}", _req); + let mut res = Response::Ok(); + res.header("x-head", HeaderValue::from_static("dummy value!")); + future::ok::<_, ()>(res.body("Hello world!")) + }) + })? + .run() +} diff --git a/actix-http/rustfmt.toml b/actix-http/rustfmt.toml new file mode 100644 index 000000000..5fcaaca0f --- /dev/null +++ b/actix-http/rustfmt.toml @@ -0,0 +1,5 @@ +max_width = 89 +reorder_imports = true +#wrap_comments = true +#fn_args_density = "Compressed" +#use_small_heuristics = false diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs new file mode 100644 index 000000000..b69c21eaa --- /dev/null +++ b/actix-http/src/body.rs @@ -0,0 +1,589 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, mem}; + +use bytes::{Bytes, BytesMut}; +use futures::Stream; +use pin_project::{pin_project, project}; + +use crate::error::Error; + +#[derive(Debug, PartialEq, Copy, Clone)] +/// Body size hint +pub enum BodySize { + None, + Empty, + Sized(usize), + Sized64(u64), + Stream, +} + +impl BodySize { + pub fn is_eof(&self) -> bool { + match self { + BodySize::None + | BodySize::Empty + | BodySize::Sized(0) + | BodySize::Sized64(0) => true, + _ => false, + } + } +} + +/// Type that provides this trait can be streamed to a peer. +pub trait MessageBody { + fn size(&self) -> BodySize; + + fn poll_next(&mut self, cx: &mut Context) -> Poll>>; +} + +impl MessageBody for () { + fn size(&self) -> BodySize { + BodySize::Empty + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + Poll::Ready(None) + } +} + +impl MessageBody for Box { + fn size(&self) -> BodySize { + self.as_ref().size() + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + self.as_mut().poll_next(cx) + } +} + +#[pin_project] +pub enum ResponseBody { + Body(B), + Other(Body), +} + +impl ResponseBody { + pub fn into_body(self) -> ResponseBody { + match self { + ResponseBody::Body(b) => ResponseBody::Other(b), + ResponseBody::Other(b) => ResponseBody::Other(b), + } + } +} + +impl ResponseBody { + pub fn take_body(&mut self) -> ResponseBody { + std::mem::replace(self, ResponseBody::Other(Body::None)) + } +} + +impl ResponseBody { + pub fn as_ref(&self) -> Option<&B> { + if let ResponseBody::Body(ref b) = self { + Some(b) + } else { + None + } + } +} + +impl MessageBody for ResponseBody { + fn size(&self) -> BodySize { + match self { + ResponseBody::Body(ref body) => body.size(), + ResponseBody::Other(ref body) => body.size(), + } + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + match self { + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), + } + } +} + +impl Stream for ResponseBody { + type Item = Result; + + #[project] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[project] + match self.project() { + ResponseBody::Body(ref mut body) => body.poll_next(cx), + ResponseBody::Other(ref mut body) => body.poll_next(cx), + } + } +} + +/// Represents various types of http message body. +pub enum Body { + /// Empty response. `Content-Length` header is not set. + None, + /// Zero sized response body. `Content-Length` header is set to `0`. + Empty, + /// Specific response body. + Bytes(Bytes), + /// Generic message body. + Message(Box), +} + +impl Body { + /// Create body from slice (copy) + pub fn from_slice(s: &[u8]) -> Body { + Body::Bytes(Bytes::from(s)) + } + + /// Create body from generic message body. + pub fn from_message(body: B) -> Body { + Body::Message(Box::new(body)) + } +} + +impl MessageBody for Body { + fn size(&self) -> BodySize { + match self { + Body::None => BodySize::None, + Body::Empty => BodySize::Empty, + Body::Bytes(ref bin) => BodySize::Sized(bin.len()), + Body::Message(ref body) => body.size(), + } + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + match self { + Body::None => Poll::Ready(None), + Body::Empty => Poll::Ready(None), + Body::Bytes(ref mut bin) => { + let len = bin.len(); + if len == 0 { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::replace(bin, Bytes::new())))) + } + } + Body::Message(ref mut body) => body.poll_next(cx), + } + } +} + +impl PartialEq for Body { + fn eq(&self, other: &Body) -> bool { + match *self { + Body::None => match *other { + Body::None => true, + _ => false, + }, + Body::Empty => match *other { + Body::Empty => true, + _ => false, + }, + Body::Bytes(ref b) => match *other { + Body::Bytes(ref b2) => b == b2, + _ => false, + }, + Body::Message(_) => false, + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Body::None => write!(f, "Body::None"), + Body::Empty => write!(f, "Body::Empty"), + Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b), + Body::Message(_) => write!(f, "Body::Message(_)"), + } + } +} + +impl From<&'static str> for Body { + fn from(s: &'static str) -> Body { + Body::Bytes(Bytes::from_static(s.as_ref())) + } +} + +impl From<&'static [u8]> for Body { + fn from(s: &'static [u8]) -> Body { + Body::Bytes(Bytes::from_static(s)) + } +} + +impl From> for Body { + fn from(vec: Vec) -> Body { + Body::Bytes(Bytes::from(vec)) + } +} + +impl From for Body { + fn from(s: String) -> Body { + s.into_bytes().into() + } +} + +impl<'a> From<&'a String> for Body { + fn from(s: &'a String) -> Body { + Body::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s))) + } +} + +impl From for Body { + fn from(s: Bytes) -> Body { + Body::Bytes(s) + } +} + +impl From for Body { + fn from(s: BytesMut) -> Body { + Body::Bytes(s.freeze()) + } +} + +impl From for Body { + fn from(v: serde_json::Value) -> Body { + Body::Bytes(v.to_string().into()) + } +} + +impl From> for Body +where + S: Stream> + 'static, +{ + fn from(s: SizedStream) -> Body { + Body::from_message(s) + } +} + +impl From> for Body +where + S: Stream> + 'static, + E: Into + 'static, +{ + fn from(s: BodyStream) -> Body { + Body::from_message(s) + } +} + +impl MessageBody for Bytes { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::replace(self, Bytes::new())))) + } + } +} + +impl MessageBody for BytesMut { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::replace(self, BytesMut::new()).freeze()))) + } + } +} + +impl MessageBody for &'static str { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from_static( + mem::replace(self, "").as_ref(), + )))) + } + } +} + +impl MessageBody for &'static [u8] { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from_static(mem::replace(self, b""))))) + } + } +} + +impl MessageBody for Vec { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from(mem::replace(self, Vec::new()))))) + } + } +} + +impl MessageBody for String { + fn size(&self) -> BodySize { + BodySize::Sized(self.len()) + } + + fn poll_next(&mut self, _: &mut Context) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from( + mem::replace(self, String::new()).into_bytes(), + )))) + } + } +} + +/// Type represent streaming body. +/// Response does not contain `content-length` header and appropriate transfer encoding is used. +#[pin_project] +pub struct BodyStream { + #[pin] + stream: S, + _t: PhantomData, +} + +impl BodyStream +where + S: Stream>, + E: Into, +{ + pub fn new(stream: S) -> Self { + BodyStream { + stream, + _t: PhantomData, + } + } +} + +impl MessageBody for BodyStream +where + S: Stream>, + E: Into, +{ + fn size(&self) -> BodySize { + BodySize::Stream + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) + .map(|res| res.map(|res| res.map_err(std::convert::Into::into))) + } +} + +/// Type represent streaming body. This body implementation should be used +/// if total size of stream is known. Data get sent as is without using transfer encoding. +#[pin_project] +pub struct SizedStream { + size: u64, + #[pin] + stream: S, +} + +impl SizedStream +where + S: Stream>, +{ + pub fn new(size: u64, stream: S) -> Self { + SizedStream { size, stream } + } +} + +impl MessageBody for SizedStream +where + S: Stream>, +{ + fn size(&self) -> BodySize { + BodySize::Sized64(self.size) + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + unsafe { Pin::new_unchecked(self) } + .project() + .stream + .poll_next(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + + impl Body { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + Body::Bytes(ref bin) => &bin, + _ => panic!(), + } + } + } + + impl ResponseBody { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + ResponseBody::Body(ref b) => b.get_ref(), + ResponseBody::Other(ref b) => b.get_ref(), + } + } + } + + #[actix_rt::test] + async fn test_static_str() { + assert_eq!(Body::from("").size(), BodySize::Sized(0)); + assert_eq!(Body::from("test").size(), BodySize::Sized(4)); + assert_eq!(Body::from("test").get_ref(), b"test"); + + assert_eq!("test".size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| "test".poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_static_bytes() { + assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); + assert_eq!( + Body::from_slice(b"test".as_ref()).size(), + BodySize::Sized(4) + ); + assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); + + assert_eq!((&b"test"[..]).size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| (&b"test"[..]).poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_vec() { + assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); + assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); + + assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| Vec::from("test").poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_bytes() { + let mut b = Bytes::from("test"); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_bytes_mut() { + let mut b = BytesMut::from("test"); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_string() { + let mut b = "test".to_owned(); + assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + assert_eq!(Body::from(&b).size(), BodySize::Sized(4)); + assert_eq!(Body::from(&b).get_ref(), b"test"); + + assert_eq!(b.size(), BodySize::Sized(4)); + assert_eq!( + poll_fn(|cx| b.poll_next(cx)).await.unwrap().ok(), + Some(Bytes::from("test")) + ); + } + + #[actix_rt::test] + async fn test_unit() { + assert_eq!(().size(), BodySize::Empty); + assert!(poll_fn(|cx| ().poll_next(cx)).await.is_none()); + } + + #[actix_rt::test] + async fn test_box() { + let mut val = Box::new(()); + assert_eq!(val.size(), BodySize::Empty); + assert!(poll_fn(|cx| val.poll_next(cx)).await.is_none()); + } + + #[actix_rt::test] + async fn test_body_eq() { + assert!(Body::None == Body::None); + assert!(Body::None != Body::Empty); + assert!(Body::Empty == Body::Empty); + assert!(Body::Empty != Body::None); + assert!( + Body::Bytes(Bytes::from_static(b"1")) + == Body::Bytes(Bytes::from_static(b"1")) + ); + assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); + } + + #[actix_rt::test] + async fn test_body_debug() { + assert!(format!("{:?}", Body::None).contains("Body::None")); + assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); + assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains("1")); + } + + #[actix_rt::test] + async fn test_serde_json() { + use serde_json::json; + assert_eq!( + Body::from(serde_json::Value::String("test".into())).size(), + BodySize::Sized(6) + ); + assert_eq!( + Body::from(json!({"test-key":"test-value"})).size(), + BodySize::Sized(25) + ); + } +} diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs new file mode 100644 index 000000000..7e1dae58f --- /dev/null +++ b/actix-http/src/builder.rs @@ -0,0 +1,230 @@ +use std::fmt; +use std::marker::PhantomData; +use std::rc::Rc; + +use actix_codec::Framed; +use actix_server_config::ServerConfig as SrvConfig; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; + +use crate::body::MessageBody; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::Error; +use crate::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler}; +use crate::h2::H2Service; +use crate::helpers::{Data, DataFactory}; +use crate::request::Request; +use crate::response::Response; +use crate::service::HttpService; + +/// A http service builder +/// +/// This type can be used to construct an instance of `http service` through a +/// builder-like pattern. +pub struct HttpServiceBuilder> { + keep_alive: KeepAlive, + client_timeout: u64, + client_disconnect: u64, + expect: X, + upgrade: Option, + on_connect: Option Box>>, + _t: PhantomData<(T, S)>, +} + +impl HttpServiceBuilder> +where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + ::Future: 'static, +{ + /// Create instance of `ServiceConfigBuilder` + pub fn new() -> Self { + HttpServiceBuilder { + keep_alive: KeepAlive::Timeout(5), + client_timeout: 5000, + client_disconnect: 0, + expect: ExpectHandler, + upgrade: None, + on_connect: None, + _t: PhantomData, + } + } +} + +impl HttpServiceBuilder +where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + ::Future: 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory< + Config = SrvConfig, + Request = (Request, Framed), + Response = (), + >, + U::Error: fmt::Display, + U::InitError: fmt::Debug, + ::Future: 'static, +{ + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(mut self, val: W) -> Self { + self.keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the request get dropped. This timeout affects secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 0. + pub fn client_disconnect(mut self, val: u64) -> Self { + self.client_disconnect = val; + self + } + + /// Provide service for `EXPECT: 100-Continue` support. + /// + /// Service get called with request that contains `EXPECT` header. + /// Service must return request in case of success, in that case + /// request will be forwarded to main service. + pub fn expect(self, expect: F) -> HttpServiceBuilder + where + F: IntoServiceFactory, + X1: ServiceFactory, + X1::Error: Into, + X1::InitError: fmt::Debug, + ::Future: 'static, + { + HttpServiceBuilder { + keep_alive: self.keep_alive, + client_timeout: self.client_timeout, + client_disconnect: self.client_disconnect, + expect: expect.into_factory(), + upgrade: self.upgrade, + on_connect: self.on_connect, + _t: PhantomData, + } + } + + /// Provide service for custom `Connection: UPGRADE` support. + /// + /// If service is provided then normal requests handling get halted + /// and this service get called with original request and framed object. + pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder + where + F: IntoServiceFactory, + U1: ServiceFactory< + Config = SrvConfig, + Request = (Request, Framed), + Response = (), + >, + U1::Error: fmt::Display, + U1::InitError: fmt::Debug, + ::Future: 'static, + { + HttpServiceBuilder { + keep_alive: self.keep_alive, + client_timeout: self.client_timeout, + client_disconnect: self.client_disconnect, + expect: self.expect, + upgrade: Some(upgrade.into_factory()), + on_connect: self.on_connect, + _t: PhantomData, + } + } + + /// Set on-connect callback. + /// + /// It get called once per connection and result of the call + /// get stored to the request's extensions. + pub fn on_connect(mut self, f: F) -> Self + where + F: Fn(&T) -> I + 'static, + I: Clone + 'static, + { + self.on_connect = Some(Rc::new(move |io| Box::new(Data(f(io))))); + self + } + + /// Finish service configuration and create *http service* for HTTP/1 protocol. + pub fn h1(self, service: F) -> H1Service + where + B: MessageBody, + F: IntoServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + H1Service::with_config(cfg, service.into_factory()) + .expect(self.expect) + .upgrade(self.upgrade) + .on_connect(self.on_connect) + } + + /// Finish service configuration and create *http service* for HTTP/2 protocol. + pub fn h2(self, service: F) -> H2Service + where + B: MessageBody + 'static, + F: IntoServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) + } + + /// Finish service configuration and create `HttpService` instance. + pub fn finish(self, service: F) -> HttpService + where + B: MessageBody + 'static, + F: IntoServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + HttpService::with_config(cfg, service.into_factory()) + .expect(self.expect) + .upgrade(self.upgrade) + .on_connect(self.on_connect) + } +} diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs new file mode 100644 index 000000000..75d393b1b --- /dev/null +++ b/actix-http/src/client/connection.rs @@ -0,0 +1,293 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use bytes::{Buf, Bytes}; +use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; +use h2::client::SendRequest; +use pin_project::{pin_project, project}; + +use crate::body::MessageBody; +use crate::h1::ClientCodec; +use crate::message::{RequestHeadType, ResponseHead}; +use crate::payload::Payload; + +use super::error::SendRequestError; +use super::pool::{Acquired, Protocol}; +use super::{h1proto, h2proto}; + +pub(crate) enum ConnectionType { + H1(Io), + H2(SendRequest), +} + +pub trait Connection { + type Io: AsyncRead + AsyncWrite + Unpin; + type Future: Future>; + + fn protocol(&self) -> Protocol; + + /// Send request and body + fn send_request>( + self, + head: H, + body: B, + ) -> Self::Future; + + type TunnelFuture: Future< + Output = Result<(ResponseHead, Framed), SendRequestError>, + >; + + /// Send request, returns Response and Framed + fn open_tunnel>(self, head: H) -> Self::TunnelFuture; +} + +pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { + /// Close connection + fn close(&mut self); + + /// Release connection to the connection pool + fn release(&mut self); +} + +#[doc(hidden)] +/// HTTP client connection +pub struct IoConnection { + io: Option>, + created: time::Instant, + pool: Option>, +} + +impl fmt::Debug for IoConnection +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.io { + Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), + Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), + None => write!(f, "Connection(Empty)"), + } + } +} + +impl IoConnection { + pub(crate) fn new( + io: ConnectionType, + created: time::Instant, + pool: Option>, + ) -> Self { + IoConnection { + pool, + created, + io: Some(io), + } + } + + pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { + (self.io.unwrap(), self.created) + } +} + +impl Connection for IoConnection +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Io = T; + type Future = + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; + + fn protocol(&self) -> Protocol { + match self.io { + Some(ConnectionType::H1(_)) => Protocol::Http1, + Some(ConnectionType::H2(_)) => Protocol::Http2, + None => Protocol::Http1, + } + } + + fn send_request>( + mut self, + head: H, + body: B, + ) -> Self::Future { + match self.io.take().unwrap() { + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } + } + } + + type TunnelFuture = Either< + LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >, + Ready), SendRequestError>>, + >; + + /// Send request, returns Response and Framed + fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { + match self.io.take().unwrap() { + ConnectionType::H1(io) => { + Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) + } + ConnectionType::H2(io) => { + if let Some(mut pool) = self.pool.take() { + pool.release(IoConnection::new( + ConnectionType::H2(io), + self.created, + None, + )); + } + Either::Right(err(SendRequestError::TunnelNotSupported)) + } + } + } +} + +#[allow(dead_code)] +pub(crate) enum EitherConnection { + A(IoConnection), + B(IoConnection), +} + +impl Connection for EitherConnection +where + A: AsyncRead + AsyncWrite + Unpin + 'static, + B: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Io = EitherIo; + type Future = + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; + + fn protocol(&self) -> Protocol { + match self { + EitherConnection::A(con) => con.protocol(), + EitherConnection::B(con) => con.protocol(), + } + } + + fn send_request>( + self, + head: H, + body: RB, + ) -> Self::Future { + match self { + EitherConnection::A(con) => con.send_request(head, body), + EitherConnection::B(con) => con.send_request(head, body), + } + } + + type TunnelFuture = LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >; + + /// Send request, returns Response and Framed + fn open_tunnel>(self, head: H) -> Self::TunnelFuture { + match self { + EitherConnection::A(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) + .boxed_local(), + EitherConnection::B(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) + .boxed_local(), + } + } +} + +#[pin_project] +pub enum EitherIo { + A(#[pin] A), + B(#[pin] B), +} + +impl AsyncRead for EitherIo +where + A: AsyncRead, + B: AsyncRead, +{ + #[project] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_read(cx, buf), + EitherIo::B(val) => val.poll_read(cx, buf), + } + } + + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match self { + EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), + EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), + } + } +} + +impl AsyncWrite for EitherIo +where + A: AsyncWrite, + B: AsyncWrite, +{ + #[project] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write(cx, buf), + EitherIo::B(val) => val.poll_write(cx, buf), + } + } + + #[project] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_flush(cx), + EitherIo::B(val) => val.poll_flush(cx), + } + } + + #[project] + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_shutdown(cx), + EitherIo::B(val) => val.poll_shutdown(cx), + } + } + + #[project] + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut U, + ) -> Poll> + where + Self: Sized, + { + #[project] + match self.project() { + EitherIo::A(val) => val.poll_write_buf(cx, buf), + EitherIo::B(val) => val.poll_write_buf(cx, buf), + } + } +} diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs new file mode 100644 index 000000000..eaa3d97e4 --- /dev/null +++ b/actix-http/src/client/connector.rs @@ -0,0 +1,531 @@ +use std::fmt; +use std::marker::PhantomData; +use std::time::Duration; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_connect::{ + default_connector, Connect as TcpConnect, Connection as TcpConnection, +}; +use actix_service::{apply_fn, Service}; +use actix_utils::timeout::{TimeoutError, TimeoutService}; +use http::Uri; +use tokio_net::tcp::TcpStream; + +use super::connection::Connection; +use super::error::ConnectError; +use super::pool::{ConnectionPool, Protocol}; +use super::Connect; + +#[cfg(feature = "openssl")] +use open_ssl::ssl::SslConnector as OpensslConnector; + +#[cfg(feature = "rustls")] +use rust_tls::ClientConfig; +#[cfg(feature = "rustls")] +use std::sync::Arc; + +#[cfg(any(feature = "openssl", feature = "rustls"))] +enum SslConnector { + #[cfg(feature = "openssl")] + Openssl(OpensslConnector), + #[cfg(feature = "rustls")] + Rustls(Arc), +} +#[cfg(not(any(feature = "openssl", feature = "rustls")))] +type SslConnector = (); + +/// Manages http client network connectivity +/// The `Connector` type uses a builder-like combinator pattern for service +/// construction that finishes by calling the `.finish()` method. +/// +/// ```rust,ignore +/// use std::time::Duration; +/// use actix_http::client::Connector; +/// +/// let connector = Connector::new() +/// .timeout(Duration::from_secs(5)) +/// .finish(); +/// ``` +pub struct Connector { + connector: T, + timeout: Duration, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Duration, + limit: usize, + #[allow(dead_code)] + ssl: SslConnector, + _t: PhantomData, +} + +trait Io: AsyncRead + AsyncWrite + Unpin {} +impl Io for T {} + +impl Connector<(), ()> { + #[allow(clippy::new_ret_no_self)] + pub fn new() -> Connector< + impl Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone, + TcpStream, + > { + let ssl = { + #[cfg(feature = "openssl")] + { + use open_ssl::ssl::SslMethod; + + let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); + let _ = ssl + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); + SslConnector::Openssl(ssl.build()) + } + #[cfg(all(not(feature = "openssl"), feature = "rustls"))] + { + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + let mut config = ClientConfig::new(); + config.set_protocols(&protos); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + SslConnector::Rustls(Arc::new(config)) + } + #[cfg(not(any(feature = "openssl", feature = "rustls")))] + {} + }; + + Connector { + ssl, + connector: default_connector(), + timeout: Duration::from_secs(1), + conn_lifetime: Duration::from_secs(75), + conn_keep_alive: Duration::from_secs(15), + disconnect_timeout: Duration::from_millis(3000), + limit: 100, + _t: PhantomData, + } + } +} + +impl Connector { + /// Use custom connector. + pub fn connector(self, connector: T1) -> Connector + where + U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, + T1: Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone, + { + Connector { + connector, + timeout: self.timeout, + conn_lifetime: self.conn_lifetime, + conn_keep_alive: self.conn_keep_alive, + disconnect_timeout: self.disconnect_timeout, + limit: self.limit, + ssl: self.ssl, + _t: PhantomData, + } + } +} + +impl Connector +where + U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, + T: Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone + + 'static, +{ + /// Connection timeout, i.e. max time to connect to remote host including dns name resolution. + /// Set to 1 second by default. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + #[cfg(feature = "openssl")] + /// Use custom `SslConnector` instance. + pub fn ssl(mut self, connector: OpensslConnector) -> Self { + self.ssl = SslConnector::Openssl(connector); + self + } + + #[cfg(feature = "rustls")] + pub fn rustls(mut self, connector: Arc) -> Self { + self.ssl = SslConnector::Rustls(connector); + self + } + + /// Set total number of simultaneous connections per type of scheme. + /// + /// If limit is 0, the connector has no limit. + /// The default limit size is 100. + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.conn_lifetime = dur; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the socket get dropped. This timeout affects only secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn disconnect_timeout(mut self, dur: Duration) -> Self { + self.disconnect_timeout = dur; + self + } + + /// Finish configuration process and create connector service. + /// The Connector builder always concludes by calling `finish()` last in + /// its combinator chain. + pub fn finish( + self, + ) -> impl Service + + Clone { + #[cfg(not(any(feature = "openssl", feature = "rustls")))] + { + let connector = TimeoutService::new( + self.timeout, + apply_fn(self.connector, |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + connector, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + } + } + #[cfg(any(feature = "openssl", feature = "rustls"))] + { + const H2: &[u8] = b"h2"; + #[cfg(feature = "openssl")] + use actix_connect::ssl::OpensslConnector; + #[cfg(feature = "rustls")] + use actix_connect::ssl::RustlsConnector; + use actix_service::{boxed::service, pipeline}; + #[cfg(feature = "rustls")] + use rust_tls::Session; + + let ssl_service = TimeoutService::new( + self.timeout, + pipeline( + apply_fn(self.connector.clone(), |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from), + ) + .and_then(match self.ssl { + #[cfg(feature = "openssl")] + SslConnector::Openssl(ssl) => service( + OpensslConnector::service(ssl) + .map(|stream| { + let sock = stream.into_parts().0; + let h2 = sock + .ssl() + .selected_alpn_protocol() + .map(|protos| protos.windows(2).any(|w| w == H2)) + .unwrap_or(false); + if h2 { + (Box::new(sock) as Box, Protocol::Http2) + } else { + (Box::new(sock) as Box, Protocol::Http1) + } + }) + .map_err(ConnectError::from), + ), + #[cfg(feature = "rustls")] + SslConnector::Rustls(ssl) => service( + RustlsConnector::service(ssl) + .map_err(ConnectError::from) + .map(|stream| { + let sock = stream.into_parts().0; + let h2 = sock + .get_ref() + .1 + .get_alpn_protocol() + .map(|protos| protos.windows(2).any(|w| w == H2)) + .unwrap_or(false); + if h2 { + (Box::new(sock) as Box, Protocol::Http2) + } else { + (Box::new(sock) as Box, Protocol::Http1) + } + }), + ), + }), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + let tcp_service = TimeoutService::new( + self.timeout, + apply_fn(self.connector, |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + tcp_service, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + ssl_pool: ConnectionPool::new( + ssl_service, + self.conn_lifetime, + self.conn_keep_alive, + Some(self.disconnect_timeout), + self.limit, + ), + } + } + } +} + +#[cfg(not(any(feature = "openssl", feature = "rustls")))] +mod connect_impl { + use std::task::{Context, Poll}; + + use futures::future::{err, Either, Ready}; + + use super::*; + use crate::client::connection::IoConnection; + + pub(crate) struct InnerConnector + where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, + { + pub(crate) tcp_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, + { + type Request = Connect; + type Response = IoConnection; + type Error = ConnectError; + type Future = Either< + as Service>::Future, + Ready, ConnectError>>, + >; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) + } + + fn call(&mut self, req: Connect) -> Self::Future { + match req.uri.scheme_str() { + Some("https") | Some("wss") => { + Either::Right(err(ConnectError::SslIsNotSupported)) + } + _ => Either::Left(self.tcp_pool.call(req)), + } + } + } +} + +#[cfg(any(feature = "openssl", feature = "rustls"))] +mod connect_impl { + use std::future::Future; + use std::marker::PhantomData; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use futures::future::Either; + use futures::ready; + + use super::*; + use crate::client::connection::EitherConnection; + + pub(crate) struct InnerConnector + where + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + T1: Service, + T2: Service, + { + pub(crate) tcp_pool: ConnectionPool, + pub(crate) ssl_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + T1: Service + + 'static, + T2: Service + + 'static, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + ssl_pool: self.ssl_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + T1: Service + + 'static, + T2: Service + + 'static, + { + type Request = Connect; + type Response = EitherConnection; + type Error = ConnectError; + type Future = Either< + InnerConnectorResponseA, + InnerConnectorResponseB, + >; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) + } + + fn call(&mut self, req: Connect) -> Self::Future { + match req.uri.scheme_str() { + Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + }), + _ => Either::Left(InnerConnectorResponseA { + fut: self.tcp_pool.call(req), + _t: PhantomData, + }), + } + } + } + + #[pin_project::pin_project] + pub(crate) struct InnerConnectorResponseA + where + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, + { + #[pin] + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseA + where + T: Service + + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::A(res)), + ) + } + } + + #[pin_project::pin_project] + pub(crate) struct InnerConnectorResponseB + where + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, + { + #[pin] + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseB + where + T: Service + + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, + { + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::B(res)), + ) + } + } +} diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs new file mode 100644 index 000000000..ee568e8be --- /dev/null +++ b/actix-http/src/client/error.rs @@ -0,0 +1,148 @@ +use std::io; + +use derive_more::{Display, From}; +use trust_dns_resolver::error::ResolveError; + +#[cfg(feature = "openssl")] +use open_ssl::ssl::{Error as SslError, HandshakeError}; + +use crate::error::{Error, ParseError, ResponseError}; +use crate::http::{Error as HttpError, StatusCode}; + +/// A set of errors that can occur while connecting to an HTTP host +#[derive(Debug, Display, From)] +pub enum ConnectError { + /// SSL feature is not enabled + #[display(fmt = "SSL is not supported")] + SslIsNotSupported, + + /// SSL error + #[cfg(feature = "openssl")] + #[display(fmt = "{}", _0)] + SslError(SslError), + + /// Failed to resolve the hostname + #[display(fmt = "Failed resolving hostname: {}", _0)] + Resolver(ResolveError), + + /// No dns records + #[display(fmt = "No dns records found for the input")] + NoRecords, + + /// Http2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + + /// Connecting took too long + #[display(fmt = "Timeout out while establishing connection")] + Timeout, + + /// Connector has been disconnected + #[display(fmt = "Internal error: connector has been disconnected")] + Disconnected, + + /// Unresolved host name + #[display(fmt = "Connector received `Connect` method with unresolved host")] + Unresolverd, + + /// Connection io error + #[display(fmt = "{}", _0)] + Io(io::Error), +} + +impl From for ConnectError { + fn from(err: actix_connect::ConnectError) -> ConnectError { + match err { + actix_connect::ConnectError::Resolver(e) => ConnectError::Resolver(e), + actix_connect::ConnectError::NoRecords => ConnectError::NoRecords, + actix_connect::ConnectError::InvalidInput => panic!(), + actix_connect::ConnectError::Unresolverd => ConnectError::Unresolverd, + actix_connect::ConnectError::Io(e) => ConnectError::Io(e), + } + } +} + +#[cfg(feature = "openssl")] +impl From> for ConnectError { + fn from(err: HandshakeError) -> ConnectError { + match err { + HandshakeError::SetupFailure(stack) => SslError::from(stack).into(), + HandshakeError::Failure(stream) => stream.into_error().into(), + HandshakeError::WouldBlock(stream) => stream.into_error().into(), + } + } +} + +#[derive(Debug, Display, From)] +pub enum InvalidUrl { + #[display(fmt = "Missing url scheme")] + MissingScheme, + #[display(fmt = "Unknown url scheme")] + UnknownScheme, + #[display(fmt = "Missing host name")] + MissingHost, + #[display(fmt = "Url parse error: {}", _0)] + HttpError(http::Error), +} + +/// A set of errors that can occur during request sending and response reading +#[derive(Debug, Display, From)] +pub enum SendRequestError { + /// Invalid URL + #[display(fmt = "Invalid URL: {}", _0)] + Url(InvalidUrl), + /// Failed to connect to host + #[display(fmt = "Failed to connect to host: {}", _0)] + Connect(ConnectError), + /// Error sending request + Send(io::Error), + /// Error parsing response + Response(ParseError), + /// Http error + #[display(fmt = "{}", _0)] + Http(HttpError), + /// Http2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + /// Response took too long + #[display(fmt = "Timeout out while waiting for response")] + Timeout, + /// Tunnels are not supported for http2 connection + #[display(fmt = "Tunnels are not supported for http2 connection")] + TunnelNotSupported, + /// Error sending request body + Body(Error), +} + +/// Convert `SendRequestError` to a server `Response` +impl ResponseError for SendRequestError { + fn status_code(&self) -> StatusCode { + match *self { + SendRequestError::Connect(ConnectError::Timeout) => { + StatusCode::GATEWAY_TIMEOUT + } + SendRequestError::Connect(_) => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +/// A set of errors that can occur during freezing a request +#[derive(Debug, Display, From)] +pub enum FreezeRequestError { + /// Invalid URL + #[display(fmt = "Invalid URL: {}", _0)] + Url(InvalidUrl), + /// Http error + #[display(fmt = "{}", _0)] + Http(HttpError), +} + +impl From for SendRequestError { + fn from(e: FreezeRequestError) -> Self { + match e { + FreezeRequestError::Url(e) => e.into(), + FreezeRequestError::Http(e) => e.into(), + } + } +} diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs new file mode 100644 index 000000000..ddfc7a314 --- /dev/null +++ b/actix-http/src/client/h1proto.rs @@ -0,0 +1,284 @@ +use std::io::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::future::poll_fn; +use futures::{SinkExt, Stream, StreamExt}; + +use crate::error::PayloadError; +use crate::h1; +use crate::header::HeaderMap; +use crate::http::header::{IntoHeaderValue, HOST}; +use crate::message::{RequestHeadType, ResponseHead}; +use crate::payload::{Payload, PayloadStream}; + +use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; +use super::error::{ConnectError, SendRequestError}; +use super::pool::Acquired; +use crate::body::{BodySize, MessageBody}; + +pub(crate) async fn send_request( + io: T, + mut head: RequestHeadType, + body: B, + created: time::Instant, + pool: Option>, +) -> Result<(ResponseHead, Payload), SendRequestError> +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + B: MessageBody, +{ + // set request host header + if !head.as_ref().headers.contains_key(HOST) + && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) + { + if let Some(host) = head.as_ref().uri.host() { + let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); + + let _ = match head.as_ref().uri.port_u16() { + None | Some(80) | Some(443) => write!(wrt, "{}", host), + Some(port) => write!(wrt, "{}:{}", host, port), + }; + + match wrt.get_mut().take().freeze().try_into() { + Ok(value) => match head { + RequestHeadType::Owned(ref mut head) => { + head.headers.insert(HOST, value) + } + RequestHeadType::Rc(_, ref mut extra_headers) => { + let headers = extra_headers.get_or_insert(HeaderMap::new()); + headers.insert(HOST, value) + } + }, + Err(e) => log::error!("Can not set HOST header {}", e), + } + } + } + + let io = H1Connection { + created, + pool, + io: Some(io), + }; + + // create Framed and send request + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, body.size()).into()).await?; + + // send request body + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), + _ => send_body(body, &mut framed).await?, + }; + + // read response and init read body + let res = framed.into_future().await; + let (head, framed) = if let (Some(result), framed) = res { + let item = result.map_err(SendRequestError::from)?; + (item, framed) + } else { + return Err(SendRequestError::from(ConnectError::Disconnected)); + }; + + match framed.get_codec().message_type() { + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok((head, Payload::None)) + } + _ => { + let pl: PayloadStream = PlStream::new(framed).boxed_local(); + Ok((head, pl.into())) + } + } +} + +pub(crate) async fn open_tunnel( + io: T, + head: RequestHeadType, +) -> Result<(ResponseHead, Framed), SendRequestError> +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + // create Framed and send request + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, BodySize::None).into()).await?; + + // read response + if let (Some(result), framed) = framed.into_future().await { + let head = result.map_err(SendRequestError::from)?; + Ok((head, framed)) + } else { + Err(SendRequestError::from(ConnectError::Disconnected)) + } +} + +/// send request body to the peer +pub(crate) async fn send_body( + mut body: B, + framed: &mut Framed, +) -> Result<(), SendRequestError> +where + I: ConnectionLifetime, + B: MessageBody, +{ + let mut eof = false; + while !eof { + while !eof && !framed.is_write_buf_full() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(result) => { + framed.write(h1::Message::Chunk(Some(result?)))?; + } + None => { + eof = true; + framed.write(h1::Message::Chunk(None))?; + } + } + } + + if !framed.is_write_buf_empty() { + poll_fn(|cx| match framed.flush(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + if !framed.is_write_buf_full() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + }) + .await?; + } + } + + SinkExt::flush(framed).await?; + Ok(()) +} + +#[doc(hidden)] +/// HTTP client connection +pub struct H1Connection { + io: Option, + created: time::Instant, + pool: Option>, +} + +impl ConnectionLifetime for H1Connection +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + /// Close connection + fn close(&mut self) { + if let Some(mut pool) = self.pool.take() { + if let Some(io) = self.io.take() { + pool.close(IoConnection::new( + ConnectionType::H1(io), + self.created, + None, + )); + } + } + } + + /// Release this connection to the connection pool + fn release(&mut self) { + if let Some(mut pool) = self.pool.take() { + if let Some(io) = self.io.take() { + pool.release(IoConnection::new( + ConnectionType::H1(io), + self.created, + None, + )); + } + } + } +} + +impl AsyncRead for H1Connection { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) + } +} + +impl AsyncWrite for H1Connection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) + } +} + +pub(crate) struct PlStream { + framed: Option>, +} + +impl PlStream { + fn new(framed: Framed) -> Self { + PlStream { + framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), + } + } +} + +impl Stream for PlStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + match this.framed.as_mut().unwrap().next_item(cx)? { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(chunk)) => { + if let Some(chunk) = chunk { + Poll::Ready(Some(Ok(chunk))) + } else { + let framed = this.framed.take().unwrap(); + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Poll::Ready(None) + } + } + Poll::Ready(None) => Poll::Ready(None), + } + } +} + +fn release_connection(framed: Framed, force_close: bool) +where + T: ConnectionLifetime, +{ + let mut parts = framed.into_parts(); + if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() { + parts.io.release() + } else { + parts.io.close() + } +} diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs new file mode 100644 index 000000000..a94562f2d --- /dev/null +++ b/actix-http/src/client/h2proto.rs @@ -0,0 +1,184 @@ +use std::time; + +use actix_codec::{AsyncRead, AsyncWrite}; +use bytes::Bytes; +use futures::future::poll_fn; +use h2::{client::SendRequest, SendStream}; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; +use http::{request::Request, HttpTryFrom, Method, Version}; + +use crate::body::{BodySize, MessageBody}; +use crate::header::HeaderMap; +use crate::message::{RequestHeadType, ResponseHead}; +use crate::payload::Payload; + +use super::connection::{ConnectionType, IoConnection}; +use super::error::SendRequestError; +use super::pool::Acquired; + +pub(crate) async fn send_request( + mut io: SendRequest, + head: RequestHeadType, + body: B, + created: time::Instant, + pool: Option>, +) -> Result<(ResponseHead, Payload), SendRequestError> +where + T: AsyncRead + AsyncWrite + Unpin + 'static, + B: MessageBody, +{ + trace!("Sending client request: {:?} {:?}", head, body.size()); + let head_req = head.as_ref().method == Method::HEAD; + let length = body.size(); + let eof = match length { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => true, + _ => false, + }; + + let mut req = Request::new(()); + *req.uri_mut() = head.as_ref().uri.clone(); + *req.method_mut() = head.as_ref().method.clone(); + *req.version_mut() = Version::HTTP_2; + + let mut skip_len = true; + // let mut has_date = false; + + // Content length + let _ = match length { + BodySize::None => None, + BodySize::Stream => { + skip_len = false; + None + } + BodySize::Empty => req + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodySize::Sized64(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; + + // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. + let (head, extra_headers) = match head { + RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), + RequestHeadType::Rc(head, extra_headers) => ( + RequestHeadType::Rc(head, None), + extra_headers.unwrap_or_else(HeaderMap::new), + ), + }; + + // merging headers from head and extra headers. + let headers = head + .as_ref() + .headers + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.iter()); + + // copy headers + for (key, value) in headers { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + // DATE => has_date = true, + _ => (), + } + req.headers_mut().append(key, value.clone()); + } + + let res = poll_fn(|cx| io.poll_ready(cx)).await; + if let Err(e) = res { + release(io, pool, created, e.is_io()); + return Err(SendRequestError::from(e)); + } + + let resp = match io.send_request(req, eof) { + Ok((fut, send)) => { + release(io, pool, created, false); + + if !eof { + send_body(body, send).await?; + } + fut.await.map_err(SendRequestError::from)? + } + Err(e) => { + release(io, pool, created, e.is_io()); + return Err(e.into()); + } + }; + + let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; + + let mut head = ResponseHead::new(parts.status); + head.version = parts.version; + head.headers = parts.headers.into(); + Ok((head, payload)) +} + +async fn send_body( + mut body: B, + mut send: SendStream, +) -> Result<(), SendRequestError> { + let mut buf = None; + loop { + if buf.is_none() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(Ok(b)) => { + send.reserve_capacity(b.len()); + buf = Some(b); + } + Some(Err(e)) => return Err(e.into()), + None => { + if let Err(e) = send.send_data(Bytes::new(), true) { + return Err(e.into()); + } + send.reserve_capacity(0); + return Ok(()); + } + } + } + + match poll_fn(|cx| send.poll_capacity(cx)).await { + None => return Ok(()), + Some(Ok(cap)) => { + let b = buf.as_mut().unwrap(); + let len = b.len(); + let bytes = b.split_to(std::cmp::min(cap, len)); + + if let Err(e) = send.send_data(bytes, false) { + return Err(e.into()); + } else { + if !b.is_empty() { + send.reserve_capacity(b.len()); + } else { + buf = None; + } + continue; + } + } + Some(Err(e)) => return Err(e.into()), + } + } +} + +// release SendRequest object +fn release( + io: SendRequest, + pool: Option>, + created: time::Instant, + close: bool, +) { + if let Some(mut pool) = pool { + if close { + pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); + } else { + pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); + } + } +} diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs new file mode 100644 index 000000000..a45aebcd5 --- /dev/null +++ b/actix-http/src/client/mod.rs @@ -0,0 +1,20 @@ +//! Http client api +use http::Uri; + +mod connection; +mod connector; +mod error; +mod h1proto; +mod h2proto; +mod pool; + +pub use self::connection::Connection; +pub use self::connector::Connector; +pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; +pub use self::pool::Protocol; + +#[derive(Clone)] +pub struct Connect { + pub uri: Uri, + pub addr: Option, +} diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs new file mode 100644 index 000000000..c61039866 --- /dev/null +++ b/actix-http/src/client/pool.rs @@ -0,0 +1,630 @@ +use std::cell::RefCell; +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::{delay_for, Delay}; +use actix_service::Service; +use actix_utils::{oneshot, task::LocalWaker}; +use bytes::Bytes; +use futures::future::{poll_fn, FutureExt, LocalBoxFuture}; +use h2::client::{handshake, Connection, SendRequest}; +use hashbrown::HashMap; +use http::uri::Authority; +use indexmap::IndexSet; +use slab::Slab; + +use super::connection::{ConnectionType, IoConnection}; +use super::error::ConnectError; +use super::Connect; + +#[derive(Clone, Copy, PartialEq)] +/// Protocol version +pub enum Protocol { + Http1, + Http2, +} + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +pub(crate) struct Key { + authority: Authority, +} + +impl From for Key { + fn from(authority: Authority) -> Key { + Key { authority } + } +} + +/// Connections pool +pub(crate) struct ConnectionPool(Rc>, Rc>>); + +impl ConnectionPool +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, +{ + pub(crate) fn new( + connector: T, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + ) -> Self { + ConnectionPool( + Rc::new(RefCell::new(connector)), + Rc::new(RefCell::new(Inner { + conn_lifetime, + conn_keep_alive, + disconnect_timeout, + limit, + acquired: 0, + waiters: Slab::new(), + waiters_queue: IndexSet::new(), + available: HashMap::new(), + waker: LocalWaker::new(), + })), + ) + } +} + +impl Clone for ConnectionPool +where + Io: 'static, +{ + fn clone(&self) -> Self { + ConnectionPool(self.0.clone(), self.1.clone()) + } +} + +impl Service for ConnectionPool +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + 'static, +{ + type Request = Connect; + type Response = IoConnection; + type Error = ConnectError; + type Future = LocalBoxFuture<'static, Result, ConnectError>>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: Connect) -> Self::Future { + // start support future + actix_rt::spawn(ConnectorPoolSupport { + connector: self.0.clone(), + inner: self.1.clone(), + }); + + let mut connector = self.0.clone(); + let inner = self.1.clone(); + + let fut = async move { + let key = if let Some(authority) = req.uri.authority_part() { + authority.clone().into() + } else { + return Err(ConnectError::Unresolverd); + }; + + // acquire connection + match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { + Acquire::Acquired(io, created) => { + // use existing connection + return Ok(IoConnection::new( + io, + created, + Some(Acquired(key, Some(inner))), + )); + } + Acquire::Available => { + // open tcp connection + let (io, proto) = connector.call(req).await?; + + let guard = OpenGuard::new(key, inner); + + if proto == Protocol::Http1 { + Ok(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(guard.consume()), + )) + } else { + let (snd, connection) = handshake(io).await?; + actix_rt::spawn(connection.map(|_| ())); + Ok(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(guard.consume()), + )) + } + } + _ => { + // connection is not available, wait + let (rx, token) = inner.borrow_mut().wait_for(req); + + let guard = WaiterGuard::new(key, token, inner); + let res = match rx.await { + Err(_) => Err(ConnectError::Disconnected), + Ok(res) => res, + }; + guard.consume(); + res + } + } + }; + + fut.boxed_local() + } +} + +struct WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + key: Key, + token: usize, + inner: Option>>>, +} + +impl WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(key: Key, token: usize, inner: Rc>>) -> Self { + Self { + key, + token, + inner: Some(inner), + } + } + + fn consume(mut self) { + let _ = self.inner.take(); + } +} + +impl Drop for WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn drop(&mut self) { + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); + inner.release_waiter(&self.key, self.token); + inner.check_availibility(); + } + } +} + +struct OpenGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + key: Key, + inner: Option>>>, +} + +impl OpenGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(key: Key, inner: Rc>>) -> Self { + Self { + key, + inner: Some(inner), + } + } + + fn consume(mut self) -> Acquired { + Acquired(self.key.clone(), self.inner.take()) + } +} + +impl Drop for OpenGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn drop(&mut self) { + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); + inner.release(); + inner.check_availibility(); + } + } +} + +enum Acquire { + Acquired(ConnectionType, Instant), + Available, + NotAvailable, +} + +struct AvailableConnection { + io: ConnectionType, + used: Instant, + created: Instant, +} + +pub(crate) struct Inner { + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + acquired: usize, + available: HashMap>>, + waiters: Slab< + Option<( + Connect, + oneshot::Sender, ConnectError>>, + )>, + >, + waiters_queue: IndexSet<(Key, usize)>, + waker: LocalWaker, +} + +impl Inner { + fn reserve(&mut self) { + self.acquired += 1; + } + + fn release(&mut self) { + self.acquired -= 1; + } + + fn release_waiter(&mut self, key: &Key, token: usize) { + self.waiters.remove(token); + let _ = self.waiters_queue.shift_remove(&(key.clone(), token)); + } +} + +impl Inner +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + /// connection is not available, wait + fn wait_for( + &mut self, + connect: Connect, + ) -> ( + oneshot::Receiver, ConnectError>>, + usize, + ) { + let (tx, rx) = oneshot::channel(); + + let key: Key = connect.uri.authority_part().unwrap().clone().into(); + let entry = self.waiters.vacant_entry(); + let token = entry.key(); + entry.insert(Some((connect, tx))); + assert!(self.waiters_queue.insert((key, token))); + + (rx, token) + } + + fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire { + // check limits + if self.limit > 0 && self.acquired >= self.limit { + return Acquire::NotAvailable; + } + + self.reserve(); + + // check if open connection is available + // cleanup stale connections at the same time + if let Some(ref mut connections) = self.available.get_mut(key) { + let now = Instant::now(); + while let Some(conn) = connections.pop_back() { + // check if it still usable + if (now - conn.used) > self.conn_keep_alive + || (now - conn.created) > self.conn_lifetime + { + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = conn.io { + actix_rt::spawn(CloseConnection::new(io, timeout)) + } + } + } else { + let mut io = conn.io; + let mut buf = [0; 2]; + if let ConnectionType::H1(ref mut s) = io { + match Pin::new(s).poll_read(cx, &mut buf) { + Poll::Pending => (), + Poll::Ready(Ok(n)) if n > 0 => { + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = io { + actix_rt::spawn(CloseConnection::new( + io, timeout, + )) + } + } + continue; + } + _ => continue, + } + } + return Acquire::Acquired(io, conn.created); + } + } + } + Acquire::Available + } + + fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { + self.acquired -= 1; + self.available + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(AvailableConnection { + io, + created, + used: Instant::now(), + }); + self.check_availibility(); + } + + fn release_close(&mut self, io: ConnectionType) { + self.acquired -= 1; + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = io { + actix_rt::spawn(CloseConnection::new(io, timeout)) + } + } + self.check_availibility(); + } + + fn check_availibility(&self) { + if !self.waiters_queue.is_empty() && self.acquired < self.limit { + self.waker.wake(); + } + } +} + +struct CloseConnection { + io: T, + timeout: Delay, +} + +impl CloseConnection +where + T: AsyncWrite + Unpin, +{ + fn new(io: T, timeout: Duration) -> Self { + CloseConnection { + io, + timeout: delay_for(timeout), + } + } +} + +impl Future for CloseConnection +where + T: AsyncWrite + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let this = self.get_mut(); + + match Pin::new(&mut this.timeout).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + }, + } + } +} + +struct ConnectorPoolSupport +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + connector: T, + inner: Rc>>, +} + +impl Future for ConnectorPoolSupport +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service, + T::Future: 'static, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + let mut inner = this.inner.as_ref().borrow_mut(); + inner.waker.register(cx.waker()); + + // check waiters + loop { + let (key, token) = { + if let Some((key, token)) = inner.waiters_queue.get_index(0) { + (key.clone(), *token) + } else { + break; + } + }; + if inner.waiters.get(token).unwrap().is_none() { + continue; + } + + match inner.acquire(&key, cx) { + Acquire::NotAvailable => break, + Acquire::Acquired(io, created) => { + let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; + if let Err(conn) = tx.send(Ok(IoConnection::new( + io, + created, + Some(Acquired(key.clone(), Some(this.inner.clone()))), + ))) { + let (io, created) = conn.unwrap().into_inner(); + inner.release_conn(&key, io, created); + } + } + Acquire::Available => { + let (connect, tx) = + inner.waiters.get_mut(token).unwrap().take().unwrap(); + OpenWaitingConnection::spawn( + key.clone(), + tx, + this.inner.clone(), + this.connector.call(connect), + ); + } + } + let _ = inner.waiters_queue.swap_remove_index(0); + } + + Poll::Pending + } +} + +struct OpenWaitingConnection +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fut: F, + key: Key, + h2: Option< + LocalBoxFuture< + 'static, + Result<(SendRequest, Connection), h2::Error>, + >, + >, + rx: Option, ConnectError>>>, + inner: Option>>>, +} + +impl OpenWaitingConnection +where + F: Future> + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn spawn( + key: Key, + rx: oneshot::Sender, ConnectError>>, + inner: Rc>>, + fut: F, + ) { + actix_rt::spawn(OpenWaitingConnection { + key, + fut, + h2: None, + rx: Some(rx), + inner: Some(inner), + }) + } +} + +impl Drop for OpenWaitingConnection +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let mut inner = inner.as_ref().borrow_mut(); + inner.release(); + inner.check_availibility(); + } + } +} + +impl Future for OpenWaitingConnection +where + F: Future>, + Io: AsyncRead + AsyncWrite + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + + if let Some(ref mut h2) = this.h2 { + return match Pin::new(h2).poll(cx) { + Poll::Ready(Ok((snd, connection))) => { + actix_rt::spawn(connection.map(|_| ())); + let rx = this.rx.take().unwrap(); + let _ = rx.send(Ok(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(Acquired(this.key.clone(), this.inner.take())), + ))); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { + let _ = rx.send(Err(ConnectError::H2(err))); + } + Poll::Ready(()) + } + }; + } + + match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) { + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { + let _ = rx.send(Err(err)); + } + Poll::Ready(()) + } + Poll::Ready(Ok((io, proto))) => { + if proto == Protocol::Http1 { + let rx = this.rx.take().unwrap(); + let _ = rx.send(Ok(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(Acquired(this.key.clone(), this.inner.take())), + ))); + Poll::Ready(()) + } else { + this.h2 = Some(handshake(io).boxed_local()); + unsafe { Pin::new_unchecked(this) }.poll(cx) + } + } + Poll::Pending => Poll::Pending, + } + } +} + +pub(crate) struct Acquired(Key, Option>>>); + +impl Acquired +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + pub(crate) fn close(&mut self, conn: IoConnection) { + if let Some(inner) = self.1.take() { + let (io, _) = conn.into_inner(); + inner.as_ref().borrow_mut().release_close(io); + } + } + pub(crate) fn release(&mut self, conn: IoConnection) { + if let Some(inner) = self.1.take() { + let (io, created) = conn.into_inner(); + inner + .as_ref() + .borrow_mut() + .release_conn(&self.0, io, created); + } + } +} + +impl Drop for Acquired { + fn drop(&mut self) { + if let Some(inner) = self.1.take() { + inner.as_ref().borrow_mut().release(); + } + } +} diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs new file mode 100644 index 000000000..18869c66d --- /dev/null +++ b/actix-http/src/cloneable.rs @@ -0,0 +1,42 @@ +use std::cell::UnsafeCell; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::Service; + +#[doc(hidden)] +/// Service that allows to turn non-clone service to a service with `Clone` impl +pub(crate) struct CloneableService(Rc>); + +impl CloneableService { + pub(crate) fn new(service: T) -> Self + where + T: Service, + { + Self(Rc::new(UnsafeCell::new(service))) + } +} + +impl Clone for CloneableService { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Service for CloneableService +where + T: Service, +{ + type Request = T::Request; + type Response = T::Response; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + unsafe { &mut *self.0.as_ref().get() }.poll_ready(cx) + } + + fn call(&mut self, req: T::Request) -> Self::Future { + unsafe { &mut *self.0.as_ref().get() }.call(req) + } +} diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs new file mode 100644 index 000000000..bab3cdc6d --- /dev/null +++ b/actix-http/src/config.rs @@ -0,0 +1,281 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::fmt::Write; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use actix_rt::time::{delay, delay_for, Delay}; +use bytes::BytesMut; +use futures::{future, FutureExt}; +use time; + +// "Sun, 06 Nov 1994 08:49:37 GMT".len() +const DATE_VALUE_LENGTH: usize = 29; + +#[derive(Debug, PartialEq, Clone, Copy)] +/// Server keep-alive setting +pub enum KeepAlive { + /// Keep alive in seconds + Timeout(usize), + /// Relay on OS to shutdown tcp connection + Os, + /// Disabled + Disabled, +} + +impl From for KeepAlive { + fn from(keepalive: usize) -> Self { + KeepAlive::Timeout(keepalive) + } +} + +impl From> for KeepAlive { + fn from(keepalive: Option) -> Self { + if let Some(keepalive) = keepalive { + KeepAlive::Timeout(keepalive) + } else { + KeepAlive::Disabled + } + } +} + +/// Http service configuration +pub struct ServiceConfig(Rc); + +struct Inner { + keep_alive: Option, + client_timeout: u64, + client_disconnect: u64, + ka_enabled: bool, + timer: DateService, +} + +impl Clone for ServiceConfig { + fn clone(&self) -> Self { + ServiceConfig(self.0.clone()) + } +} + +impl Default for ServiceConfig { + fn default() -> Self { + Self::new(KeepAlive::Timeout(5), 0, 0) + } +} + +impl ServiceConfig { + /// Create instance of `ServiceConfig` + pub fn new( + keep_alive: KeepAlive, + client_timeout: u64, + client_disconnect: u64, + ) -> ServiceConfig { + let (keep_alive, ka_enabled) = match keep_alive { + KeepAlive::Timeout(val) => (val as u64, true), + KeepAlive::Os => (0, true), + KeepAlive::Disabled => (0, false), + }; + let keep_alive = if ka_enabled && keep_alive > 0 { + Some(Duration::from_secs(keep_alive)) + } else { + None + }; + + ServiceConfig(Rc::new(Inner { + keep_alive, + ka_enabled, + client_timeout, + client_disconnect, + timer: DateService::new(), + })) + } + + #[inline] + /// Keep alive duration if configured. + pub fn keep_alive(&self) -> Option { + self.0.keep_alive + } + + #[inline] + /// Return state of connection keep-alive funcitonality + pub fn keep_alive_enabled(&self) -> bool { + self.0.ka_enabled + } + + #[inline] + /// Client timeout for first request. + pub fn client_timer(&self) -> Option { + let delay_time = self.0.client_timeout; + if delay_time != 0 { + Some(delay( + self.0.timer.now() + Duration::from_millis(delay_time), + )) + } else { + None + } + } + + /// Client timeout for first request. + pub fn client_timer_expire(&self) -> Option { + let delay = self.0.client_timeout; + if delay != 0 { + Some(self.0.timer.now() + Duration::from_millis(delay)) + } else { + None + } + } + + /// Client disconnect timer + pub fn client_disconnect_timer(&self) -> Option { + let delay = self.0.client_disconnect; + if delay != 0 { + Some(self.0.timer.now() + Duration::from_millis(delay)) + } else { + None + } + } + + #[inline] + /// Return keep-alive timer delay is configured. + pub fn keep_alive_timer(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(delay(self.0.timer.now() + ka)) + } else { + None + } + } + + /// Keep-alive expire time + pub fn keep_alive_expire(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(self.0.timer.now() + ka) + } else { + None + } + } + + #[inline] + pub(crate) fn now(&self) -> Instant { + self.0.timer.now() + } + + #[doc(hidden)] + pub fn set_date(&self, dst: &mut BytesMut) { + let mut buf: [u8; 39] = [0; 39]; + buf[..6].copy_from_slice(b"date: "); + self.0 + .timer + .set_date(|date| buf[6..35].copy_from_slice(&date.bytes)); + buf[35..].copy_from_slice(b"\r\n\r\n"); + dst.extend_from_slice(&buf); + } + + pub(crate) fn set_date_header(&self, dst: &mut BytesMut) { + self.0 + .timer + .set_date(|date| dst.extend_from_slice(&date.bytes)); + } +} + +#[derive(Copy, Clone)] +struct Date { + bytes: [u8; DATE_VALUE_LENGTH], + pos: usize, +} + +impl Date { + fn new() -> Date { + let mut date = Date { + bytes: [0; DATE_VALUE_LENGTH], + pos: 0, + }; + date.update(); + date + } + fn update(&mut self) { + self.pos = 0; + write!(self, "{}", time::at_utc(time::get_time()).rfc822()).unwrap(); + } +} + +impl fmt::Write for Date { + fn write_str(&mut self, s: &str) -> fmt::Result { + let len = s.len(); + self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); + self.pos += len; + Ok(()) + } +} + +#[derive(Clone)] +struct DateService(Rc); + +struct DateServiceInner { + current: UnsafeCell>, +} + +impl DateServiceInner { + fn new() -> Self { + DateServiceInner { + current: UnsafeCell::new(None), + } + } + + fn reset(&self) { + unsafe { (&mut *self.current.get()).take() }; + } + + fn update(&self) { + let now = Instant::now(); + let date = Date::new(); + *(unsafe { &mut *self.current.get() }) = Some((date, now)); + } +} + +impl DateService { + fn new() -> Self { + DateService(Rc::new(DateServiceInner::new())) + } + + fn check_date(&self) { + if unsafe { (&*self.0.current.get()).is_none() } { + self.0.update(); + + // periodic date update + let s = self.clone(); + actix_rt::spawn(delay_for(Duration::from_millis(500)).then(move |_| { + s.0.reset(); + future::ready(()) + })); + } + } + + fn now(&self) -> Instant { + self.check_date(); + unsafe { (&*self.0.current.get()).as_ref().unwrap().1 } + } + + fn set_date(&self, mut f: F) { + self.check_date(); + f(&unsafe { (&*self.0.current.get()).as_ref().unwrap().0 }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_date_len() { + assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); + } + + #[actix_rt::test] + async fn test_date() { + let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); + let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf1); + let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf2); + assert_eq!(buf1, buf2); + } +} diff --git a/actix-http/src/cookie/builder.rs b/actix-http/src/cookie/builder.rs new file mode 100644 index 000000000..efeddbb62 --- /dev/null +++ b/actix-http/src/cookie/builder.rs @@ -0,0 +1,260 @@ +use std::borrow::Cow; + +use chrono::Duration; +use time::Tm; + +use super::{Cookie, SameSite}; + +/// Structure that follows the builder pattern for building `Cookie` structs. +/// +/// To construct a cookie: +/// +/// 1. Call [`Cookie::build`](struct.Cookie.html#method.build) to start building. +/// 2. Use any of the builder methods to set fields in the cookie. +/// 3. Call [finish](#method.finish) to retrieve the built cookie. +/// +/// # Example +/// +/// ```rust +/// use actix_http::cookie::Cookie; +/// +/// # fn main() { +/// let cookie: Cookie = Cookie::build("name", "value") +/// .domain("www.rust-lang.org") +/// .path("/") +/// .secure(true) +/// .http_only(true) +/// .max_age(84600) +/// .finish(); +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct CookieBuilder { + /// The cookie being built. + cookie: Cookie<'static>, +} + +impl CookieBuilder { + /// Creates a new `CookieBuilder` instance from the given name and value. + /// + /// This method is typically called indirectly via + /// [Cookie::build](struct.Cookie.html#method.build). + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar").finish(); + /// assert_eq!(c.name_value(), ("foo", "bar")); + /// ``` + pub fn new(name: N, value: V) -> CookieBuilder + where + N: Into>, + V: Into>, + { + CookieBuilder { + cookie: Cookie::new(name, value), + } + } + + /// Sets the `expires` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// # fn main() { + /// let c = Cookie::build("foo", "bar") + /// .expires(time::now()) + /// .finish(); + /// + /// assert!(c.expires().is_some()); + /// # } + /// ``` + #[inline] + pub fn expires(mut self, when: Tm) -> CookieBuilder { + self.cookie.set_expires(when); + self + } + + /// Sets the `max_age` field in seconds in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// # fn main() { + /// let c = Cookie::build("foo", "bar") + /// .max_age(1800) + /// .finish(); + /// + /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); + /// # } + /// ``` + #[inline] + pub fn max_age(self, seconds: i64) -> CookieBuilder { + self.max_age_time(Duration::seconds(seconds)) + } + + /// Sets the `max_age` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// # fn main() { + /// let c = Cookie::build("foo", "bar") + /// .max_age_time(time::Duration::minutes(30)) + /// .finish(); + /// + /// assert_eq!(c.max_age(), Some(time::Duration::seconds(30 * 60))); + /// # } + /// ``` + #[inline] + pub fn max_age_time(mut self, value: Duration) -> CookieBuilder { + self.cookie.set_max_age(value); + self + } + + /// Sets the `domain` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar") + /// .domain("www.rust-lang.org") + /// .finish(); + /// + /// assert_eq!(c.domain(), Some("www.rust-lang.org")); + /// ``` + pub fn domain>>(mut self, value: D) -> CookieBuilder { + self.cookie.set_domain(value); + self + } + + /// Sets the `path` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar") + /// .path("/") + /// .finish(); + /// + /// assert_eq!(c.path(), Some("/")); + /// ``` + pub fn path>>(mut self, path: P) -> CookieBuilder { + self.cookie.set_path(path); + self + } + + /// Sets the `secure` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar") + /// .secure(true) + /// .finish(); + /// + /// assert_eq!(c.secure(), Some(true)); + /// ``` + #[inline] + pub fn secure(mut self, value: bool) -> CookieBuilder { + self.cookie.set_secure(value); + self + } + + /// Sets the `http_only` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar") + /// .http_only(true) + /// .finish(); + /// + /// assert_eq!(c.http_only(), Some(true)); + /// ``` + #[inline] + pub fn http_only(mut self, value: bool) -> CookieBuilder { + self.cookie.set_http_only(value); + self + } + + /// Sets the `same_site` field in the cookie being built. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{Cookie, SameSite}; + /// + /// let c = Cookie::build("foo", "bar") + /// .same_site(SameSite::Strict) + /// .finish(); + /// + /// assert_eq!(c.same_site(), Some(SameSite::Strict)); + /// ``` + #[inline] + pub fn same_site(mut self, value: SameSite) -> CookieBuilder { + self.cookie.set_same_site(value); + self + } + + /// Makes the cookie being built 'permanent' by extending its expiration and + /// max age 20 years into the future. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// use chrono::Duration; + /// + /// # fn main() { + /// let c = Cookie::build("foo", "bar") + /// .permanent() + /// .finish(); + /// + /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); + /// # assert!(c.expires().is_some()); + /// # } + /// ``` + #[inline] + pub fn permanent(mut self) -> CookieBuilder { + self.cookie.make_permanent(); + self + } + + /// Finishes building and returns the built `Cookie`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar") + /// .domain("crates.io") + /// .path("/") + /// .finish(); + /// + /// assert_eq!(c.name_value(), ("foo", "bar")); + /// assert_eq!(c.domain(), Some("crates.io")); + /// assert_eq!(c.path(), Some("/")); + /// ``` + #[inline] + pub fn finish(self) -> Cookie<'static> { + self.cookie + } +} diff --git a/actix-http/src/cookie/delta.rs b/actix-http/src/cookie/delta.rs new file mode 100644 index 000000000..a001a5bb8 --- /dev/null +++ b/actix-http/src/cookie/delta.rs @@ -0,0 +1,71 @@ +use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; +use std::ops::{Deref, DerefMut}; + +use super::Cookie; + +/// A `DeltaCookie` is a helper structure used in a cookie jar. It wraps a +/// `Cookie` so that it can be hashed and compared purely by name. It further +/// records whether the wrapped cookie is a "removal" cookie, that is, a cookie +/// that when sent to the client removes the named cookie on the client's +/// machine. +#[derive(Clone, Debug)] +pub struct DeltaCookie { + pub cookie: Cookie<'static>, + pub removed: bool, +} + +impl DeltaCookie { + /// Create a new `DeltaCookie` that is being added to a jar. + #[inline] + pub fn added(cookie: Cookie<'static>) -> DeltaCookie { + DeltaCookie { + cookie, + removed: false, + } + } + + /// Create a new `DeltaCookie` that is being removed from a jar. The + /// `cookie` should be a "removal" cookie. + #[inline] + pub fn removed(cookie: Cookie<'static>) -> DeltaCookie { + DeltaCookie { + cookie, + removed: true, + } + } +} + +impl Deref for DeltaCookie { + type Target = Cookie<'static>; + + fn deref(&self) -> &Cookie<'static> { + &self.cookie + } +} + +impl DerefMut for DeltaCookie { + fn deref_mut(&mut self) -> &mut Cookie<'static> { + &mut self.cookie + } +} + +impl PartialEq for DeltaCookie { + fn eq(&self, other: &DeltaCookie) -> bool { + self.name() == other.name() + } +} + +impl Eq for DeltaCookie {} + +impl Hash for DeltaCookie { + fn hash(&self, state: &mut H) { + self.name().hash(state); + } +} + +impl Borrow for DeltaCookie { + fn borrow(&self) -> &str { + self.name() + } +} diff --git a/actix-http/src/cookie/draft.rs b/actix-http/src/cookie/draft.rs new file mode 100644 index 000000000..362133946 --- /dev/null +++ b/actix-http/src/cookie/draft.rs @@ -0,0 +1,98 @@ +//! This module contains types that represent cookie properties that are not yet +//! standardized. That is, _draft_ features. + +use std::fmt; + +/// The `SameSite` cookie attribute. +/// +/// A cookie with a `SameSite` attribute is imposed restrictions on when it is +/// sent to the origin server in a cross-site request. If the `SameSite` +/// attribute is "Strict", then the cookie is never sent in cross-site requests. +/// If the `SameSite` attribute is "Lax", the cookie is only sent in cross-site +/// requests with "safe" HTTP methods, i.e, `GET`, `HEAD`, `OPTIONS`, `TRACE`. +/// If the `SameSite` attribute is not present (made explicit via the +/// `SameSite::None` variant), then the cookie will be sent as normal. +/// +/// **Note:** This cookie attribute is an HTTP draft! Its meaning and definition +/// are subject to change. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SameSite { + /// The "Strict" `SameSite` attribute. + Strict, + /// The "Lax" `SameSite` attribute. + Lax, + /// No `SameSite` attribute. + None, +} + +impl SameSite { + /// Returns `true` if `self` is `SameSite::Strict` and `false` otherwise. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::SameSite; + /// + /// let strict = SameSite::Strict; + /// assert!(strict.is_strict()); + /// assert!(!strict.is_lax()); + /// assert!(!strict.is_none()); + /// ``` + #[inline] + pub fn is_strict(self) -> bool { + match self { + SameSite::Strict => true, + SameSite::Lax | SameSite::None => false, + } + } + + /// Returns `true` if `self` is `SameSite::Lax` and `false` otherwise. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::SameSite; + /// + /// let lax = SameSite::Lax; + /// assert!(lax.is_lax()); + /// assert!(!lax.is_strict()); + /// assert!(!lax.is_none()); + /// ``` + #[inline] + pub fn is_lax(self) -> bool { + match self { + SameSite::Lax => true, + SameSite::Strict | SameSite::None => false, + } + } + + /// Returns `true` if `self` is `SameSite::None` and `false` otherwise. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::SameSite; + /// + /// let none = SameSite::None; + /// assert!(none.is_none()); + /// assert!(!none.is_lax()); + /// assert!(!none.is_strict()); + /// ``` + #[inline] + pub fn is_none(self) -> bool { + match self { + SameSite::None => true, + SameSite::Lax | SameSite::Strict => false, + } + } +} + +impl fmt::Display for SameSite { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + SameSite::Strict => write!(f, "Strict"), + SameSite::Lax => write!(f, "Lax"), + SameSite::None => Ok(()), + } + } +} diff --git a/actix-http/src/cookie/jar.rs b/actix-http/src/cookie/jar.rs new file mode 100644 index 000000000..cc67536c6 --- /dev/null +++ b/actix-http/src/cookie/jar.rs @@ -0,0 +1,655 @@ +use std::collections::HashSet; +use std::mem::replace; + +use chrono::Duration; + +use super::delta::DeltaCookie; +use super::Cookie; + +#[cfg(feature = "secure-cookies")] +use super::secure::{Key, PrivateJar, SignedJar}; + +/// A collection of cookies that tracks its modifications. +/// +/// A `CookieJar` provides storage for any number of cookies. Any changes made +/// to the jar are tracked; the changes can be retrieved via the +/// [delta](#method.delta) method which returns an interator over the changes. +/// +/// # Usage +/// +/// A jar's life begins via [new](#method.new) and calls to +/// [`add_original`](#method.add_original): +/// +/// ```rust +/// use actix_http::cookie::{Cookie, CookieJar}; +/// +/// let mut jar = CookieJar::new(); +/// jar.add_original(Cookie::new("name", "value")); +/// jar.add_original(Cookie::new("second", "another")); +/// ``` +/// +/// Cookies can be added via [add](#method.add) and removed via +/// [remove](#method.remove). Finally, cookies can be looked up via +/// [get](#method.get): +/// +/// ```rust +/// # use actix_http::cookie::{Cookie, CookieJar}; +/// let mut jar = CookieJar::new(); +/// jar.add(Cookie::new("a", "one")); +/// jar.add(Cookie::new("b", "two")); +/// +/// assert_eq!(jar.get("a").map(|c| c.value()), Some("one")); +/// assert_eq!(jar.get("b").map(|c| c.value()), Some("two")); +/// +/// jar.remove(Cookie::named("b")); +/// assert!(jar.get("b").is_none()); +/// ``` +/// +/// # Deltas +/// +/// A jar keeps track of any modifications made to it over time. The +/// modifications are recorded as cookies. The modifications can be retrieved +/// via [delta](#method.delta). Any new `Cookie` added to a jar via `add` +/// results in the same `Cookie` appearing in the `delta`; cookies added via +/// `add_original` do not count towards the delta. Any _original_ cookie that is +/// removed from a jar results in a "removal" cookie appearing in the delta. A +/// "removal" cookie is a cookie that a server sends so that the cookie is +/// removed from the client's machine. +/// +/// Deltas are typically used to create `Set-Cookie` headers corresponding to +/// the changes made to a cookie jar over a period of time. +/// +/// ```rust +/// # use actix_http::cookie::{Cookie, CookieJar}; +/// let mut jar = CookieJar::new(); +/// +/// // original cookies don't affect the delta +/// jar.add_original(Cookie::new("original", "value")); +/// assert_eq!(jar.delta().count(), 0); +/// +/// // new cookies result in an equivalent `Cookie` in the delta +/// jar.add(Cookie::new("a", "one")); +/// jar.add(Cookie::new("b", "two")); +/// assert_eq!(jar.delta().count(), 2); +/// +/// // removing an original cookie adds a "removal" cookie to the delta +/// jar.remove(Cookie::named("original")); +/// assert_eq!(jar.delta().count(), 3); +/// +/// // removing a new cookie that was added removes that `Cookie` from the delta +/// jar.remove(Cookie::named("a")); +/// assert_eq!(jar.delta().count(), 2); +/// ``` +#[derive(Default, Debug, Clone)] +pub struct CookieJar { + original_cookies: HashSet, + delta_cookies: HashSet, +} + +impl CookieJar { + /// Creates an empty cookie jar. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::CookieJar; + /// + /// let jar = CookieJar::new(); + /// assert_eq!(jar.iter().count(), 0); + /// ``` + pub fn new() -> CookieJar { + CookieJar::default() + } + + /// Returns a reference to the `Cookie` inside this jar with the name + /// `name`. If no such cookie exists, returns `None`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// assert!(jar.get("name").is_none()); + /// + /// jar.add(Cookie::new("name", "value")); + /// assert_eq!(jar.get("name").map(|c| c.value()), Some("value")); + /// ``` + pub fn get(&self, name: &str) -> Option<&Cookie<'static>> { + self.delta_cookies + .get(name) + .or_else(|| self.original_cookies.get(name)) + .and_then(|c| if !c.removed { Some(&c.cookie) } else { None }) + } + + /// Adds an "original" `cookie` to this jar. If an original cookie with the + /// same name already exists, it is replaced with `cookie`. Cookies added + /// with `add` take precedence and are not replaced by this method. + /// + /// Adding an original cookie does not affect the [delta](#method.delta) + /// computation. This method is intended to be used to seed the cookie jar + /// with cookies received from a client's HTTP message. + /// + /// For accurate `delta` computations, this method should not be called + /// after calling `remove`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// jar.add_original(Cookie::new("name", "value")); + /// jar.add_original(Cookie::new("second", "two")); + /// + /// assert_eq!(jar.get("name").map(|c| c.value()), Some("value")); + /// assert_eq!(jar.get("second").map(|c| c.value()), Some("two")); + /// assert_eq!(jar.iter().count(), 2); + /// assert_eq!(jar.delta().count(), 0); + /// ``` + pub fn add_original(&mut self, cookie: Cookie<'static>) { + self.original_cookies.replace(DeltaCookie::added(cookie)); + } + + /// Adds `cookie` to this jar. If a cookie with the same name already + /// exists, it is replaced with `cookie`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// jar.add(Cookie::new("name", "value")); + /// jar.add(Cookie::new("second", "two")); + /// + /// assert_eq!(jar.get("name").map(|c| c.value()), Some("value")); + /// assert_eq!(jar.get("second").map(|c| c.value()), Some("two")); + /// assert_eq!(jar.iter().count(), 2); + /// assert_eq!(jar.delta().count(), 2); + /// ``` + pub fn add(&mut self, cookie: Cookie<'static>) { + self.delta_cookies.replace(DeltaCookie::added(cookie)); + } + + /// Removes `cookie` from this jar. If an _original_ cookie with the same + /// name as `cookie` is present in the jar, a _removal_ cookie will be + /// present in the `delta` computation. To properly generate the removal + /// cookie, `cookie` must contain the same `path` and `domain` as the cookie + /// that was initially set. + /// + /// A "removal" cookie is a cookie that has the same name as the original + /// cookie but has an empty value, a max-age of 0, and an expiration date + /// far in the past. + /// + /// # Example + /// + /// Removing an _original_ cookie results in a _removal_ cookie: + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// use chrono::Duration; + /// + /// # fn main() { + /// let mut jar = CookieJar::new(); + /// + /// // Assume this cookie originally had a path of "/" and domain of "a.b". + /// jar.add_original(Cookie::new("name", "value")); + /// + /// // If the path and domain were set, they must be provided to `remove`. + /// jar.remove(Cookie::build("name", "").path("/").domain("a.b").finish()); + /// + /// // The delta will contain the removal cookie. + /// let delta: Vec<_> = jar.delta().collect(); + /// assert_eq!(delta.len(), 1); + /// assert_eq!(delta[0].name(), "name"); + /// assert_eq!(delta[0].max_age(), Some(Duration::seconds(0))); + /// # } + /// ``` + /// + /// Removing a new cookie does not result in a _removal_ cookie: + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// jar.add(Cookie::new("name", "value")); + /// assert_eq!(jar.delta().count(), 1); + /// + /// jar.remove(Cookie::named("name")); + /// assert_eq!(jar.delta().count(), 0); + /// ``` + pub fn remove(&mut self, mut cookie: Cookie<'static>) { + if self.original_cookies.contains(cookie.name()) { + cookie.set_value(""); + cookie.set_max_age(Duration::seconds(0)); + cookie.set_expires(time::now() - Duration::days(365)); + self.delta_cookies.replace(DeltaCookie::removed(cookie)); + } else { + self.delta_cookies.remove(cookie.name()); + } + } + + /// Removes `cookie` from this jar completely. This method differs from + /// `remove` in that no delta cookie is created under any condition. Neither + /// the `delta` nor `iter` methods will return a cookie that is removed + /// using this method. + /// + /// # Example + /// + /// Removing an _original_ cookie; no _removal_ cookie is generated: + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// use chrono::Duration; + /// + /// # fn main() { + /// let mut jar = CookieJar::new(); + /// + /// // Add an original cookie and a new cookie. + /// jar.add_original(Cookie::new("name", "value")); + /// jar.add(Cookie::new("key", "value")); + /// assert_eq!(jar.delta().count(), 1); + /// assert_eq!(jar.iter().count(), 2); + /// + /// // Now force remove the original cookie. + /// jar.force_remove(Cookie::new("name", "value")); + /// assert_eq!(jar.delta().count(), 1); + /// assert_eq!(jar.iter().count(), 1); + /// + /// // Now force remove the new cookie. + /// jar.force_remove(Cookie::new("key", "value")); + /// assert_eq!(jar.delta().count(), 0); + /// assert_eq!(jar.iter().count(), 0); + /// # } + /// ``` + pub fn force_remove<'a>(&mut self, cookie: Cookie<'a>) { + self.original_cookies.remove(cookie.name()); + self.delta_cookies.remove(cookie.name()); + } + + /// Removes all cookies from this cookie jar. + #[deprecated( + since = "0.7.0", + note = "calling this method may not remove \ + all cookies since the path and domain are not specified; use \ + `remove` instead" + )] + pub fn clear(&mut self) { + self.delta_cookies.clear(); + for delta in replace(&mut self.original_cookies, HashSet::new()) { + self.remove(delta.cookie); + } + } + + /// Returns an iterator over cookies that represent the changes to this jar + /// over time. These cookies can be rendered directly as `Set-Cookie` header + /// values to affect the changes made to this jar on the client. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// jar.add_original(Cookie::new("name", "value")); + /// jar.add_original(Cookie::new("second", "two")); + /// + /// // Add new cookies. + /// jar.add(Cookie::new("new", "third")); + /// jar.add(Cookie::new("another", "fourth")); + /// jar.add(Cookie::new("yac", "fifth")); + /// + /// // Remove some cookies. + /// jar.remove(Cookie::named("name")); + /// jar.remove(Cookie::named("another")); + /// + /// // Delta contains two new cookies ("new", "yac") and a removal ("name"). + /// assert_eq!(jar.delta().count(), 3); + /// ``` + pub fn delta(&self) -> Delta { + Delta { + iter: self.delta_cookies.iter(), + } + } + + /// Returns an iterator over all of the cookies present in this jar. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie}; + /// + /// let mut jar = CookieJar::new(); + /// + /// jar.add_original(Cookie::new("name", "value")); + /// jar.add_original(Cookie::new("second", "two")); + /// + /// jar.add(Cookie::new("new", "third")); + /// jar.add(Cookie::new("another", "fourth")); + /// jar.add(Cookie::new("yac", "fifth")); + /// + /// jar.remove(Cookie::named("name")); + /// jar.remove(Cookie::named("another")); + /// + /// // There are three cookies in the jar: "second", "new", and "yac". + /// # assert_eq!(jar.iter().count(), 3); + /// for cookie in jar.iter() { + /// match cookie.name() { + /// "second" => assert_eq!(cookie.value(), "two"), + /// "new" => assert_eq!(cookie.value(), "third"), + /// "yac" => assert_eq!(cookie.value(), "fifth"), + /// _ => unreachable!("there are only three cookies in the jar") + /// } + /// } + /// ``` + pub fn iter(&self) -> Iter { + Iter { + delta_cookies: self + .delta_cookies + .iter() + .chain(self.original_cookies.difference(&self.delta_cookies)), + } + } + + /// Returns a `PrivateJar` with `self` as its parent jar using the key `key` + /// to sign/encrypt and verify/decrypt cookies added/retrieved from the + /// child jar. + /// + /// Any modifications to the child jar will be reflected on the parent jar, + /// and any retrievals from the child jar will be made from the parent jar. + /// + /// This method is only available when the `secure` feature is enabled. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{Cookie, CookieJar, Key}; + /// + /// // Generate a secure key. + /// let key = Key::generate(); + /// + /// // Add a private (signed + encrypted) cookie. + /// let mut jar = CookieJar::new(); + /// jar.private(&key).add(Cookie::new("private", "text")); + /// + /// // The cookie's contents are encrypted. + /// assert_ne!(jar.get("private").unwrap().value(), "text"); + /// + /// // They can be decrypted and verified through the child jar. + /// assert_eq!(jar.private(&key).get("private").unwrap().value(), "text"); + /// + /// // A tampered with cookie does not validate but still exists. + /// let mut cookie = jar.get("private").unwrap().clone(); + /// jar.add(Cookie::new("private", cookie.value().to_string() + "!")); + /// assert!(jar.private(&key).get("private").is_none()); + /// assert!(jar.get("private").is_some()); + /// ``` + #[cfg(feature = "secure-cookies")] + pub fn private(&mut self, key: &Key) -> PrivateJar { + PrivateJar::new(self, key) + } + + /// Returns a `SignedJar` with `self` as its parent jar using the key `key` + /// to sign/verify cookies added/retrieved from the child jar. + /// + /// Any modifications to the child jar will be reflected on the parent jar, + /// and any retrievals from the child jar will be made from the parent jar. + /// + /// This method is only available when the `secure` feature is enabled. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{Cookie, CookieJar, Key}; + /// + /// // Generate a secure key. + /// let key = Key::generate(); + /// + /// // Add a signed cookie. + /// let mut jar = CookieJar::new(); + /// jar.signed(&key).add(Cookie::new("signed", "text")); + /// + /// // The cookie's contents are signed but still in plaintext. + /// assert_ne!(jar.get("signed").unwrap().value(), "text"); + /// assert!(jar.get("signed").unwrap().value().contains("text")); + /// + /// // They can be verified through the child jar. + /// assert_eq!(jar.signed(&key).get("signed").unwrap().value(), "text"); + /// + /// // A tampered with cookie does not validate but still exists. + /// let mut cookie = jar.get("signed").unwrap().clone(); + /// jar.add(Cookie::new("signed", cookie.value().to_string() + "!")); + /// assert!(jar.signed(&key).get("signed").is_none()); + /// assert!(jar.get("signed").is_some()); + /// ``` + #[cfg(feature = "secure-cookies")] + pub fn signed(&mut self, key: &Key) -> SignedJar { + SignedJar::new(self, key) + } +} + +use std::collections::hash_set::Iter as HashSetIter; + +/// Iterator over the changes to a cookie jar. +pub struct Delta<'a> { + iter: HashSetIter<'a, DeltaCookie>, +} + +impl<'a> Iterator for Delta<'a> { + type Item = &'a Cookie<'static>; + + fn next(&mut self) -> Option<&'a Cookie<'static>> { + self.iter.next().map(|c| &c.cookie) + } +} + +use std::collections::hash_map::RandomState; +use std::collections::hash_set::Difference; +use std::iter::Chain; + +/// Iterator over all of the cookies in a jar. +pub struct Iter<'a> { + delta_cookies: + Chain, Difference<'a, DeltaCookie, RandomState>>, +} + +impl<'a> Iterator for Iter<'a> { + type Item = &'a Cookie<'static>; + + fn next(&mut self) -> Option<&'a Cookie<'static>> { + for cookie in self.delta_cookies.by_ref() { + if !cookie.removed { + return Some(&*cookie); + } + } + + None + } +} + +#[cfg(test)] +mod test { + #[cfg(feature = "secure-cookies")] + use super::Key; + use super::{Cookie, CookieJar}; + + #[test] + #[allow(deprecated)] + fn simple() { + let mut c = CookieJar::new(); + + c.add(Cookie::new("test", "")); + c.add(Cookie::new("test2", "")); + c.remove(Cookie::named("test")); + + assert!(c.get("test").is_none()); + assert!(c.get("test2").is_some()); + + c.add(Cookie::new("test3", "")); + c.clear(); + + assert!(c.get("test").is_none()); + assert!(c.get("test2").is_none()); + assert!(c.get("test3").is_none()); + } + + #[test] + fn jar_is_send() { + fn is_send(_: T) -> bool { + true + } + + assert!(is_send(CookieJar::new())) + } + + #[test] + #[cfg(feature = "secure-cookies")] + fn iter() { + let key = Key::generate(); + let mut c = CookieJar::new(); + + c.add_original(Cookie::new("original", "original")); + + c.add(Cookie::new("test", "test")); + c.add(Cookie::new("test2", "test2")); + c.add(Cookie::new("test3", "test3")); + assert_eq!(c.iter().count(), 4); + + c.signed(&key).add(Cookie::new("signed", "signed")); + c.private(&key).add(Cookie::new("encrypted", "encrypted")); + assert_eq!(c.iter().count(), 6); + + c.remove(Cookie::named("test")); + assert_eq!(c.iter().count(), 5); + + c.remove(Cookie::named("signed")); + c.remove(Cookie::named("test2")); + assert_eq!(c.iter().count(), 3); + + c.add(Cookie::new("test2", "test2")); + assert_eq!(c.iter().count(), 4); + + c.remove(Cookie::named("test2")); + assert_eq!(c.iter().count(), 3); + } + + #[test] + #[cfg(feature = "secure-cookies")] + fn delta() { + use chrono::Duration; + use std::collections::HashMap; + + let mut c = CookieJar::new(); + + c.add_original(Cookie::new("original", "original")); + c.add_original(Cookie::new("original1", "original1")); + + c.add(Cookie::new("test", "test")); + c.add(Cookie::new("test2", "test2")); + c.add(Cookie::new("test3", "test3")); + c.add(Cookie::new("test4", "test4")); + + c.remove(Cookie::named("test")); + c.remove(Cookie::named("original")); + + assert_eq!(c.delta().count(), 4); + + let names: HashMap<_, _> = c.delta().map(|c| (c.name(), c.max_age())).collect(); + + assert!(names.get("test2").unwrap().is_none()); + assert!(names.get("test3").unwrap().is_none()); + assert!(names.get("test4").unwrap().is_none()); + assert_eq!(names.get("original").unwrap(), &Some(Duration::seconds(0))); + } + + #[test] + fn replace_original() { + let mut jar = CookieJar::new(); + jar.add_original(Cookie::new("original_a", "a")); + jar.add_original(Cookie::new("original_b", "b")); + assert_eq!(jar.get("original_a").unwrap().value(), "a"); + + jar.add(Cookie::new("original_a", "av2")); + assert_eq!(jar.get("original_a").unwrap().value(), "av2"); + } + + #[test] + fn empty_delta() { + let mut jar = CookieJar::new(); + jar.add(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 1); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().count(), 0); + + jar.add_original(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 0); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().count(), 1); + + jar.add(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 1); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().count(), 1); + } + + #[test] + fn add_remove_add() { + let mut jar = CookieJar::new(); + jar.add_original(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 0); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().count(), 1); + + // The cookie's been deleted. Another original doesn't change that. + jar.add_original(Cookie::new("name", "val")); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().count(), 1); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().count(), 1); + + jar.add(Cookie::new("name", "val")); + assert_eq!(jar.delta().filter(|c| !c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().count(), 1); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().count(), 1); + } + + #[test] + fn replace_remove() { + let mut jar = CookieJar::new(); + jar.add_original(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 0); + + jar.add(Cookie::new("name", "val")); + assert_eq!(jar.delta().count(), 1); + assert_eq!(jar.delta().filter(|c| !c.value().is_empty()).count(), 1); + + jar.remove(Cookie::named("name")); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + } + + #[test] + fn remove_with_path() { + let mut jar = CookieJar::new(); + jar.add_original(Cookie::build("name", "val").finish()); + assert_eq!(jar.iter().count(), 1); + assert_eq!(jar.delta().count(), 0); + assert_eq!(jar.iter().filter(|c| c.path().is_none()).count(), 1); + + jar.remove(Cookie::build("name", "").path("/").finish()); + assert_eq!(jar.iter().count(), 0); + assert_eq!(jar.delta().count(), 1); + assert_eq!(jar.delta().filter(|c| c.value().is_empty()).count(), 1); + assert_eq!(jar.delta().filter(|c| c.path() == Some("/")).count(), 1); + } +} diff --git a/actix-http/src/cookie/mod.rs b/actix-http/src/cookie/mod.rs new file mode 100644 index 000000000..db8211427 --- /dev/null +++ b/actix-http/src/cookie/mod.rs @@ -0,0 +1,1106 @@ +//! https://github.com/alexcrichton/cookie-rs fork +//! +//! HTTP cookie parsing and cookie jar management. +//! +//! This crates provides the [`Cookie`](struct.Cookie.html) type, which directly +//! maps to an HTTP cookie, and the [`CookieJar`](struct.CookieJar.html) type, +//! which allows for simple management of many cookies as well as encryption and +//! signing of cookies for session management. +//! +//! # Features +//! +//! This crates can be configured at compile-time through the following Cargo +//! features: +//! +//! +//! * **secure** (disabled by default) +//! +//! Enables signed and private (signed + encrypted) cookie jars. +//! +//! When this feature is enabled, the +//! [signed](struct.CookieJar.html#method.signed) and +//! [private](struct.CookieJar.html#method.private) method of `CookieJar` and +//! [`SignedJar`](struct.SignedJar.html) and +//! [`PrivateJar`](struct.PrivateJar.html) structures are available. The jars +//! act as "children jars", allowing for easy retrieval and addition of signed +//! and/or encrypted cookies to a cookie jar. When this feature is disabled, +//! none of the types are available. +//! +//! * **percent-encode** (disabled by default) +//! +//! Enables percent encoding and decoding of names and values in cookies. +//! +//! When this feature is enabled, the +//! [encoded](struct.Cookie.html#method.encoded) and +//! [`parse_encoded`](struct.Cookie.html#method.parse_encoded) methods of +//! `Cookie` become available. The `encoded` method returns a wrapper around a +//! `Cookie` whose `Display` implementation percent-encodes the name and value +//! of the cookie. The `parse_encoded` method percent-decodes the name and +//! value of a `Cookie` during parsing. When this feature is disabled, the +//! `encoded` and `parse_encoded` methods are not available. +//! +//! You can enable features via the `Cargo.toml` file: +//! +//! ```ignore +//! [dependencies.cookie] +//! features = ["secure", "percent-encode"] +//! ``` + +#![doc(html_root_url = "https://docs.rs/cookie/0.11")] +#![deny(missing_docs)] + +mod builder; +mod delta; +mod draft; +mod jar; +mod parse; + +#[cfg(feature = "secure-cookies")] +#[macro_use] +mod secure; +#[cfg(feature = "secure-cookies")] +pub use self::secure::*; + +use std::borrow::Cow; +use std::fmt; +use std::str::FromStr; + +use chrono::Duration; +use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; +use time::Tm; + +pub use self::builder::CookieBuilder; +pub use self::draft::*; +pub use self::jar::{CookieJar, Delta, Iter}; +use self::parse::parse_cookie; +pub use self::parse::ParseError; + +/// https://url.spec.whatwg.org/#fragment-percent-encode-set +const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); + +/// https://url.spec.whatwg.org/#path-percent-encode-set +const PATH: &AsciiSet = &FRAGMENT.add(b'#').add(b'?').add(b'{').add(b'}'); + +/// https://url.spec.whatwg.org/#userinfo-percent-encode-set +pub const USERINFO: &AsciiSet = &PATH + .add(b'/') + .add(b':') + .add(b';') + .add(b'=') + .add(b'@') + .add(b'[') + .add(b'\\') + .add(b']') + .add(b'^') + .add(b'|'); + +#[derive(Debug, Clone)] +enum CookieStr { + /// An string derived from indexes (start, end). + Indexed(usize, usize), + /// A string derived from a concrete string. + Concrete(Cow<'static, str>), +} + +impl CookieStr { + /// Retrieves the string `self` corresponds to. If `self` is derived from + /// indexes, the corresponding subslice of `string` is returned. Otherwise, + /// the concrete string is returned. + /// + /// # Panics + /// + /// Panics if `self` is an indexed string and `string` is None. + fn to_str<'s>(&'s self, string: Option<&'s Cow>) -> &'s str { + match *self { + CookieStr::Indexed(i, j) => { + let s = string.expect( + "`Some` base string must exist when \ + converting indexed str to str! (This is a module invariant.)", + ); + &s[i..j] + } + CookieStr::Concrete(ref cstr) => &*cstr, + } + } + + #[allow(clippy::ptr_arg)] + fn to_raw_str<'s, 'c: 's>(&'s self, string: &'s Cow<'c, str>) -> Option<&'c str> { + match *self { + CookieStr::Indexed(i, j) => match *string { + Cow::Borrowed(s) => Some(&s[i..j]), + Cow::Owned(_) => None, + }, + CookieStr::Concrete(_) => None, + } + } +} + +/// Representation of an HTTP cookie. +/// +/// # Constructing a `Cookie` +/// +/// To construct a cookie with only a name/value, use the [new](#method.new) +/// method: +/// +/// ```rust +/// use actix_http::cookie::Cookie; +/// +/// let cookie = Cookie::new("name", "value"); +/// assert_eq!(&cookie.to_string(), "name=value"); +/// ``` +/// +/// To construct more elaborate cookies, use the [build](#method.build) method +/// and [`CookieBuilder`](struct.CookieBuilder.html) methods: +/// +/// ```rust +/// use actix_http::cookie::Cookie; +/// +/// let cookie = Cookie::build("name", "value") +/// .domain("www.rust-lang.org") +/// .path("/") +/// .secure(true) +/// .http_only(true) +/// .finish(); +/// ``` +#[derive(Debug, Clone)] +pub struct Cookie<'c> { + /// Storage for the cookie string. Only used if this structure was derived + /// from a string that was subsequently parsed. + cookie_string: Option>, + /// The cookie's name. + name: CookieStr, + /// The cookie's value. + value: CookieStr, + /// The cookie's expiration, if any. + expires: Option, + /// The cookie's maximum age, if any. + max_age: Option, + /// The cookie's domain, if any. + domain: Option, + /// The cookie's path domain, if any. + path: Option, + /// Whether this cookie was marked Secure. + secure: Option, + /// Whether this cookie was marked HttpOnly. + http_only: Option, + /// The draft `SameSite` attribute. + same_site: Option, +} + +impl Cookie<'static> { + /// Creates a new `Cookie` with the given name and value. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie = Cookie::new("name", "value"); + /// assert_eq!(cookie.name_value(), ("name", "value")); + /// ``` + pub fn new(name: N, value: V) -> Cookie<'static> + where + N: Into>, + V: Into>, + { + Cookie { + cookie_string: None, + name: CookieStr::Concrete(name.into()), + value: CookieStr::Concrete(value.into()), + expires: None, + max_age: None, + domain: None, + path: None, + secure: None, + http_only: None, + same_site: None, + } + } + + /// Creates a new `Cookie` with the given name and an empty value. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie = Cookie::named("name"); + /// assert_eq!(cookie.name(), "name"); + /// assert!(cookie.value().is_empty()); + /// ``` + pub fn named(name: N) -> Cookie<'static> + where + N: Into>, + { + Cookie::new(name, "") + } + + /// Creates a new `CookieBuilder` instance from the given key and value + /// strings. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::build("foo", "bar").finish(); + /// assert_eq!(c.name_value(), ("foo", "bar")); + /// ``` + pub fn build(name: N, value: V) -> CookieBuilder + where + N: Into>, + V: Into>, + { + CookieBuilder::new(name, value) + } +} + +impl<'c> Cookie<'c> { + /// Parses a `Cookie` from the given HTTP cookie header value string. Does + /// not perform any percent-decoding. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("foo=bar%20baz; HttpOnly").unwrap(); + /// assert_eq!(c.name_value(), ("foo", "bar%20baz")); + /// assert_eq!(c.http_only(), Some(true)); + /// ``` + pub fn parse(s: S) -> Result, ParseError> + where + S: Into>, + { + parse_cookie(s, false) + } + + /// Parses a `Cookie` from the given HTTP cookie header value string where + /// the name and value fields are percent-encoded. Percent-decodes the + /// name/value fields. + /// + /// This API requires the `percent-encode` feature to be enabled on this + /// crate. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse_encoded("foo=bar%20baz; HttpOnly").unwrap(); + /// assert_eq!(c.name_value(), ("foo", "bar baz")); + /// assert_eq!(c.http_only(), Some(true)); + /// ``` + pub fn parse_encoded(s: S) -> Result, ParseError> + where + S: Into>, + { + parse_cookie(s, true) + } + + /// Wraps `self` in an `EncodedCookie`: a cost-free wrapper around `Cookie` + /// whose `Display` implementation percent-encodes the name and value of the + /// wrapped `Cookie`. + /// + /// This method is only available when the `percent-encode` feature is + /// enabled. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("my name", "this; value?"); + /// assert_eq!(&c.encoded().to_string(), "my%20name=this%3B%20value%3F"); + /// ``` + pub fn encoded<'a>(&'a self) -> EncodedCookie<'a, 'c> { + EncodedCookie(self) + } + + /// Converts `self` into a `Cookie` with a static lifetime. This method + /// results in at most one allocation. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::new("a", "b"); + /// let owned_cookie = c.into_owned(); + /// assert_eq!(owned_cookie.name_value(), ("a", "b")); + /// ``` + pub fn into_owned(self) -> Cookie<'static> { + Cookie { + cookie_string: self.cookie_string.map(|s| s.into_owned().into()), + name: self.name, + value: self.value, + expires: self.expires, + max_age: self.max_age, + domain: self.domain, + path: self.path, + secure: self.secure, + http_only: self.http_only, + same_site: self.same_site, + } + } + + /// Returns the name of `self`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::new("name", "value"); + /// assert_eq!(c.name(), "name"); + /// ``` + #[inline] + pub fn name(&self) -> &str { + self.name.to_str(self.cookie_string.as_ref()) + } + + /// Returns the value of `self`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::new("name", "value"); + /// assert_eq!(c.value(), "value"); + /// ``` + #[inline] + pub fn value(&self) -> &str { + self.value.to_str(self.cookie_string.as_ref()) + } + + /// Returns the name and value of `self` as a tuple of `(name, value)`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::new("name", "value"); + /// assert_eq!(c.name_value(), ("name", "value")); + /// ``` + #[inline] + pub fn name_value(&self) -> (&str, &str) { + (self.name(), self.value()) + } + + /// Returns whether this cookie was marked `HttpOnly` or not. Returns + /// `Some(true)` when the cookie was explicitly set (manually or parsed) as + /// `HttpOnly`, `Some(false)` when `http_only` was manually set to `false`, + /// and `None` otherwise. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value; httponly").unwrap(); + /// assert_eq!(c.http_only(), Some(true)); + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.http_only(), None); + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.http_only(), None); + /// + /// // An explicitly set "false" value. + /// c.set_http_only(false); + /// assert_eq!(c.http_only(), Some(false)); + /// + /// // An explicitly set "true" value. + /// c.set_http_only(true); + /// assert_eq!(c.http_only(), Some(true)); + /// ``` + #[inline] + pub fn http_only(&self) -> Option { + self.http_only + } + + /// Returns whether this cookie was marked `Secure` or not. Returns + /// `Some(true)` when the cookie was explicitly set (manually or parsed) as + /// `Secure`, `Some(false)` when `secure` was manually set to `false`, and + /// `None` otherwise. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value; Secure").unwrap(); + /// assert_eq!(c.secure(), Some(true)); + /// + /// let mut c = Cookie::parse("name=value").unwrap(); + /// assert_eq!(c.secure(), None); + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.secure(), None); + /// + /// // An explicitly set "false" value. + /// c.set_secure(false); + /// assert_eq!(c.secure(), Some(false)); + /// + /// // An explicitly set "true" value. + /// c.set_secure(true); + /// assert_eq!(c.secure(), Some(true)); + /// ``` + #[inline] + pub fn secure(&self) -> Option { + self.secure + } + + /// Returns the `SameSite` attribute of this cookie if one was specified. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{Cookie, SameSite}; + /// + /// let c = Cookie::parse("name=value; SameSite=Lax").unwrap(); + /// assert_eq!(c.same_site(), Some(SameSite::Lax)); + /// ``` + #[inline] + pub fn same_site(&self) -> Option { + self.same_site + } + + /// Returns the specified max-age of the cookie if one was specified. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value").unwrap(); + /// assert_eq!(c.max_age(), None); + /// + /// let c = Cookie::parse("name=value; Max-Age=3600").unwrap(); + /// assert_eq!(c.max_age().map(|age| age.num_hours()), Some(1)); + /// ``` + #[inline] + pub fn max_age(&self) -> Option { + self.max_age + } + + /// Returns the `Path` of the cookie if one was specified. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value").unwrap(); + /// assert_eq!(c.path(), None); + /// + /// let c = Cookie::parse("name=value; Path=/").unwrap(); + /// assert_eq!(c.path(), Some("/")); + /// + /// let c = Cookie::parse("name=value; path=/sub").unwrap(); + /// assert_eq!(c.path(), Some("/sub")); + /// ``` + #[inline] + pub fn path(&self) -> Option<&str> { + match self.path { + Some(ref c) => Some(c.to_str(self.cookie_string.as_ref())), + None => None, + } + } + + /// Returns the `Domain` of the cookie if one was specified. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value").unwrap(); + /// assert_eq!(c.domain(), None); + /// + /// let c = Cookie::parse("name=value; Domain=crates.io").unwrap(); + /// assert_eq!(c.domain(), Some("crates.io")); + /// ``` + #[inline] + pub fn domain(&self) -> Option<&str> { + match self.domain { + Some(ref c) => Some(c.to_str(self.cookie_string.as_ref())), + None => None, + } + } + + /// Returns the `Expires` time of the cookie if one was specified. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let c = Cookie::parse("name=value").unwrap(); + /// assert_eq!(c.expires(), None); + /// + /// let expire_time = "Wed, 21 Oct 2017 07:28:00 GMT"; + /// let cookie_str = format!("name=value; Expires={}", expire_time); + /// let c = Cookie::parse(cookie_str).unwrap(); + /// assert_eq!(c.expires().map(|t| t.tm_year), Some(117)); + /// ``` + #[inline] + pub fn expires(&self) -> Option { + self.expires + } + + /// Sets the name of `self` to `name`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.name(), "name"); + /// + /// c.set_name("foo"); + /// assert_eq!(c.name(), "foo"); + /// ``` + pub fn set_name>>(&mut self, name: N) { + self.name = CookieStr::Concrete(name.into()) + } + + /// Sets the value of `self` to `value`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.value(), "value"); + /// + /// c.set_value("bar"); + /// assert_eq!(c.value(), "bar"); + /// ``` + pub fn set_value>>(&mut self, value: V) { + self.value = CookieStr::Concrete(value.into()) + } + + /// Sets the value of `http_only` in `self` to `value`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.http_only(), None); + /// + /// c.set_http_only(true); + /// assert_eq!(c.http_only(), Some(true)); + /// ``` + #[inline] + pub fn set_http_only(&mut self, value: bool) { + self.http_only = Some(value); + } + + /// Sets the value of `secure` in `self` to `value`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.secure(), None); + /// + /// c.set_secure(true); + /// assert_eq!(c.secure(), Some(true)); + /// ``` + #[inline] + pub fn set_secure(&mut self, value: bool) { + self.secure = Some(value); + } + + /// Sets the value of `same_site` in `self` to `value`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{Cookie, SameSite}; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert!(c.same_site().is_none()); + /// + /// c.set_same_site(SameSite::Strict); + /// assert_eq!(c.same_site(), Some(SameSite::Strict)); + /// ``` + #[inline] + pub fn set_same_site(&mut self, value: SameSite) { + self.same_site = Some(value); + } + + /// Sets the value of `max_age` in `self` to `value`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// use chrono::Duration; + /// + /// # fn main() { + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.max_age(), None); + /// + /// c.set_max_age(Duration::hours(10)); + /// assert_eq!(c.max_age(), Some(Duration::hours(10))); + /// # } + /// ``` + #[inline] + pub fn set_max_age(&mut self, value: Duration) { + self.max_age = Some(value); + } + + /// Sets the `path` of `self` to `path`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.path(), None); + /// + /// c.set_path("/"); + /// assert_eq!(c.path(), Some("/")); + /// ``` + pub fn set_path>>(&mut self, path: P) { + self.path = Some(CookieStr::Concrete(path.into())); + } + + /// Sets the `domain` of `self` to `domain`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.domain(), None); + /// + /// c.set_domain("rust-lang.org"); + /// assert_eq!(c.domain(), Some("rust-lang.org")); + /// ``` + pub fn set_domain>>(&mut self, domain: D) { + self.domain = Some(CookieStr::Concrete(domain.into())); + } + + /// Sets the expires field of `self` to `time`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// # fn main() { + /// let mut c = Cookie::new("name", "value"); + /// assert_eq!(c.expires(), None); + /// + /// let mut now = time::now(); + /// now.tm_year += 1; + /// + /// c.set_expires(now); + /// assert!(c.expires().is_some()) + /// # } + /// ``` + #[inline] + pub fn set_expires(&mut self, time: Tm) { + self.expires = Some(time); + } + + /// Makes `self` a "permanent" cookie by extending its expiration and max + /// age 20 years into the future. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// use chrono::Duration; + /// + /// # fn main() { + /// let mut c = Cookie::new("foo", "bar"); + /// assert!(c.expires().is_none()); + /// assert!(c.max_age().is_none()); + /// + /// c.make_permanent(); + /// assert!(c.expires().is_some()); + /// assert_eq!(c.max_age(), Some(Duration::days(365 * 20))); + /// # } + /// ``` + pub fn make_permanent(&mut self) { + let twenty_years = Duration::days(365 * 20); + self.set_max_age(twenty_years); + self.set_expires(time::now() + twenty_years); + } + + fn fmt_parameters(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(true) = self.http_only() { + write!(f, "; HttpOnly")?; + } + + if let Some(true) = self.secure() { + write!(f, "; Secure")?; + } + + if let Some(same_site) = self.same_site() { + if !same_site.is_none() { + write!(f, "; SameSite={}", same_site)?; + } + } + + if let Some(path) = self.path() { + write!(f, "; Path={}", path)?; + } + + if let Some(domain) = self.domain() { + write!(f, "; Domain={}", domain)?; + } + + if let Some(max_age) = self.max_age() { + write!(f, "; Max-Age={}", max_age.num_seconds())?; + } + + if let Some(time) = self.expires() { + write!(f, "; Expires={}", time.rfc822())?; + } + + Ok(()) + } + + /// Returns the name of `self` as a string slice of the raw string `self` + /// was originally parsed from. If `self` was not originally parsed from a + /// raw string, returns `None`. + /// + /// This method differs from [name](#method.name) in that it returns a + /// string with the same lifetime as the originally parsed string. This + /// lifetime may outlive `self`. If a longer lifetime is not required, or + /// you're unsure if you need a longer lifetime, use [name](#method.name). + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie_string = format!("{}={}", "foo", "bar"); + /// + /// // `c` will be dropped at the end of the scope, but `name` will live on + /// let name = { + /// let c = Cookie::parse(cookie_string.as_str()).unwrap(); + /// c.name_raw() + /// }; + /// + /// assert_eq!(name, Some("foo")); + /// ``` + #[inline] + pub fn name_raw(&self) -> Option<&'c str> { + self.cookie_string + .as_ref() + .and_then(|s| self.name.to_raw_str(s)) + } + + /// Returns the value of `self` as a string slice of the raw string `self` + /// was originally parsed from. If `self` was not originally parsed from a + /// raw string, returns `None`. + /// + /// This method differs from [value](#method.value) in that it returns a + /// string with the same lifetime as the originally parsed string. This + /// lifetime may outlive `self`. If a longer lifetime is not required, or + /// you're unsure if you need a longer lifetime, use [value](#method.value). + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie_string = format!("{}={}", "foo", "bar"); + /// + /// // `c` will be dropped at the end of the scope, but `value` will live on + /// let value = { + /// let c = Cookie::parse(cookie_string.as_str()).unwrap(); + /// c.value_raw() + /// }; + /// + /// assert_eq!(value, Some("bar")); + /// ``` + #[inline] + pub fn value_raw(&self) -> Option<&'c str> { + self.cookie_string + .as_ref() + .and_then(|s| self.value.to_raw_str(s)) + } + + /// Returns the `Path` of `self` as a string slice of the raw string `self` + /// was originally parsed from. If `self` was not originally parsed from a + /// raw string, or if `self` doesn't contain a `Path`, or if the `Path` has + /// changed since parsing, returns `None`. + /// + /// This method differs from [path](#method.path) in that it returns a + /// string with the same lifetime as the originally parsed string. This + /// lifetime may outlive `self`. If a longer lifetime is not required, or + /// you're unsure if you need a longer lifetime, use [path](#method.path). + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie_string = format!("{}={}; Path=/", "foo", "bar"); + /// + /// // `c` will be dropped at the end of the scope, but `path` will live on + /// let path = { + /// let c = Cookie::parse(cookie_string.as_str()).unwrap(); + /// c.path_raw() + /// }; + /// + /// assert_eq!(path, Some("/")); + /// ``` + #[inline] + pub fn path_raw(&self) -> Option<&'c str> { + match (self.path.as_ref(), self.cookie_string.as_ref()) { + (Some(path), Some(string)) => path.to_raw_str(string), + _ => None, + } + } + + /// Returns the `Domain` of `self` as a string slice of the raw string + /// `self` was originally parsed from. If `self` was not originally parsed + /// from a raw string, or if `self` doesn't contain a `Domain`, or if the + /// `Domain` has changed since parsing, returns `None`. + /// + /// This method differs from [domain](#method.domain) in that it returns a + /// string with the same lifetime as the originally parsed string. This + /// lifetime may outlive `self` struct. If a longer lifetime is not + /// required, or you're unsure if you need a longer lifetime, use + /// [domain](#method.domain). + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let cookie_string = format!("{}={}; Domain=crates.io", "foo", "bar"); + /// + /// //`c` will be dropped at the end of the scope, but `domain` will live on + /// let domain = { + /// let c = Cookie::parse(cookie_string.as_str()).unwrap(); + /// c.domain_raw() + /// }; + /// + /// assert_eq!(domain, Some("crates.io")); + /// ``` + #[inline] + pub fn domain_raw(&self) -> Option<&'c str> { + match (self.domain.as_ref(), self.cookie_string.as_ref()) { + (Some(domain), Some(string)) => domain.to_raw_str(string), + _ => None, + } + } +} + +/// Wrapper around `Cookie` whose `Display` implementation percent-encodes the +/// cookie's name and value. +/// +/// A value of this type can be obtained via the +/// [encoded](struct.Cookie.html#method.encoded) method on +/// [Cookie](struct.Cookie.html). This type should only be used for its +/// `Display` implementation. +/// +/// This type is only available when the `percent-encode` feature is enabled. +/// +/// # Example +/// +/// ```rust +/// use actix_http::cookie::Cookie; +/// +/// let mut c = Cookie::new("my name", "this; value?"); +/// assert_eq!(&c.encoded().to_string(), "my%20name=this%3B%20value%3F"); +/// ``` +pub struct EncodedCookie<'a, 'c: 'a>(&'a Cookie<'c>); + +impl<'a, 'c: 'a> fmt::Display for EncodedCookie<'a, 'c> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Percent-encode the name and value. + let name = percent_encode(self.0.name().as_bytes(), USERINFO); + let value = percent_encode(self.0.value().as_bytes(), USERINFO); + + // Write out the name/value pair and the cookie's parameters. + write!(f, "{}={}", name, value)?; + self.0.fmt_parameters(f) + } +} + +impl<'c> fmt::Display for Cookie<'c> { + /// Formats the cookie `self` as a `Set-Cookie` header value. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Cookie; + /// + /// let mut cookie = Cookie::build("foo", "bar") + /// .path("/") + /// .finish(); + /// + /// assert_eq!(&cookie.to_string(), "foo=bar; Path=/"); + /// ``` + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}={}", self.name(), self.value())?; + self.fmt_parameters(f) + } +} + +impl FromStr for Cookie<'static> { + type Err = ParseError; + + fn from_str(s: &str) -> Result, ParseError> { + Cookie::parse(s).map(|c| c.into_owned()) + } +} + +impl<'a, 'b> PartialEq> for Cookie<'a> { + fn eq(&self, other: &Cookie<'b>) -> bool { + let so_far_so_good = self.name() == other.name() + && self.value() == other.value() + && self.http_only() == other.http_only() + && self.secure() == other.secure() + && self.max_age() == other.max_age() + && self.expires() == other.expires(); + + if !so_far_so_good { + return false; + } + + match (self.path(), other.path()) { + (Some(a), Some(b)) if a.eq_ignore_ascii_case(b) => {} + (None, None) => {} + _ => return false, + }; + + match (self.domain(), other.domain()) { + (Some(a), Some(b)) if a.eq_ignore_ascii_case(b) => {} + (None, None) => {} + _ => return false, + }; + + true + } +} + +#[cfg(test)] +mod tests { + use super::{Cookie, SameSite}; + use time::strptime; + + #[test] + fn format() { + let cookie = Cookie::new("foo", "bar"); + assert_eq!(&cookie.to_string(), "foo=bar"); + + let cookie = Cookie::build("foo", "bar").http_only(true).finish(); + assert_eq!(&cookie.to_string(), "foo=bar; HttpOnly"); + + let cookie = Cookie::build("foo", "bar").max_age(10).finish(); + assert_eq!(&cookie.to_string(), "foo=bar; Max-Age=10"); + + let cookie = Cookie::build("foo", "bar").secure(true).finish(); + assert_eq!(&cookie.to_string(), "foo=bar; Secure"); + + let cookie = Cookie::build("foo", "bar").path("/").finish(); + assert_eq!(&cookie.to_string(), "foo=bar; Path=/"); + + let cookie = Cookie::build("foo", "bar") + .domain("www.rust-lang.org") + .finish(); + assert_eq!(&cookie.to_string(), "foo=bar; Domain=www.rust-lang.org"); + + let time_str = "Wed, 21 Oct 2015 07:28:00 GMT"; + let expires = strptime(time_str, "%a, %d %b %Y %H:%M:%S %Z").unwrap(); + let cookie = Cookie::build("foo", "bar").expires(expires).finish(); + assert_eq!( + &cookie.to_string(), + "foo=bar; Expires=Wed, 21 Oct 2015 07:28:00 GMT" + ); + + let cookie = Cookie::build("foo", "bar") + .same_site(SameSite::Strict) + .finish(); + assert_eq!(&cookie.to_string(), "foo=bar; SameSite=Strict"); + + let cookie = Cookie::build("foo", "bar") + .same_site(SameSite::Lax) + .finish(); + assert_eq!(&cookie.to_string(), "foo=bar; SameSite=Lax"); + + let cookie = Cookie::build("foo", "bar") + .same_site(SameSite::None) + .finish(); + assert_eq!(&cookie.to_string(), "foo=bar"); + } + + #[test] + fn cookie_string_long_lifetimes() { + let cookie_string = + "bar=baz; Path=/subdir; HttpOnly; Domain=crates.io".to_owned(); + let (name, value, path, domain) = { + // Create a cookie passing a slice + let c = Cookie::parse(cookie_string.as_str()).unwrap(); + (c.name_raw(), c.value_raw(), c.path_raw(), c.domain_raw()) + }; + + assert_eq!(name, Some("bar")); + assert_eq!(value, Some("baz")); + assert_eq!(path, Some("/subdir")); + assert_eq!(domain, Some("crates.io")); + } + + #[test] + fn owned_cookie_string() { + let cookie_string = + "bar=baz; Path=/subdir; HttpOnly; Domain=crates.io".to_owned(); + let (name, value, path, domain) = { + // Create a cookie passing an owned string + let c = Cookie::parse(cookie_string).unwrap(); + (c.name_raw(), c.value_raw(), c.path_raw(), c.domain_raw()) + }; + + assert_eq!(name, None); + assert_eq!(value, None); + assert_eq!(path, None); + assert_eq!(domain, None); + } + + #[test] + fn owned_cookie_struct() { + let cookie_string = "bar=baz; Path=/subdir; HttpOnly; Domain=crates.io"; + let (name, value, path, domain) = { + // Create an owned cookie + let c = Cookie::parse(cookie_string).unwrap().into_owned(); + + (c.name_raw(), c.value_raw(), c.path_raw(), c.domain_raw()) + }; + + assert_eq!(name, None); + assert_eq!(value, None); + assert_eq!(path, None); + assert_eq!(domain, None); + } + + #[test] + fn format_encoded() { + let cookie = Cookie::build("foo !?=", "bar;; a").finish(); + let cookie_str = cookie.encoded().to_string(); + assert_eq!(&cookie_str, "foo%20!%3F%3D=bar%3B%3B%20a"); + + let cookie = Cookie::parse_encoded(cookie_str).unwrap(); + assert_eq!(cookie.name_value(), ("foo !?=", "bar;; a")); + } +} diff --git a/actix-http/src/cookie/parse.rs b/actix-http/src/cookie/parse.rs new file mode 100644 index 000000000..42a2c1fcf --- /dev/null +++ b/actix-http/src/cookie/parse.rs @@ -0,0 +1,425 @@ +use std::borrow::Cow; +use std::cmp; +use std::convert::From; +use std::error::Error; +use std::fmt; +use std::str::Utf8Error; + +use chrono::Duration; +use percent_encoding::percent_decode; + +use super::{Cookie, CookieStr, SameSite}; + +/// Enum corresponding to a parsing error. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ParseError { + /// The cookie did not contain a name/value pair. + MissingPair, + /// The cookie's name was empty. + EmptyName, + /// Decoding the cookie's name or value resulted in invalid UTF-8. + Utf8Error(Utf8Error), + /// It is discouraged to exhaustively match on this enum as its variants may + /// grow without a breaking-change bump in version numbers. + #[doc(hidden)] + __Nonexhasutive, +} + +impl ParseError { + /// Returns a description of this error as a string + pub fn as_str(&self) -> &'static str { + match *self { + ParseError::MissingPair => "the cookie is missing a name/value pair", + ParseError::EmptyName => "the cookie's name is empty", + ParseError::Utf8Error(_) => { + "decoding the cookie's name or value resulted in invalid UTF-8" + } + ParseError::__Nonexhasutive => unreachable!("__Nonexhasutive ParseError"), + } + } +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl From for ParseError { + fn from(error: Utf8Error) -> ParseError { + ParseError::Utf8Error(error) + } +} + +impl Error for ParseError { + fn description(&self) -> &str { + self.as_str() + } +} + +fn indexes_of(needle: &str, haystack: &str) -> Option<(usize, usize)> { + let haystack_start = haystack.as_ptr() as usize; + let needle_start = needle.as_ptr() as usize; + + if needle_start < haystack_start { + return None; + } + + if (needle_start + needle.len()) > (haystack_start + haystack.len()) { + return None; + } + + let start = needle_start - haystack_start; + let end = start + needle.len(); + Some((start, end)) +} + +fn name_val_decoded( + name: &str, + val: &str, +) -> Result<(CookieStr, CookieStr), ParseError> { + let decoded_name = percent_decode(name.as_bytes()).decode_utf8()?; + let decoded_value = percent_decode(val.as_bytes()).decode_utf8()?; + let name = CookieStr::Concrete(Cow::Owned(decoded_name.into_owned())); + let val = CookieStr::Concrete(Cow::Owned(decoded_value.into_owned())); + + Ok((name, val)) +} + +// This function does the real parsing but _does not_ set the `cookie_string` in +// the returned cookie object. This only exists so that the borrow to `s` is +// returned at the end of the call, allowing the `cookie_string` field to be +// set in the outer `parse` function. +fn parse_inner<'c>(s: &str, decode: bool) -> Result, ParseError> { + let mut attributes = s.split(';'); + let key_value = match attributes.next() { + Some(s) => s, + _ => panic!(), + }; + + // Determine the name = val. + let (name, value) = match key_value.find('=') { + Some(i) => (key_value[..i].trim(), key_value[(i + 1)..].trim()), + None => return Err(ParseError::MissingPair), + }; + + if name.is_empty() { + return Err(ParseError::EmptyName); + } + + // Create a cookie with all of the defaults. We'll fill things in while we + // iterate through the parameters below. + let (name, value) = if decode { + name_val_decoded(name, value)? + } else { + let name_indexes = indexes_of(name, s).expect("name sub"); + let value_indexes = indexes_of(value, s).expect("value sub"); + let name = CookieStr::Indexed(name_indexes.0, name_indexes.1); + let value = CookieStr::Indexed(value_indexes.0, value_indexes.1); + + (name, value) + }; + + let mut cookie = Cookie { + name, + value, + cookie_string: None, + expires: None, + max_age: None, + domain: None, + path: None, + secure: None, + http_only: None, + same_site: None, + }; + + for attr in attributes { + let (key, value) = match attr.find('=') { + Some(i) => (attr[..i].trim(), Some(attr[(i + 1)..].trim())), + None => (attr.trim(), None), + }; + + match (&*key.to_ascii_lowercase(), value) { + ("secure", _) => cookie.secure = Some(true), + ("httponly", _) => cookie.http_only = Some(true), + ("max-age", Some(v)) => { + // See RFC 6265 Section 5.2.2, negative values indicate that the + // earliest possible expiration time should be used, so set the + // max age as 0 seconds. + cookie.max_age = match v.parse() { + Ok(val) if val <= 0 => Some(Duration::zero()), + Ok(val) => { + // Don't panic if the max age seconds is greater than what's supported by + // `Duration`. + let val = cmp::min(val, Duration::max_value().num_seconds()); + Some(Duration::seconds(val)) + } + Err(_) => continue, + }; + } + ("domain", Some(mut domain)) if !domain.is_empty() => { + if domain.starts_with('.') { + domain = &domain[1..]; + } + + let (i, j) = indexes_of(domain, s).expect("domain sub"); + cookie.domain = Some(CookieStr::Indexed(i, j)); + } + ("path", Some(v)) => { + let (i, j) = indexes_of(v, s).expect("path sub"); + cookie.path = Some(CookieStr::Indexed(i, j)); + } + ("samesite", Some(v)) => { + if v.eq_ignore_ascii_case("strict") { + cookie.same_site = Some(SameSite::Strict); + } else if v.eq_ignore_ascii_case("lax") { + cookie.same_site = Some(SameSite::Lax); + } else { + // We do nothing here, for now. When/if the `SameSite` + // attribute becomes standard, the spec says that we should + // ignore this cookie, i.e, fail to parse it, when an + // invalid value is passed in. The draft is at + // http://httpwg.org/http-extensions/draft-ietf-httpbis-cookie-same-site.html. + } + } + ("expires", Some(v)) => { + // Try strptime with three date formats according to + // http://tools.ietf.org/html/rfc2616#section-3.3.1. Try + // additional ones as encountered in the real world. + let tm = time::strptime(v, "%a, %d %b %Y %H:%M:%S %Z") + .or_else(|_| time::strptime(v, "%A, %d-%b-%y %H:%M:%S %Z")) + .or_else(|_| time::strptime(v, "%a, %d-%b-%Y %H:%M:%S %Z")) + .or_else(|_| time::strptime(v, "%a %b %d %H:%M:%S %Y")); + + if let Ok(time) = tm { + cookie.expires = Some(time) + } + } + _ => { + // We're going to be permissive here. If we have no idea what + // this is, then it's something nonstandard. We're not going to + // store it (because it's not compliant), but we're also not + // going to emit an error. + } + } + } + + Ok(cookie) +} + +pub fn parse_cookie<'c, S>(cow: S, decode: bool) -> Result, ParseError> +where + S: Into>, +{ + let s = cow.into(); + let mut cookie = parse_inner(&s, decode)?; + cookie.cookie_string = Some(s); + Ok(cookie) +} + +#[cfg(test)] +mod tests { + use super::{Cookie, SameSite}; + use chrono::Duration; + use time::strptime; + + macro_rules! assert_eq_parse { + ($string:expr, $expected:expr) => { + let cookie = match Cookie::parse($string) { + Ok(cookie) => cookie, + Err(e) => panic!("Failed to parse {:?}: {:?}", $string, e), + }; + + assert_eq!(cookie, $expected); + }; + } + + macro_rules! assert_ne_parse { + ($string:expr, $expected:expr) => { + let cookie = match Cookie::parse($string) { + Ok(cookie) => cookie, + Err(e) => panic!("Failed to parse {:?}: {:?}", $string, e), + }; + + assert_ne!(cookie, $expected); + }; + } + + #[test] + fn parse_same_site() { + let expected = Cookie::build("foo", "bar") + .same_site(SameSite::Lax) + .finish(); + + assert_eq_parse!("foo=bar; SameSite=Lax", expected); + assert_eq_parse!("foo=bar; SameSite=lax", expected); + assert_eq_parse!("foo=bar; SameSite=LAX", expected); + assert_eq_parse!("foo=bar; samesite=Lax", expected); + assert_eq_parse!("foo=bar; SAMESITE=Lax", expected); + + let expected = Cookie::build("foo", "bar") + .same_site(SameSite::Strict) + .finish(); + + assert_eq_parse!("foo=bar; SameSite=Strict", expected); + assert_eq_parse!("foo=bar; SameSITE=Strict", expected); + assert_eq_parse!("foo=bar; SameSite=strict", expected); + assert_eq_parse!("foo=bar; SameSite=STrICT", expected); + assert_eq_parse!("foo=bar; SameSite=STRICT", expected); + } + + #[test] + fn parse() { + assert!(Cookie::parse("bar").is_err()); + assert!(Cookie::parse("=bar").is_err()); + assert!(Cookie::parse(" =bar").is_err()); + assert!(Cookie::parse("foo=").is_ok()); + + let expected = Cookie::build("foo", "bar=baz").finish(); + assert_eq_parse!("foo=bar=baz", expected); + + let mut expected = Cookie::build("foo", "bar").finish(); + assert_eq_parse!("foo=bar", expected); + assert_eq_parse!("foo = bar", expected); + assert_eq_parse!(" foo=bar ", expected); + assert_eq_parse!(" foo=bar ;Domain=", expected); + assert_eq_parse!(" foo=bar ;Domain= ", expected); + assert_eq_parse!(" foo=bar ;Ignored", expected); + + let mut unexpected = Cookie::build("foo", "bar").http_only(false).finish(); + assert_ne_parse!(" foo=bar ;HttpOnly", unexpected); + assert_ne_parse!(" foo=bar; httponly", unexpected); + + expected.set_http_only(true); + assert_eq_parse!(" foo=bar ;HttpOnly", expected); + assert_eq_parse!(" foo=bar ;httponly", expected); + assert_eq_parse!(" foo=bar ;HTTPONLY=whatever", expected); + assert_eq_parse!(" foo=bar ; sekure; HTTPONLY", expected); + + expected.set_secure(true); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure=aaaa", expected); + + unexpected.set_http_only(true); + unexpected.set_secure(true); + assert_ne_parse!(" foo=bar ;HttpOnly; skeure", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; =secure", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly;", unexpected); + + unexpected.set_secure(false); + assert_ne_parse!(" foo=bar ;HttpOnly; secure", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; secure", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; secure", unexpected); + + expected.set_max_age(Duration::zero()); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=0", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = 0 ", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=-1", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = -1 ", expected); + + expected.set_max_age(Duration::minutes(1)); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=60", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = 60 ", expected); + + expected.set_max_age(Duration::seconds(4)); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=4", expected); + assert_eq_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = 4 ", expected); + + unexpected.set_secure(true); + unexpected.set_max_age(Duration::minutes(1)); + assert_ne_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=122", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = 38 ", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; Secure; Max-Age=51", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = -1 ", unexpected); + assert_ne_parse!(" foo=bar ;HttpOnly; Secure; Max-Age = 0", unexpected); + + expected.set_path("/"); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4; Path=/", expected); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4;Path=/", expected); + + expected.set_path("/foo"); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4; Path=/foo", expected); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4;Path=/foo", expected); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4;path=/foo", expected); + assert_eq_parse!("foo=bar;HttpOnly; Secure; Max-Age=4;path = /foo", expected); + + unexpected.set_max_age(Duration::seconds(4)); + unexpected.set_path("/bar"); + assert_ne_parse!("foo=bar;HttpOnly; Secure; Max-Age=4; Path=/foo", unexpected); + assert_ne_parse!("foo=bar;HttpOnly; Secure; Max-Age=4;Path=/baz", unexpected); + + expected.set_domain("www.foo.com"); + assert_eq_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=www.foo.com", + expected + ); + + expected.set_domain("foo.com"); + assert_eq_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=foo.com", + expected + ); + assert_eq_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=FOO.COM", + expected + ); + + unexpected.set_path("/foo"); + unexpected.set_domain("bar.com"); + assert_ne_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=foo.com", + unexpected + ); + assert_ne_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=FOO.COM", + unexpected + ); + + let time_str = "Wed, 21 Oct 2015 07:28:00 GMT"; + let expires = strptime(time_str, "%a, %d %b %Y %H:%M:%S %Z").unwrap(); + expected.set_expires(expires); + assert_eq_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=foo.com; Expires=Wed, 21 Oct 2015 07:28:00 GMT", + expected + ); + + unexpected.set_domain("foo.com"); + let bad_expires = strptime(time_str, "%a, %d %b %Y %H:%S:%M %Z").unwrap(); + expected.set_expires(bad_expires); + assert_ne_parse!( + " foo=bar ;HttpOnly; Secure; Max-Age=4; Path=/foo; \ + Domain=foo.com; Expires=Wed, 21 Oct 2015 07:28:00 GMT", + unexpected + ); + } + + #[test] + fn odd_characters() { + let expected = Cookie::new("foo", "b%2Fr"); + assert_eq_parse!("foo=b%2Fr", expected); + } + + #[test] + fn odd_characters_encoded() { + let expected = Cookie::new("foo", "b/r"); + let cookie = match Cookie::parse_encoded("foo=b%2Fr") { + Ok(cookie) => cookie, + Err(e) => panic!("Failed to parse: {:?}", e), + }; + + assert_eq!(cookie, expected); + } + + #[test] + fn do_not_panic_on_large_max_ages() { + let max_seconds = Duration::max_value().num_seconds(); + let expected = Cookie::build("foo", "bar").max_age(max_seconds).finish(); + assert_eq_parse!(format!(" foo=bar; Max-Age={:?}", max_seconds + 1), expected); + } +} diff --git a/actix-http/src/cookie/secure/key.rs b/actix-http/src/cookie/secure/key.rs new file mode 100644 index 000000000..779c16b75 --- /dev/null +++ b/actix-http/src/cookie/secure/key.rs @@ -0,0 +1,190 @@ +use ring::hkdf::{Algorithm, KeyType, Prk, HKDF_SHA256}; +use ring::rand::{SecureRandom, SystemRandom}; + +use super::private::KEY_LEN as PRIVATE_KEY_LEN; +use super::signed::KEY_LEN as SIGNED_KEY_LEN; + +static HKDF_DIGEST: Algorithm = HKDF_SHA256; +const KEYS_INFO: &[&[u8]] = &[b"COOKIE;SIGNED:HMAC-SHA256;PRIVATE:AEAD-AES-256-GCM"]; + +/// A cryptographic master key for use with `Signed` and/or `Private` jars. +/// +/// This structure encapsulates secure, cryptographic keys for use with both +/// [PrivateJar](struct.PrivateJar.html) and [SignedJar](struct.SignedJar.html). +/// It can be derived from a single master key via +/// [from_master](#method.from_master) or generated from a secure random source +/// via [generate](#method.generate). A single instance of `Key` can be used for +/// both a `PrivateJar` and a `SignedJar`. +/// +/// This type is only available when the `secure` feature is enabled. +#[derive(Clone)] +pub struct Key { + signing_key: [u8; SIGNED_KEY_LEN], + encryption_key: [u8; PRIVATE_KEY_LEN], +} + +impl KeyType for &Key { + #[inline] + fn len(&self) -> usize { + SIGNED_KEY_LEN + PRIVATE_KEY_LEN + } +} + +impl Key { + /// Derives new signing/encryption keys from a master key. + /// + /// The master key must be at least 256-bits (32 bytes). For security, the + /// master key _must_ be cryptographically random. The keys are derived + /// deterministically from the master key. + /// + /// # Panics + /// + /// Panics if `key` is less than 32 bytes in length. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Key; + /// + /// # /* + /// let master_key = { /* a cryptographically random key >= 32 bytes */ }; + /// # */ + /// # let master_key: &Vec = &(0..32).collect(); + /// + /// let key = Key::from_master(master_key); + /// ``` + pub fn from_master(key: &[u8]) -> Key { + if key.len() < 32 { + panic!( + "bad master key length: expected at least 32 bytes, found {}", + key.len() + ); + } + + // An empty `Key` structure; will be filled in with HKDF derived keys. + let mut output_key = Key { + signing_key: [0; SIGNED_KEY_LEN], + encryption_key: [0; PRIVATE_KEY_LEN], + }; + + // Expand the master key into two HKDF generated keys. + let mut both_keys = [0; SIGNED_KEY_LEN + PRIVATE_KEY_LEN]; + let prk = Prk::new_less_safe(HKDF_DIGEST, key); + let okm = prk.expand(KEYS_INFO, &output_key).expect("okm expand"); + okm.fill(&mut both_keys).expect("fill keys"); + + // Copy the key parts into their respective fields. + output_key + .signing_key + .copy_from_slice(&both_keys[..SIGNED_KEY_LEN]); + output_key + .encryption_key + .copy_from_slice(&both_keys[SIGNED_KEY_LEN..]); + output_key + } + + /// Generates signing/encryption keys from a secure, random source. Keys are + /// generated nondeterministically. + /// + /// # Panics + /// + /// Panics if randomness cannot be retrieved from the operating system. See + /// [try_generate](#method.try_generate) for a non-panicking version. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Key; + /// + /// let key = Key::generate(); + /// ``` + pub fn generate() -> Key { + Self::try_generate().expect("failed to generate `Key` from randomness") + } + + /// Attempts to generate signing/encryption keys from a secure, random + /// source. Keys are generated nondeterministically. If randomness cannot be + /// retrieved from the underlying operating system, returns `None`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Key; + /// + /// let key = Key::try_generate(); + /// ``` + pub fn try_generate() -> Option { + let mut sign_key = [0; SIGNED_KEY_LEN]; + let mut enc_key = [0; PRIVATE_KEY_LEN]; + + let rng = SystemRandom::new(); + if rng.fill(&mut sign_key).is_err() || rng.fill(&mut enc_key).is_err() { + return None; + } + + Some(Key { + signing_key: sign_key, + encryption_key: enc_key, + }) + } + + /// Returns the raw bytes of a key suitable for signing cookies. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Key; + /// + /// let key = Key::generate(); + /// let signing_key = key.signing(); + /// ``` + pub fn signing(&self) -> &[u8] { + &self.signing_key[..] + } + + /// Returns the raw bytes of a key suitable for encrypting cookies. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::Key; + /// + /// let key = Key::generate(); + /// let encryption_key = key.encryption(); + /// ``` + pub fn encryption(&self) -> &[u8] { + &self.encryption_key[..] + } +} + +#[cfg(test)] +mod test { + use super::Key; + + #[test] + fn deterministic_from_master() { + let master_key: Vec = (0..32).collect(); + + let key_a = Key::from_master(&master_key); + let key_b = Key::from_master(&master_key); + + assert_eq!(key_a.signing(), key_b.signing()); + assert_eq!(key_a.encryption(), key_b.encryption()); + assert_ne!(key_a.encryption(), key_a.signing()); + + let master_key_2: Vec = (32..64).collect(); + let key_2 = Key::from_master(&master_key_2); + + assert_ne!(key_2.signing(), key_a.signing()); + assert_ne!(key_2.encryption(), key_a.encryption()); + } + + #[test] + fn non_deterministic_generate() { + let key_a = Key::generate(); + let key_b = Key::generate(); + + assert_ne!(key_a.signing(), key_b.signing()); + assert_ne!(key_a.encryption(), key_b.encryption()); + } +} diff --git a/actix-http/src/cookie/secure/macros.rs b/actix-http/src/cookie/secure/macros.rs new file mode 100644 index 000000000..089047c4e --- /dev/null +++ b/actix-http/src/cookie/secure/macros.rs @@ -0,0 +1,40 @@ +#[cfg(test)] +macro_rules! assert_simple_behaviour { + ($clear:expr, $secure:expr) => {{ + assert_eq!($clear.iter().count(), 0); + + $secure.add(Cookie::new("name", "val")); + assert_eq!($clear.iter().count(), 1); + assert_eq!($secure.get("name").unwrap().value(), "val"); + assert_ne!($clear.get("name").unwrap().value(), "val"); + + $secure.add(Cookie::new("another", "two")); + assert_eq!($clear.iter().count(), 2); + + $clear.remove(Cookie::named("another")); + assert_eq!($clear.iter().count(), 1); + + $secure.remove(Cookie::named("name")); + assert_eq!($clear.iter().count(), 0); + }}; +} + +#[cfg(test)] +macro_rules! assert_secure_behaviour { + ($clear:expr, $secure:expr) => {{ + $secure.add(Cookie::new("secure", "secure")); + assert!($clear.get("secure").unwrap().value() != "secure"); + assert!($secure.get("secure").unwrap().value() == "secure"); + + let mut cookie = $clear.get("secure").unwrap().clone(); + let new_val = format!("{}l", cookie.value()); + cookie.set_value(new_val); + $clear.add(cookie); + assert!($secure.get("secure").is_none()); + + let mut cookie = $clear.get("secure").unwrap().clone(); + cookie.set_value("foobar"); + $clear.add(cookie); + assert!($secure.get("secure").is_none()); + }}; +} diff --git a/actix-http/src/cookie/secure/mod.rs b/actix-http/src/cookie/secure/mod.rs new file mode 100644 index 000000000..e0fba9733 --- /dev/null +++ b/actix-http/src/cookie/secure/mod.rs @@ -0,0 +1,10 @@ +//! Fork of https://github.com/alexcrichton/cookie-rs +#[macro_use] +mod macros; +mod key; +mod private; +mod signed; + +pub use self::key::*; +pub use self::private::*; +pub use self::signed::*; diff --git a/actix-http/src/cookie/secure/private.rs b/actix-http/src/cookie/secure/private.rs new file mode 100644 index 000000000..6c16e94e8 --- /dev/null +++ b/actix-http/src/cookie/secure/private.rs @@ -0,0 +1,275 @@ +use std::str; + +use log::warn; +use ring::aead::{Aad, Algorithm, Nonce, AES_256_GCM}; +use ring::aead::{LessSafeKey, UnboundKey}; +use ring::rand::{SecureRandom, SystemRandom}; + +use super::Key; +use crate::cookie::{Cookie, CookieJar}; + +// Keep these in sync, and keep the key len synced with the `private` docs as +// well as the `KEYS_INFO` const in secure::Key. +static ALGO: &'static Algorithm = &AES_256_GCM; +const NONCE_LEN: usize = 12; +pub const KEY_LEN: usize = 32; + +/// A child cookie jar that provides authenticated encryption for its cookies. +/// +/// A _private_ child jar signs and encrypts all the cookies added to it and +/// verifies and decrypts cookies retrieved from it. Any cookies stored in a +/// `PrivateJar` are simultaneously assured confidentiality, integrity, and +/// authenticity. In other words, clients cannot discover nor tamper with the +/// contents of a cookie, nor can they fabricate cookie data. +/// +/// This type is only available when the `secure` feature is enabled. +pub struct PrivateJar<'a> { + parent: &'a mut CookieJar, + key: [u8; KEY_LEN], +} + +impl<'a> PrivateJar<'a> { + /// Creates a new child `PrivateJar` with parent `parent` and key `key`. + /// This method is typically called indirectly via the `signed` method of + /// `CookieJar`. + #[doc(hidden)] + pub fn new(parent: &'a mut CookieJar, key: &Key) -> PrivateJar<'a> { + let mut key_array = [0u8; KEY_LEN]; + key_array.copy_from_slice(key.encryption()); + PrivateJar { + parent, + key: key_array, + } + } + + /// Given a sealed value `str` and a key name `name`, where the nonce is + /// prepended to the original value and then both are Base64 encoded, + /// verifies and decrypts the sealed value and returns it. If there's a + /// problem, returns an `Err` with a string describing the issue. + fn unseal(&self, name: &str, value: &str) -> Result { + let mut data = base64::decode(value).map_err(|_| "bad base64 value")?; + if data.len() <= NONCE_LEN { + return Err("length of decoded data is <= NONCE_LEN"); + } + + let ad = Aad::from(name.as_bytes()); + let key = LessSafeKey::new( + UnboundKey::new(&ALGO, &self.key).expect("matching key length"), + ); + let (nonce, mut sealed) = data.split_at_mut(NONCE_LEN); + let nonce = + Nonce::try_assume_unique_for_key(nonce).expect("invalid length of `nonce`"); + let unsealed = key + .open_in_place(nonce, ad, &mut sealed) + .map_err(|_| "invalid key/nonce/value: bad seal")?; + + if let Ok(unsealed_utf8) = str::from_utf8(unsealed) { + Ok(unsealed_utf8.to_string()) + } else { + warn!( + "Private cookie does not have utf8 content! +It is likely the secret key used to encrypt them has been leaked. +Please change it as soon as possible." + ); + Err("bad unsealed utf8") + } + } + + /// Returns a reference to the `Cookie` inside this jar with the name `name` + /// and authenticates and decrypts the cookie's value, returning a `Cookie` + /// with the decrypted value. If the cookie cannot be found, or the cookie + /// fails to authenticate or decrypt, `None` is returned. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// let mut private_jar = jar.private(&key); + /// assert!(private_jar.get("name").is_none()); + /// + /// private_jar.add(Cookie::new("name", "value")); + /// assert_eq!(private_jar.get("name").unwrap().value(), "value"); + /// ``` + pub fn get(&self, name: &str) -> Option> { + if let Some(cookie_ref) = self.parent.get(name) { + let mut cookie = cookie_ref.clone(); + if let Ok(value) = self.unseal(name, cookie.value()) { + cookie.set_value(value); + return Some(cookie); + } + } + + None + } + + /// Adds `cookie` to the parent jar. The cookie's value is encrypted with + /// authenticated encryption assuring confidentiality, integrity, and + /// authenticity. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// jar.private(&key).add(Cookie::new("name", "value")); + /// + /// assert_ne!(jar.get("name").unwrap().value(), "value"); + /// assert_eq!(jar.private(&key).get("name").unwrap().value(), "value"); + /// ``` + pub fn add(&mut self, mut cookie: Cookie<'static>) { + self.encrypt_cookie(&mut cookie); + + // Add the sealed cookie to the parent. + self.parent.add(cookie); + } + + /// Adds an "original" `cookie` to parent jar. The cookie's value is + /// encrypted with authenticated encryption assuring confidentiality, + /// integrity, and authenticity. Adding an original cookie does not affect + /// the [`CookieJar::delta()`](struct.CookieJar.html#method.delta) + /// computation. This method is intended to be used to seed the cookie jar + /// with cookies received from a client's HTTP message. + /// + /// For accurate `delta` computations, this method should not be called + /// after calling `remove`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// jar.private(&key).add_original(Cookie::new("name", "value")); + /// + /// assert_eq!(jar.iter().count(), 1); + /// assert_eq!(jar.delta().count(), 0); + /// ``` + pub fn add_original(&mut self, mut cookie: Cookie<'static>) { + self.encrypt_cookie(&mut cookie); + + // Add the sealed cookie to the parent. + self.parent.add_original(cookie); + } + + /// Encrypts the cookie's value with + /// authenticated encryption assuring confidentiality, integrity, and authenticity. + fn encrypt_cookie(&self, cookie: &mut Cookie) { + let name = cookie.name().as_bytes(); + let value = cookie.value().as_bytes(); + let data = encrypt_name_value(name, value, &self.key); + + // Base64 encode the nonce and encrypted value. + let sealed_value = base64::encode(&data); + cookie.set_value(sealed_value); + } + + /// Removes `cookie` from the parent jar. + /// + /// For correct removal, the passed in `cookie` must contain the same `path` + /// and `domain` as the cookie that was initially set. + /// + /// See [CookieJar::remove](struct.CookieJar.html#method.remove) for more + /// details. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// let mut private_jar = jar.private(&key); + /// + /// private_jar.add(Cookie::new("name", "value")); + /// assert!(private_jar.get("name").is_some()); + /// + /// private_jar.remove(Cookie::named("name")); + /// assert!(private_jar.get("name").is_none()); + /// ``` + pub fn remove(&mut self, cookie: Cookie<'static>) { + self.parent.remove(cookie); + } +} + +fn encrypt_name_value(name: &[u8], value: &[u8], key: &[u8]) -> Vec { + // Create the `SealingKey` structure. + let unbound = UnboundKey::new(&ALGO, key).expect("matching key length"); + let key = LessSafeKey::new(unbound); + + // Create a vec to hold the [nonce | cookie value | overhead]. + let mut data = vec![0; NONCE_LEN + value.len() + ALGO.tag_len()]; + + // Randomly generate the nonce, then copy the cookie value as input. + let (nonce, in_out) = data.split_at_mut(NONCE_LEN); + let (in_out, tag) = in_out.split_at_mut(value.len()); + in_out.copy_from_slice(value); + + // Randomly generate the nonce into the nonce piece. + SystemRandom::new() + .fill(nonce) + .expect("couldn't random fill nonce"); + let nonce = Nonce::try_assume_unique_for_key(nonce).expect("invalid `nonce` length"); + + // Use cookie's name as associated data to prevent value swapping. + let ad = Aad::from(name); + let ad_tag = key + .seal_in_place_separate_tag(nonce, ad, in_out) + .expect("in-place seal"); + + // Copy the tag into the tag piece. + tag.copy_from_slice(ad_tag.as_ref()); + + // Remove the overhead and return the sealed content. + data +} + +#[cfg(test)] +mod test { + use super::{encrypt_name_value, Cookie, CookieJar, Key}; + + #[test] + fn simple() { + let key = Key::generate(); + let mut jar = CookieJar::new(); + assert_simple_behaviour!(jar, jar.private(&key)); + } + + #[test] + fn private() { + let key = Key::generate(); + let mut jar = CookieJar::new(); + assert_secure_behaviour!(jar, jar.private(&key)); + } + + #[test] + fn non_utf8() { + let key = Key::generate(); + let mut jar = CookieJar::new(); + + let name = "malicious"; + let mut assert_non_utf8 = |value: &[u8]| { + let sealed = encrypt_name_value(name.as_bytes(), value, &key.encryption()); + let encoded = base64::encode(&sealed); + assert_eq!( + jar.private(&key).unseal(name, &encoded), + Err("bad unsealed utf8") + ); + jar.add(Cookie::new(name, encoded)); + assert_eq!(jar.private(&key).get(name), None); + }; + + assert_non_utf8(&[0x72, 0xfb, 0xdf, 0x74]); // rûst in ISO/IEC 8859-1 + + let mut malicious = + String::from(r#"{"id":"abc123??%X","admin":true}"#).into_bytes(); + malicious[8] |= 0b1100_0000; + malicious[9] |= 0b1100_0000; + assert_non_utf8(&malicious); + } +} diff --git a/actix-http/src/cookie/secure/signed.rs b/actix-http/src/cookie/secure/signed.rs new file mode 100644 index 000000000..3fcd2cd84 --- /dev/null +++ b/actix-http/src/cookie/secure/signed.rs @@ -0,0 +1,184 @@ +use ring::hmac::{self, sign, verify}; + +use super::Key; +use crate::cookie::{Cookie, CookieJar}; + +// Keep these in sync, and keep the key len synced with the `signed` docs as +// well as the `KEYS_INFO` const in secure::Key. +static HMAC_DIGEST: hmac::Algorithm = hmac::HMAC_SHA256; +const BASE64_DIGEST_LEN: usize = 44; +pub const KEY_LEN: usize = 32; + +/// A child cookie jar that authenticates its cookies. +/// +/// A _signed_ child jar signs all the cookies added to it and verifies cookies +/// retrieved from it. Any cookies stored in a `SignedJar` are assured integrity +/// and authenticity. In other words, clients cannot tamper with the contents of +/// a cookie nor can they fabricate cookie values, but the data is visible in +/// plaintext. +/// +/// This type is only available when the `secure` feature is enabled. +pub struct SignedJar<'a> { + parent: &'a mut CookieJar, + key: hmac::Key, +} + +impl<'a> SignedJar<'a> { + /// Creates a new child `SignedJar` with parent `parent` and key `key`. This + /// method is typically called indirectly via the `signed` method of + /// `CookieJar`. + #[doc(hidden)] + pub fn new(parent: &'a mut CookieJar, key: &Key) -> SignedJar<'a> { + SignedJar { + parent, + key: hmac::Key::new(HMAC_DIGEST, key.signing()), + } + } + + /// Given a signed value `str` where the signature is prepended to `value`, + /// verifies the signed value and returns it. If there's a problem, returns + /// an `Err` with a string describing the issue. + fn verify(&self, cookie_value: &str) -> Result { + if cookie_value.len() < BASE64_DIGEST_LEN { + return Err("length of value is <= BASE64_DIGEST_LEN"); + } + + let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN); + let sig = base64::decode(digest_str).map_err(|_| "bad base64 digest")?; + + verify(&self.key, value.as_bytes(), &sig) + .map(|_| value.to_string()) + .map_err(|_| "value did not verify") + } + + /// Returns a reference to the `Cookie` inside this jar with the name `name` + /// and verifies the authenticity and integrity of the cookie's value, + /// returning a `Cookie` with the authenticated value. If the cookie cannot + /// be found, or the cookie fails to verify, `None` is returned. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// let mut signed_jar = jar.signed(&key); + /// assert!(signed_jar.get("name").is_none()); + /// + /// signed_jar.add(Cookie::new("name", "value")); + /// assert_eq!(signed_jar.get("name").unwrap().value(), "value"); + /// ``` + pub fn get(&self, name: &str) -> Option> { + if let Some(cookie_ref) = self.parent.get(name) { + let mut cookie = cookie_ref.clone(); + if let Ok(value) = self.verify(cookie.value()) { + cookie.set_value(value); + return Some(cookie); + } + } + + None + } + + /// Adds `cookie` to the parent jar. The cookie's value is signed assuring + /// integrity and authenticity. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// jar.signed(&key).add(Cookie::new("name", "value")); + /// + /// assert_ne!(jar.get("name").unwrap().value(), "value"); + /// assert!(jar.get("name").unwrap().value().contains("value")); + /// assert_eq!(jar.signed(&key).get("name").unwrap().value(), "value"); + /// ``` + pub fn add(&mut self, mut cookie: Cookie<'static>) { + self.sign_cookie(&mut cookie); + self.parent.add(cookie); + } + + /// Adds an "original" `cookie` to this jar. The cookie's value is signed + /// assuring integrity and authenticity. Adding an original cookie does not + /// affect the [`CookieJar::delta()`](struct.CookieJar.html#method.delta) + /// computation. This method is intended to be used to seed the cookie jar + /// with cookies received from a client's HTTP message. + /// + /// For accurate `delta` computations, this method should not be called + /// after calling `remove`. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// jar.signed(&key).add_original(Cookie::new("name", "value")); + /// + /// assert_eq!(jar.iter().count(), 1); + /// assert_eq!(jar.delta().count(), 0); + /// ``` + pub fn add_original(&mut self, mut cookie: Cookie<'static>) { + self.sign_cookie(&mut cookie); + self.parent.add_original(cookie); + } + + /// Signs the cookie's value assuring integrity and authenticity. + fn sign_cookie(&self, cookie: &mut Cookie) { + let digest = sign(&self.key, cookie.value().as_bytes()); + let mut new_value = base64::encode(digest.as_ref()); + new_value.push_str(cookie.value()); + cookie.set_value(new_value); + } + + /// Removes `cookie` from the parent jar. + /// + /// For correct removal, the passed in `cookie` must contain the same `path` + /// and `domain` as the cookie that was initially set. + /// + /// See [CookieJar::remove](struct.CookieJar.html#method.remove) for more + /// details. + /// + /// # Example + /// + /// ```rust + /// use actix_http::cookie::{CookieJar, Cookie, Key}; + /// + /// let key = Key::generate(); + /// let mut jar = CookieJar::new(); + /// let mut signed_jar = jar.signed(&key); + /// + /// signed_jar.add(Cookie::new("name", "value")); + /// assert!(signed_jar.get("name").is_some()); + /// + /// signed_jar.remove(Cookie::named("name")); + /// assert!(signed_jar.get("name").is_none()); + /// ``` + pub fn remove(&mut self, cookie: Cookie<'static>) { + self.parent.remove(cookie); + } +} + +#[cfg(test)] +mod test { + use super::{Cookie, CookieJar, Key}; + + #[test] + fn simple() { + let key = Key::generate(); + let mut jar = CookieJar::new(); + assert_simple_behaviour!(jar, jar.signed(&key)); + } + + #[test] + fn private() { + let key = Key::generate(); + let mut jar = CookieJar::new(); + assert_secure_behaviour!(jar, jar.signed(&key)); + } +} diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs new file mode 100644 index 000000000..1e51e8b56 --- /dev/null +++ b/actix-http/src/encoding/decoder.rs @@ -0,0 +1,241 @@ +use std::future::Future; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_threadpool::{run, CpuFuture}; +#[cfg(feature = "brotli")] +use brotli2::write::BrotliDecoder; +use bytes::Bytes; +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +use flate2::write::{GzDecoder, ZlibDecoder}; +use futures::{ready, Stream}; + +use super::Writer; +use crate::error::PayloadError; +use crate::http::header::{ContentEncoding, HeaderMap, CONTENT_ENCODING}; + +const INPLACE: usize = 2049; + +pub struct Decoder { + decoder: Option, + stream: S, + eof: bool, + fut: Option, ContentDecoder), io::Error>>>, +} + +impl Decoder +where + S: Stream>, +{ + /// Construct a decoder. + #[inline] + pub fn new(stream: S, encoding: ContentEncoding) -> Decoder { + let decoder = match encoding { + #[cfg(feature = "brotli")] + ContentEncoding::Br => Some(ContentDecoder::Br(Box::new( + BrotliDecoder::new(Writer::new()), + ))), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new( + ZlibDecoder::new(Writer::new()), + ))), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new( + GzDecoder::new(Writer::new()), + ))), + _ => None, + }; + Decoder { + decoder, + stream, + fut: None, + eof: false, + } + } + + /// Construct decoder based on headers. + #[inline] + pub fn from_headers(stream: S, headers: &HeaderMap) -> Decoder { + // check content-encoding + let encoding = if let Some(enc) = headers.get(&CONTENT_ENCODING) { + if let Ok(enc) = enc.to_str() { + ContentEncoding::from(enc) + } else { + ContentEncoding::Identity + } + } else { + ContentEncoding::Identity + }; + + Self::new(stream, encoding) + } +} + +impl Stream for Decoder +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + loop { + if let Some(ref mut fut) = self.fut { + let (chunk, decoder) = match ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(item)) => item, + Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; + self.decoder = Some(decoder); + self.fut.take(); + if let Some(chunk) = chunk { + return Poll::Ready(Some(Ok(chunk))); + } + } + + if self.eof { + return Poll::Ready(None); + } + + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), + Poll::Ready(Some(Ok(chunk))) => { + if let Some(mut decoder) = self.decoder.take() { + if chunk.len() < INPLACE { + let chunk = decoder.feed_data(chunk)?; + self.decoder = Some(decoder); + if let Some(chunk) = chunk { + return Poll::Ready(Some(Ok(chunk))); + } + } else { + self.fut = Some(run(move || { + let chunk = decoder.feed_data(chunk)?; + Ok((chunk, decoder)) + })); + } + continue; + } else { + return Poll::Ready(Some(Ok(chunk))); + } + } + Poll::Ready(None) => { + self.eof = true; + return if let Some(mut decoder) = self.decoder.take() { + match decoder.feed_eof() { + Ok(Some(res)) => Poll::Ready(Some(Ok(res))), + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err.into()))), + } + } else { + Poll::Ready(None) + }; + } + Poll::Pending => break, + } + } + Poll::Pending + } +} + +enum ContentDecoder { + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + Deflate(Box>), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + Gzip(Box>), + #[cfg(feature = "brotli")] + Br(Box>), +} + +impl ContentDecoder { + #[allow(unreachable_patterns)] + fn feed_eof(&mut self) -> io::Result> { + match self { + #[cfg(feature = "brotli")] + ContentDecoder::Br(ref mut decoder) => match decoder.finish() { + Ok(mut writer) => { + let b = writer.take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + _ => Ok(None), + } + } + + #[allow(unreachable_patterns)] + fn feed_data(&mut self, data: Bytes) -> io::Result> { + match self { + #[cfg(feature = "brotli")] + ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) { + Ok(_) => { + decoder.flush()?; + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { + Ok(_) => { + decoder.flush()?; + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { + Ok(_) => { + decoder.flush()?; + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, + _ => Ok(Some(data)), + } + } +} diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs new file mode 100644 index 000000000..295d99a2a --- /dev/null +++ b/actix-http/src/encoding/encoder.rs @@ -0,0 +1,266 @@ +//! Stream encoder +use std::future::Future; +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_threadpool::{run, CpuFuture}; +#[cfg(feature = "brotli")] +use brotli2::write::BrotliEncoder; +use bytes::Bytes; +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +use flate2::write::{GzEncoder, ZlibEncoder}; + +use crate::body::{Body, BodySize, MessageBody, ResponseBody}; +use crate::http::header::{ContentEncoding, CONTENT_ENCODING}; +use crate::http::{HeaderValue, HttpTryFrom, StatusCode}; +use crate::{Error, ResponseHead}; + +use super::Writer; + +const INPLACE: usize = 2049; + +pub struct Encoder { + eof: bool, + body: EncoderBody, + encoder: Option, + fut: Option>>, +} + +impl Encoder { + pub fn response( + encoding: ContentEncoding, + head: &mut ResponseHead, + body: ResponseBody, + ) -> ResponseBody> { + let can_encode = !(head.headers().contains_key(&CONTENT_ENCODING) + || head.status == StatusCode::SWITCHING_PROTOCOLS + || head.status == StatusCode::NO_CONTENT + || encoding == ContentEncoding::Identity + || encoding == ContentEncoding::Auto); + + let body = match body { + ResponseBody::Other(b) => match b { + Body::None => return ResponseBody::Other(Body::None), + Body::Empty => return ResponseBody::Other(Body::Empty), + Body::Bytes(buf) => { + if can_encode { + EncoderBody::Bytes(buf) + } else { + return ResponseBody::Other(Body::Bytes(buf)); + } + } + Body::Message(stream) => EncoderBody::BoxedStream(stream), + }, + ResponseBody::Body(stream) => EncoderBody::Stream(stream), + }; + + if can_encode { + // Modify response body only if encoder is not None + if let Some(enc) = ContentEncoder::encoder(encoding) { + update_head(encoding, head); + head.no_chunking(false); + return ResponseBody::Body(Encoder { + body, + eof: false, + fut: None, + encoder: Some(enc), + }); + } + } + ResponseBody::Body(Encoder { + body, + eof: false, + fut: None, + encoder: None, + }) + } +} + +enum EncoderBody { + Bytes(Bytes), + Stream(B), + BoxedStream(Box), +} + +impl MessageBody for Encoder { + fn size(&self) -> BodySize { + if self.encoder.is_none() { + match self.body { + EncoderBody::Bytes(ref b) => b.size(), + EncoderBody::Stream(ref b) => b.size(), + EncoderBody::BoxedStream(ref b) => b.size(), + } + } else { + BodySize::Stream + } + } + + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + loop { + if self.eof { + return Poll::Ready(None); + } + + if let Some(ref mut fut) = self.fut { + let mut encoder = match futures::ready!(Pin::new(fut).poll(cx)) { + Ok(Ok(item)) => item, + Ok(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; + let chunk = encoder.take(); + self.encoder = Some(encoder); + self.fut.take(); + if !chunk.is_empty() { + return Poll::Ready(Some(Ok(chunk))); + } + } + + let result = match self.body { + EncoderBody::Bytes(ref mut b) => { + if b.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(std::mem::replace(b, Bytes::new())))) + } + } + EncoderBody::Stream(ref mut b) => b.poll_next(cx), + EncoderBody::BoxedStream(ref mut b) => b.poll_next(cx), + }; + match result { + Poll::Ready(Some(Ok(chunk))) => { + if let Some(mut encoder) = self.encoder.take() { + if chunk.len() < INPLACE { + encoder.write(&chunk)?; + let chunk = encoder.take(); + self.encoder = Some(encoder); + if !chunk.is_empty() { + return Poll::Ready(Some(Ok(chunk))); + } + } else { + self.fut = Some(run(move || { + encoder.write(&chunk)?; + Ok(encoder) + })); + } + } else { + return Poll::Ready(Some(Ok(chunk))); + } + } + Poll::Ready(None) => { + if let Some(encoder) = self.encoder.take() { + let chunk = encoder.finish()?; + if chunk.is_empty() { + return Poll::Ready(None); + } else { + self.eof = true; + return Poll::Ready(Some(Ok(chunk))); + } + } else { + return Poll::Ready(None); + } + } + val => return val, + } + } + } +} + +fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { + head.headers_mut().insert( + CONTENT_ENCODING, + HeaderValue::try_from(Bytes::from_static(encoding.as_str().as_bytes())).unwrap(), + ); +} + +enum ContentEncoder { + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + Deflate(ZlibEncoder), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + Gzip(GzEncoder), + #[cfg(feature = "brotli")] + Br(BrotliEncoder), +} + +impl ContentEncoder { + fn encoder(encoding: ContentEncoding) -> Option { + match encoding { + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoding::Deflate => Some(ContentEncoder::Deflate(ZlibEncoder::new( + Writer::new(), + flate2::Compression::fast(), + ))), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new( + Writer::new(), + flate2::Compression::fast(), + ))), + #[cfg(feature = "brotli")] + ContentEncoding::Br => { + Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) + } + _ => None, + } + } + + #[inline] + pub(crate) fn take(&mut self) -> Bytes { + match *self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), + } + } + + fn finish(self) -> Result { + match self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(encoder) => match encoder.finish() { + Ok(writer) => Ok(writer.buf.freeze()), + Err(err) => Err(err), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Gzip(encoder) => match encoder.finish() { + Ok(writer) => Ok(writer.buf.freeze()), + Err(err) => Err(err), + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Deflate(encoder) => match encoder.finish() { + Ok(writer) => Ok(writer.buf.freeze()), + Err(err) => Err(err), + }, + } + } + + fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { + match *self { + #[cfg(feature = "brotli")] + ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { + Ok(_) => Ok(()), + Err(err) => { + trace!("Error decoding br encoding: {}", err); + Err(err) + } + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { + Ok(_) => Ok(()), + Err(err) => { + trace!("Error decoding gzip encoding: {}", err); + Err(err) + } + }, + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { + Ok(_) => Ok(()), + Err(err) => { + trace!("Error decoding deflate encoding: {}", err); + Err(err) + } + }, + } + } +} diff --git a/actix-http/src/encoding/mod.rs b/actix-http/src/encoding/mod.rs new file mode 100644 index 000000000..b55a43a7c --- /dev/null +++ b/actix-http/src/encoding/mod.rs @@ -0,0 +1,35 @@ +//! Content-Encoding support +use std::io; + +use bytes::{Bytes, BytesMut}; + +mod decoder; +mod encoder; + +pub use self::decoder::Decoder; +pub use self::encoder::Encoder; + +pub(self) struct Writer { + buf: BytesMut, +} + +impl Writer { + fn new() -> Writer { + Writer { + buf: BytesMut::with_capacity(8192), + } + } + fn take(&mut self) -> Bytes { + self.buf.take().freeze() + } +} + +impl io::Write for Writer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs new file mode 100644 index 000000000..587849bde --- /dev/null +++ b/actix-http/src/error.rs @@ -0,0 +1,1204 @@ +//! Error and Result module +use std::any::TypeId; +use std::cell::RefCell; +use std::io::Write; +use std::str::Utf8Error; +use std::string::FromUtf8Error; +use std::{fmt, io, result}; + +use actix_utils::timeout::TimeoutError; +use bytes::BytesMut; +use derive_more::{Display, From}; +pub use futures::channel::oneshot::Canceled; +use http::uri::InvalidUri; +use http::{header, Error as HttpError, StatusCode}; +use httparse; +use serde::de::value::Error as DeError; +use serde_json::error::Error as JsonError; +use serde_urlencoded::ser::Error as FormError; + +// re-export for convinience +use crate::body::Body; +pub use crate::cookie::ParseError as CookieParseError; +use crate::helpers::Writer; +use crate::response::Response; + +/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) +/// for actix web operations +/// +/// This typedef is generally used to avoid writing out +/// `actix_http::error::Error` directly and is otherwise a direct mapping to +/// `Result`. +pub type Result = result::Result; + +/// General purpose actix web error. +/// +/// An actix web error is used to carry errors from `failure` or `std::error` +/// through actix in a convenient way. It can be created through +/// converting errors with `into()`. +/// +/// Whenever it is created from an external object a response error is created +/// for it that can be used to create an http response from it this means that +/// if you have access to an actix `Error` you can always get a +/// `ResponseError` reference from it. +pub struct Error { + cause: Box, +} + +impl Error { + /// Returns the reference to the underlying `ResponseError`. + pub fn as_response_error(&self) -> &dyn ResponseError { + self.cause.as_ref() + } + + /// Similar to `as_response_error` but downcasts. + pub fn as_error(&self) -> Option<&T> { + ResponseError::downcast_ref(self.cause.as_ref()) + } +} + +/// Error that can be converted to `Response` +pub trait ResponseError: fmt::Debug + fmt::Display { + /// Response's status code + /// + /// Internal server error is generated by default. + fn status_code(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } + + /// Create response for error + /// + /// Internal server error is generated by default. + fn error_response(&self) -> Response { + let mut resp = Response::new(self.status_code()); + let mut buf = BytesMut::new(); + let _ = write!(Writer(&mut buf), "{}", self); + resp.headers_mut().insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain; charset=utf-8"), + ); + resp.set_body(Body::from(buf)) + } + + #[doc(hidden)] + fn __private_get_type_id__(&self) -> TypeId + where + Self: 'static, + { + TypeId::of::() + } +} + +impl dyn ResponseError + 'static { + /// Downcasts a response error to a specific type. + pub fn downcast_ref(&self) -> Option<&T> { + if self.__private_get_type_id__() == TypeId::of::() { + unsafe { Some(&*(self as *const dyn ResponseError as *const T)) } + } else { + None + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{:?}", &self.cause) + } +} + +impl From<()> for Error { + fn from(_: ()) -> Self { + Error::from(UnitError) + } +} + +impl std::error::Error for Error { + fn description(&self) -> &str { + "actix-http::Error" + } + + fn cause(&self) -> Option<&dyn std::error::Error> { + None + } + + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl From for Error { + fn from(_: std::convert::Infallible) -> Self { + // `std::convert::Infallible` indicates an error + // that will never happen + unreachable!() + } +} + +/// Convert `Error` to a `Response` instance +impl From for Response { + fn from(err: Error) -> Self { + Response::from_error(err) + } +} + +/// `Error` for any error that implements `ResponseError` +impl From for Error { + fn from(err: T) -> Error { + Error { + cause: Box::new(err), + } + } +} + +/// Return `GATEWAY_TIMEOUT` for `TimeoutError` +impl ResponseError for TimeoutError { + fn status_code(&self) -> StatusCode { + match self { + TimeoutError::Service(e) => e.status_code(), + TimeoutError::Timeout => StatusCode::GATEWAY_TIMEOUT, + } + } +} + +#[derive(Debug, Display)] +#[display(fmt = "UnknownError")] +struct UnitError; + +/// `InternalServerError` for `UnitError` +impl ResponseError for UnitError {} + +/// `InternalServerError` for `JsonError` +impl ResponseError for JsonError {} + +/// `InternalServerError` for `FormError` +impl ResponseError for FormError {} + +#[cfg(feature = "openssl")] +/// `InternalServerError` for `openssl::ssl::Error` +impl ResponseError for open_ssl::ssl::Error {} + +#[cfg(feature = "openssl")] +/// `InternalServerError` for `openssl::ssl::HandshakeError` +impl ResponseError for open_ssl::ssl::HandshakeError {} + +/// Return `BAD_REQUEST` for `de::value::Error` +impl ResponseError for DeError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// `InternalServerError` for `Canceled` +impl ResponseError for Canceled {} + +/// Return `BAD_REQUEST` for `Utf8Error` +impl ResponseError for Utf8Error { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Return `InternalServerError` for `HttpError`, +/// Response generation can return `HttpError`, so it is internal error +impl ResponseError for HttpError {} + +/// Return `InternalServerError` for `io::Error` +impl ResponseError for io::Error { + fn status_code(&self) -> StatusCode { + match self.kind() { + io::ErrorKind::NotFound => StatusCode::NOT_FOUND, + io::ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +/// `BadRequest` for `InvalidHeaderValue` +impl ResponseError for header::InvalidHeaderValue { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// `BadRequest` for `InvalidHeaderValue` +impl ResponseError for header::InvalidHeaderValueBytes { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// A set of errors that can occur during parsing HTTP streams +#[derive(Debug, Display)] +pub enum ParseError { + /// An invalid `Method`, such as `GE.T`. + #[display(fmt = "Invalid Method specified")] + Method, + /// An invalid `Uri`, such as `exam ple.domain`. + #[display(fmt = "Uri error: {}", _0)] + Uri(InvalidUri), + /// An invalid `HttpVersion`, such as `HTP/1.1` + #[display(fmt = "Invalid HTTP version specified")] + Version, + /// An invalid `Header`. + #[display(fmt = "Invalid Header provided")] + Header, + /// A message head is too large to be reasonable. + #[display(fmt = "Message head is too large")] + TooLarge, + /// A message reached EOF, but is not complete. + #[display(fmt = "Message is incomplete")] + Incomplete, + /// An invalid `Status`, such as `1337 ELITE`. + #[display(fmt = "Invalid Status provided")] + Status, + /// A timeout occurred waiting for an IO event. + #[allow(dead_code)] + #[display(fmt = "Timeout")] + Timeout, + /// An `io::Error` that occurred while trying to read or write to a network + /// stream. + #[display(fmt = "IO error: {}", _0)] + Io(io::Error), + /// Parsing a field as string failed + #[display(fmt = "UTF8 error: {}", _0)] + Utf8(Utf8Error), +} + +/// Return `BadRequest` for `ParseError` +impl ResponseError for ParseError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl From for ParseError { + fn from(err: io::Error) -> ParseError { + ParseError::Io(err) + } +} + +impl From for ParseError { + fn from(err: InvalidUri) -> ParseError { + ParseError::Uri(err) + } +} + +impl From for ParseError { + fn from(err: Utf8Error) -> ParseError { + ParseError::Utf8(err) + } +} + +impl From for ParseError { + fn from(err: FromUtf8Error) -> ParseError { + ParseError::Utf8(err.utf8_error()) + } +} + +impl From for ParseError { + fn from(err: httparse::Error) -> ParseError { + match err { + httparse::Error::HeaderName + | httparse::Error::HeaderValue + | httparse::Error::NewLine + | httparse::Error::Token => ParseError::Header, + httparse::Error::Status => ParseError::Status, + httparse::Error::TooManyHeaders => ParseError::TooLarge, + httparse::Error::Version => ParseError::Version, + } + } +} + +#[derive(Display, Debug)] +/// A set of errors that can occur during payload parsing +pub enum PayloadError { + /// A payload reached EOF, but is not complete. + #[display( + fmt = "A payload reached EOF, but is not complete. With error: {:?}", + _0 + )] + Incomplete(Option), + /// Content encoding stream corruption + #[display(fmt = "Can not decode content-encoding.")] + EncodingCorrupted, + /// A payload reached size limit. + #[display(fmt = "A payload reached size limit.")] + Overflow, + /// A payload length is unknown. + #[display(fmt = "A payload length is unknown.")] + UnknownLength, + /// Http2 payload error + #[display(fmt = "{}", _0)] + Http2Payload(h2::Error), + /// Io error + #[display(fmt = "{}", _0)] + Io(io::Error), +} + +impl From for PayloadError { + fn from(err: h2::Error) -> Self { + PayloadError::Http2Payload(err) + } +} + +impl From> for PayloadError { + fn from(err: Option) -> Self { + PayloadError::Incomplete(err) + } +} + +impl From for PayloadError { + fn from(err: io::Error) -> Self { + PayloadError::Incomplete(Some(err)) + } +} + +impl From for PayloadError { + fn from(_: Canceled) -> Self { + PayloadError::Io(io::Error::new( + io::ErrorKind::Other, + "Operation is canceled", + )) + } +} + +/// `PayloadError` returns two possible results: +/// +/// - `Overflow` returns `PayloadTooLarge` +/// - Other errors returns `BadRequest` +impl ResponseError for PayloadError { + fn status_code(&self) -> StatusCode { + match *self { + PayloadError::Overflow => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, + } + } +} + +/// Return `BadRequest` for `cookie::ParseError` +impl ResponseError for crate::cookie::ParseError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +#[derive(Debug, Display, From)] +/// A set of errors that can occur during dispatching http requests +pub enum DispatchError { + /// Service error + Service(Error), + + /// Upgrade service error + Upgrade, + + /// An `io::Error` that occurred while trying to read or write to a network + /// stream. + #[display(fmt = "IO error: {}", _0)] + Io(io::Error), + + /// Http request parse error. + #[display(fmt = "Parse error: {}", _0)] + Parse(ParseError), + + /// Http/2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + + /// The first request did not complete within the specified timeout. + #[display(fmt = "The first request did not complete within the specified timeout")] + SlowRequestTimeout, + + /// Disconnect timeout. Makes sense for ssl streams. + #[display(fmt = "Connection shutdown timeout")] + DisconnectTimeout, + + /// Payload is not consumed + #[display(fmt = "Task is completed but request's payload is not consumed")] + PayloadIsNotConsumed, + + /// Malformed request + #[display(fmt = "Malformed request")] + MalformedRequest, + + /// Internal error + #[display(fmt = "Internal error")] + InternalError, + + /// Unknown error + #[display(fmt = "Unknown error")] + Unknown, +} + +/// A set of error that can occure during parsing content type +#[derive(PartialEq, Debug, Display)] +pub enum ContentTypeError { + /// Can not parse content type + #[display(fmt = "Can not parse content type")] + ParseError, + /// Unknown content encoding + #[display(fmt = "Unknown content encoding")] + UnknownEncoding, +} + +/// Return `BadRequest` for `ContentTypeError` +impl ResponseError for ContentTypeError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Helper type that can wrap any error and generate custom response. +/// +/// In following example any `io::Error` will be converted into "BAD REQUEST" +/// response as opposite to *INTERNAL SERVER ERROR* which is defined by +/// default. +/// +/// ```rust +/// # extern crate actix_http; +/// # use std::io; +/// # use actix_http::*; +/// +/// fn index(req: Request) -> Result<&'static str> { +/// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error"))) +/// } +/// # fn main() {} +/// ``` +pub struct InternalError { + cause: T, + status: InternalErrorType, +} + +enum InternalErrorType { + Status(StatusCode), + Response(RefCell>), +} + +impl InternalError { + /// Create `InternalError` instance + pub fn new(cause: T, status: StatusCode) -> Self { + InternalError { + cause, + status: InternalErrorType::Status(status), + } + } + + /// Create `InternalError` with predefined `Response`. + pub fn from_response(cause: T, response: Response) -> Self { + InternalError { + cause, + status: InternalErrorType::Response(RefCell::new(Some(response))), + } + } +} + +impl fmt::Debug for InternalError +where + T: fmt::Debug + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.cause, f) + } +} + +impl fmt::Display for InternalError +where + T: fmt::Display + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +impl ResponseError for InternalError +where + T: fmt::Debug + fmt::Display + 'static, +{ + fn status_code(&self) -> StatusCode { + match self.status { + InternalErrorType::Status(st) => st, + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow().as_ref() { + resp.head().status + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + } + } + + fn error_response(&self) -> Response { + match self.status { + InternalErrorType::Status(st) => { + let mut res = Response::new(st); + let mut buf = BytesMut::new(); + let _ = write!(Writer(&mut buf), "{}", self); + res.headers_mut().insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain; charset=utf-8"), + ); + res.set_body(Body::from(buf)) + } + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow_mut().take() { + resp + } else { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + } +} + +/// Convert Response to a Error +impl From for Error { + fn from(res: Response) -> Error { + InternalError::from_response("", res).into() + } +} + +/// Helper function that creates wrapper of any error and generate *BAD +/// REQUEST* response. +#[allow(non_snake_case)] +pub fn ErrorBadRequest(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::BAD_REQUEST).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNAUTHORIZED* response. +#[allow(non_snake_case)] +pub fn ErrorUnauthorized(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNAUTHORIZED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PAYMENT_REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPaymentRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYMENT_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate *FORBIDDEN* +/// response. +#[allow(non_snake_case)] +pub fn ErrorForbidden(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::FORBIDDEN).into() +} + +/// Helper function that creates wrapper of any error and generate *NOT FOUND* +/// response. +#[allow(non_snake_case)] +pub fn ErrorNotFound(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_FOUND).into() +} + +/// Helper function that creates wrapper of any error and generate *METHOD NOT +/// ALLOWED* response. +#[allow(non_snake_case)] +pub fn ErrorMethodNotAllowed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() +} + +/// Helper function that creates wrapper of any error and generate *NOT +/// ACCEPTABLE* response. +#[allow(non_snake_case)] +pub fn ErrorNotAcceptable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_ACCEPTABLE).into() +} + +/// Helper function that creates wrapper of any error and generate *PROXY +/// AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorProxyAuthenticationRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PROXY_AUTHENTICATION_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate *REQUEST +/// TIMEOUT* response. +#[allow(non_snake_case)] +pub fn ErrorRequestTimeout(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() +} + +/// Helper function that creates wrapper of any error and generate *CONFLICT* +/// response. +#[allow(non_snake_case)] +pub fn ErrorConflict(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::CONFLICT).into() +} + +/// Helper function that creates wrapper of any error and generate *GONE* +/// response. +#[allow(non_snake_case)] +pub fn ErrorGone(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::GONE).into() +} + +/// Helper function that creates wrapper of any error and generate *LENGTH +/// REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorLengthRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LENGTH_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PAYLOAD TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorPayloadTooLarge(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYLOAD_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *URI TOO LONG* response. +#[allow(non_snake_case)] +pub fn ErrorUriTooLong(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::URI_TOO_LONG).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNSUPPORTED MEDIA TYPE* response. +#[allow(non_snake_case)] +pub fn ErrorUnsupportedMediaType(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNSUPPORTED_MEDIA_TYPE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *RANGE NOT SATISFIABLE* response. +#[allow(non_snake_case)] +pub fn ErrorRangeNotSatisfiable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::RANGE_NOT_SATISFIABLE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *IM A TEAPOT* response. +#[allow(non_snake_case)] +pub fn ErrorImATeapot(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::IM_A_TEAPOT).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *MISDIRECTED REQUEST* response. +#[allow(non_snake_case)] +pub fn ErrorMisdirectedRequest(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::MISDIRECTED_REQUEST).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNPROCESSABLE ENTITY* response. +#[allow(non_snake_case)] +pub fn ErrorUnprocessableEntity(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNPROCESSABLE_ENTITY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *LOCKED* response. +#[allow(non_snake_case)] +pub fn ErrorLocked(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOCKED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *FAILED DEPENDENCY* response. +#[allow(non_snake_case)] +pub fn ErrorFailedDependency(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::FAILED_DEPENDENCY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UPGRADE REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorUpgradeRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UPGRADE_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PRECONDITION FAILED* response. +#[allow(non_snake_case)] +pub fn ErrorPreconditionFailed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PRECONDITION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPreconditionRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PRECONDITION_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *TOO MANY REQUESTS* response. +#[allow(non_snake_case)] +pub fn ErrorTooManyRequests(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::TOO_MANY_REQUESTS).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *REQUEST HEADER FIELDS TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorRequestHeaderFieldsTooLarge(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNAVAILABLE FOR LEGAL REASONS* response. +#[allow(non_snake_case)] +pub fn ErrorUnavailableForLegalReasons(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *EXPECTATION FAILED* response. +#[allow(non_snake_case)] +pub fn ErrorExpectationFailed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *INTERNAL SERVER ERROR* response. +#[allow(non_snake_case)] +pub fn ErrorInternalServerError(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NOT IMPLEMENTED* response. +#[allow(non_snake_case)] +pub fn ErrorNotImplemented(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_IMPLEMENTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *BAD GATEWAY* response. +#[allow(non_snake_case)] +pub fn ErrorBadGateway(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::BAD_GATEWAY).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *SERVICE UNAVAILABLE* response. +#[allow(non_snake_case)] +pub fn ErrorServiceUnavailable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::SERVICE_UNAVAILABLE).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *GATEWAY TIMEOUT* response. +#[allow(non_snake_case)] +pub fn ErrorGatewayTimeout(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::GATEWAY_TIMEOUT).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *HTTP VERSION NOT SUPPORTED* response. +#[allow(non_snake_case)] +pub fn ErrorHttpVersionNotSupported(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::HTTP_VERSION_NOT_SUPPORTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *VARIANT ALSO NEGOTIATES* response. +#[allow(non_snake_case)] +pub fn ErrorVariantAlsoNegotiates(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::VARIANT_ALSO_NEGOTIATES).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *INSUFFICIENT STORAGE* response. +#[allow(non_snake_case)] +pub fn ErrorInsufficientStorage(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::INSUFFICIENT_STORAGE).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *LOOP DETECTED* response. +#[allow(non_snake_case)] +pub fn ErrorLoopDetected(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOOP_DETECTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NOT EXTENDED* response. +#[allow(non_snake_case)] +pub fn ErrorNotExtended(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_EXTENDED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NETWORK AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorNetworkAuthenticationRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() +} + +#[cfg(feature = "fail")] +mod failure_integration { + use super::*; + + /// Compatibility for `failure::Error` + impl ResponseError for failure::Error {} +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{Error as HttpError, StatusCode}; + use httparse; + use std::error::Error as StdError; + use std::io; + + #[test] + fn test_into_response() { + let resp: Response = ParseError::Incomplete.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); + let resp: Response = err.error_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[test] + fn test_cookie_parse() { + let resp: Response = CookieParseError::EmptyName.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_as_response() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let e: Error = ParseError::Io(orig).into(); + assert_eq!(format!("{}", e.as_response_error()), "IO error: other"); + } + + #[test] + fn test_error_cause() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let desc = orig.description().to_owned(); + let e = Error::from(orig); + assert_eq!(format!("{}", e.as_response_error()), desc); + } + + #[test] + fn test_error_display() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let desc = orig.description().to_owned(); + let e = Error::from(orig); + assert_eq!(format!("{}", e), desc); + } + + #[test] + fn test_error_http_response() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let e = Error::from(orig); + let resp: Response = e.into(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[test] + fn test_payload_error() { + let err: PayloadError = + io::Error::new(io::ErrorKind::Other, "ParseError").into(); + assert!(format!("{}", err).contains("ParseError")); + + let err = PayloadError::Incomplete(None); + assert_eq!( + format!("{}", err), + "A payload reached EOF, but is not complete. With error: None" + ); + } + + macro_rules! from { + ($from:expr => $error:pat) => { + match ParseError::from($from) { + e @ $error => { + assert!(format!("{}", e).len() >= 5); + } + e => unreachable!("{:?}", e), + } + }; + } + + macro_rules! from_and_cause { + ($from:expr => $error:pat) => { + match ParseError::from($from) { + e @ $error => { + let desc = format!("{}", e); + assert_eq!(desc, format!("IO error: {}", $from.description())); + } + _ => unreachable!("{:?}", $from), + } + }; + } + + #[test] + fn test_from() { + from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..)); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderValue => ParseError::Header); + from!(httparse::Error::NewLine => ParseError::Header); + from!(httparse::Error::Status => ParseError::Status); + from!(httparse::Error::Token => ParseError::Header); + from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); + from!(httparse::Error::Version => ParseError::Version); + } + + #[test] + fn test_internal_error() { + let err = + InternalError::from_response(ParseError::Method, Response::Ok().into()); + let resp: Response = err.error_response(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_error_casting() { + let err = PayloadError::Overflow; + let resp_err: &dyn ResponseError = &err; + let err = resp_err.downcast_ref::().unwrap(); + assert_eq!(err.to_string(), "A payload reached size limit."); + let not_err = resp_err.downcast_ref::(); + assert!(not_err.is_none()); + } + + #[test] + fn test_error_helpers() { + let r: Response = ErrorBadRequest("err").into(); + assert_eq!(r.status(), StatusCode::BAD_REQUEST); + + let r: Response = ErrorUnauthorized("err").into(); + assert_eq!(r.status(), StatusCode::UNAUTHORIZED); + + let r: Response = ErrorPaymentRequired("err").into(); + assert_eq!(r.status(), StatusCode::PAYMENT_REQUIRED); + + let r: Response = ErrorForbidden("err").into(); + assert_eq!(r.status(), StatusCode::FORBIDDEN); + + let r: Response = ErrorNotFound("err").into(); + assert_eq!(r.status(), StatusCode::NOT_FOUND); + + let r: Response = ErrorMethodNotAllowed("err").into(); + assert_eq!(r.status(), StatusCode::METHOD_NOT_ALLOWED); + + let r: Response = ErrorNotAcceptable("err").into(); + assert_eq!(r.status(), StatusCode::NOT_ACCEPTABLE); + + let r: Response = ErrorProxyAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::PROXY_AUTHENTICATION_REQUIRED); + + let r: Response = ErrorRequestTimeout("err").into(); + assert_eq!(r.status(), StatusCode::REQUEST_TIMEOUT); + + let r: Response = ErrorConflict("err").into(); + assert_eq!(r.status(), StatusCode::CONFLICT); + + let r: Response = ErrorGone("err").into(); + assert_eq!(r.status(), StatusCode::GONE); + + let r: Response = ErrorLengthRequired("err").into(); + assert_eq!(r.status(), StatusCode::LENGTH_REQUIRED); + + let r: Response = ErrorPreconditionFailed("err").into(); + assert_eq!(r.status(), StatusCode::PRECONDITION_FAILED); + + let r: Response = ErrorPayloadTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::PAYLOAD_TOO_LARGE); + + let r: Response = ErrorUriTooLong("err").into(); + assert_eq!(r.status(), StatusCode::URI_TOO_LONG); + + let r: Response = ErrorUnsupportedMediaType("err").into(); + assert_eq!(r.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + + let r: Response = ErrorRangeNotSatisfiable("err").into(); + assert_eq!(r.status(), StatusCode::RANGE_NOT_SATISFIABLE); + + let r: Response = ErrorExpectationFailed("err").into(); + assert_eq!(r.status(), StatusCode::EXPECTATION_FAILED); + + let r: Response = ErrorImATeapot("err").into(); + assert_eq!(r.status(), StatusCode::IM_A_TEAPOT); + + let r: Response = ErrorMisdirectedRequest("err").into(); + assert_eq!(r.status(), StatusCode::MISDIRECTED_REQUEST); + + let r: Response = ErrorUnprocessableEntity("err").into(); + assert_eq!(r.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let r: Response = ErrorLocked("err").into(); + assert_eq!(r.status(), StatusCode::LOCKED); + + let r: Response = ErrorFailedDependency("err").into(); + assert_eq!(r.status(), StatusCode::FAILED_DEPENDENCY); + + let r: Response = ErrorUpgradeRequired("err").into(); + assert_eq!(r.status(), StatusCode::UPGRADE_REQUIRED); + + let r: Response = ErrorPreconditionRequired("err").into(); + assert_eq!(r.status(), StatusCode::PRECONDITION_REQUIRED); + + let r: Response = ErrorTooManyRequests("err").into(); + assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS); + + let r: Response = ErrorRequestHeaderFieldsTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); + + let r: Response = ErrorUnavailableForLegalReasons("err").into(); + assert_eq!(r.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); + + let r: Response = ErrorInternalServerError("err").into(); + assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let r: Response = ErrorNotImplemented("err").into(); + assert_eq!(r.status(), StatusCode::NOT_IMPLEMENTED); + + let r: Response = ErrorBadGateway("err").into(); + assert_eq!(r.status(), StatusCode::BAD_GATEWAY); + + let r: Response = ErrorServiceUnavailable("err").into(); + assert_eq!(r.status(), StatusCode::SERVICE_UNAVAILABLE); + + let r: Response = ErrorGatewayTimeout("err").into(); + assert_eq!(r.status(), StatusCode::GATEWAY_TIMEOUT); + + let r: Response = ErrorHttpVersionNotSupported("err").into(); + assert_eq!(r.status(), StatusCode::HTTP_VERSION_NOT_SUPPORTED); + + let r: Response = ErrorVariantAlsoNegotiates("err").into(); + assert_eq!(r.status(), StatusCode::VARIANT_ALSO_NEGOTIATES); + + let r: Response = ErrorInsufficientStorage("err").into(); + assert_eq!(r.status(), StatusCode::INSUFFICIENT_STORAGE); + + let r: Response = ErrorLoopDetected("err").into(); + assert_eq!(r.status(), StatusCode::LOOP_DETECTED); + + let r: Response = ErrorNotExtended("err").into(); + assert_eq!(r.status(), StatusCode::NOT_EXTENDED); + + let r: Response = ErrorNetworkAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::NETWORK_AUTHENTICATION_REQUIRED); + } +} diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs new file mode 100644 index 000000000..c6266f56e --- /dev/null +++ b/actix-http/src/extensions.rs @@ -0,0 +1,91 @@ +use std::any::{Any, TypeId}; +use std::fmt; + +use hashbrown::HashMap; + +#[derive(Default)] +/// A type map of request extensions. +pub struct Extensions { + map: HashMap>, +} + +impl Extensions { + /// Create an empty `Extensions`. + #[inline] + pub fn new() -> Extensions { + Extensions { + map: HashMap::default(), + } + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned. + pub fn insert(&mut self, val: T) { + self.map.insert(TypeId::of::(), Box::new(val)); + } + + /// Check if container contains entry + pub fn contains(&self) -> bool { + self.map.get(&TypeId::of::()).is_some() + } + + /// Get a reference to a type previously inserted on this `Extensions`. + pub fn get(&self) -> Option<&T> { + self.map + .get(&TypeId::of::()) + .and_then(|boxed| (&**boxed as &(dyn Any + 'static)).downcast_ref()) + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .get_mut(&TypeId::of::()) + .and_then(|boxed| (&mut **boxed as &mut (dyn Any + 'static)).downcast_mut()) + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + pub fn remove(&mut self) -> Option { + self.map.remove(&TypeId::of::()).and_then(|boxed| { + (boxed as Box) + .downcast() + .ok() + .map(|boxed| *boxed) + }) + } + + /// Clear the `Extensions` of all inserted extensions. + #[inline] + pub fn clear(&mut self) { + self.map.clear(); + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} + +#[test] +fn test_extensions() { + #[derive(Debug, PartialEq)] + struct MyType(i32); + + let mut extensions = Extensions::new(); + + extensions.insert(5i32); + extensions.insert(MyType(10)); + + assert_eq!(extensions.get(), Some(&5i32)); + assert_eq!(extensions.get_mut(), Some(&mut 5i32)); + + assert_eq!(extensions.remove::(), Some(5i32)); + assert!(extensions.get::().is_none()); + + assert_eq!(extensions.get::(), None); + assert_eq!(extensions.get(), Some(&MyType(10))); +} diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs new file mode 100644 index 000000000..bea629c4f --- /dev/null +++ b/actix-http/src/h1/client.rs @@ -0,0 +1,251 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::io::{self, Write}; +use std::rc::Rc; + +use actix_codec::{Decoder, Encoder}; +use bitflags::bitflags; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{ + HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, +}; +use http::{Method, Version}; + +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder, reserve_readbuf}; +use super::{Message, MessageType}; +use crate::body::BodySize; +use crate::config::ServiceConfig; +use crate::error::{ParseError, PayloadError}; +use crate::header::HeaderMap; +use crate::helpers; +use crate::message::{ + ConnectionType, Head, MessagePool, RequestHead, RequestHeadType, ResponseHead, +}; + +bitflags! { + struct Flags: u8 { + const HEAD = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_1000; + const STREAM = 0b0001_0000; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +/// HTTP/1 Codec +pub struct ClientCodec { + inner: ClientCodecInner, +} + +/// HTTP/1 Payload Codec +pub struct ClientPayloadCodec { + inner: ClientCodecInner, +} + +struct ClientCodecInner { + config: ServiceConfig, + decoder: decoder::MessageDecoder, + payload: Option, + version: Version, + ctype: ConnectionType, + + // encoder part + flags: Flags, + headers_size: u32, + encoder: encoder::MessageEncoder, +} + +impl Default for ClientCodec { + fn default() -> Self { + ClientCodec::new(ServiceConfig::default()) + } +} + +impl ClientCodec { + /// Create HTTP/1 codec. + /// + /// `keepalive_enabled` how response `connection` header get generated. + pub fn new(config: ServiceConfig) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + ClientCodec { + inner: ClientCodecInner { + config, + decoder: decoder::MessageDecoder::default(), + payload: None, + version: Version::HTTP_11, + ctype: ConnectionType::Close, + + flags, + headers_size: 0, + encoder: encoder::MessageEncoder::default(), + }, + } + } + + /// Check if request is upgrade + pub fn upgrade(&self) -> bool { + self.inner.ctype == ConnectionType::Upgrade + } + + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.inner.ctype == ConnectionType::KeepAlive + } + + /// Check last request's message type + pub fn message_type(&self) -> MessageType { + if self.inner.flags.contains(Flags::STREAM) { + MessageType::Stream + } else if self.inner.payload.is_none() { + MessageType::None + } else { + MessageType::Payload + } + } + + /// Convert message codec to a payload codec + pub fn into_payload_codec(self) -> ClientPayloadCodec { + ClientPayloadCodec { inner: self.inner } + } +} + +impl ClientPayloadCodec { + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.inner.ctype == ConnectionType::KeepAlive + } + + /// Transform payload codec to a message codec + pub fn into_message_codec(self) -> ClientCodec { + ClientCodec { inner: self.inner } + } +} + +impl Decoder for ClientCodec { + type Item = ResponseHead; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); + + if let Some((req, payload)) = self.inner.decoder.decode(src)? { + if let Some(ctype) = req.ctype() { + // do not use peer's keep-alive + self.inner.ctype = if ctype == ConnectionType::KeepAlive { + self.inner.ctype + } else { + ctype + }; + } + + if !self.inner.flags.contains(Flags::HEAD) { + match payload { + PayloadType::None => self.inner.payload = None, + PayloadType::Payload(pl) => self.inner.payload = Some(pl), + PayloadType::Stream(pl) => { + self.inner.payload = Some(pl); + self.inner.flags.insert(Flags::STREAM); + } + } + } else { + self.inner.payload = None; + } + reserve_readbuf(src); + Ok(Some(req)) + } else { + Ok(None) + } + } +} + +impl Decoder for ClientPayloadCodec { + type Item = Option; + type Error = PayloadError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + debug_assert!( + self.inner.payload.is_some(), + "Payload decoder is not specified" + ); + + Ok(match self.inner.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => { + reserve_readbuf(src); + Some(Some(chunk)) + } + Some(PayloadItem::Eof) => { + self.inner.payload.take(); + Some(None) + } + None => None, + }) + } +} + +impl Encoder for ClientCodec { + type Item = Message<(RequestHeadType, BodySize)>; + type Error = io::Error; + + fn encode( + &mut self, + item: Self::Item, + dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + Message::Item((mut head, length)) => { + let inner = &mut self.inner; + inner.version = head.as_ref().version; + inner + .flags + .set(Flags::HEAD, head.as_ref().method == Method::HEAD); + + // connection status + inner.ctype = match head.as_ref().connection_type() { + ConnectionType::KeepAlive => { + if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { + ConnectionType::KeepAlive + } else { + ConnectionType::Close + } + } + ConnectionType::Upgrade => ConnectionType::Upgrade, + ConnectionType::Close => ConnectionType::Close, + }; + + inner.encoder.encode( + dst, + &mut head, + false, + false, + inner.version, + length, + inner.ctype, + &inner.config, + )?; + } + Message::Chunk(Some(bytes)) => { + self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?; + } + Message::Chunk(None) => { + self.inner.encoder.encode_eof(dst)?; + } + } + Ok(()) + } +} + +pub struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs new file mode 100644 index 000000000..22c7ed232 --- /dev/null +++ b/actix-http/src/h1/codec.rs @@ -0,0 +1,253 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::io::Write; +use std::{fmt, io, net}; + +use actix_codec::{Decoder, Encoder}; +use bitflags::bitflags; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::{Method, StatusCode, Version}; + +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder}; +use super::{Message, MessageType}; +use crate::body::BodySize; +use crate::config::ServiceConfig; +use crate::error::ParseError; +use crate::helpers; +use crate::message::{ConnectionType, Head, ResponseHead}; +use crate::request::Request; +use crate::response::Response; + +bitflags! { + struct Flags: u8 { + const HEAD = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_0010; + const STREAM = 0b0000_0100; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +/// HTTP/1 Codec +pub struct Codec { + config: ServiceConfig, + decoder: decoder::MessageDecoder, + payload: Option, + version: Version, + ctype: ConnectionType, + + // encoder part + flags: Flags, + encoder: encoder::MessageEncoder>, +} + +impl Default for Codec { + fn default() -> Self { + Codec::new(ServiceConfig::default()) + } +} + +impl fmt::Debug for Codec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "h1::Codec({:?})", self.flags) + } +} + +impl Codec { + /// Create HTTP/1 codec. + /// + /// `keepalive_enabled` how response `connection` header get generated. + pub fn new(config: ServiceConfig) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + Codec { + config, + flags, + decoder: decoder::MessageDecoder::default(), + payload: None, + version: Version::HTTP_11, + ctype: ConnectionType::Close, + encoder: encoder::MessageEncoder::default(), + } + } + + #[inline] + /// Check if request is upgrade + pub fn upgrade(&self) -> bool { + self.ctype == ConnectionType::Upgrade + } + + #[inline] + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.ctype == ConnectionType::KeepAlive + } + + #[inline] + /// Check if keep-alive enabled on server level + pub fn keepalive_enabled(&self) -> bool { + self.flags.contains(Flags::KEEPALIVE_ENABLED) + } + + #[inline] + /// Check last request's message type + pub fn message_type(&self) -> MessageType { + if self.flags.contains(Flags::STREAM) { + MessageType::Stream + } else if self.payload.is_none() { + MessageType::None + } else { + MessageType::Payload + } + } + + #[inline] + pub fn config(&self) -> &ServiceConfig { + &self.config + } +} + +impl Decoder for Codec { + type Item = Message; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if self.payload.is_some() { + Ok(match self.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), + Some(PayloadItem::Eof) => { + self.payload.take(); + Some(Message::Chunk(None)) + } + None => None, + }) + } else if let Some((req, payload)) = self.decoder.decode(src)? { + let head = req.head(); + self.flags.set(Flags::HEAD, head.method == Method::HEAD); + self.version = head.version; + self.ctype = head.connection_type(); + if self.ctype == ConnectionType::KeepAlive + && !self.flags.contains(Flags::KEEPALIVE_ENABLED) + { + self.ctype = ConnectionType::Close + } + match payload { + PayloadType::None => self.payload = None, + PayloadType::Payload(pl) => self.payload = Some(pl), + PayloadType::Stream(pl) => { + self.payload = Some(pl); + self.flags.insert(Flags::STREAM); + } + } + Ok(Some(Message::Item(req))) + } else { + Ok(None) + } + } +} + +impl Encoder for Codec { + type Item = Message<(Response<()>, BodySize)>; + type Error = io::Error; + + fn encode( + &mut self, + item: Self::Item, + dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + Message::Item((mut res, length)) => { + // set response version + res.head_mut().version = self.version; + + // connection status + self.ctype = if let Some(ct) = res.head().ctype() { + if ct == ConnectionType::KeepAlive { + self.ctype + } else { + ct + } + } else { + self.ctype + }; + + // encode message + let len = dst.len(); + self.encoder.encode( + dst, + &mut res, + self.flags.contains(Flags::HEAD), + self.flags.contains(Flags::STREAM), + self.version, + length, + self.ctype, + &self.config, + )?; + // self.headers_size = (dst.len() - len) as u32; + } + Message::Chunk(Some(bytes)) => { + self.encoder.encode_chunk(bytes.as_ref(), dst)?; + } + Message::Chunk(None) => { + self.encoder.encode_eof(dst)?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{cmp, io}; + + use actix_codec::{AsyncRead, AsyncWrite}; + use bytes::{Buf, Bytes, BytesMut}; + use http::{Method, Version}; + + use super::*; + use crate::error::ParseError; + use crate::h1::Message; + use crate::httpmessage::HttpMessage; + use crate::request::Request; + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut codec = Codec::default(); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let item = codec.decode(&mut buf).unwrap().unwrap(); + let req = item.message(); + + assert_eq!(req.method(), Method::GET); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + // decode next message + let item = codec.decode(&mut buf).unwrap().unwrap(); + let req = item.message(); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } +} diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs new file mode 100644 index 000000000..ffa00288f --- /dev/null +++ b/actix-http/src/h1/decoder.rs @@ -0,0 +1,1221 @@ +use std::io; +use std::marker::PhantomData; +use std::mem::MaybeUninit; +use std::task::Poll; + +use actix_codec::Decoder; +use bytes::{Bytes, BytesMut}; +use http::header::{HeaderName, HeaderValue}; +use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version}; +use httparse; +use log::{debug, error, trace}; + +use crate::error::ParseError; +use crate::header::HeaderMap; +use crate::message::{ConnectionType, ResponseHead}; +use crate::request::Request; + +const MAX_BUFFER_SIZE: usize = 131_072; +const MAX_HEADERS: usize = 96; + +/// Incoming messagd decoder +pub(crate) struct MessageDecoder(PhantomData); + +#[derive(Debug)] +/// Incoming request type +pub(crate) enum PayloadType { + None, + Payload(PayloadDecoder), + Stream(PayloadDecoder), +} + +impl Default for MessageDecoder { + fn default() -> Self { + MessageDecoder(PhantomData) + } +} + +impl Decoder for MessageDecoder { + type Item = (T, PayloadType); + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + T::decode(src) + } +} + +pub(crate) enum PayloadLength { + Payload(PayloadType), + Upgrade, + None, +} + +pub(crate) trait MessageType: Sized { + fn set_connection_type(&mut self, ctype: Option); + + fn set_expect(&mut self); + + fn headers_mut(&mut self) -> &mut HeaderMap; + + fn decode(src: &mut BytesMut) -> Result, ParseError>; + + fn set_headers( + &mut self, + slice: &Bytes, + raw_headers: &[HeaderIndex], + ) -> Result { + let mut ka = None; + let mut has_upgrade = false; + let mut expect = false; + let mut chunked = false; + let mut content_length = None; + + { + let headers = self.headers_mut(); + + for idx in raw_headers.iter() { + let name = + HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap(); + + // Unsafe: httparse check header value for valid utf-8 + let value = unsafe { + HeaderValue::from_shared_unchecked( + slice.slice(idx.value.0, idx.value.1), + ) + }; + match name { + header::CONTENT_LENGTH => { + if let Ok(s) = value.to_str() { + if let Ok(len) = s.parse::() { + if len != 0 { + content_length = Some(len); + } + } else { + debug!("illegal Content-Length: {:?}", s); + return Err(ParseError::Header); + } + } else { + debug!("illegal Content-Length: {:?}", value); + return Err(ParseError::Header); + } + } + // transfer-encoding + header::TRANSFER_ENCODING => { + if let Ok(s) = value.to_str().map(|s| s.trim()) { + chunked = s.eq_ignore_ascii_case("chunked"); + } else { + return Err(ParseError::Header); + } + } + // connection keep-alive state + header::CONNECTION => { + ka = if let Ok(conn) = value.to_str().map(|conn| conn.trim()) { + if conn.eq_ignore_ascii_case("keep-alive") { + Some(ConnectionType::KeepAlive) + } else if conn.eq_ignore_ascii_case("close") { + Some(ConnectionType::Close) + } else if conn.eq_ignore_ascii_case("upgrade") { + Some(ConnectionType::Upgrade) + } else { + None + } + } else { + None + }; + } + header::UPGRADE => { + has_upgrade = true; + // check content-length, some clients (dart) + // sends "content-length: 0" with websocket upgrade + if let Ok(val) = value.to_str().map(|val| val.trim()) { + if val.eq_ignore_ascii_case("websocket") { + content_length = None; + } + } + } + header::EXPECT => { + let bytes = value.as_bytes(); + if bytes.len() >= 4 && &bytes[0..4] == b"100-" { + expect = true; + } + } + _ => (), + } + + headers.append(name, value); + } + } + self.set_connection_type(ka); + if expect { + self.set_expect() + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + if chunked { + // Chunked encoding + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::chunked(), + ))) + } else if let Some(len) = content_length { + // Content-Length + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::length(len), + ))) + } else if has_upgrade { + Ok(PayloadLength::Upgrade) + } else { + Ok(PayloadLength::None) + } + } +} + +impl MessageType for Request { + fn set_connection_type(&mut self, ctype: Option) { + if let Some(ctype) = ctype { + self.head_mut().set_connection_type(ctype); + } + } + + fn set_expect(&mut self) { + self.head_mut().set_expect(); + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head_mut().headers + } + + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + + let (len, method, uri, ver, h_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + + let mut req = httparse::Request::new(&mut parsed); + match req.parse(src)? { + httparse::Status::Complete(len) => { + let method = Method::from_bytes(req.method.unwrap().as_bytes()) + .map_err(|_| ParseError::Method)?; + let uri = Uri::try_from(req.path.unwrap())?; + let version = if req.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + HeaderIndex::record(src, req.headers, &mut headers); + + (len, method, uri, version, req.headers.len()) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let mut msg = Request::new(); + + // convert headers + let length = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; + + // payload decoder + let decoder = match length { + PayloadLength::Payload(pl) => pl, + PayloadLength::Upgrade => { + // upgrade(websocket) + PayloadType::Stream(PayloadDecoder::eof()) + } + PayloadLength::None => { + if method == Method::CONNECT { + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None + } + } + }; + + let head = msg.head_mut(); + head.uri = uri; + head.method = method; + head.version = ver; + + Ok(Some((msg, decoder))) + } +} + +impl MessageType for ResponseHead { + fn set_connection_type(&mut self, ctype: Option) { + if let Some(ctype) = ctype { + ResponseHead::set_connection_type(self, ctype); + } + } + + fn set_expect(&mut self) {} + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + + let (len, ver, status, h_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = + unsafe { MaybeUninit::uninit().assume_init() }; + + let mut res = httparse::Response::new(&mut parsed); + match res.parse(src)? { + httparse::Status::Complete(len) => { + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + let status = StatusCode::from_u16(res.code.unwrap()) + .map_err(|_| ParseError::Status)?; + HeaderIndex::record(src, res.headers, &mut headers); + + (len, version, status, res.headers.len()) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let mut msg = ResponseHead::new(status); + msg.version = ver; + + // convert headers + let length = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; + + // message payload + let decoder = if let PayloadLength::Payload(pl) = length { + pl + } else if status == StatusCode::SWITCHING_PROTOCOLS { + // switching protocol or connect + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + // for HTTP/1.0 read to eof and close connection + if msg.version == Version::HTTP_10 { + msg.set_connection_type(ConnectionType::Close); + PayloadType::Payload(PayloadDecoder::eof()) + } else { + PayloadType::None + } + }; + + Ok(Some((msg, decoder))) + } +} + +#[derive(Clone, Copy)] +pub(crate) struct HeaderIndex { + pub(crate) name: (usize, usize), + pub(crate) value: (usize, usize), +} + +impl HeaderIndex { + pub(crate) fn record( + bytes: &[u8], + headers: &[httparse::Header], + indices: &mut [HeaderIndex], + ) { + let bytes_ptr = bytes.as_ptr() as usize; + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + indices.name = (name_start, name_end); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + indices.value = (value_start, value_end); + } + } +} + +#[derive(Debug, Clone, PartialEq)] +/// Http payload item +pub enum PayloadItem { + Chunk(Bytes), + Eof, +} + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Debug, Clone, PartialEq)] +pub struct PayloadDecoder { + kind: Kind, +} + +impl PayloadDecoder { + pub fn length(x: u64) -> PayloadDecoder { + PayloadDecoder { + kind: Kind::Length(x), + } + } + + pub fn chunked() -> PayloadDecoder { + PayloadDecoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub fn eof() -> PayloadDecoder { + PayloadDecoder { kind: Kind::Eof } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive + /// integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof, +} + +#[derive(Debug, PartialEq, Clone)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl Decoder for PayloadDecoder { + type Item = PayloadItem; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match self.kind { + Kind::Length(ref mut remaining) => { + if *remaining == 0 { + Ok(Some(PayloadItem::Eof)) + } else { + if src.is_empty() { + return Ok(None); + } + let len = src.len() as u64; + let buf; + if *remaining > len { + buf = src.take().freeze(); + *remaining -= len; + } else { + buf = src.split_to(*remaining as usize).freeze(); + *remaining = 0; + }; + trace!("Length read: {}", buf.len()); + Ok(Some(PayloadItem::Chunk(buf))) + } + } + Kind::Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = match state.step(src, size, &mut buf) { + Poll::Pending => return Ok(None), + Poll::Ready(Ok(state)) => state, + Poll::Ready(Err(e)) => return Err(e), + }; + if *state == ChunkedState::End { + trace!("End of chunked stream"); + return Ok(Some(PayloadItem::Eof)); + } + if let Some(buf) = buf { + return Ok(Some(PayloadItem::Chunk(buf))); + } + if src.is_empty() { + return Ok(None); + } + } + } + Kind::Eof => { + if src.is_empty() { + Ok(None) + } else { + Ok(Some(PayloadItem::Chunk(src.take().freeze()))) + } + } + } + } +} + +macro_rules! byte ( + ($rdr:ident) => ({ + if $rdr.len() > 0 { + let b = $rdr[0]; + $rdr.split_to(1); + b + } else { + return Poll::Pending + } + }) +); + +impl ChunkedState { + fn step( + &self, + body: &mut BytesMut, + size: &mut u64, + buf: &mut Option, + ) -> Poll> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(body, size), + SizeLws => ChunkedState::read_size_lws(body), + Extension => ChunkedState::read_extension(body), + SizeLf => ChunkedState::read_size_lf(body, size), + Body => ChunkedState::read_body(body, size, buf), + BodyCr => ChunkedState::read_body_cr(body), + BodyLf => ChunkedState::read_body_lf(body), + EndCr => ChunkedState::read_end_cr(body), + EndLf => ChunkedState::read_end_lf(body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + + fn read_size( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll> { + let radix = 16; + match byte!(rdr) { + b @ b'0'..=b'9' => { + *size *= radix; + *size += u64::from(b - b'0'); + } + b @ b'a'..=b'f' => { + *size *= radix; + *size += u64::from(b + 10 - b'a'); + } + b @ b'A'..=b'F' => { + *size *= radix; + *size += u64::from(b + 10 - b'A'); + } + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + } + Poll::Ready(Ok(ChunkedState::Size)) + } + + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { + trace!("read_size_lws"); + match byte!(rdr) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll> { + match byte!(rdr) { + b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), + b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body( + rdr: &mut BytesMut, + rem: &mut u64, + buf: &mut Option, + ) -> Poll> { + trace!("Chunked read, remaining={:?}", rem); + + let len = rdr.len() as u64; + if len == 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + let slice; + if *rem > len { + slice = rdr.take().freeze(); + *rem -= len; + } else { + slice = rdr.split_to(*rem as usize).freeze(); + *rem = 0; + } + *buf = Some(slice); + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + } + + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + ))), + } + } + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use http::{Method, Version}; + + use super::*; + use crate::error::ParseError; + use crate::http::header::{HeaderName, SET_COOKIE}; + use crate::httpmessage::HttpMessage; + + impl PayloadType { + fn unwrap(self) -> PayloadDecoder { + match self { + PayloadType::Payload(pl) => pl, + _ => panic!(), + } + } + + fn is_unhandled(&self) -> bool { + match self { + PayloadType::Stream(_) => true, + _ => false, + } + } + } + + impl PayloadItem { + fn chunk(self) -> Bytes { + match self { + PayloadItem::Chunk(chunk) => chunk, + _ => panic!("error"), + } + } + fn eof(&self) -> bool { + match *self { + PayloadItem::Eof => true, + _ => false, + } + } + } + + macro_rules! parse_ready { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Ok(Some((msg, _))) => msg, + Ok(_) => unreachable!("Eof during parsing http request"), + Err(err) => unreachable!("Error during parsing http request: {:?}", err), + } + }}; + } + + macro_rules! expect_parse_err { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Err(err) => match err { + ParseError::Io(_) => unreachable!("Parse error expected"), + _ => (), + }, + _ => unreachable!("Error expected"), + } + }}; + } + + #[test] + fn test_parse() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); + + let mut reader = MessageDecoder::::default(); + match reader.decode(&mut buf) { + Ok(Some((req, _))) => { + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + } + Ok(_) | Err(_) => unreachable!("Error during parsing http request"), + } + } + + #[test] + fn test_parse_partial() { + let mut buf = BytesMut::from("PUT /test HTTP/1"); + + let mut reader = MessageDecoder::::default(); + assert!(reader.decode(&mut buf).unwrap().is_none()); + + buf.extend(b".1\r\n\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::PUT); + assert_eq!(req.path(), "/test"); + } + + #[test] + fn test_parse_post() { + let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); + + let mut reader = MessageDecoder::::default(); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_10); + assert_eq!(*req.method(), Method::POST); + assert_eq!(req.path(), "/test2"); + } + + #[test] + fn test_parse_body() { + let mut buf = + BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"body" + ); + } + + #[test] + fn test_parse_body_crlf() { + let mut buf = + BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"body" + ); + } + + #[test] + fn test_parse_partial_eof() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); + let mut reader = MessageDecoder::::default(); + assert!(reader.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + } + + #[test] + fn test_headers_split_field() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); + + let mut reader = MessageDecoder::::default(); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"t"); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"es"); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"t: value\r\n\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!( + req.headers() + .get(HeaderName::try_from("test").unwrap()) + .unwrap() + .as_bytes(), + b"value" + ); + } + + #[test] + fn test_headers_multi_value() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Set-Cookie: c1=cookie1\r\n\ + Set-Cookie: c2=cookie2\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + + let val: Vec<_> = req + .headers() + .get_all(SET_COOKIE) + .map(|v| v.to_str().unwrap().to_owned()) + .collect(); + assert_eq!(val[1], "c1=cookie1"); + assert_eq!(val[0], "c2=cookie2"); + } + + #[test] + fn test_conn_default_1_0() { + let mut buf = BytesMut::from("GET /test HTTP/1.0\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_default_1_1() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_close() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: close\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: Close\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_close_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: close\r\n\r\n", + ); + + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_keep_alive_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: keep-alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: Keep-Alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_keep_alive_1_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: keep-alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_other_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: other\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_other_1_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: other\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_upgrade() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + upgrade: websockets\r\n\ + connection: upgrade\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + assert_eq!(req.head().connection_type(), ConnectionType::Upgrade); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + upgrade: Websockets\r\n\ + connection: Upgrade\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + assert_eq!(req.head().connection_type(), ConnectionType::Upgrade); + } + + #[test] + fn test_conn_upgrade_connect_method() { + let mut buf = BytesMut::from( + "CONNECT /test HTTP/1.1\r\n\ + content-type: text/plain\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + } + + #[test] + fn test_request_chunked() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(val); + } else { + unreachable!("Error"); + } + + // type in chunked + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chnked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(!val); + } else { + unreachable!("Error"); + } + } + + #[test] + fn test_headers_content_length_err_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + content-length: line\r\n\r\n", + ); + + expect_parse_err!(&mut buf) + } + + #[test] + fn test_headers_content_length_err_2() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + content-length: -1\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_invalid_header() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + test line\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_invalid_name() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + test[]: line\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_bad_status_line() { + let mut buf = BytesMut::from("getpath \r\n\r\n"); + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_upgrade() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: upgrade\r\n\ + upgrade: websocket\r\n\r\n\ + some raw data", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.head().connection_type(), ConnectionType::Upgrade); + assert!(req.upgrade()); + assert!(pl.is_unhandled()); + } + + #[test] + fn test_http_request_parser_utf8() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + x-test: теÑÑ‚\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!( + req.headers().get("x-test").unwrap().as_bytes(), + "теÑÑ‚".as_bytes() + ); + } + + #[test] + fn test_http_request_parser_two_slashes() { + let mut buf = BytesMut::from("GET //path HTTP/1.1\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.path(), "//path"); + } + + #[test] + fn test_http_request_parser_bad_method() { + let mut buf = BytesMut::from("!12%()+=~$ /get HTTP/1.1\r\n\r\n"); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_parser_bad_version() { + let mut buf = BytesMut::from("GET //get HT/11\r\n\r\n"); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_chunked_payload() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"data" + ); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"line" + ); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert!(req.chunked().unwrap()); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } + + #[test] + fn test_http_request_chunked_payload_chunks() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\n1111\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"1111"); + + buf.extend(b"4\r\ndata\r"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + buf.extend(b"\n4"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + buf.extend(b"\n"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"li"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"li"); + + //trailers + //buf.feed_data("test: test\r\n"); + //not_ready!(reader.parse(&mut buf, &mut readbuf)); + + buf.extend(b"ne\r\n0\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"ne"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_parse_chunked_payload_chunk_extension() { + let mut buf = BytesMut::from( + &"GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n"[..], + ); + + let mut reader = MessageDecoder::::default(); + let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(msg.chunked().unwrap()); + + buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"data")); + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"line")); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + } + + #[test] + fn test_response_http10_read_until_eof() { + let mut buf = BytesMut::from(&"HTTP/1.0 200 Ok\r\n\r\ntest data"[..]); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let chunk = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data"))); + } +} diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs new file mode 100644 index 000000000..154b3ed40 --- /dev/null +++ b/actix-http/src/h1/dispatcher.rs @@ -0,0 +1,923 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use std::{fmt, io, net}; + +use actix_codec::{AsyncRead, Decoder, Encoder, Framed, FramedParts}; +use actix_rt::time::{delay, Delay}; +use actix_server_config::IoStream; +use actix_service::Service; +use bitflags::bitflags; +use bytes::{BufMut, BytesMut}; +use log::{error, trace}; + +use crate::body::{Body, BodySize, MessageBody, ResponseBody}; +use crate::cloneable::CloneableService; +use crate::config::ServiceConfig; +use crate::error::{DispatchError, Error}; +use crate::error::{ParseError, PayloadError}; +use crate::helpers::DataFactory; +use crate::httpmessage::HttpMessage; +use crate::request::Request; +use crate::response::Response; + +use super::codec::Codec; +use super::payload::{Payload, PayloadSender, PayloadStatus}; +use super::{Message, MessageType}; + +const LW_BUFFER_SIZE: usize = 4096; +const HW_BUFFER_SIZE: usize = 32_768; +const MAX_PIPELINED_MESSAGES: usize = 16; + +bitflags! { + pub struct Flags: u8 { + const STARTED = 0b0000_0001; + const KEEPALIVE = 0b0000_0010; + const POLLED = 0b0000_0100; + const SHUTDOWN = 0b0000_1000; + const READ_DISCONNECT = 0b0001_0000; + const WRITE_DISCONNECT = 0b0010_0000; + const UPGRADE = 0b0100_0000; + } +} + +/// Dispatcher for HTTP/1.1 protocol +pub struct Dispatcher +where + S: Service, + S::Error: Into, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + inner: DispatcherState, +} + +enum DispatcherState +where + S: Service, + S::Error: Into, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + Normal(InnerDispatcher), + Upgrade(U::Future), + None, +} + +struct InnerDispatcher +where + S: Service, + S::Error: Into, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + service: CloneableService, + expect: CloneableService, + upgrade: Option>, + on_connect: Option>, + flags: Flags, + peer_addr: Option, + error: Option, + + state: State, + payload: Option, + messages: VecDeque, + + ka_expire: Instant, + ka_timer: Option, + + io: T, + read_buf: BytesMut, + write_buf: BytesMut, + codec: Codec, +} + +enum DispatcherMessage { + Item(Request), + Upgrade(Request), + Error(Response<()>), +} + +enum State +where + S: Service, + X: Service, + B: MessageBody, +{ + None, + ExpectCall(X::Future), + ServiceCall(S::Future), + SendPayload(ResponseBody), +} + +impl State +where + S: Service, + X: Service, + B: MessageBody, +{ + fn is_empty(&self) -> bool { + if let State::None = self { + true + } else { + false + } + } + + fn is_call(&self) -> bool { + if let State::ServiceCall(_) = self { + true + } else { + false + } + } +} + +enum PollResponse { + Upgrade(Request), + DoNothing, + DrainWriteBuf, +} + +impl PartialEq for PollResponse { + fn eq(&self, other: &PollResponse) -> bool { + match self { + PollResponse::DrainWriteBuf => match other { + PollResponse::DrainWriteBuf => true, + _ => false, + }, + PollResponse::DoNothing => match other { + PollResponse::DoNothing => true, + _ => false, + }, + _ => false, + } + } +} + +impl Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + /// Create http/1 dispatcher. + pub(crate) fn new( + stream: T, + config: ServiceConfig, + service: CloneableService, + expect: CloneableService, + upgrade: Option>, + on_connect: Option>, + ) -> Self { + Dispatcher::with_timeout( + stream, + Codec::new(config.clone()), + config, + BytesMut::with_capacity(HW_BUFFER_SIZE), + None, + service, + expect, + upgrade, + on_connect, + ) + } + + /// Create http/1 dispatcher with slow request timeout. + pub(crate) fn with_timeout( + io: T, + codec: Codec, + config: ServiceConfig, + read_buf: BytesMut, + timeout: Option, + service: CloneableService, + expect: CloneableService, + upgrade: Option>, + on_connect: Option>, + ) -> Self { + let keepalive = config.keep_alive_enabled(); + let flags = if keepalive { + Flags::KEEPALIVE + } else { + Flags::empty() + }; + + // keep-alive timer + let (ka_expire, ka_timer) = if let Some(delay) = timeout { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = config.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (config.now(), None) + }; + + Dispatcher { + inner: DispatcherState::Normal(InnerDispatcher { + write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), + payload: None, + state: State::None, + error: None, + peer_addr: io.peer_addr(), + messages: VecDeque::new(), + io, + codec, + read_buf, + service, + expect, + upgrade, + on_connect, + flags, + ka_expire, + ka_timer, + }), + } + } +} + +impl InnerDispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + fn can_read(&self, cx: &mut Context) -> bool { + if self + .flags + .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) + { + false + } else if let Some(ref info) = self.payload { + info.need_read(cx) == PayloadStatus::Read + } else { + true + } + } + + // if checked is set to true, delay disconnect until all tasks have finished. + fn client_disconnected(&mut self) { + self.flags + .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + } + + /// Flush stream + /// + /// true - got whouldblock + /// false - didnt get whouldblock + fn poll_flush(&mut self, cx: &mut Context) -> Result { + if self.write_buf.is_empty() { + return Ok(false); + } + + let len = self.write_buf.len(); + let mut written = 0; + while written < len { + match unsafe { Pin::new_unchecked(&mut self.io) } + .poll_write(cx, &self.write_buf[written..]) + { + Poll::Ready(Ok(0)) => { + return Err(DispatchError::Io(io::Error::new( + io::ErrorKind::WriteZero, + "", + ))); + } + Poll::Ready(Ok(n)) => { + written += n; + } + Poll::Pending => { + if written > 0 { + let _ = self.write_buf.split_to(written); + } + return Ok(true); + } + Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), + } + } + if written == self.write_buf.len() { + unsafe { self.write_buf.set_len(0) } + } else { + let _ = self.write_buf.split_to(written); + } + Ok(false) + } + + fn send_response( + &mut self, + message: Response<()>, + body: ResponseBody, + ) -> Result, DispatchError> { + self.codec + .encode(Message::Item((message, body.size())), &mut self.write_buf) + .map_err(|err| { + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + DispatchError::Io(err) + })?; + + self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); + match body.size() { + BodySize::None | BodySize::Empty => Ok(State::None), + _ => Ok(State::SendPayload(body)), + } + } + + fn send_continue(&mut self) { + self.write_buf + .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); + } + + fn poll_response( + &mut self, + cx: &mut Context, + ) -> Result { + loop { + let state = match self.state { + State::None => match self.messages.pop_front() { + Some(DispatcherMessage::Item(req)) => { + Some(self.handle_request(req, cx)?) + } + Some(DispatcherMessage::Error(res)) => { + Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) + } + Some(DispatcherMessage::Upgrade(req)) => { + return Ok(PollResponse::Upgrade(req)); + } + None => None, + }, + State::ExpectCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(req)) => { + self.send_continue(); + self.state = State::ServiceCall(self.service.call(req)); + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, + } + } + State::ServiceCall(ref mut fut) => { + match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.state = self.send_response(res, body)?; + continue; + } + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + Poll::Pending => None, + } + } + State::SendPayload(ref mut stream) => { + loop { + if self.write_buf.len() < HW_BUFFER_SIZE { + match stream.poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { + self.codec.encode( + Message::Chunk(Some(item)), + &mut self.write_buf, + )?; + continue; + } + Poll::Ready(None) => { + self.codec.encode( + Message::Chunk(None), + &mut self.write_buf, + )?; + self.state = State::None; + } + Poll::Ready(Some(Err(_))) => { + return Err(DispatchError::Unknown) + } + Poll::Pending => return Ok(PollResponse::DoNothing), + } + } else { + return Ok(PollResponse::DrainWriteBuf); + } + break; + } + continue; + } + }; + + // set new state + if let Some(state) = state { + self.state = state; + if !self.state.is_empty() { + continue; + } + } else { + // if read-backpressure is enabled and we consumed some data. + // we may read more data and retry + if self.state.is_call() { + if self.poll_request(cx)? { + continue; + } + } else if !self.messages.is_empty() { + continue; + } + } + break; + } + + Ok(PollResponse::DoNothing) + } + + fn handle_request( + &mut self, + req: Request, + cx: &mut Context, + ) -> Result, DispatchError> { + // Handle `EXPECT: 100-Continue` header + let req = if req.head().expect() { + let mut task = self.expect.call(req); + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(req)) => { + self.send_continue(); + req + } + Poll::Pending => return Ok(State::ExpectCall(task)), + Poll::Ready(Err(e)) => { + let e = e.into(); + let res: Response = e.into(); + let (res, body) = res.replace_body(()); + return self.send_response(res, body.into_body()); + } + } + } else { + req + }; + + // Call service + let mut task = self.service.call(req); + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.send_response(res, body) + } + Poll::Pending => Ok(State::ServiceCall(task)), + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + self.send_response(res, body.into_body()) + } + } + } + + /// Process one incoming requests + pub(self) fn poll_request( + &mut self, + cx: &mut Context, + ) -> Result { + // limit a mount of non processed requests + if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { + return Ok(false); + } + + let mut updated = false; + loop { + match self.codec.decode(&mut self.read_buf) { + Ok(Some(msg)) => { + updated = true; + self.flags.insert(Flags::STARTED); + + match msg { + Message::Item(mut req) => { + let pl = self.codec.message_type(); + req.head_mut().peer_addr = self.peer_addr; + + // set on_connect data + if let Some(ref on_connect) = self.on_connect { + on_connect.set(&mut req.extensions_mut()); + } + + if pl == MessageType::Stream && self.upgrade.is_some() { + self.messages.push_back(DispatcherMessage::Upgrade(req)); + break; + } + if pl == MessageType::Payload || pl == MessageType::Stream { + let (ps, pl) = Payload::create(false); + let (req1, _) = + req.replace_payload(crate::Payload::H1(pl)); + req = req1; + self.payload = Some(ps); + } + + // handle request early + if self.state.is_empty() { + self.state = self.handle_request(req, cx)?; + } else { + self.messages.push_back(DispatcherMessage::Item(req)); + } + } + Message::Chunk(Some(chunk)) => { + if let Some(ref mut payload) = self.payload { + payload.feed_data(chunk); + } else { + error!( + "Internal server error: unexpected payload chunk" + ); + self.flags.insert(Flags::READ_DISCONNECT); + self.messages.push_back(DispatcherMessage::Error( + Response::InternalServerError().finish().drop_body(), + )); + self.error = Some(DispatchError::InternalError); + break; + } + } + Message::Chunk(None) => { + if let Some(mut payload) = self.payload.take() { + payload.feed_eof(); + } else { + error!("Internal server error: unexpected eof"); + self.flags.insert(Flags::READ_DISCONNECT); + self.messages.push_back(DispatcherMessage::Error( + Response::InternalServerError().finish().drop_body(), + )); + self.error = Some(DispatchError::InternalError); + break; + } + } + } + } + Ok(None) => break, + Err(ParseError::Io(e)) => { + self.client_disconnected(); + self.error = Some(DispatchError::Io(e)); + break; + } + Err(e) => { + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::EncodingCorrupted); + } + + // Malformed requests should be responded with 400 + self.messages.push_back(DispatcherMessage::Error( + Response::BadRequest().finish().drop_body(), + )); + self.flags.insert(Flags::READ_DISCONNECT); + self.error = Some(e.into()); + break; + } + } + } + + if updated && self.ka_timer.is_some() { + if let Some(expire) = self.codec.config().keep_alive_expire() { + self.ka_expire = expire; + } + } + Ok(updated) + } + + /// keep-alive timer + fn poll_keepalive(&mut self, cx: &mut Context) -> Result<(), DispatchError> { + if self.ka_timer.is_none() { + // shutdown timeout + if self.flags.contains(Flags::SHUTDOWN) { + if let Some(interval) = self.codec.config().client_disconnect_timer() { + self.ka_timer = Some(delay(interval)); + } else { + self.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + return Ok(()); + } + } else { + return Ok(()); + } + } + + match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { + Poll::Ready(()) => { + // if we get timeout during shutdown, drop connection + if self.flags.contains(Flags::SHUTDOWN) { + return Err(DispatchError::DisconnectTimeout); + } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { + // check for any outstanding tasks + if self.state.is_empty() && self.write_buf.is_empty() { + if self.flags.contains(Flags::STARTED) { + trace!("Keep-alive timeout, close connection"); + self.flags.insert(Flags::SHUTDOWN); + + // start shutdown timer + if let Some(deadline) = + self.codec.config().client_disconnect_timer() + { + if let Some(mut timer) = self.ka_timer.as_mut() { + timer.reset(deadline); + let _ = Pin::new(&mut timer).poll(cx); + } + } else { + // no shutdown timeout, drop socket + self.flags.insert(Flags::WRITE_DISCONNECT); + return Ok(()); + } + } else { + // timeout on first request (slow request) return 408 + if !self.flags.contains(Flags::STARTED) { + trace!("Slow request timeout"); + let _ = self.send_response( + Response::RequestTimeout().finish().drop_body(), + ResponseBody::Other(Body::Empty), + ); + } else { + trace!("Keep-alive connection timeout"); + } + self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); + self.state = State::None; + } + } else if let Some(deadline) = + self.codec.config().keep_alive_expire() + { + if let Some(mut timer) = self.ka_timer.as_mut() { + timer.reset(deadline); + let _ = Pin::new(&mut timer).poll(cx); + } + } + } else if let Some(mut timer) = self.ka_timer.as_mut() { + timer.reset(self.ka_expire); + let _ = Pin::new(&mut timer).poll(cx); + } + } + Poll::Pending => (), + } + + Ok(()) + } +} + +impl Unpin for Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ +} + +impl Future for Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + type Output = Result<(), DispatchError>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.as_mut().inner { + DispatcherState::Normal(ref mut inner) => { + inner.poll_keepalive(cx)?; + + if inner.flags.contains(Flags::SHUTDOWN) { + if inner.flags.contains(Flags::WRITE_DISCONNECT) { + Poll::Ready(Ok(())) + } else { + // flush buffer + inner.poll_flush(cx)?; + if !inner.write_buf.is_empty() { + Poll::Pending + } else { + match Pin::new(&mut inner.io).poll_shutdown(cx) { + Poll::Ready(res) => { + Poll::Ready(res.map_err(DispatchError::from)) + } + Poll::Pending => Poll::Pending, + } + } + } + } else { + // read socket into a buf + let should_disconnect = + if !inner.flags.contains(Flags::READ_DISCONNECT) { + read_available(cx, &mut inner.io, &mut inner.read_buf)? + } else { + None + }; + + inner.poll_request(cx)?; + if let Some(true) = should_disconnect { + inner.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = inner.payload.take() { + payload.feed_eof(); + } + }; + + loop { + if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { + inner.write_buf.reserve(HW_BUFFER_SIZE); + } + let result = inner.poll_response(cx)?; + let drain = result == PollResponse::DrainWriteBuf; + + // switch to upgrade handler + if let PollResponse::Upgrade(req) = result { + if let DispatcherState::Normal(inner) = + std::mem::replace(&mut self.inner, DispatcherState::None) + { + let mut parts = FramedParts::with_read_buf( + inner.io, + inner.codec, + inner.read_buf, + ); + parts.write_buf = inner.write_buf; + let framed = Framed::from_parts(parts); + self.inner = DispatcherState::Upgrade( + inner.upgrade.unwrap().call((req, framed)), + ); + return self.poll(cx); + } else { + panic!() + } + } + + // we didnt get WouldBlock from write operation, + // so data get written to kernel completely (OSX) + // and we have to write again otherwise response can get stuck + if inner.poll_flush(cx)? || !drain { + break; + } + } + + // client is gone + if inner.flags.contains(Flags::WRITE_DISCONNECT) { + return Poll::Ready(Ok(())); + } + + let is_empty = inner.state.is_empty(); + + // read half is closed and we do not processing any responses + if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { + inner.flags.insert(Flags::SHUTDOWN); + } + + // keep-alive and stream errors + if is_empty && inner.write_buf.is_empty() { + if let Some(err) = inner.error.take() { + Poll::Ready(Err(err)) + } + // disconnect if keep-alive is not enabled + else if inner.flags.contains(Flags::STARTED) + && !inner.flags.intersects(Flags::KEEPALIVE) + { + inner.flags.insert(Flags::SHUTDOWN); + self.poll(cx) + } + // disconnect if shutdown + else if inner.flags.contains(Flags::SHUTDOWN) { + self.poll(cx) + } else { + Poll::Pending + } + } else { + Poll::Pending + } + } + } + DispatcherState::Upgrade(ref mut fut) => { + unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| { + error!("Upgrade handler error: {}", e); + DispatchError::Upgrade + }) + } + DispatcherState::None => panic!(), + } + } +} + +fn read_available( + cx: &mut Context, + io: &mut T, + buf: &mut BytesMut, +) -> Result, io::Error> +where + T: AsyncRead + Unpin, +{ + let mut read_some = false; + loop { + if buf.remaining_mut() < LW_BUFFER_SIZE { + buf.reserve(HW_BUFFER_SIZE); + } + + match read(cx, io, buf) { + Poll::Pending => { + return if read_some { Ok(Some(false)) } else { Ok(None) }; + } + Poll::Ready(Ok(n)) => { + if n == 0 { + return Ok(Some(true)); + } else { + read_some = true; + } + } + Poll::Ready(Err(e)) => { + return if e.kind() == io::ErrorKind::WouldBlock { + if read_some { + Ok(Some(false)) + } else { + Ok(None) + } + } else if e.kind() == io::ErrorKind::ConnectionReset && read_some { + Ok(Some(true)) + } else { + Err(e) + } + } + } + } +} + +fn read( + cx: &mut Context, + io: &mut T, + buf: &mut BytesMut, +) -> Poll> +where + T: AsyncRead + Unpin, +{ + Pin::new(io).poll_read_buf(cx, buf) +} + +#[cfg(test)] +mod tests { + use actix_service::IntoService; + use futures::future::{lazy, ok}; + + use super::*; + use crate::error::Error; + use crate::h1::{ExpectHandler, UpgradeHandler}; + use crate::test::TestBuffer; + + #[actix_rt::test] + async fn test_req_parse_err() { + lazy(|cx| { + let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); + + let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + buf, + ServiceConfig::default(), + CloneableService::new( + (|_| ok::<_, Error>(Response::Ok().finish())).into_service(), + ), + CloneableService::new(ExpectHandler), + None, + None, + ); + match Pin::new(&mut h1).poll(cx) { + Poll::Pending => panic!(), + Poll::Ready(res) => assert!(res.is_err()), + } + + if let DispatcherState::Normal(ref inner) = h1.inner { + assert!(inner.flags.contains(Flags::READ_DISCONNECT)); + assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); + } + }) + .await; + } +} diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs new file mode 100644 index 000000000..6396f3b55 --- /dev/null +++ b/actix-http/src/h1/encoder.rs @@ -0,0 +1,636 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::fmt::Write as FmtWrite; +use std::io::Write; +use std::marker::PhantomData; +use std::rc::Rc; +use std::str::FromStr; +use std::{cmp, fmt, io, mem}; + +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::body::BodySize; +use crate::config::ServiceConfig; +use crate::header::{map, ContentEncoding}; +use crate::helpers; +use crate::http::header::{ + HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use crate::http::{HeaderMap, Method, StatusCode, Version}; +use crate::message::{ConnectionType, Head, RequestHead, RequestHeadType, ResponseHead}; +use crate::request::Request; +use crate::response::Response; + +const AVERAGE_HEADER_SIZE: usize = 30; + +#[derive(Debug)] +pub(crate) struct MessageEncoder { + pub length: BodySize, + pub te: TransferEncoding, + _t: PhantomData, +} + +impl Default for MessageEncoder { + fn default() -> Self { + MessageEncoder { + length: BodySize::None, + te: TransferEncoding::empty(), + _t: PhantomData, + } + } +} + +pub(crate) trait MessageType: Sized { + fn status(&self) -> Option; + + fn headers(&self) -> &HeaderMap; + + fn extra_headers(&self) -> Option<&HeaderMap>; + + fn camel_case(&self) -> bool { + false + } + + fn chunked(&self) -> bool; + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>; + + fn encode_headers( + &mut self, + dst: &mut BytesMut, + version: Version, + mut length: BodySize, + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + let chunked = self.chunked(); + let mut skip_len = length != BodySize::Stream; + let camel_case = self.camel_case(); + + // Content length + if let Some(status) = self.status() { + match status { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::PROCESSING => length = BodySize::None, + StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + length = BodySize::Stream; + } + _ => (), + } + } + match length { + BodySize::Stream => { + if chunked { + if camel_case { + dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n") + } else { + dst.put_slice(b"\r\ntransfer-encoding: chunked\r\n") + } + } else { + skip_len = false; + dst.put_slice(b"\r\n"); + } + } + BodySize::Empty => { + if camel_case { + dst.put_slice(b"\r\nContent-Length: 0\r\n"); + } else { + dst.put_slice(b"\r\ncontent-length: 0\r\n"); + } + } + BodySize::Sized(len) => helpers::write_content_length(len, dst), + BodySize::Sized64(len) => { + if camel_case { + dst.put_slice(b"\r\nContent-Length: "); + } else { + dst.put_slice(b"\r\ncontent-length: "); + } + write!(dst.writer(), "{}\r\n", len)?; + } + BodySize::None => dst.put_slice(b"\r\n"), + } + + // Connection + match ctype { + ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"), + ConnectionType::KeepAlive if version < Version::HTTP_11 => { + if camel_case { + dst.put_slice(b"Connection: keep-alive\r\n") + } else { + dst.put_slice(b"connection: keep-alive\r\n") + } + } + ConnectionType::Close if version >= Version::HTTP_11 => { + if camel_case { + dst.put_slice(b"Connection: close\r\n") + } else { + dst.put_slice(b"connection: close\r\n") + } + } + _ => (), + } + + // merging headers from head and extra headers. HeaderMap::new() does not allocate. + let empty_headers = HeaderMap::new(); + let extra_headers = self.extra_headers().unwrap_or(&empty_headers); + let headers = self + .headers() + .inner + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.inner.iter()); + + // write headers + let mut pos = 0; + let mut has_date = false; + let mut remaining = dst.remaining_mut(); + let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) }; + for (key, value) in headers { + match *key { + CONNECTION => continue, + TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue, + DATE => { + has_date = true; + } + _ => (), + } + let k = key.as_str().as_bytes(); + match value { + map::Value::One(ref val) => { + let v = val.as_ref(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { + dst.advance_mut(pos); + } + pos = 0; + dst.reserve(len * 2); + remaining = dst.remaining_mut(); + unsafe { + buf = &mut *(dst.bytes_mut() as *mut _); + } + } + // use upper Camel-Case + if camel_case { + write_camel_case(k, &mut buf[pos..pos + k.len()]); + } else { + buf[pos..pos + k.len()].copy_from_slice(k); + } + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + map::Value::Multi(ref vec) => { + for val in vec { + let v = val.as_ref(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { + dst.advance_mut(pos); + } + pos = 0; + dst.reserve(len * 2); + remaining = dst.remaining_mut(); + unsafe { + buf = &mut *(dst.bytes_mut() as *mut _); + } + } + // use upper Camel-Case + if camel_case { + write_camel_case(k, &mut buf[pos..pos + k.len()]); + } else { + buf[pos..pos + k.len()].copy_from_slice(k); + } + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + } + } + } + unsafe { + dst.advance_mut(pos); + } + + // optimized date header, set_date writes \r\n + if !has_date { + config.set_date(dst); + } else { + // msg eof + dst.extend_from_slice(b"\r\n"); + } + + Ok(()) + } +} + +impl MessageType for Response<()> { + fn status(&self) -> Option { + Some(self.head().status) + } + + fn chunked(&self) -> bool { + self.head().chunked() + } + + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + fn extra_headers(&self) -> Option<&HeaderMap> { + None + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + let head = self.head(); + let reason = head.reason().as_bytes(); + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); + + // status line + helpers::write_status_line(head.version, head.status.as_u16(), dst); + dst.put_slice(reason); + Ok(()) + } +} + +impl MessageType for RequestHeadType { + fn status(&self) -> Option { + None + } + + fn chunked(&self) -> bool { + self.as_ref().chunked() + } + + fn camel_case(&self) -> bool { + self.as_ref().camel_case_headers() + } + + fn headers(&self) -> &HeaderMap { + self.as_ref().headers() + } + + fn extra_headers(&self) -> Option<&HeaderMap> { + self.extra_headers() + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + let head = self.as_ref(); + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); + write!( + Writer(dst), + "{} {} {}", + head.method, + head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + match head.version { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + } + ) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} + +impl MessageEncoder { + /// Encode message + pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + self.te.encode(msg, buf) + } + + /// Encode eof + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + self.te.encode_eof(buf) + } + + pub fn encode( + &mut self, + dst: &mut BytesMut, + message: &mut T, + head: bool, + stream: bool, + version: Version, + length: BodySize, + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + // transfer encoding + if !head { + self.te = match length { + BodySize::Empty => TransferEncoding::empty(), + BodySize::Sized(len) => TransferEncoding::length(len as u64), + BodySize::Sized64(len) => TransferEncoding::length(len), + BodySize::Stream => { + if message.chunked() && !stream { + TransferEncoding::chunked() + } else { + TransferEncoding::eof() + } + } + BodySize::None => TransferEncoding::empty(), + }; + } else { + self.te = TransferEncoding::empty(); + } + + message.encode_status(dst)?; + message.encode_headers(dst, version, length, ctype, config) + } +} + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug)] +pub(crate) struct TransferEncoding { + kind: TransferEncodingKind, +} + +#[derive(Debug, PartialEq, Clone)] +enum TransferEncodingKind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked(bool), + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when Content-Length is not known. + /// + /// Application decides when to stop writing. + Eof, +} + +impl TransferEncoding { + #[inline] + pub fn empty() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Length(0), + } + } + + #[inline] + pub fn eof() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Eof, + } + } + + #[inline] + pub fn chunked() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Chunked(false), + } + } + + #[inline] + pub fn length(len: u64) -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Length(len), + } + } + + /// Encode message. Return `EOF` state of encoder + #[inline] + pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + match self.kind { + TransferEncodingKind::Eof => { + let eof = msg.is_empty(); + buf.extend_from_slice(msg); + Ok(eof) + } + TransferEncodingKind::Chunked(ref mut eof) => { + if *eof { + return Ok(true); + } + + if msg.is_empty() { + *eof = true; + buf.extend_from_slice(b"0\r\n\r\n"); + } else { + writeln!(Writer(buf), "{:X}\r", msg.len()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + buf.reserve(msg.len() + 2); + buf.extend_from_slice(msg); + buf.extend_from_slice(b"\r\n"); + } + Ok(*eof) + } + TransferEncodingKind::Length(ref mut remaining) => { + if *remaining > 0 { + if msg.is_empty() { + return Ok(*remaining == 0); + } + let len = cmp::min(*remaining, msg.len() as u64); + + buf.extend_from_slice(&msg[..len as usize]); + + *remaining -= len as u64; + Ok(*remaining == 0) + } else { + Ok(true) + } + } + } + } + + /// Encode eof. Return `EOF` state of encoder + #[inline] + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + match self.kind { + TransferEncodingKind::Eof => Ok(()), + TransferEncodingKind::Length(rem) => { + if rem != 0 { + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "")) + } else { + Ok(()) + } + } + TransferEncodingKind::Chunked(ref mut eof) => { + if !*eof { + *eof = true; + buf.extend_from_slice(b"0\r\n\r\n"); + } + Ok(()) + } + } + } +} + +struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +fn write_camel_case(value: &[u8], buffer: &mut [u8]) { + let mut index = 0; + let key = value; + let mut key_iter = key.iter(); + + if let Some(c) = key_iter.next() { + if *c >= b'a' && *c <= b'z' { + buffer[index] = *c ^ b' '; + index += 1; + } + } else { + return; + } + + while let Some(c) = key_iter.next() { + buffer[index] = *c; + index += 1; + if *c == b'-' { + if let Some(c) = key_iter.next() { + if *c >= b'a' && *c <= b'z' { + buffer[index] = *c ^ b' '; + index += 1; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + //use std::rc::Rc; + + use super::*; + use crate::http::header::{HeaderValue, CONTENT_TYPE}; + use http::header::AUTHORIZATION; + + #[test] + fn test_chunked_te() { + let mut bytes = BytesMut::new(); + let mut enc = TransferEncoding::chunked(); + { + assert!(!enc.encode(b"test", &mut bytes).ok().unwrap()); + assert!(enc.encode(b"", &mut bytes).ok().unwrap()); + } + assert_eq!( + bytes.take().freeze(), + Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") + ); + } + + #[test] + fn test_camel_case() { + let mut bytes = BytesMut::with_capacity(2048); + let mut head = RequestHead::default(); + head.set_camel_case_headers(true); + head.headers.insert(DATE, HeaderValue::from_static("date")); + head.headers + .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text")); + + let mut head = RequestHeadType::Owned(head); + + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Empty, + ConnectionType::Close, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 0\r\n")); + assert!(data.contains("Connection: close\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); + + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Stream, + ConnectionType::KeepAlive, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Transfer-Encoding: chunked\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); + + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Sized64(100), + ConnectionType::KeepAlive, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("Content-Length: 100\r\n")); + assert!(data.contains("Content-Type: plain/text\r\n")); + assert!(data.contains("Date: date\r\n")); + + let mut head = RequestHead::default(); + head.set_camel_case_headers(false); + head.headers.insert(DATE, HeaderValue::from_static("date")); + head.headers + .insert(CONTENT_TYPE, HeaderValue::from_static("plain/text")); + head.headers + .append(CONTENT_TYPE, HeaderValue::from_static("xml")); + + let mut head = RequestHeadType::Owned(head); + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Stream, + ConnectionType::KeepAlive, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("transfer-encoding: chunked\r\n")); + assert!(data.contains("content-type: xml\r\n")); + assert!(data.contains("content-type: plain/text\r\n")); + assert!(data.contains("date: date\r\n")); + } + + #[test] + fn test_extra_headers() { + let mut bytes = BytesMut::with_capacity(2048); + + let mut head = RequestHead::default(); + head.headers.insert( + AUTHORIZATION, + HeaderValue::from_static("some authorization"), + ); + + let mut extra_headers = HeaderMap::new(); + extra_headers.insert( + AUTHORIZATION, + HeaderValue::from_static("another authorization"), + ); + extra_headers.insert(DATE, HeaderValue::from_static("date")); + + let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers)); + + let _ = head.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Empty, + ConnectionType::Close, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.take().freeze().as_ref())).unwrap(); + assert!(data.contains("content-length: 0\r\n")); + assert!(data.contains("connection: close\r\n")); + assert!(data.contains("authorization: another authorization\r\n")); + assert!(data.contains("date: date\r\n")); + } +} diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs new file mode 100644 index 000000000..d6b4a9f1e --- /dev/null +++ b/actix-http/src/h1/expect.rs @@ -0,0 +1,39 @@ +use std::task::{Context, Poll}; + +use actix_server_config::ServerConfig; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; + +use crate::error::Error; +use crate::request::Request; + +pub struct ExpectHandler; + +impl ServiceFactory for ExpectHandler { + type Config = ServerConfig; + type Request = Request; + type Response = Request; + type Error = Error; + type Service = ExpectHandler; + type InitError = Error; + type Future = Ready>; + + fn new_service(&self, _: &ServerConfig) -> Self::Future { + ok(ExpectHandler) + } +} + +impl Service for ExpectHandler { + type Request = Request; + type Response = Request; + type Error = Error; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + ok(req) + } +} diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs new file mode 100644 index 000000000..0c85f076a --- /dev/null +++ b/actix-http/src/h1/mod.rs @@ -0,0 +1,85 @@ +//! HTTP/1 implementation +use bytes::{Bytes, BytesMut}; + +mod client; +mod codec; +mod decoder; +mod dispatcher; +mod encoder; +mod expect; +mod payload; +mod service; +mod upgrade; +mod utils; + +pub use self::client::{ClientCodec, ClientPayloadCodec}; +pub use self::codec::Codec; +pub use self::dispatcher::Dispatcher; +pub use self::expect::ExpectHandler; +pub use self::payload::Payload; +pub use self::service::{H1Service, H1ServiceHandler, OneRequest}; +pub use self::upgrade::UpgradeHandler; +pub use self::utils::SendResponse; + +#[derive(Debug)] +/// Codec message +pub enum Message { + /// Http message + Item(T), + /// Payload chunk + Chunk(Option), +} + +impl From for Message { + fn from(item: T) -> Self { + Message::Item(item) + } +} + +/// Incoming request type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageType { + None, + Payload, + Stream, +} + +const LW: usize = 2 * 1024; +const HW: usize = 32 * 1024; + +pub(crate) fn reserve_readbuf(src: &mut BytesMut) { + let cap = src.capacity(); + if cap < LW { + src.reserve(HW - cap); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::request::Request; + + impl Message { + pub fn message(self) -> Request { + match self { + Message::Item(req) => req, + _ => panic!("error"), + } + } + + pub fn chunk(self) -> Bytes { + match self { + Message::Chunk(Some(data)) => data, + _ => panic!("error"), + } + } + + pub fn eof(self) -> bool { + match self { + Message::Chunk(None) => true, + Message::Chunk(Some(_)) => false, + _ => panic!("error"), + } + } + } +} diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs new file mode 100644 index 000000000..46f2f9728 --- /dev/null +++ b/actix-http/src/h1/payload.rs @@ -0,0 +1,244 @@ +//! Payload stream +use std::cell::RefCell; +use std::collections::VecDeque; +use std::pin::Pin; +use std::rc::{Rc, Weak}; +use std::task::{Context, Poll}; + +use actix_utils::task::LocalWaker; +use bytes::Bytes; +use futures::Stream; + +use crate::error::PayloadError; + +/// max buffer size 32k +pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; + +#[derive(Debug, PartialEq)] +pub enum PayloadStatus { + Read, + Pause, + Dropped, +} + +/// Buffered stream of bytes chunks +/// +/// Payload stores chunks in a vector. First chunk can be received with +/// `.readany()` method. Payload stream is not thread safe. Payload does not +/// notify current task when new data is available. +/// +/// Payload stream can be used as `Response` body stream. +#[derive(Debug)] +pub struct Payload { + inner: Rc>, +} + +impl Payload { + /// Create payload stream. + /// + /// This method construct two objects responsible for bytes stream + /// generation. + /// + /// * `PayloadSender` - *Sender* side of the stream + /// + /// * `Payload` - *Receiver* side of the stream + pub fn create(eof: bool) -> (PayloadSender, Payload) { + let shared = Rc::new(RefCell::new(Inner::new(eof))); + + ( + PayloadSender { + inner: Rc::downgrade(&shared), + }, + Payload { inner: shared }, + ) + } + + /// Create empty payload + #[doc(hidden)] + pub fn empty() -> Payload { + Payload { + inner: Rc::new(RefCell::new(Inner::new(true))), + } + } + + /// Length of the data in this payload + #[cfg(test)] + pub fn len(&self) -> usize { + self.inner.borrow().len() + } + + /// Is payload empty + #[cfg(test)] + pub fn is_empty(&self) -> bool { + self.inner.borrow().len() == 0 + } + + /// Put unused data back to payload + #[inline] + pub fn unread_data(&mut self, data: Bytes) { + self.inner.borrow_mut().unread_data(data); + } + + #[inline] + pub fn readany( + &mut self, + cx: &mut Context, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) + } +} + +impl Stream for Payload { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll>> { + self.inner.borrow_mut().readany(cx) + } +} + +/// Sender part of the payload stream +pub struct PayloadSender { + inner: Weak>, +} + +impl PayloadSender { + #[inline] + pub fn set_error(&mut self, err: PayloadError) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().set_error(err) + } + } + + #[inline] + pub fn feed_eof(&mut self) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().feed_eof() + } + } + + #[inline] + pub fn feed_data(&mut self, data: Bytes) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().feed_data(data) + } + } + + #[inline] + pub fn need_read(&self, cx: &mut Context) -> PayloadStatus { + // we check need_read only if Payload (other side) is alive, + // otherwise always return true (consume payload) + if let Some(shared) = self.inner.upgrade() { + if shared.borrow().need_read { + PayloadStatus::Read + } else { + shared.borrow_mut().io_task.register(cx.waker()); + PayloadStatus::Pause + } + } else { + PayloadStatus::Dropped + } + } +} + +#[derive(Debug)] +struct Inner { + len: usize, + eof: bool, + err: Option, + need_read: bool, + items: VecDeque, + task: LocalWaker, + io_task: LocalWaker, +} + +impl Inner { + fn new(eof: bool) -> Self { + Inner { + eof, + len: 0, + err: None, + items: VecDeque::new(), + need_read: true, + task: LocalWaker::new(), + io_task: LocalWaker::new(), + } + } + + #[inline] + fn set_error(&mut self, err: PayloadError) { + self.err = Some(err); + } + + #[inline] + fn feed_eof(&mut self) { + self.eof = true; + } + + #[inline] + fn feed_data(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_back(data); + self.need_read = self.len < MAX_BUFFER_SIZE; + if let Some(task) = self.task.take() { + task.wake() + } + } + + #[cfg(test)] + fn len(&self) -> usize { + self.len + } + + fn readany( + &mut self, + cx: &mut Context, + ) -> Poll>> { + if let Some(data) = self.items.pop_front() { + self.len -= data.len(); + self.need_read = self.len < MAX_BUFFER_SIZE; + + if self.need_read && !self.eof { + self.task.register(cx.waker()); + } + self.io_task.wake(); + Poll::Ready(Some(Ok(data))) + } else if let Some(err) = self.err.take() { + Poll::Ready(Some(Err(err))) + } else if self.eof { + Poll::Ready(None) + } else { + self.need_read = true; + self.task.register(cx.waker()); + self.io_task.wake(); + Poll::Pending + } + } + + fn unread_data(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_front(data); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + + #[actix_rt::test] + async fn test_unread_data() { + let (_, mut payload) = Payload::create(false); + + payload.unread_data(Bytes::from("data")); + assert!(!payload.is_empty()); + assert_eq!(payload.len(), 4); + + assert_eq!( + Bytes::from("data"), + poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() + ); + } +} diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs new file mode 100644 index 000000000..197c92887 --- /dev/null +++ b/actix-http/src/h1/service.rs @@ -0,0 +1,451 @@ +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_codec::Framed; +use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{ok, Ready}; +use futures::ready; + +use crate::body::MessageBody; +use crate::cloneable::CloneableService; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::{DispatchError, Error, ParseError}; +use crate::helpers::DataFactory; +use crate::request::Request; +use crate::response::Response; + +use super::codec::Codec; +use super::dispatcher::Dispatcher; +use super::{ExpectHandler, Message, UpgradeHandler}; + +/// `ServiceFactory` implementation for HTTP1 transport +pub struct H1Service> { + srv: S, + cfg: ServiceConfig, + expect: X, + upgrade: Option, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B)>, +} + +impl H1Service +where + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + B: MessageBody, +{ + /// Create new `HttpService` instance with default config. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + H1Service { + cfg, + srv: service.into_factory(), + expect: ExpectHandler, + upgrade: None, + on_connect: None, + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + H1Service { + cfg, + srv: service.into_factory(), + expect: ExpectHandler, + upgrade: None, + on_connect: None, + _t: PhantomData, + } + } +} + +impl H1Service +where + S: ServiceFactory, + S::Error: Into, + S::Response: Into>, + S::InitError: fmt::Debug, + B: MessageBody, +{ + pub fn expect(self, expect: X1) -> H1Service + where + X1: ServiceFactory, + X1::Error: Into, + X1::InitError: fmt::Debug, + { + H1Service { + expect, + cfg: self.cfg, + srv: self.srv, + upgrade: self.upgrade, + on_connect: self.on_connect, + _t: PhantomData, + } + } + + pub fn upgrade(self, upgrade: Option) -> H1Service + where + U1: ServiceFactory), Response = ()>, + U1::Error: fmt::Display, + U1::InitError: fmt::Debug, + { + H1Service { + upgrade, + cfg: self.cfg, + srv: self.srv, + expect: self.expect, + on_connect: self.on_connect, + _t: PhantomData, + } + } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } +} + +impl ServiceFactory for H1Service +where + T: IoStream, + S: ServiceFactory, + S::Error: Into, + S::Response: Into>, + S::InitError: fmt::Debug, + B: MessageBody, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory< + Config = SrvConfig, + Request = (Request, Framed), + Response = (), + >, + U::Error: fmt::Display, + U::InitError: fmt::Debug, +{ + type Config = SrvConfig; + type Request = Io; + type Response = (); + type Error = DispatchError; + type InitError = (); + type Service = H1ServiceHandler; + type Future = H1ServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + H1ServiceResponse { + fut: self.srv.new_service(cfg), + fut_ex: Some(self.expect.new_service(cfg)), + fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), + expect: None, + upgrade: None, + on_connect: self.on_connect.clone(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +#[pin_project::pin_project] +pub struct H1ServiceResponse +where + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory), Response = ()>, + U::Error: fmt::Display, + U::InitError: fmt::Debug, +{ + #[pin] + fut: S::Future, + #[pin] + fut_ex: Option, + #[pin] + fut_upg: Option, + expect: Option, + upgrade: Option, + on_connect: Option Box>>, + cfg: Option, + _t: PhantomData<(T, P, B)>, +} + +impl Future for H1ServiceResponse +where + T: IoStream, + S: ServiceFactory, + S::Error: Into, + S::Response: Into>, + S::InitError: fmt::Debug, + B: MessageBody, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + U: ServiceFactory), Response = ()>, + U::Error: fmt::Display, + U::InitError: fmt::Debug, +{ + type Output = + Result, ()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); + } + + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); + } + + let result = ready!(this + .fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e))); + + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + H1ServiceHandler::new( + this.cfg.take().unwrap(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) + } +} + +/// `Service` implementation for HTTP1 transport +pub struct H1ServiceHandler { + srv: CloneableService, + expect: CloneableService, + upgrade: Option>, + on_connect: Option Box>>, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl H1ServiceHandler +where + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + fn new( + cfg: ServiceConfig, + srv: S, + expect: X, + upgrade: Option, + on_connect: Option Box>>, + ) -> H1ServiceHandler { + H1ServiceHandler { + srv: CloneableService::new(srv), + expect: CloneableService::new(expect), + upgrade: upgrade.map(CloneableService::new), + cfg, + on_connect, + _t: PhantomData, + } + } +} + +impl Service for H1ServiceHandler +where + T: IoStream, + S: Service, + S::Error: Into, + S::Response: Into>, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type Future = Dispatcher; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + let ready = self + .expect + .poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready(); + + let ready = self + .srv + .poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready; + + if ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let io = req.into_parts().0; + + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + + Dispatcher::new( + io, + self.cfg.clone(), + self.srv.clone(), + self.expect.clone(), + self.upgrade.clone(), + on_connect, + ) + } +} + +/// `ServiceFactory` implementation for `OneRequestService` service +#[derive(Default)] +pub struct OneRequest { + config: ServiceConfig, + _t: PhantomData<(T, P)>, +} + +impl OneRequest +where + T: IoStream, +{ + /// Create new `H1SimpleService` instance. + pub fn new() -> Self { + OneRequest { + config: ServiceConfig::default(), + _t: PhantomData, + } + } +} + +impl ServiceFactory for OneRequest +where + T: IoStream, +{ + type Config = SrvConfig; + type Request = Io; + type Response = (Request, Framed); + type Error = ParseError; + type InitError = (); + type Service = OneRequestService; + type Future = Ready>; + + fn new_service(&self, _: &SrvConfig) -> Self::Future { + ok(OneRequestService { + config: self.config.clone(), + _t: PhantomData, + }) + } +} + +/// `Service` implementation for HTTP1 transport. Reads one request and returns +/// request and framed object. +pub struct OneRequestService { + config: ServiceConfig, + _t: PhantomData<(T, P)>, +} + +impl Service for OneRequestService +where + T: IoStream, +{ + type Request = Io; + type Response = (Request, Framed); + type Error = ParseError; + type Future = OneRequestServiceResponse; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + OneRequestServiceResponse { + framed: Some(Framed::new( + req.into_parts().0, + Codec::new(self.config.clone()), + )), + } + } +} + +#[doc(hidden)] +pub struct OneRequestServiceResponse +where + T: IoStream, +{ + framed: Option>, +} + +impl Future for OneRequestServiceResponse +where + T: IoStream, +{ + type Output = Result<(Request, Framed), ParseError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.framed.as_mut().unwrap().next_item(cx) { + Poll::Ready(Some(Ok(req))) => match req { + Message::Item(req) => { + Poll::Ready(Ok((req, self.framed.take().unwrap()))) + } + Message::Chunk(_) => unreachable!("Something is wrong"), + }, + Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), + Poll::Ready(None) => Poll::Ready(Err(ParseError::Incomplete)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs new file mode 100644 index 000000000..ce46fbe93 --- /dev/null +++ b/actix-http/src/h1/upgrade.rs @@ -0,0 +1,42 @@ +use std::marker::PhantomData; +use std::task::{Context, Poll}; + +use actix_codec::Framed; +use actix_server_config::ServerConfig; +use actix_service::{Service, ServiceFactory}; +use futures::future::Ready; + +use crate::error::Error; +use crate::h1::Codec; +use crate::request::Request; + +pub struct UpgradeHandler(PhantomData); + +impl ServiceFactory for UpgradeHandler { + type Config = ServerConfig; + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type Service = UpgradeHandler; + type InitError = Error; + type Future = Ready>; + + fn new_service(&self, _: &ServerConfig) -> Self::Future { + unimplemented!() + } +} + +impl Service for UpgradeHandler { + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Self::Request) -> Self::Future { + unimplemented!() + } +} diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs new file mode 100644 index 000000000..7af0b124e --- /dev/null +++ b/actix-http/src/h1/utils.rs @@ -0,0 +1,97 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; + +use crate::body::{BodySize, MessageBody, ResponseBody}; +use crate::error::Error; +use crate::h1::{Codec, Message}; +use crate::response::Response; + +/// Send http/1 response +#[pin_project::pin_project] +pub struct SendResponse { + res: Option, BodySize)>>, + body: Option>, + framed: Option>, +} + +impl SendResponse +where + B: MessageBody, +{ + pub fn new(framed: Framed, response: Response) -> Self { + let (res, body) = response.into_parts(); + + SendResponse { + res: Some((res, body.size()).into()), + body: Some(body), + framed: Some(framed), + } + } +} + +impl Future for SendResponse +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + loop { + let mut body_ready = this.body.is_some(); + let framed = this.framed.as_mut().unwrap(); + + // send body + if this.res.is_none() && this.body.is_some() { + while body_ready && this.body.is_some() && !framed.is_write_buf_full() { + match this.body.as_mut().unwrap().poll_next(cx)? { + Poll::Ready(item) => { + // body is done + if item.is_none() { + let _ = this.body.take(); + } + framed.write(Message::Chunk(item))?; + } + Poll::Pending => body_ready = false, + } + } + } + + // flush write buffer + if !framed.is_write_buf_empty() { + match framed.flush(cx)? { + Poll::Ready(_) => { + if body_ready { + continue; + } else { + return Poll::Pending; + } + } + Poll::Pending => return Poll::Pending, + } + } + + // send response + if let Some(res) = this.res.take() { + framed.write(res)?; + continue; + } + + if this.body.is_some() { + if body_ready { + continue; + } else { + return Poll::Pending; + } + } else { + break; + } + } + Poll::Ready(Ok(this.framed.take().unwrap())) + } +} diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs new file mode 100644 index 000000000..188553806 --- /dev/null +++ b/actix-http/src/h2/dispatcher.rs @@ -0,0 +1,366 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use std::{fmt, mem, net}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::Delay; +use actix_server_config::IoStream; +use actix_service::Service; +use bitflags::bitflags; +use bytes::{Bytes, BytesMut}; +use futures::{ready, Sink, Stream}; +use h2::server::{Connection, SendResponse}; +use h2::{RecvStream, SendStream}; +use http::header::{ + HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use http::HttpTryFrom; +use log::{debug, error, trace}; + +use crate::body::{Body, BodySize, MessageBody, ResponseBody}; +use crate::cloneable::CloneableService; +use crate::config::ServiceConfig; +use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError}; +use crate::helpers::DataFactory; +use crate::httpmessage::HttpMessage; +use crate::message::ResponseHead; +use crate::payload::Payload; +use crate::request::Request; +use crate::response::Response; + +const CHUNK_SIZE: usize = 16_384; + +/// Dispatcher for HTTP/2 protocol +#[pin_project::pin_project] +pub struct Dispatcher, B: MessageBody> { + service: CloneableService, + connection: Connection, + on_connect: Option>, + config: ServiceConfig, + peer_addr: Option, + ka_expire: Instant, + ka_timer: Option, + _t: PhantomData, +} + +impl Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into, + // S::Future: 'static, + S::Response: Into>, + B: MessageBody, +{ + pub(crate) fn new( + service: CloneableService, + connection: Connection, + on_connect: Option>, + config: ServiceConfig, + timeout: Option, + peer_addr: Option, + ) -> Self { + // let keepalive = config.keep_alive_enabled(); + // let flags = if keepalive { + // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED + // } else { + // Flags::empty() + // }; + + // keep-alive timer + let (ka_expire, ka_timer) = if let Some(delay) = timeout { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = config.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (config.now(), None) + }; + + Dispatcher { + service, + config, + peer_addr, + connection, + on_connect, + ka_expire, + ka_timer, + _t: PhantomData, + } + } +} + +impl Future for Dispatcher +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, +{ + type Output = Result<(), DispatchError>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + loop { + match Pin::new(&mut this.connection).poll_accept(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(Some(Ok((req, res)))) => { + // update keep-alive expire + if this.ka_timer.is_some() { + if let Some(expire) = this.config.keep_alive_expire() { + this.ka_expire = expire; + } + } + + let (parts, body) = req.into_parts(); + let mut req = Request::with_payload(Payload::< + crate::payload::PayloadStream, + >::H2( + crate::h2::Payload::new(body) + )); + + let head = &mut req.head_mut(); + head.uri = parts.uri; + head.method = parts.method; + head.version = parts.version; + head.headers = parts.headers.into(); + head.peer_addr = this.peer_addr; + + // set on_connect data + if let Some(ref on_connect) = this.on_connect { + on_connect.set(&mut req.extensions_mut()); + } + + actix_rt::spawn(ServiceResponse::< + S::Future, + S::Response, + S::Error, + B, + > { + state: ServiceResponseState::ServiceCall( + this.service.call(req), + Some(res), + ), + config: this.config.clone(), + buffer: None, + _t: PhantomData, + }); + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +#[pin_project::pin_project] +struct ServiceResponse { + state: ServiceResponseState, + config: ServiceConfig, + buffer: Option, + _t: PhantomData<(I, E)>, +} + +enum ServiceResponseState { + ServiceCall(F, Option>), + SendPayload(SendStream, ResponseBody), +} + +impl ServiceResponse +where + F: Future>, + E: Into, + I: Into>, + B: MessageBody, +{ + fn prepare_response( + &self, + head: &ResponseHead, + size: &mut BodySize, + ) -> http::Response<()> { + let mut has_date = false; + let mut skip_len = size != &BodySize::Stream; + + let mut res = http::Response::new(()); + *res.status_mut() = head.status; + *res.version_mut() = http::Version::HTTP_2; + + // Content length + match head.status { + http::StatusCode::NO_CONTENT + | http::StatusCode::CONTINUE + | http::StatusCode::PROCESSING => *size = BodySize::None, + http::StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + *size = BodySize::Stream; + } + _ => (), + } + let _ = match size { + BodySize::None | BodySize::Stream => None, + BodySize::Empty => res + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodySize::Sized64(len) => res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; + + // copy headers + for (key, value) in head.headers.iter() { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + DATE => has_date = true, + _ => (), + } + res.headers_mut().append(key, value.clone()); + } + + // set date header + if !has_date { + let mut bytes = BytesMut::with_capacity(29); + self.config.set_date_header(&mut bytes); + res.headers_mut() + .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + } + + res + } +} + +impl Future for ServiceResponse +where + F: Future>, + E: Into, + I: Into>, + B: MessageBody, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + match this.state { + ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { + match unsafe { Pin::new_unchecked(call) }.poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + + let mut send = send.take().unwrap(); + let mut size = body.size(); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); + + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { + trace!("Error sending h2 response: {:?}", e); + return Poll::Ready(()); + } + Ok(stream) => stream, + }; + + if size.is_eof() { + Poll::Ready(()) + } else { + *this.state = + ServiceResponseState::SendPayload(stream, body); + self.poll(cx) + } + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + + let mut send = send.take().unwrap(); + let mut size = body.size(); + let h2_res = + self.as_mut().prepare_response(res.head(), &mut size); + this = self.as_mut().project(); + + let stream = match send.send_response(h2_res, size.is_eof()) { + Err(e) => { + trace!("Error sending h2 response: {:?}", e); + return Poll::Ready(()); + } + Ok(stream) => stream, + }; + + if size.is_eof() { + Poll::Ready(()) + } else { + *this.state = ServiceResponseState::SendPayload( + stream, + body.into_body(), + ); + self.poll(cx) + } + } + } + } + ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { + loop { + if let Some(ref mut buffer) = this.buffer { + match stream.poll_capacity(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(Ok(cap))) => { + let len = buffer.len(); + let bytes = buffer.split_to(std::cmp::min(cap, len)); + + if let Err(e) = stream.send_data(bytes, false) { + warn!("{:?}", e); + return Poll::Ready(()); + } else if !buffer.is_empty() { + let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); + stream.reserve_capacity(cap); + } else { + this.buffer.take(); + } + } + Poll::Ready(Some(Err(e))) => { + warn!("{:?}", e); + return Poll::Ready(()); + } + } + } else { + match body.poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { + if let Err(e) = stream.send_data(Bytes::new(), true) { + warn!("{:?}", e); + } + return Poll::Ready(()); + } + Poll::Ready(Some(Ok(chunk))) => { + stream.reserve_capacity(std::cmp::min( + chunk.len(), + CHUNK_SIZE, + )); + *this.buffer = Some(chunk); + } + Poll::Ready(Some(Err(e))) => { + error!("Response payload stream error: {:?}", e); + return Poll::Ready(()); + } + } + } + } + }, + } + } +} diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs new file mode 100644 index 000000000..9c902f18c --- /dev/null +++ b/actix-http/src/h2/mod.rs @@ -0,0 +1,49 @@ +#![allow(dead_code, unused_imports)] +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures::Stream; +use h2::RecvStream; + +mod dispatcher; +mod service; + +pub use self::dispatcher::Dispatcher; +pub use self::service::H2Service; +use crate::error::PayloadError; + +/// H2 receive stream +pub struct Payload { + pl: RecvStream, +} + +impl Payload { + pub(crate) fn new(pl: RecvStream) -> Self { + Self { pl } + } +} + +impl Stream for Payload { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + match Pin::new(&mut this.pl).poll_data(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let len = chunk.len(); + if let Err(err) = this.pl.release_capacity().release_capacity(len) { + Poll::Ready(Some(Err(err.into()))) + } else { + Poll::Ready(Some(Ok(chunk))) + } + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err.into()))), + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + } + } +} diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs new file mode 100644 index 000000000..860a61f73 --- /dev/null +++ b/actix-http/src/h2/service.rs @@ -0,0 +1,281 @@ +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, net, rc}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use bytes::Bytes; +use futures::future::{ok, Ready}; +use futures::{ready, Stream}; +use h2::server::{self, Connection, Handshake}; +use h2::RecvStream; +use log::error; + +use crate::body::MessageBody; +use crate::cloneable::CloneableService; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::{DispatchError, Error, ParseError, ResponseError}; +use crate::helpers::DataFactory; +use crate::payload::Payload; +use crate::request::Request; +use crate::response::Response; + +use super::dispatcher::Dispatcher; + +/// `ServiceFactory` implementation for HTTP2 transport +pub struct H2Service { + srv: S, + cfg: ServiceConfig, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B)>, +} + +impl H2Service +where + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + /// Create new `HttpService` instance. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + H2Service { + cfg, + on_connect: None, + srv: service.into_factory(), + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + H2Service { + cfg, + on_connect: None, + srv: service.into_factory(), + _t: PhantomData, + } + } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } +} + +impl ServiceFactory for H2Service +where + T: IoStream, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + type Config = SrvConfig; + type Request = Io; + type Response = (); + type Error = DispatchError; + type InitError = S::InitError; + type Service = H2ServiceHandler; + type Future = H2ServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + H2ServiceResponse { + fut: self.srv.new_service(cfg), + cfg: Some(self.cfg.clone()), + on_connect: self.on_connect.clone(), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +#[pin_project::pin_project] +pub struct H2ServiceResponse { + #[pin] + fut: S::Future, + cfg: Option, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B)>, +} + +impl Future for H2ServiceResponse +where + T: IoStream, + S: ServiceFactory, + S::Error: Into + 'static, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + type Output = Result, S::InitError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + Poll::Ready(ready!(this.fut.poll(cx)).map(|service| { + let this = self.as_mut().project(); + H2ServiceHandler::new( + this.cfg.take().unwrap(), + this.on_connect.clone(), + service, + ) + })) + } +} + +/// `Service` implementation for http/2 transport +pub struct H2ServiceHandler { + srv: CloneableService, + cfg: ServiceConfig, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B)>, +} + +impl H2ServiceHandler +where + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, +{ + fn new( + cfg: ServiceConfig, + on_connect: Option Box>>, + srv: S, + ) -> H2ServiceHandler { + H2ServiceHandler { + cfg, + on_connect, + srv: CloneableService::new(srv), + _t: PhantomData, + } + } +} + +impl Service for H2ServiceHandler +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type Future = H2ServiceHandlerResponse; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.srv.poll_ready(cx).map_err(|e| { + let e = e.into(); + error!("Service readiness error: {:?}", e); + DispatchError::Service(e) + }) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let io = req.into_parts().0; + let peer_addr = io.peer_addr(); + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + + H2ServiceHandlerResponse { + state: State::Handshake( + Some(self.srv.clone()), + Some(self.cfg.clone()), + peer_addr, + on_connect, + server::handshake(io), + ), + } + } +} + +enum State, B: MessageBody> +where + S::Future: 'static, +{ + Incoming(Dispatcher), + Handshake( + Option>, + Option, + Option, + Option>, + Handshake, + ), +} + +pub struct H2ServiceHandlerResponse +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, +{ + state: State, +} + +impl Future for H2ServiceHandlerResponse +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody, +{ + type Output = Result<(), DispatchError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match self.state { + State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), + State::Handshake( + ref mut srv, + ref mut config, + ref peer_addr, + ref mut on_connect, + ref mut handshake, + ) => match Pin::new(handshake).poll(cx) { + Poll::Ready(Ok(conn)) => { + self.state = State::Incoming(Dispatcher::new( + srv.take().unwrap(), + conn, + on_connect.take(), + config.take().unwrap(), + None, + *peer_addr, + )); + self.poll(cx) + } + Poll::Ready(Err(err)) => { + trace!("H2 handshake error: {}", err); + Poll::Ready(Err(err.into())) + } + Poll::Pending => Poll::Pending, + }, + } + } +} diff --git a/src/header/common/accept.rs b/actix-http/src/header/common/accept.rs similarity index 79% rename from src/header/common/accept.rs rename to actix-http/src/header/common/accept.rs index be49b151f..d52eba241 100644 --- a/src/header/common/accept.rs +++ b/actix-http/src/header/common/accept.rs @@ -1,6 +1,7 @@ -use mime::{self, Mime}; -use header::{QualityItem, qitem}; -use http::header as http; +use mime::Mime; + +use crate::header::{qitem, QualityItem}; +use crate::http::header; header! { /// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2) @@ -30,13 +31,13 @@ header! { /// /// # Examples /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// extern crate mime; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, qitem}; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// /// builder.set( /// Accept(vec![ @@ -47,13 +48,13 @@ header! { /// ``` /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// extern crate mime; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, qitem}; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// /// builder.set( /// Accept(vec![ @@ -64,13 +65,13 @@ header! { /// ``` /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// extern crate mime; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, QualityItem, q, qitem}; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, QualityItem, q, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// /// builder.set( /// Accept(vec![ @@ -89,7 +90,7 @@ header! { /// ); /// # } /// ``` - (Accept, http::ACCEPT) => (QualityItem)+ + (Accept, header::ACCEPT) => (QualityItem)+ test_accept { // Tests from the RFC @@ -104,8 +105,8 @@ header! { test2, vec![b"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"], Some(HeaderField(vec![ - QualityItem::new(TEXT_PLAIN, q(500)), - qitem(TEXT_HTML), + QualityItem::new(mime::TEXT_PLAIN, q(500)), + qitem(mime::TEXT_HTML), QualityItem::new( "text/x-dvi".parse().unwrap(), q(800)), @@ -116,20 +117,20 @@ header! { test3, vec![b"text/plain; charset=utf-8"], Some(Accept(vec![ - qitem(TEXT_PLAIN_UTF_8), + qitem(mime::TEXT_PLAIN_UTF_8), ]))); test_header!( test4, vec![b"text/plain; charset=utf-8; q=0.5"], Some(Accept(vec![ - QualityItem::new(TEXT_PLAIN_UTF_8, + QualityItem::new(mime::TEXT_PLAIN_UTF_8, q(500)), ]))); #[test] fn test_fuzzing1() { - use test::TestRequest; - let req = TestRequest::with_header(super::http::ACCEPT, "chunk#;e").finish(); + use crate::test::TestRequest; + let req = TestRequest::with_header(crate::header::ACCEPT, "chunk#;e").finish(); let header = Accept::parse(&req); assert!(header.is_ok()); } diff --git a/src/header/common/accept_charset.rs b/actix-http/src/header/common/accept_charset.rs similarity index 71% rename from src/header/common/accept_charset.rs rename to actix-http/src/header/common/accept_charset.rs index 3282198e4..117e2015d 100644 --- a/src/header/common/accept_charset.rs +++ b/actix-http/src/header/common/accept_charset.rs @@ -1,4 +1,4 @@ -use header::{ACCEPT_CHARSET, Charset, QualityItem}; +use crate::header::{Charset, QualityItem, ACCEPT_CHARSET}; header! { /// `Accept-Charset` header, defined in @@ -22,24 +22,24 @@ header! { /// /// # Examples /// ```rust - /// # extern crate actix_web; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptCharset, Charset, qitem}; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// AcceptCharset(vec![qitem(Charset::Us_Ascii)]) /// ); /// # } /// ``` /// ```rust - /// # extern crate actix_web; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptCharset, Charset, q, QualityItem}; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, q, QualityItem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// AcceptCharset(vec![ /// QualityItem::new(Charset::Us_Ascii, q(900)), @@ -49,12 +49,12 @@ header! { /// # } /// ``` /// ```rust - /// # extern crate actix_web; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptCharset, Charset, qitem}; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// AcceptCharset(vec![qitem(Charset::Ext("utf-8".to_owned()))]) /// ); diff --git a/src/header/common/accept_encoding.rs b/actix-http/src/header/common/accept_encoding.rs similarity index 100% rename from src/header/common/accept_encoding.rs rename to actix-http/src/header/common/accept_encoding.rs diff --git a/src/header/common/accept_language.rs b/actix-http/src/header/common/accept_language.rs similarity index 81% rename from src/header/common/accept_language.rs rename to actix-http/src/header/common/accept_language.rs index c9059beed..55879b57f 100644 --- a/src/header/common/accept_language.rs +++ b/actix-http/src/header/common/accept_language.rs @@ -1,5 +1,5 @@ +use crate::header::{QualityItem, ACCEPT_LANGUAGE}; use language_tags::LanguageTag; -use header::{ACCEPT_LANGUAGE, QualityItem}; header! { /// `Accept-Language` header, defined in @@ -23,13 +23,13 @@ header! { /// # Examples /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// # extern crate language_tags; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptLanguage, LanguageTag, qitem}; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptLanguage, LanguageTag, qitem}; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// let mut langtag: LanguageTag = Default::default(); /// langtag.language = Some("en".to_owned()); /// langtag.region = Some("US".to_owned()); @@ -42,13 +42,13 @@ header! { /// ``` /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// # #[macro_use] extern crate language_tags; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptLanguage, QualityItem, q, qitem}; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptLanguage, QualityItem, q, qitem}; /// # /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// AcceptLanguage(vec![ /// qitem(langtag!(da)), diff --git a/src/header/common/allow.rs b/actix-http/src/header/common/allow.rs similarity index 86% rename from src/header/common/allow.rs rename to actix-http/src/header/common/allow.rs index 5046290de..432cc00d5 100644 --- a/src/header/common/allow.rs +++ b/actix-http/src/header/common/allow.rs @@ -24,13 +24,13 @@ header! { /// /// ```rust /// # extern crate http; - /// # extern crate actix_web; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::Allow; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::Allow; /// use http::Method; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// Allow(vec![Method::GET]) /// ); @@ -39,13 +39,13 @@ header! { /// /// ```rust /// # extern crate http; - /// # extern crate actix_web; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::Allow; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::Allow; /// use http::Method; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// Allow(vec![ /// Method::GET, diff --git a/src/header/common/cache_control.rs b/actix-http/src/header/common/cache_control.rs similarity index 50% rename from src/header/common/cache_control.rs rename to actix-http/src/header/common/cache_control.rs index 09a39b184..55774619b 100644 --- a/src/header/common/cache_control.rs +++ b/actix-http/src/header/common/cache_control.rs @@ -1,8 +1,11 @@ use std::fmt::{self, Write}; use std::str::FromStr; + use http::header; -use header::{Header, IntoHeaderValue, Writer}; -use header::{from_comma_delimited, fmt_comma_delimited}; + +use crate::header::{ + fmt_comma_delimited, from_comma_delimited, Header, IntoHeaderValue, Writer, +}; /// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) /// @@ -26,29 +29,24 @@ use header::{from_comma_delimited, fmt_comma_delimited}; /// /// # Examples /// ```rust -/// use actix_web::HttpResponse; -/// use actix_web::http::header::{CacheControl, CacheDirective}; +/// use actix_http::Response; +/// use actix_http::http::header::{CacheControl, CacheDirective}; /// -/// let mut builder = HttpResponse::Ok(); -/// builder.set( -/// CacheControl(vec![CacheDirective::MaxAge(86400u32)]) -/// ); +/// let mut builder = Response::Ok(); +/// builder.set(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); /// ``` /// /// ```rust -/// use actix_web::HttpResponse; -/// use actix_web::http::header::{CacheControl, CacheDirective}; +/// use actix_http::Response; +/// use actix_http::http::header::{CacheControl, CacheDirective}; /// -/// let mut builder = HttpResponse::Ok(); -/// builder.set( -/// CacheControl(vec![ -/// CacheDirective::NoCache, -/// CacheDirective::Private, -/// CacheDirective::MaxAge(360u32), -/// CacheDirective::Extension("foo".to_owned(), -/// Some("bar".to_owned())), -/// ]) -/// ); +/// let mut builder = Response::Ok(); +/// builder.set(CacheControl(vec![ +/// CacheDirective::NoCache, +/// CacheDirective::Private, +/// CacheDirective::MaxAge(360u32), +/// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), +/// ])); /// ``` #[derive(PartialEq, Clone, Debug)] pub struct CacheControl(pub Vec); @@ -62,14 +60,15 @@ impl Header for CacheControl { } #[inline] - fn parse(msg: &T) -> Result - where T: ::HttpMessage + fn parse(msg: &T) -> Result + where + T: crate::HttpMessage, { - let directives = from_comma_delimited(msg.headers().get_all(Self::name()))?; + let directives = from_comma_delimited(msg.headers().get_all(&Self::name()))?; if !directives.is_empty() { Ok(CacheControl(directives)) } else { - Err(::error::ParseError::Header) + Err(crate::error::ParseError::Header) } } } @@ -123,32 +122,36 @@ pub enum CacheDirective { SMaxAge(u32), /// Extension directives. Optionally include an argument. - Extension(String, Option) + Extension(String, Option), } impl fmt::Display for CacheDirective { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use self::CacheDirective::*; - fmt::Display::fmt(match *self { - NoCache => "no-cache", - NoStore => "no-store", - NoTransform => "no-transform", - OnlyIfCached => "only-if-cached", + fmt::Display::fmt( + match *self { + NoCache => "no-cache", + NoStore => "no-store", + NoTransform => "no-transform", + OnlyIfCached => "only-if-cached", - MaxAge(secs) => return write!(f, "max-age={}", secs), - MaxStale(secs) => return write!(f, "max-stale={}", secs), - MinFresh(secs) => return write!(f, "min-fresh={}", secs), + MaxAge(secs) => return write!(f, "max-age={}", secs), + MaxStale(secs) => return write!(f, "max-stale={}", secs), + MinFresh(secs) => return write!(f, "min-fresh={}", secs), - MustRevalidate => "must-revalidate", - Public => "public", - Private => "private", - ProxyRevalidate => "proxy-revalidate", - SMaxAge(secs) => return write!(f, "s-maxage={}", secs), + MustRevalidate => "must-revalidate", + Public => "public", + Private => "private", + ProxyRevalidate => "proxy-revalidate", + SMaxAge(secs) => return write!(f, "s-maxage={}", secs), - Extension(ref name, None) => &name[..], - Extension(ref name, Some(ref arg)) => return write!(f, "{}={}", name, arg), - - }, f) + Extension(ref name, None) => &name[..], + Extension(ref name, Some(ref arg)) => { + return write!(f, "{}={}", name, arg); + } + }, + f, + ) } } @@ -167,16 +170,20 @@ impl FromStr for CacheDirective { "proxy-revalidate" => Ok(ProxyRevalidate), "" => Err(None), _ => match s.find('=') { - Some(idx) if idx+1 < s.len() => match (&s[..idx], (&s[idx+1..]).trim_matches('"')) { - ("max-age" , secs) => secs.parse().map(MaxAge).map_err(Some), - ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), - ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), - ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), - (left, right) => Ok(Extension(left.to_owned(), Some(right.to_owned()))) - }, + Some(idx) if idx + 1 < s.len() => { + match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) { + ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some), + ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), + ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), + ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), + (left, right) => { + Ok(Extension(left.to_owned(), Some(right.to_owned()))) + } + } + } Some(_) => Err(None), - None => Ok(Extension(s.to_owned(), None)) - } + None => Ok(Extension(s.to_owned(), None)), + }, } } } @@ -184,43 +191,61 @@ impl FromStr for CacheDirective { #[cfg(test)] mod tests { use super::*; - use header::Header; - use test::TestRequest; + use crate::header::Header; + use crate::test::TestRequest; #[test] fn test_parse_multiple_headers() { - let req = TestRequest::with_header( - header::CACHE_CONTROL, "no-cache, private").finish(); + let req = TestRequest::with_header(header::CACHE_CONTROL, "no-cache, private") + .finish(); let cache = Header::parse(&req); - assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::NoCache, - CacheDirective::Private]))) + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::NoCache, + CacheDirective::Private, + ])) + ) } #[test] fn test_parse_argument() { - let req = TestRequest::with_header( - header::CACHE_CONTROL, "max-age=100, private").finish(); + let req = + TestRequest::with_header(header::CACHE_CONTROL, "max-age=100, private") + .finish(); let cache = Header::parse(&req); - assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(100), - CacheDirective::Private]))) + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::MaxAge(100), + CacheDirective::Private, + ])) + ) } #[test] fn test_parse_quote_form() { - let req = TestRequest::with_header( - header::CACHE_CONTROL, "max-age=\"200\"").finish(); + let req = + TestRequest::with_header(header::CACHE_CONTROL, "max-age=\"200\"").finish(); let cache = Header::parse(&req); - assert_eq!(cache.ok(), Some(CacheControl(vec![CacheDirective::MaxAge(200)]))) + assert_eq!( + cache.ok(), + Some(CacheControl(vec![CacheDirective::MaxAge(200)])) + ) } #[test] fn test_parse_extension() { - let req = TestRequest::with_header( - header::CACHE_CONTROL, "foo, bar=baz").finish(); + let req = + TestRequest::with_header(header::CACHE_CONTROL, "foo, bar=baz").finish(); let cache = Header::parse(&req); - assert_eq!(cache.ok(), Some(CacheControl(vec![ - CacheDirective::Extension("foo".to_owned(), None), - CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned()))]))) + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::Extension("foo".to_owned(), None), + CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), + ])) + ) } #[test] diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs new file mode 100644 index 000000000..b2b6f34d7 --- /dev/null +++ b/actix-http/src/header/common/content_disposition.rs @@ -0,0 +1,995 @@ +// # References +// +// "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt +// "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt +// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt +// Browser conformance tests at: http://greenbytes.de/tech/tc2231/ +// IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml + +use lazy_static::lazy_static; +use regex::Regex; +use std::fmt::{self, Write}; + +use crate::header::{self, ExtendedValue, Header, IntoHeaderValue, Writer}; + +/// Split at the index of the first `needle` if it exists or at the end. +fn split_once(haystack: &str, needle: char) -> (&str, &str) { + haystack.find(needle).map_or_else( + || (haystack, ""), + |sc| { + let (first, last) = haystack.split_at(sc); + (first, last.split_at(1).1) + }, + ) +} + +/// Split at the index of the first `needle` if it exists or at the end, trim the right of the +/// first part and the left of the last part. +fn split_once_and_trim(haystack: &str, needle: char) -> (&str, &str) { + let (first, last) = split_once(haystack, needle); + (first.trim_end(), last.trim_start()) +} + +/// The implied disposition of the content of the HTTP body. +#[derive(Clone, Debug, PartialEq)] +pub enum DispositionType { + /// Inline implies default processing + Inline, + /// Attachment implies that the recipient should prompt the user to save the response locally, + /// rather than process it normally (as per its media type). + Attachment, + /// Used in *multipart/form-data* as defined in + /// [RFC7578](https://tools.ietf.org/html/rfc7578) to carry the field name and the file name. + FormData, + /// Extension type. Should be handled by recipients the same way as Attachment + Ext(String), +} + +impl<'a> From<&'a str> for DispositionType { + fn from(origin: &'a str) -> DispositionType { + if origin.eq_ignore_ascii_case("inline") { + DispositionType::Inline + } else if origin.eq_ignore_ascii_case("attachment") { + DispositionType::Attachment + } else if origin.eq_ignore_ascii_case("form-data") { + DispositionType::FormData + } else { + DispositionType::Ext(origin.to_owned()) + } + } +} + +/// Parameter in [`ContentDisposition`]. +/// +/// # Examples +/// ``` +/// use actix_http::http::header::DispositionParam; +/// +/// let param = DispositionParam::Filename(String::from("sample.txt")); +/// assert!(param.is_filename()); +/// assert_eq!(param.as_filename().unwrap(), "sample.txt"); +/// ``` +#[derive(Clone, Debug, PartialEq)] +#[allow(clippy::large_enum_variant)] +pub enum DispositionParam { + /// For [`DispositionType::FormData`] (i.e. *multipart/form-data*), the name of an field from + /// the form. + Name(String), + /// A plain file name. + /// + /// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any + /// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where + /// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead + /// in case there are Unicode characters in file names. + Filename(String), + /// An extended file name. It must not exist for `ContentType::Formdata` according to + /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). + FilenameExt(ExtendedValue), + /// An unrecognized regular parameter as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *reg-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *token "=" value*. Recipients should + /// ignore unrecognizable parameters. + Unknown(String, String), + /// An unrecognized extended paramater as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *ext-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *ext-token "=" ext-value*. The single + /// trailling asterisk is not included. Recipients should ignore unrecognizable parameters. + UnknownExt(String, ExtendedValue), +} + +impl DispositionParam { + /// Returns `true` if the paramater is [`Name`](DispositionParam::Name). + #[inline] + pub fn is_name(&self) -> bool { + self.as_name().is_some() + } + + /// Returns `true` if the paramater is [`Filename`](DispositionParam::Filename). + #[inline] + pub fn is_filename(&self) -> bool { + self.as_filename().is_some() + } + + /// Returns `true` if the paramater is [`FilenameExt`](DispositionParam::FilenameExt). + #[inline] + pub fn is_filename_ext(&self) -> bool { + self.as_filename_ext().is_some() + } + + /// Returns `true` if the paramater is [`Unknown`](DispositionParam::Unknown) and the `name` + #[inline] + /// matches. + pub fn is_unknown>(&self, name: T) -> bool { + self.as_unknown(name).is_some() + } + + /// Returns `true` if the paramater is [`UnknownExt`](DispositionParam::UnknownExt) and the + /// `name` matches. + #[inline] + pub fn is_unknown_ext>(&self, name: T) -> bool { + self.as_unknown_ext(name).is_some() + } + + /// Returns the name if applicable. + #[inline] + pub fn as_name(&self) -> Option<&str> { + match self { + DispositionParam::Name(ref name) => Some(name.as_str()), + _ => None, + } + } + + /// Returns the filename if applicable. + #[inline] + pub fn as_filename(&self) -> Option<&str> { + match self { + DispositionParam::Filename(ref filename) => Some(filename.as_str()), + _ => None, + } + } + + /// Returns the filename* if applicable. + #[inline] + pub fn as_filename_ext(&self) -> Option<&ExtendedValue> { + match self { + DispositionParam::FilenameExt(ref value) => Some(value), + _ => None, + } + } + + /// Returns the value of the unrecognized regular parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown>(&self, name: T) -> Option<&str> { + match self { + DispositionParam::Unknown(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value.as_str()) + } + _ => None, + } + } + + /// Returns the value of the unrecognized extended parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + match self { + DispositionParam::UnknownExt(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value) + } + _ => None, + } + } +} + +/// A *Content-Disposition* header. It is compatible to be used either as +/// [a response header for the main body](https://mdn.io/Content-Disposition#As_a_response_header_for_the_main_body) +/// as (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266), or as +/// [a header for a multipart body](https://mdn.io/Content-Disposition#As_a_header_for_a_multipart_body) +/// as (re)defined in [RFC7587](https://tools.ietf.org/html/rfc7578). +/// +/// In a regular HTTP response, the *Content-Disposition* response header is a header indicating if +/// the content is expected to be displayed *inline* in the browser, that is, as a Web page or as +/// part of a Web page, or as an attachment, that is downloaded and saved locally, and also can be +/// used to attach additional metadata, such as the filename to use when saving the response payload +/// locally. +/// +/// In a *multipart/form-data* body, the HTTP *Content-Disposition* general header is a header that +/// can be used on the subpart of a multipart body to give information about the field it applies to. +/// The subpart is delimited by the boundary defined in the *Content-Type* header. Used on the body +/// itself, *Content-Disposition* has no effect. +/// +/// # ABNF + +/// ```text +/// content-disposition = "Content-Disposition" ":" +/// disposition-type *( ";" disposition-parm ) +/// +/// disposition-type = "inline" | "attachment" | disp-ext-type +/// ; case-insensitive +/// +/// disp-ext-type = token +/// +/// disposition-parm = filename-parm | disp-ext-parm +/// +/// filename-parm = "filename" "=" value +/// | "filename*" "=" ext-value +/// +/// disp-ext-parm = token "=" value +/// | ext-token "=" ext-value +/// +/// ext-token = +/// ``` +/// +/// # Note +/// +/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any +/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where +/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file +/// names. +/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded +/// directly in a *Content-Disposition* header for *multipart/form-data*, though. +/// +/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// *multipart/form-data*. +/// +/// # Example +/// +/// ``` +/// use actix_http::http::header::{ +/// Charset, ContentDisposition, DispositionParam, DispositionType, +/// ExtendedValue, +/// }; +/// +/// let cd1 = ContentDisposition { +/// disposition: DispositionType::Attachment, +/// parameters: vec![DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Iso_8859_1, // The character set for the bytes of the filename +/// language_tag: None, // The optional language tag (see `language-tag` crate) +/// value: b"\xa9 Copyright 1989.txt".to_vec(), // the actual bytes of the filename +/// })], +/// }; +/// assert!(cd1.is_attachment()); +/// assert!(cd1.get_filename_ext().is_some()); +/// +/// let cd2 = ContentDisposition { +/// disposition: DispositionType::FormData, +/// parameters: vec![ +/// DispositionParam::Name(String::from("file")), +/// DispositionParam::Filename(String::from("bill.odt")), +/// ], +/// }; +/// assert_eq!(cd2.get_name(), Some("file")); // field name +/// assert_eq!(cd2.get_filename(), Some("bill.odt")); +/// +/// // HTTP response header with Unicode characters in file names +/// let cd3 = ContentDisposition { +/// disposition: DispositionType::Attachment, +/// parameters: vec![ +/// DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Ext(String::from("UTF-8")), +/// language_tag: None, +/// value: String::from("\u{1f600}.svg").into_bytes(), +/// }), +/// // fallback for better compatibility +/// DispositionParam::Filename(String::from("Grinning-Face-Emoji.svg")) +/// ], +/// }; +/// assert_eq!(cd3.get_filename_ext().map(|ev| ev.value.as_ref()), +/// Some("\u{1f600}.svg".as_bytes())); +/// ``` +/// +/// # WARN +/// If "filename" parameter is supplied, do not use the file name blindly, check and possibly +/// change to match local file system conventions if applicable, and do not use directory path +/// information that may be present. See [RFC2183](https://tools.ietf.org/html/rfc2183#section-2.3) +/// . +#[derive(Clone, Debug, PartialEq)] +pub struct ContentDisposition { + /// The disposition type + pub disposition: DispositionType, + /// Disposition parameters + pub parameters: Vec, +} + +impl ContentDisposition { + /// Parse a raw Content-Disposition header value. + pub fn from_raw(hv: &header::HeaderValue) -> Result { + // `header::from_one_raw_str` invokes `hv.to_str` which assumes `hv` contains only visible + // ASCII characters. So `hv.as_bytes` is necessary here. + let hv = String::from_utf8(hv.as_bytes().to_vec()) + .map_err(|_| crate::error::ParseError::Header)?; + let (disp_type, mut left) = split_once_and_trim(hv.as_str().trim(), ';'); + if disp_type.is_empty() { + return Err(crate::error::ParseError::Header); + } + let mut cd = ContentDisposition { + disposition: disp_type.into(), + parameters: Vec::new(), + }; + + while !left.is_empty() { + let (param_name, new_left) = split_once_and_trim(left, '='); + if param_name.is_empty() || param_name == "*" || new_left.is_empty() { + return Err(crate::error::ParseError::Header); + } + left = new_left; + if param_name.ends_with('*') { + // extended parameters + let param_name = ¶m_name[..param_name.len() - 1]; // trim asterisk + let (ext_value, new_left) = split_once_and_trim(left, ';'); + left = new_left; + let ext_value = header::parse_extended_value(ext_value)?; + + let param = if param_name.eq_ignore_ascii_case("filename") { + DispositionParam::FilenameExt(ext_value) + } else { + DispositionParam::UnknownExt(param_name.to_owned(), ext_value) + }; + cd.parameters.push(param); + } else { + // regular parameters + let value = if left.starts_with('\"') { + // quoted-string: defined in RFC6266 -> RFC2616 Section 3.6 + let mut escaping = false; + let mut quoted_string = vec![]; + let mut end = None; + // search for closing quote + for (i, &c) in left.as_bytes().iter().skip(1).enumerate() { + if escaping { + escaping = false; + quoted_string.push(c); + } else if c == 0x5c { + // backslash + escaping = true; + } else if c == 0x22 { + // double quote + end = Some(i + 1); // cuz skipped 1 for the leading quote + break; + } else { + quoted_string.push(c); + } + } + left = &left[end.ok_or(crate::error::ParseError::Header)? + 1..]; + left = split_once(left, ';').1.trim_start(); + // In fact, it should not be Err if the above code is correct. + String::from_utf8(quoted_string) + .map_err(|_| crate::error::ParseError::Header)? + } else { + // token: won't contains semicolon according to RFC 2616 Section 2.2 + let (token, new_left) = split_once_and_trim(left, ';'); + left = new_left; + if token.is_empty() { + // quoted-string can be empty, but token cannot be empty + return Err(crate::error::ParseError::Header); + } + token.to_owned() + }; + + let param = if param_name.eq_ignore_ascii_case("name") { + DispositionParam::Name(value) + } else if param_name.eq_ignore_ascii_case("filename") { + // See also comments in test_from_raw_uncessary_percent_decode. + DispositionParam::Filename(value) + } else { + DispositionParam::Unknown(param_name.to_owned(), value) + }; + cd.parameters.push(param); + } + } + + Ok(cd) + } + + /// Returns `true` if it is [`Inline`](DispositionType::Inline). + pub fn is_inline(&self) -> bool { + match self.disposition { + DispositionType::Inline => true, + _ => false, + } + } + + /// Returns `true` if it is [`Attachment`](DispositionType::Attachment). + pub fn is_attachment(&self) -> bool { + match self.disposition { + DispositionType::Attachment => true, + _ => false, + } + } + + /// Returns `true` if it is [`FormData`](DispositionType::FormData). + pub fn is_form_data(&self) -> bool { + match self.disposition { + DispositionType::FormData => true, + _ => false, + } + } + + /// Returns `true` if it is [`Ext`](DispositionType::Ext) and the `disp_type` matches. + pub fn is_ext>(&self, disp_type: T) -> bool { + match self.disposition { + DispositionType::Ext(ref t) + if t.eq_ignore_ascii_case(disp_type.as_ref()) => + { + true + } + _ => false, + } + } + + /// Return the value of *name* if exists. + pub fn get_name(&self) -> Option<&str> { + self.parameters.iter().filter_map(|p| p.as_name()).nth(0) + } + + /// Return the value of *filename* if exists. + pub fn get_filename(&self) -> Option<&str> { + self.parameters + .iter() + .filter_map(|p| p.as_filename()) + .nth(0) + } + + /// Return the value of *filename\** if exists. + pub fn get_filename_ext(&self) -> Option<&ExtendedValue> { + self.parameters + .iter() + .filter_map(|p| p.as_filename_ext()) + .nth(0) + } + + /// Return the value of the parameter which the `name` matches. + pub fn get_unknown>(&self, name: T) -> Option<&str> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown(name)) + .nth(0) + } + + /// Return the value of the extended parameter which the `name` matches. + pub fn get_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown_ext(name)) + .nth(0) + } +} + +impl IntoHeaderValue for ContentDisposition { + type Error = header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut writer = Writer::new(); + let _ = write!(&mut writer, "{}", self); + header::HeaderValue::from_shared(writer.take()) + } +} + +impl Header for ContentDisposition { + fn name() -> header::HeaderName { + header::CONTENT_DISPOSITION + } + + fn parse(msg: &T) -> Result { + if let Some(h) = msg.headers().get(&Self::name()) { + Self::from_raw(&h) + } else { + Err(crate::error::ParseError::Header) + } + } +} + +impl fmt::Display for DispositionType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DispositionType::Inline => write!(f, "inline"), + DispositionType::Attachment => write!(f, "attachment"), + DispositionType::FormData => write!(f, "form-data"), + DispositionType::Ext(ref s) => write!(f, "{}", s), + } + } +} + +impl fmt::Display for DispositionParam { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and + // backslash should be escaped in quoted-string (i.e. "foobar"). + // Ref: RFC6266 S4.1 -> RFC2616 S3.6 + // filename-parm = "filename" "=" value + // value = token | quoted-string + // quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) + // qdtext = > + // quoted-pair = "\" CHAR + // TEXT = + // LWS = [CRLF] 1*( SP | HT ) + // OCTET = + // CHAR = + // CTL = + // + // Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1 + // parameter := attribute "=" value + // attribute := token + // ; Matching of attributes + // ; is ALWAYS case-insensitive. + // value := token / quoted-string + // token := 1* + // tspecials := "(" / ")" / "<" / ">" / "@" / + // "," / ";" / ":" / "\" / <"> + // "/" / "[" / "]" / "?" / "=" + // ; Must be in quoted-string, + // ; to use within parameter values + // + // + // See also comments in test_from_raw_uncessary_percent_decode. + lazy_static! { + static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap(); + } + match self { + DispositionParam::Name(ref value) => write!(f, "name={}", value), + DispositionParam::Filename(ref value) => { + write!(f, "filename=\"{}\"", RE.replace_all(value, "\\$0").as_ref()) + } + DispositionParam::Unknown(ref name, ref value) => write!( + f, + "{}=\"{}\"", + name, + &RE.replace_all(value, "\\$0").as_ref() + ), + DispositionParam::FilenameExt(ref ext_value) => { + write!(f, "filename*={}", ext_value) + } + DispositionParam::UnknownExt(ref name, ref ext_value) => { + write!(f, "{}*={}", name, ext_value) + } + } + } +} + +impl fmt::Display for ContentDisposition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.disposition)?; + self.parameters + .iter() + .map(|param| write!(f, "; {}", param)) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::{ContentDisposition, DispositionParam, DispositionType}; + use crate::header::shared::Charset; + use crate::header::{ExtendedValue, HeaderValue}; + + #[test] + fn test_from_raw_basic() { + assert!(ContentDisposition::from_raw(&HeaderValue::from_static("")).is_err()); + + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"sample.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static("attachment; filename=\"image.jpg\""); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static("inline; filename=image.jpg"); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; creation-date=\"Wed, 12 Feb 1997 16:29:51 -0500\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Unknown( + String::from("creation-date"), + "Wed, 12 Feb 1997 16:29:51 -0500".to_owned(), + )], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_extended() { + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_extra_whitespace() { + let a = HeaderValue::from_static( + "form-data ; du-mmy= 3 ; name =upload ; filename = \"sample.png\" ; ", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("du-mmy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_unordered() { + let a = HeaderValue::from_static( + "form-data; dummy=3; filename=\"sample.png\" ; name=upload;", + // Actually, a trailling semolocon is not compliant. But it is fine to accept. + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + DispositionParam::Name("upload".to_owned()), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_str( + "attachment; filename*=iso-8859-1''foo-%E4.html; filename=\"foo-ä.html\"", + ) + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_only_disp() { + let a = ContentDisposition::from_raw(&HeaderValue::from_static("attachment")) + .unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = + ContentDisposition::from_raw(&HeaderValue::from_static("inline ;")).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = ContentDisposition::from_raw(&HeaderValue::from_static( + "unknown-disp-param", + )) + .unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Ext(String::from("unknown-disp-param")), + parameters: vec![], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_mixed_case() { + let a = HeaderValue::from_str( + "InLInE; fIlenAME*=iso-8859-1''foo-%E4.html; filEName=\"foo-ä.html\"", + ) + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_unicode() { + /* RFC7578 Section 4.2: + Some commonly deployed systems use multipart/form-data with file names directly encoded + including octets outside the US-ASCII range. The encoding used for the file names is + typically UTF-8, although HTML forms will use the charset associated with the form. + + Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. + (And now, only UTF-8 is handled by this implementation.) + */ + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from("文件.webp")), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_str( + "form-data; name=upload; filename=\"余固知謇謇之為患兮,å¿è€Œä¸èƒ½èˆä¹Ÿ.pptx\"", + ) + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from( + "余固知謇謇之為患兮,å¿è€Œä¸èƒ½èˆä¹Ÿ.pptx", + )), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_escape() { + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"s\\amp\\\"le.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename( + ['s', 'a', 'm', 'p', '\"', 'l', 'e', '.', 'p', 'n', 'g'] + .iter() + .collect(), + ), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_semicolon() { + let a = + HeaderValue::from_static("form-data; filename=\"A semicolon here;.pdf\""); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![DispositionParam::Filename(String::from( + "A semicolon here;.pdf", + ))], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_uncessary_percent_decode() { + // In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with + // non-ASCII characters MAY be percent-encoded. + // On the contrary, RFC6266 or other RFCs related to Content-Disposition response header + // do not mention such percent-encoding. + // So, it appears to be undecidable whether to percent-decode or not without + // knowing the usage scenario (multipart/form-data v.s. HTTP response header) and + // inevitable to unnecessarily percent-decode filename with %XX in the former scenario. + // Fortunately, it seems that almost all mainstream browsers just send UTF-8 encoded file + // names in quoted-string format (tested on Edge, IE11, Chrome and Firefox) without + // percent-encoding. So we do not bother to attempt to percent-decode. + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74%2e%70%6e%67")), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74.png")), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_param_value_missing() { + let a = HeaderValue::from_static("form-data; name=upload ; filename="); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("attachment; dummy=; filename=invoice.pdf"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename= "); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename=\"\""); + assert!(ContentDisposition::from_raw(&a).expect("parse cd").get_filename().expect("filename").is_empty()); + } + + #[test] + fn test_from_raw_param_name_missing() { + let a = HeaderValue::from_static("inline; =\"test.txt\""); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; =diary.odt"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; ="); + assert!(ContentDisposition::from_raw(&a).is_err()); + } + + #[test] + fn test_display_extended() { + let as_string = + "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + + let a = HeaderValue::from_static("attachment; filename=colourful.csv"); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"colourful.csv\"".to_owned(), + display_rendered + ); + } + + #[test] + fn test_display_quote() { + let as_string = "form-data; name=upload; filename=\"Quote\\\"here.png\""; + as_string + .find(['\\', '\"'].iter().collect::().as_str()) + .unwrap(); // ensure `\"` is there + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + } + + #[test] + fn test_display_space_tab() { + let as_string = "form-data; name=upload; filename=\"Space here.png\""; + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("Tab\there.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"Tab\x09here.png\"", display_rendered); + } + + #[test] + fn test_display_control_characters() { + /* let a = "attachment; filename=\"carriage\rreturn.png\""; + let a = HeaderValue::from_static(a); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"carriage\\\rreturn.png\"", + display_rendered + );*/ + // No way to create a HeaderValue containing a carriage return. + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("bell\x07.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"bell\\\x07.png\"", display_rendered); + } + + #[test] + fn test_param_methods() { + let param = DispositionParam::Filename(String::from("sample.txt")); + assert!(param.is_filename()); + assert_eq!(param.as_filename().unwrap(), "sample.txt"); + + let param = DispositionParam::Unknown(String::from("foo"), String::from("bar")); + assert!(param.is_unknown("foo")); + assert_eq!(param.as_unknown("fOo"), Some("bar")); + } + + #[test] + fn test_disposition_methods() { + let cd = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(cd.get_name(), Some("upload")); + assert_eq!(cd.get_unknown("dummy"), Some("3")); + assert_eq!(cd.get_filename(), Some("sample.png")); + assert_eq!(cd.get_unknown_ext("dummy"), None); + assert_eq!(cd.get_unknown("duMMy"), Some("3")); + } +} diff --git a/src/header/common/content_language.rs b/actix-http/src/header/common/content_language.rs similarity index 74% rename from src/header/common/content_language.rs rename to actix-http/src/header/common/content_language.rs index a567ab691..838981a39 100644 --- a/src/header/common/content_language.rs +++ b/actix-http/src/header/common/content_language.rs @@ -1,21 +1,21 @@ +use crate::header::{QualityItem, CONTENT_LANGUAGE}; use language_tags::LanguageTag; -use header::{CONTENT_LANGUAGE, QualityItem}; header! { /// `Content-Language` header, defined in /// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2) - /// + /// /// The `Content-Language` header field describes the natural language(s) /// of the intended audience for the representation. Note that this /// might not be equivalent to all the languages used within the /// representation. - /// + /// /// # ABNF /// /// ```text /// Content-Language = 1#language-tag /// ``` - /// + /// /// # Example values /// /// * `da` @@ -24,13 +24,13 @@ header! { /// # Examples /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// # #[macro_use] extern crate language_tags; - /// use actix_web::HttpResponse; - /// # use actix_web::http::header::{ContentLanguage, qitem}; - /// # + /// use actix_http::Response; + /// # use actix_http::http::header::{ContentLanguage, qitem}; + /// # /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// ContentLanguage(vec![ /// qitem(langtag!(en)), @@ -40,14 +40,14 @@ header! { /// ``` /// /// ```rust - /// # extern crate actix_web; + /// # extern crate actix_http; /// # #[macro_use] extern crate language_tags; - /// use actix_web::HttpResponse; - /// # use actix_web::http::header::{ContentLanguage, qitem}; + /// use actix_http::Response; + /// # use actix_http::http::header::{ContentLanguage, qitem}; /// # /// # fn main() { - /// - /// let mut builder = HttpResponse::Ok(); + /// + /// let mut builder = Response::Ok(); /// builder.set( /// ContentLanguage(vec![ /// qitem(langtag!(da)), diff --git a/src/header/common/content_range.rs b/actix-http/src/header/common/content_range.rs similarity index 76% rename from src/header/common/content_range.rs rename to actix-http/src/header/common/content_range.rs index 8916cf541..cc7f27548 100644 --- a/src/header/common/content_range.rs +++ b/actix-http/src/header/common/content_range.rs @@ -1,8 +1,10 @@ use std::fmt::{self, Display, Write}; use std::str::FromStr; -use error::ParseError; -use header::{IntoHeaderValue, Writer, - HeaderValue, InvalidHeaderValueBytes, CONTENT_RANGE}; + +use crate::error::ParseError; +use crate::header::{ + HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer, CONTENT_RANGE, +}; header! { /// `Content-Range` header, defined in @@ -69,7 +71,6 @@ header! { } } - /// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2) /// /// # ABNF @@ -99,7 +100,7 @@ pub enum ContentRangeSpec { range: Option<(u64, u64)>, /// Total length of the instance, can be omitted if unknown - instance_length: Option + instance_length: Option, }, /// Custom range, with unit not registered at IANA @@ -108,15 +109,15 @@ pub enum ContentRangeSpec { unit: String, /// other-range-resp - resp: String - } + resp: String, + }, } fn split_in_two(s: &str, separator: char) -> Option<(&str, &str)> { let mut iter = s.splitn(2, separator); match (iter.next(), iter.next()) { (Some(a), Some(b)) => Some((a, b)), - _ => None + _ => None, } } @@ -126,40 +127,39 @@ impl FromStr for ContentRangeSpec { fn from_str(s: &str) -> Result { let res = match split_in_two(s, ' ') { Some(("bytes", resp)) => { - let (range, instance_length) = split_in_two( - resp, '/').ok_or(ParseError::Header)?; + let (range, instance_length) = + split_in_two(resp, '/').ok_or(ParseError::Header)?; let instance_length = if instance_length == "*" { None } else { - Some(instance_length.parse() - .map_err(|_| ParseError::Header)?) + Some(instance_length.parse().map_err(|_| ParseError::Header)?) }; let range = if range == "*" { None } else { - let (first_byte, last_byte) = split_in_two( - range, '-').ok_or(ParseError::Header)?; - let first_byte = first_byte.parse() - .map_err(|_| ParseError::Header)?; - let last_byte = last_byte.parse() - .map_err(|_| ParseError::Header)?; + let (first_byte, last_byte) = + split_in_two(range, '-').ok_or(ParseError::Header)?; + let first_byte = + first_byte.parse().map_err(|_| ParseError::Header)?; + let last_byte = last_byte.parse().map_err(|_| ParseError::Header)?; if last_byte < first_byte { return Err(ParseError::Header); } Some((first_byte, last_byte)) }; - ContentRangeSpec::Bytes {range, instance_length} - } - Some((unit, resp)) => { - ContentRangeSpec::Unregistered { - unit: unit.to_owned(), - resp: resp.to_owned() + ContentRangeSpec::Bytes { + range, + instance_length, } } - _ => return Err(ParseError::Header) + Some((unit, resp)) => ContentRangeSpec::Unregistered { + unit: unit.to_owned(), + resp: resp.to_owned(), + }, + _ => return Err(ParseError::Header), }; Ok(res) } @@ -168,17 +168,20 @@ impl FromStr for ContentRangeSpec { impl Display for ContentRangeSpec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ContentRangeSpec::Bytes { range, instance_length } => { - try!(f.write_str("bytes ")); + ContentRangeSpec::Bytes { + range, + instance_length, + } => { + f.write_str("bytes ")?; match range { Some((first_byte, last_byte)) => { - try!(write!(f, "{}-{}", first_byte, last_byte)); - }, + write!(f, "{}-{}", first_byte, last_byte)?; + } None => { - try!(f.write_str("*")); + f.write_str("*")?; } }; - try!(f.write_str("/")); + f.write_str("/")?; if let Some(v) = instance_length { write!(f, "{}", v) } else { @@ -186,8 +189,8 @@ impl Display for ContentRangeSpec { } } ContentRangeSpec::Unregistered { ref unit, ref resp } => { - try!(f.write_str(unit)); - try!(f.write_str(" ")); + f.write_str(unit)?; + f.write_str(" ")?; f.write_str(resp) } } diff --git a/src/header/common/content_type.rs b/actix-http/src/header/common/content_type.rs similarity index 81% rename from src/header/common/content_type.rs rename to actix-http/src/header/common/content_type.rs index 939054a05..a0baa5637 100644 --- a/src/header/common/content_type.rs +++ b/actix-http/src/header/common/content_type.rs @@ -1,6 +1,5 @@ -use mime::{self, Mime}; -use header::CONTENT_TYPE; - +use crate::header::CONTENT_TYPE; +use mime::Mime; header! { /// `Content-Type` header, defined in @@ -32,11 +31,11 @@ header! { /// # Examples /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::ContentType; + /// use actix_http::Response; + /// use actix_http::http::header::ContentType; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// ContentType::json() /// ); @@ -45,13 +44,13 @@ header! { /// /// ```rust /// # extern crate mime; - /// # extern crate actix_web; + /// # extern crate actix_http; /// use mime::TEXT_HTML; - /// use actix_web::HttpResponse; - /// use actix_web::http::header::ContentType; + /// use actix_http::Response; + /// use actix_http::http::header::ContentType; /// /// # fn main() { - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// ContentType(TEXT_HTML) /// ); @@ -63,18 +62,20 @@ header! { test_header!( test1, vec![b"text/html"], - Some(HeaderField(TEXT_HTML))); + Some(HeaderField(mime::TEXT_HTML))); } } impl ContentType { - /// A constructor to easily create a `Content-Type: application/json` header. + /// A constructor to easily create a `Content-Type: application/json` + /// header. #[inline] pub fn json() -> ContentType { ContentType(mime::APPLICATION_JSON) } - /// A constructor to easily create a `Content-Type: text/plain; charset=utf-8` header. + /// A constructor to easily create a `Content-Type: text/plain; + /// charset=utf-8` header. #[inline] pub fn plaintext() -> ContentType { ContentType(mime::TEXT_PLAIN_UTF_8) @@ -92,7 +93,8 @@ impl ContentType { ContentType(mime::TEXT_XML) } - /// A constructor to easily create a `Content-Type: application/www-form-url-encoded` header. + /// A constructor to easily create a `Content-Type: + /// application/www-form-url-encoded` header. #[inline] pub fn form_url_encoded() -> ContentType { ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) @@ -109,7 +111,8 @@ impl ContentType { ContentType(mime::IMAGE_PNG) } - /// A constructor to easily create a `Content-Type: application/octet-stream` header. + /// A constructor to easily create a `Content-Type: + /// application/octet-stream` header. #[inline] pub fn octet_stream() -> ContentType { ContentType(mime::APPLICATION_OCTET_STREAM) diff --git a/src/header/common/date.rs b/actix-http/src/header/common/date.rs similarity index 78% rename from src/header/common/date.rs rename to actix-http/src/header/common/date.rs index 59d37d73b..784100e8d 100644 --- a/src/header/common/date.rs +++ b/actix-http/src/header/common/date.rs @@ -1,6 +1,5 @@ +use crate::header::{HttpDate, DATE}; use std::time::SystemTime; -use header::{DATE, HttpDate}; - header! { /// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2) @@ -21,11 +20,11 @@ header! { /// # Example /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::Date; + /// use actix_http::Response; + /// use actix_http::http::header::Date; /// use std::time::SystemTime; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set(Date(SystemTime::now().into())); /// ``` (Date, DATE) => [HttpDate] @@ -36,6 +35,7 @@ header! { } impl Date { + /// Create a date instance set to the current system time pub fn now() -> Date { Date(SystemTime::now().into()) } diff --git a/src/header/common/etag.rs b/actix-http/src/header/common/etag.rs similarity index 90% rename from src/header/common/etag.rs rename to actix-http/src/header/common/etag.rs index a52bd0a8a..325b91cbf 100644 --- a/src/header/common/etag.rs +++ b/actix-http/src/header/common/etag.rs @@ -1,4 +1,4 @@ -use header::{ETAG, EntityTag}; +use crate::header::{EntityTag, ETAG}; header! { /// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3) @@ -28,18 +28,18 @@ header! { /// # Examples /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{ETag, EntityTag}; + /// use actix_http::Response; + /// use actix_http::http::header::{ETag, EntityTag}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set(ETag(EntityTag::new(false, "xyzzy".to_owned()))); /// ``` /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{ETag, EntityTag}; + /// use actix_http::Response; + /// use actix_http::http::header::{ETag, EntityTag}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set(ETag(EntityTag::new(true, "xyzzy".to_owned()))); /// ``` (ETag, ETAG) => [EntityTag] diff --git a/src/header/common/expires.rs b/actix-http/src/header/common/expires.rs similarity index 85% rename from src/header/common/expires.rs rename to actix-http/src/header/common/expires.rs index aab751b0a..3b9a7873d 100644 --- a/src/header/common/expires.rs +++ b/actix-http/src/header/common/expires.rs @@ -1,4 +1,4 @@ -use header::{EXPIRES, HttpDate}; +use crate::header::{HttpDate, EXPIRES}; header! { /// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3) @@ -22,11 +22,11 @@ header! { /// # Example /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::Expires; + /// use actix_http::Response; + /// use actix_http::http::header::Expires; /// use std::time::{SystemTime, Duration}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// let expiration = SystemTime::now() + Duration::from_secs(60 * 60 * 24); /// builder.set(Expires(expiration.into())); /// ``` diff --git a/src/header/common/if_match.rs b/actix-http/src/header/common/if_match.rs similarity index 87% rename from src/header/common/if_match.rs rename to actix-http/src/header/common/if_match.rs index a7ad7f704..7e0e9a7e0 100644 --- a/src/header/common/if_match.rs +++ b/actix-http/src/header/common/if_match.rs @@ -1,4 +1,4 @@ -use header::{IF_MATCH, EntityTag}; +use crate::header::{EntityTag, IF_MATCH}; header! { /// `If-Match` header, defined in @@ -30,18 +30,18 @@ header! { /// # Examples /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::IfMatch; + /// use actix_http::Response; + /// use actix_http::http::header::IfMatch; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set(IfMatch::Any); /// ``` /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{IfMatch, EntityTag}; + /// use actix_http::Response; + /// use actix_http::http::header::{IfMatch, EntityTag}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// IfMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), diff --git a/src/header/common/if_modified_since.rs b/actix-http/src/header/common/if_modified_since.rs similarity index 85% rename from src/header/common/if_modified_since.rs rename to actix-http/src/header/common/if_modified_since.rs index 48d3c9382..39aca595d 100644 --- a/src/header/common/if_modified_since.rs +++ b/actix-http/src/header/common/if_modified_since.rs @@ -1,4 +1,4 @@ -use header::{IF_MODIFIED_SINCE, HttpDate}; +use crate::header::{HttpDate, IF_MODIFIED_SINCE}; header! { /// `If-Modified-Since` header, defined in @@ -22,11 +22,11 @@ header! { /// # Example /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::IfModifiedSince; + /// use actix_http::Response; + /// use actix_http::http::header::IfModifiedSince; /// use std::time::{SystemTime, Duration}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.set(IfModifiedSince(modified.into())); /// ``` diff --git a/src/header/common/if_none_match.rs b/actix-http/src/header/common/if_none_match.rs similarity index 82% rename from src/header/common/if_none_match.rs rename to actix-http/src/header/common/if_none_match.rs index 8381988aa..7f6ccb137 100644 --- a/src/header/common/if_none_match.rs +++ b/actix-http/src/header/common/if_none_match.rs @@ -1,4 +1,4 @@ -use header::{IF_NONE_MATCH, EntityTag}; +use crate::header::{EntityTag, IF_NONE_MATCH}; header! { /// `If-None-Match` header, defined in @@ -32,18 +32,18 @@ header! { /// # Examples /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::IfNoneMatch; + /// use actix_http::Response; + /// use actix_http::http::header::IfNoneMatch; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set(IfNoneMatch::Any); /// ``` /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::{IfNoneMatch, EntityTag}; + /// use actix_http::Response; + /// use actix_http::http::header::{IfNoneMatch, EntityTag}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// builder.set( /// IfNoneMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), @@ -66,8 +66,8 @@ header! { #[cfg(test)] mod tests { use super::IfNoneMatch; - use test::TestRequest; - use header::{IF_NONE_MATCH, Header, EntityTag}; + use crate::header::{EntityTag, Header, IF_NONE_MATCH}; + use crate::test::TestRequest; #[test] fn test_if_none_match() { @@ -77,8 +77,9 @@ mod tests { if_none_match = Header::parse(&req); assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any)); - let req = TestRequest::with_header( - IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]).finish(); + let req = + TestRequest::with_header(IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]) + .finish(); if_none_match = Header::parse(&req); let mut entities: Vec = Vec::new(); diff --git a/src/header/common/if_range.rs b/actix-http/src/header/common/if_range.rs similarity index 75% rename from src/header/common/if_range.rs rename to actix-http/src/header/common/if_range.rs index 7848f12d0..e910ebd96 100644 --- a/src/header/common/if_range.rs +++ b/actix-http/src/header/common/if_range.rs @@ -1,10 +1,11 @@ use std::fmt::{self, Display, Write}; -use error::ParseError; -use httpmessage::HttpMessage; -use http::header; -use header::from_one_raw_str; -use header::{IntoHeaderValue, Header, HeaderName, HeaderValue, - EntityTag, HttpDate, Writer, InvalidHeaderValueBytes}; + +use crate::error::ParseError; +use crate::header::{ + self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, + IntoHeaderValue, InvalidHeaderValueBytes, Writer, +}; +use crate::httpmessage::HttpMessage; /// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) /// @@ -35,19 +36,22 @@ use header::{IntoHeaderValue, Header, HeaderName, HeaderValue, /// # Examples /// /// ```rust -/// use actix_web::HttpResponse; -/// use actix_web::http::header::{IfRange, EntityTag}; +/// use actix_http::Response; +/// use actix_http::http::header::{EntityTag, IfRange}; /// -/// let mut builder = HttpResponse::Ok(); -/// builder.set(IfRange::EntityTag(EntityTag::new(false, "xyzzy".to_owned()))); +/// let mut builder = Response::Ok(); +/// builder.set(IfRange::EntityTag(EntityTag::new( +/// false, +/// "xyzzy".to_owned(), +/// ))); /// ``` /// /// ```rust -/// use actix_web::HttpResponse; -/// use actix_web::http::header::IfRange; -/// use std::time::{SystemTime, Duration}; +/// use actix_http::Response; +/// use actix_http::http::header::IfRange; +/// use std::time::{Duration, SystemTime}; /// -/// let mut builder = HttpResponse::Ok(); +/// let mut builder = Response::Ok(); /// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.set(IfRange::Date(fetched.into())); /// ``` @@ -64,15 +68,17 @@ impl Header for IfRange { header::IF_RANGE } #[inline] - fn parse(msg: &T) -> Result where T: HttpMessage + fn parse(msg: &T) -> Result + where + T: HttpMessage, { let etag: Result = - from_one_raw_str(msg.headers().get(header::IF_RANGE)); + from_one_raw_str(msg.headers().get(&header::IF_RANGE)); if let Ok(etag) = etag { return Ok(IfRange::EntityTag(etag)); } let date: Result = - from_one_raw_str(msg.headers().get(header::IF_RANGE)); + from_one_raw_str(msg.headers().get(&header::IF_RANGE)); if let Ok(date) = date { return Ok(IfRange::Date(date)); } @@ -99,12 +105,11 @@ impl IntoHeaderValue for IfRange { } } - #[cfg(test)] mod test_if_range { - use std::str; - use header::*; use super::IfRange as HeaderField; + use crate::header::*; + use std::str; test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); test_header!(test2, vec![b"\"xyzzy\""]); test_header!(test3, vec![b"this-is-invalid"], None::); diff --git a/src/header/common/if_unmodified_since.rs b/actix-http/src/header/common/if_unmodified_since.rs similarity index 86% rename from src/header/common/if_unmodified_since.rs rename to actix-http/src/header/common/if_unmodified_since.rs index 4750de0e6..d6c099e64 100644 --- a/src/header/common/if_unmodified_since.rs +++ b/actix-http/src/header/common/if_unmodified_since.rs @@ -1,4 +1,4 @@ -use header::{IF_UNMODIFIED_SINCE, HttpDate}; +use crate::header::{HttpDate, IF_UNMODIFIED_SINCE}; header! { /// `If-Unmodified-Since` header, defined in @@ -23,11 +23,11 @@ header! { /// # Example /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::IfUnmodifiedSince; + /// use actix_http::Response; + /// use actix_http::http::header::IfUnmodifiedSince; /// use std::time::{SystemTime, Duration}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.set(IfUnmodifiedSince(modified.into())); /// ``` diff --git a/src/header/common/last_modified.rs b/actix-http/src/header/common/last_modified.rs similarity index 85% rename from src/header/common/last_modified.rs rename to actix-http/src/header/common/last_modified.rs index a882b0d1a..cc888ccb0 100644 --- a/src/header/common/last_modified.rs +++ b/actix-http/src/header/common/last_modified.rs @@ -1,4 +1,4 @@ -use header::{LAST_MODIFIED, HttpDate}; +use crate::header::{HttpDate, LAST_MODIFIED}; header! { /// `Last-Modified` header, defined in @@ -22,11 +22,11 @@ header! { /// # Example /// /// ```rust - /// use actix_web::HttpResponse; - /// use actix_web::http::header::LastModified; + /// use actix_http::Response; + /// use actix_http::http::header::LastModified; /// use std::time::{SystemTime, Duration}; /// - /// let mut builder = HttpResponse::Ok(); + /// let mut builder = Response::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.set(LastModified(modified.into())); /// ``` diff --git a/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs similarity index 73% rename from src/header/common/mod.rs rename to actix-http/src/header/common/mod.rs index 5f548f012..30dfcaa6d 100644 --- a/src/header/common/mod.rs +++ b/actix-http/src/header/common/mod.rs @@ -5,6 +5,7 @@ //! Several header fields use MIME values for their contents. Keeping with the //! strongly-typed theme, the [mime](https://docs.rs/mime) crate //! is used, such as `ContentType(pub Mime)`. +#![cfg_attr(rustfmt, rustfmt_skip)] pub use self::accept_charset::AcceptCharset; //pub use self::accept_encoding::AcceptEncoding; @@ -12,7 +13,7 @@ pub use self::accept_language::AcceptLanguage; pub use self::accept::Accept; pub use self::allow::Allow; pub use self::cache_control::{CacheControl, CacheDirective}; -//pub use self::content_disposition::{ContentDisposition, DispositionType, DispositionParam}; +pub use self::content_disposition::{ContentDisposition, DispositionType, DispositionParam}; pub use self::content_language::ContentLanguage; pub use self::content_range::{ContentRange, ContentRangeSpec}; pub use self::content_type::ContentType; @@ -58,8 +59,8 @@ macro_rules! __hyper__tm { mod $tm{ use std::str; use http::Method; + use mime::*; use $crate::header::*; - use $crate::mime::*; use super::$id as HeaderField; $($tf)* } @@ -73,14 +74,14 @@ macro_rules! test_header { ($id:ident, $raw:expr) => { #[test] fn $id() { - #[allow(unused, deprecated)] - use std::ascii::AsciiExt; - use test; + use $crate::test; + use super::*; + let raw = $raw; let a: Vec> = raw.iter().map(|x| x.to_vec()).collect(); let mut req = test::TestRequest::default(); for item in a { - req = req.header(HeaderField::name(), item); + req = req.header(HeaderField::name(), item).take(); } let req = req.finish(); let value = HeaderField::parse(&req); @@ -103,10 +104,11 @@ macro_rules! test_header { #[test] fn $id() { use $crate::test; + let a: Vec> = $raw.iter().map(|x| x.to_vec()).collect(); let mut req = test::TestRequest::default(); for item in a { - req = req.header(HeaderField::name(), item); + req.header(HeaderField::name(), item); } let req = req.finish(); let val = HeaderField::parse(&req); @@ -142,33 +144,33 @@ macro_rules! header { #[derive(Clone, Debug, PartialEq)] pub struct $id(pub Vec<$item>); __hyper__deref!($id => Vec<$item>); - impl $crate::header::Header for $id { + impl $crate::http::header::Header for $id { #[inline] - fn name() -> $crate::header::HeaderName { + fn name() -> $crate::http::header::HeaderName { $name } #[inline] fn parse(msg: &T) -> Result where T: $crate::HttpMessage { - $crate::header::from_comma_delimited( + $crate::http::header::from_comma_delimited( msg.headers().get_all(Self::name())).map($id) } } - impl ::std::fmt::Display for $id { + impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - $crate::header::fmt_comma_delimited(f, &self.0[..]) + fn fmt(&self, f: &mut std::fmt::Formatter) -> ::std::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } - impl $crate::header::IntoHeaderValue for $id { - type Error = $crate::header::InvalidHeaderValueBytes; + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; - fn try_into(self) -> Result<$crate::header::HeaderValue, Self::Error> { + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; - let mut writer = $crate::header::Writer::new(); + let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_shared(writer.take()) } } }; @@ -178,33 +180,33 @@ macro_rules! header { #[derive(Clone, Debug, PartialEq)] pub struct $id(pub Vec<$item>); __hyper__deref!($id => Vec<$item>); - impl $crate::header::Header for $id { + impl $crate::http::header::Header for $id { #[inline] - fn name() -> $crate::header::HeaderName { + fn name() -> $crate::http::header::HeaderName { $name } #[inline] fn parse(msg: &T) -> Result where T: $crate::HttpMessage { - $crate::header::from_comma_delimited( + $crate::http::header::from_comma_delimited( msg.headers().get_all(Self::name())).map($id) } } - impl ::std::fmt::Display for $id { + impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - $crate::header::fmt_comma_delimited(f, &self.0[..]) + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } - impl $crate::header::IntoHeaderValue for $id { - type Error = $crate::header::InvalidHeaderValueBytes; + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; - fn try_into(self) -> Result<$crate::header::HeaderValue, Self::Error> { + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; - let mut writer = $crate::header::Writer::new(); + let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_shared(writer.take()) } } }; @@ -214,29 +216,29 @@ macro_rules! header { #[derive(Clone, Debug, PartialEq)] pub struct $id(pub $value); __hyper__deref!($id => $value); - impl $crate::header::Header for $id { + impl $crate::http::header::Header for $id { #[inline] - fn name() -> $crate::header::HeaderName { + fn name() -> $crate::http::header::HeaderName { $name } #[inline] fn parse(msg: &T) -> Result where T: $crate::HttpMessage { - $crate::header::from_one_raw_str( + $crate::http::header::from_one_raw_str( msg.headers().get(Self::name())).map($id) } } - impl ::std::fmt::Display for $id { + impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - ::std::fmt::Display::fmt(&self.0, f) + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) } } - impl $crate::header::IntoHeaderValue for $id { - type Error = $crate::header::InvalidHeaderValueBytes; + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; - fn try_into(self) -> Result<$crate::header::HeaderValue, Self::Error> { + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { self.0.try_into() } } @@ -251,9 +253,9 @@ macro_rules! header { /// Only the listed items are a match Items(Vec<$item>), } - impl $crate::header::Header for $id { + impl $crate::http::header::Header for $id { #[inline] - fn name() -> $crate::header::HeaderName { + fn name() -> $crate::http::header::HeaderName { $name } #[inline] @@ -267,29 +269,29 @@ macro_rules! header { Ok($id::Any) } else { Ok($id::Items( - $crate::header::from_comma_delimited( + $crate::http::header::from_comma_delimited( msg.headers().get_all(Self::name()))?)) } } } - impl ::std::fmt::Display for $id { + impl std::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match *self { $id::Any => f.write_str("*"), - $id::Items(ref fields) => $crate::header::fmt_comma_delimited( + $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( f, &fields[..]) } } } - impl $crate::header::IntoHeaderValue for $id { - type Error = $crate::header::InvalidHeaderValueBytes; + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; - fn try_into(self) -> Result<$crate::header::HeaderValue, Self::Error> { + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { use std::fmt::Write; - let mut writer = $crate::header::Writer::new(); + let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); - $crate::header::HeaderValue::from_shared(writer.take()) + $crate::http::header::HeaderValue::from_shared(writer.take()) } } }; @@ -335,7 +337,7 @@ mod accept_language; mod accept; mod allow; mod cache_control; -//mod content_disposition; +mod content_disposition; mod content_language; mod content_range; mod content_type; @@ -348,4 +350,3 @@ mod if_none_match; mod if_range; mod if_unmodified_since; mod last_modified; -//mod range; diff --git a/src/header/common/range.rs b/actix-http/src/header/common/range.rs similarity index 74% rename from src/header/common/range.rs rename to actix-http/src/header/common/range.rs index d0fca0f3e..71718fc7a 100644 --- a/src/header/common/range.rs +++ b/actix-http/src/header/common/range.rs @@ -1,8 +1,8 @@ use std::fmt::{self, Display}; use std::str::FromStr; +use header::parsing::from_one_raw_str; use header::{Header, Raw}; -use header::parsing::{from_one_raw_str}; /// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) /// @@ -65,7 +65,7 @@ pub enum Range { Bytes(Vec), /// Custom range, with unit not registered at IANA /// (`other-range-unit`: String , `other-range-set`: String) - Unregistered(String, String) + Unregistered(String, String), } /// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`. @@ -77,25 +77,25 @@ pub enum ByteRangeSpec { /// Get all bytes starting from x ("x-") AllFrom(u64), /// Get last x bytes ("-x") - Last(u64) + Last(u64), } impl ByteRangeSpec { /// Given the full length of the entity, attempt to normalize the byte range /// into an satisfiable end-inclusive (from, to) range. /// - /// The resulting range is guaranteed to be a satisfiable range within the bounds - /// of `0 <= from <= to < full_length`. + /// The resulting range is guaranteed to be a satisfiable range within the + /// bounds of `0 <= from <= to < full_length`. /// /// If the byte range is deemed unsatisfiable, `None` is returned. /// An unsatisfiable range is generally cause for a server to either reject /// the client request with a `416 Range Not Satisfiable` status code, or to - /// simply ignore the range header and serve the full entity using a `200 OK` - /// status code. + /// simply ignore the range header and serve the full entity using a `200 + /// OK` status code. /// /// This function closely follows [RFC 7233][1] section 2.1. - /// As such, it considers ranges to be satisfiable if they meet the following - /// conditions: + /// As such, it considers ranges to be satisfiable if they meet the + /// following conditions: /// /// > If a valid byte-range-set includes at least one byte-range-spec with /// a first-byte-pos that is less than the current length of the @@ -125,14 +125,14 @@ impl ByteRangeSpec { } else { None } - }, + } &ByteRangeSpec::AllFrom(from) => { if from < full_length { Some((from, full_length - 1)) } else { None } - }, + } &ByteRangeSpec::Last(last) => { if last > 0 { // From the RFC: If the selected representation is shorter @@ -160,11 +160,15 @@ impl Range { /// Get byte range header with multiple subranges /// ("bytes=from1-to1,from2-to2,fromX-toX") pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { - Range::Bytes(ranges.iter().map(|r| ByteRangeSpec::FromTo(r.0, r.1)).collect()) + Range::Bytes( + ranges + .iter() + .map(|r| ByteRangeSpec::FromTo(r.0, r.1)) + .collect(), + ) } } - impl fmt::Display for ByteRangeSpec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -175,7 +179,6 @@ impl fmt::Display for ByteRangeSpec { } } - impl fmt::Display for Range { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -189,10 +192,10 @@ impl fmt::Display for Range { try!(Display::fmt(range, f)); } Ok(()) - }, + } Range::Unregistered(ref unit, ref range_str) => { write!(f, "{}={}", unit, range_str) - }, + } } } } @@ -211,11 +214,10 @@ impl FromStr for Range { } Ok(Range::Bytes(ranges)) } - (Some(unit), Some(range_str)) if unit != "" && range_str != "" => { - Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())) - - }, - _ => Err(::Error::Header) + (Some(unit), Some(range_str)) if unit != "" && range_str != "" => Ok( + Range::Unregistered(unit.to_owned(), range_str.to_owned()), + ), + _ => Err(::Error::Header), } } } @@ -227,19 +229,20 @@ impl FromStr for ByteRangeSpec { let mut parts = s.splitn(2, '-'); match (parts.next(), parts.next()) { - (Some(""), Some(end)) => { - end.parse().or(Err(::Error::Header)).map(ByteRangeSpec::Last) - }, - (Some(start), Some("")) => { - start.parse().or(Err(::Error::Header)).map(ByteRangeSpec::AllFrom) - }, - (Some(start), Some(end)) => { - match (start.parse(), end.parse()) { - (Ok(start), Ok(end)) if start <= end => Ok(ByteRangeSpec::FromTo(start, end)), - _ => Err(::Error::Header) + (Some(""), Some(end)) => end.parse() + .or(Err(::Error::Header)) + .map(ByteRangeSpec::Last), + (Some(start), Some("")) => start + .parse() + .or(Err(::Error::Header)) + .map(ByteRangeSpec::AllFrom), + (Some(start), Some(end)) => match (start.parse(), end.parse()) { + (Ok(start), Ok(end)) if start <= end => { + Ok(ByteRangeSpec::FromTo(start, end)) } + _ => Err(::Error::Header), }, - _ => Err(::Error::Header) + _ => Err(::Error::Header), } } } @@ -248,14 +251,13 @@ fn from_comma_delimited(s: &str) -> Vec { s.split(',') .filter_map(|x| match x.trim() { "" => None, - y => Some(y) + y => Some(y), }) .filter_map(|x| x.parse().ok()) .collect() } impl Header for Range { - fn header_name() -> &'static str { static NAME: &'static str = "Range"; NAME @@ -268,51 +270,52 @@ impl Header for Range { fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { f.fmt_line(self) } - } #[test] fn test_parse_bytes_range_valid() { let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap(); let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap(); - let r3 = Range::bytes(1, 100); + let r3 = Range::bytes(1, 100); assert_eq!(r, r2); assert_eq!(r2, r3); let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap(); - let r2: Range = Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); - let r3 = Range::Bytes( - vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)] - ); + let r2: Range = + Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::AllFrom(200), + ]); assert_eq!(r, r2); assert_eq!(r2, r3); let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap(); let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap(); - let r3 = Range::Bytes( - vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::Last(100)] - ); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::Last(100), + ]); assert_eq!(r, r2); assert_eq!(r2, r3); let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); assert_eq!(r, r2); - } #[test] fn test_parse_unregistered_range_valid() { let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); assert_eq!(r, r2); let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); + let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); assert_eq!(r, r2); let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); + let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); assert_eq!(r, r2); } @@ -346,10 +349,10 @@ fn test_fmt() { let mut headers = Headers::new(); - headers.set( - Range::Bytes( - vec![ByteRangeSpec::FromTo(0, 1000), ByteRangeSpec::AllFrom(2000)] - )); + headers.set(Range::Bytes(vec![ + ByteRangeSpec::FromTo(0, 1000), + ByteRangeSpec::AllFrom(2000), + ])); assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n"); headers.clear(); @@ -358,30 +361,74 @@ fn test_fmt() { assert_eq!(&headers.to_string(), "Range: bytes=\r\n"); headers.clear(); - headers.set(Range::Unregistered("custom".to_owned(), "1-xxx".to_owned())); + headers.set(Range::Unregistered( + "custom".to_owned(), + "1-xxx".to_owned(), + )); assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n"); } #[test] fn test_byte_range_spec_to_satisfiable_range() { - assert_eq!(Some((0, 0)), ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3)); - assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3)); - assert_eq!(Some((1, 2)), ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)); + assert_eq!( + Some((0, 0)), + ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0) + ); - assert_eq!(Some((0, 2)), ByteRangeSpec::AllFrom(0).to_satisfiable_range(3)); - assert_eq!(Some((2, 2)), ByteRangeSpec::AllFrom(2).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)); + assert_eq!( + Some((0, 2)), + ByteRangeSpec::AllFrom(0).to_satisfiable_range(3) + ); + assert_eq!( + Some((2, 2)), + ByteRangeSpec::AllFrom(2).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(3).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(5).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(0).to_satisfiable_range(0) + ); - assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3)); - assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3)); - assert_eq!(Some((0, 2)), ByteRangeSpec::Last(5).to_satisfiable_range(3)); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::Last(2).to_satisfiable_range(3) + ); + assert_eq!( + Some((2, 2)), + ByteRangeSpec::Last(1).to_satisfiable_range(3) + ); + assert_eq!( + Some((0, 2)), + ByteRangeSpec::Last(5).to_satisfiable_range(3) + ); assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3)); assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0)); } - diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs new file mode 100644 index 000000000..f2f1ba51c --- /dev/null +++ b/actix-http/src/header/map.rs @@ -0,0 +1,384 @@ +use either::Either; +use hashbrown::hash_map::{self, Entry}; +use hashbrown::HashMap; +use http::header::{HeaderName, HeaderValue}; +use http::HttpTryFrom; + +/// A set of HTTP headers +/// +/// `HeaderMap` is an multimap of [`HeaderName`] to values. +/// +/// [`HeaderName`]: struct.HeaderName.html +#[derive(Debug, Clone)] +pub struct HeaderMap { + pub(crate) inner: HashMap, +} + +#[derive(Debug, Clone)] +pub(crate) enum Value { + One(HeaderValue), + Multi(Vec), +} + +impl Value { + fn get(&self) -> &HeaderValue { + match self { + Value::One(ref val) => val, + Value::Multi(ref val) => &val[0], + } + } + + fn get_mut(&mut self) -> &mut HeaderValue { + match self { + Value::One(ref mut val) => val, + Value::Multi(ref mut val) => &mut val[0], + } + } + + fn append(&mut self, val: HeaderValue) { + match self { + Value::One(_) => { + let data = std::mem::replace(self, Value::Multi(vec![val])); + match data { + Value::One(val) => self.append(val), + Value::Multi(_) => unreachable!(), + } + } + Value::Multi(ref mut vec) => vec.push(val), + } + } +} + +impl HeaderMap { + /// Create an empty `HeaderMap`. + /// + /// The map will be created without any capacity. This function will not + /// allocate. + pub fn new() -> Self { + HeaderMap { + inner: HashMap::new(), + } + } + + /// Create an empty `HeaderMap` with the specified capacity. + /// + /// The returned map will allocate internal storage in order to hold about + /// `capacity` elements without reallocating. However, this is a "best + /// effort" as there are usage patterns that could cause additional + /// allocations before `capacity` headers are stored in the map. + /// + /// More capacity than requested may be allocated. + pub fn with_capacity(capacity: usize) -> HeaderMap { + HeaderMap { + inner: HashMap::with_capacity(capacity), + } + } + + /// Returns the number of keys stored in the map. + /// + /// This number could be be less than or equal to actual headers stored in + /// the map. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns true if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.inner.len() == 0 + } + + /// Clears the map, removing all key-value pairs. Keeps the allocated memory + /// for reuse. + pub fn clear(&mut self) { + self.inner.clear(); + } + + /// Returns the number of headers the map can hold without reallocating. + /// + /// This number is an approximation as certain usage patterns could cause + /// additional allocations before the returned capacity is filled. + pub fn capacity(&self) -> usize { + self.inner.capacity() + } + + /// Reserves capacity for at least `additional` more headers to be inserted + /// into the `HeaderMap`. + /// + /// The header map may reserve more space to avoid frequent reallocations. + /// Like with `with_capacity`, this will be a "best effort" to avoid + /// allocations until `additional` more headers are inserted. Certain usage + /// patterns could cause additional allocations before the number is + /// reached. + pub fn reserve(&mut self, additional: usize) { + self.inner.reserve(additional) + } + + /// Returns a reference to the value associated with the key. + /// + /// If there are multiple values associated with the key, then the first one + /// is returned. Use `get_all` to get all values associated with a given + /// key. Returns `None` if there are no values associated with the key. + pub fn get(&self, name: N) -> Option<&HeaderValue> { + self.get2(name).map(|v| v.get()) + } + + fn get2(&self, name: N) -> Option<&Value> { + match name.as_name() { + Either::Left(name) => self.inner.get(name), + Either::Right(s) => { + if let Ok(name) = HeaderName::try_from(s) { + self.inner.get(&name) + } else { + None + } + } + } + } + + /// Returns a view of all values associated with a key. + /// + /// The returned view does not incur any allocations and allows iterating + /// the values associated with the key. See [`GetAll`] for more details. + /// Returns `None` if there are no values associated with the key. + /// + /// [`GetAll`]: struct.GetAll.html + pub fn get_all(&self, name: N) -> GetAll { + GetAll { + idx: 0, + item: self.get2(name), + } + } + + /// Returns a mutable reference to the value associated with the key. + /// + /// If there are multiple values associated with the key, then the first one + /// is returned. Use `entry` to get all values associated with a given + /// key. Returns `None` if there are no values associated with the key. + pub fn get_mut(&mut self, name: N) -> Option<&mut HeaderValue> { + match name.as_name() { + Either::Left(name) => self.inner.get_mut(name).map(|v| v.get_mut()), + Either::Right(s) => { + if let Ok(name) = HeaderName::try_from(s) { + self.inner.get_mut(&name).map(|v| v.get_mut()) + } else { + None + } + } + } + } + + /// Returns true if the map contains a value for the specified key. + pub fn contains_key(&self, key: N) -> bool { + match key.as_name() { + Either::Left(name) => self.inner.contains_key(name), + Either::Right(s) => { + if let Ok(name) = HeaderName::try_from(s) { + self.inner.contains_key(&name) + } else { + false + } + } + } + } + + /// An iterator visiting all key-value pairs. + /// + /// The iteration order is arbitrary, but consistent across platforms for + /// the same crate version. Each key will be yielded once per associated + /// value. So, if a key has 3 associated values, it will be yielded 3 times. + pub fn iter(&self) -> Iter { + Iter::new(self.inner.iter()) + } + + /// An iterator visiting all keys. + /// + /// The iteration order is arbitrary, but consistent across platforms for + /// the same crate version. Each key will be yielded only once even if it + /// has multiple associated values. + pub fn keys(&self) -> Keys { + Keys(self.inner.keys()) + } + + /// Inserts a key-value pair into the map. + /// + /// If the map did not previously have this key present, then `None` is + /// returned. + /// + /// If the map did have this key present, the new value is associated with + /// the key and all previous values are removed. **Note** that only a single + /// one of the previous values is returned. If there are multiple values + /// that have been previously associated with the key, then the first one is + /// returned. See `insert_mult` on `OccupiedEntry` for an API that returns + /// all values. + /// + /// The key is not updated, though; this matters for types that can be `==` + /// without being identical. + pub fn insert(&mut self, key: HeaderName, val: HeaderValue) { + let _ = self.inner.insert(key, Value::One(val)); + } + + /// Inserts a key-value pair into the map. + /// + /// If the map did not previously have this key present, then `false` is + /// returned. + /// + /// If the map did have this key present, the new value is pushed to the end + /// of the list of values currently associated with the key. The key is not + /// updated, though; this matters for types that can be `==` without being + /// identical. + pub fn append(&mut self, key: HeaderName, value: HeaderValue) { + match self.inner.entry(key) { + Entry::Occupied(mut entry) => entry.get_mut().append(value), + Entry::Vacant(entry) => { + entry.insert(Value::One(value)); + } + } + } + + /// Removes all headers for a particular header name from the map. + pub fn remove(&mut self, key: N) { + match key.as_name() { + Either::Left(name) => { + let _ = self.inner.remove(name); + } + Either::Right(s) => { + if let Ok(name) = HeaderName::try_from(s) { + let _ = self.inner.remove(&name); + } + } + } + } +} + +#[doc(hidden)] +pub trait AsName { + fn as_name(&self) -> Either<&HeaderName, &str>; +} + +impl AsName for HeaderName { + fn as_name(&self) -> Either<&HeaderName, &str> { + Either::Left(self) + } +} + +impl<'a> AsName for &'a HeaderName { + fn as_name(&self) -> Either<&HeaderName, &str> { + Either::Left(self) + } +} + +impl<'a> AsName for &'a str { + fn as_name(&self) -> Either<&HeaderName, &str> { + Either::Right(self) + } +} + +impl AsName for String { + fn as_name(&self) -> Either<&HeaderName, &str> { + Either::Right(self.as_str()) + } +} + +impl<'a> AsName for &'a String { + fn as_name(&self) -> Either<&HeaderName, &str> { + Either::Right(self.as_str()) + } +} + +pub struct GetAll<'a> { + idx: usize, + item: Option<&'a Value>, +} + +impl<'a> Iterator for GetAll<'a> { + type Item = &'a HeaderValue; + + #[inline] + fn next(&mut self) -> Option<&'a HeaderValue> { + if let Some(ref val) = self.item { + match val { + Value::One(ref val) => { + self.item.take(); + Some(val) + } + Value::Multi(ref vec) => { + if self.idx < vec.len() { + let item = Some(&vec[self.idx]); + self.idx += 1; + item + } else { + self.item.take(); + None + } + } + } + } else { + None + } + } +} + +pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>); + +impl<'a> Iterator for Keys<'a> { + type Item = &'a HeaderName; + + #[inline] + fn next(&mut self) -> Option<&'a HeaderName> { + self.0.next() + } +} + +impl<'a> IntoIterator for &'a HeaderMap { + type Item = (&'a HeaderName, &'a HeaderValue); + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +pub struct Iter<'a> { + idx: usize, + current: Option<(&'a HeaderName, &'a Vec)>, + iter: hash_map::Iter<'a, HeaderName, Value>, +} + +impl<'a> Iter<'a> { + fn new(iter: hash_map::Iter<'a, HeaderName, Value>) -> Self { + Self { + iter, + idx: 0, + current: None, + } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = (&'a HeaderName, &'a HeaderValue); + + #[inline] + fn next(&mut self) -> Option<(&'a HeaderName, &'a HeaderValue)> { + if let Some(ref mut item) = self.current { + if self.idx < item.1.len() { + let item = (item.0, &item.1[self.idx]); + self.idx += 1; + return Some(item); + } else { + self.idx = 0; + self.current.take(); + } + } + if let Some(item) = self.iter.next() { + match item.1 { + Value::One(ref value) => Some((item.0, value)), + Value::Multi(ref vec) => { + self.current = Some((item.0, vec)); + self.next() + } + } + } else { + None + } + } +} diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs new file mode 100644 index 000000000..59cbb11c4 --- /dev/null +++ b/actix-http/src/header/mod.rs @@ -0,0 +1,504 @@ +//! Various http headers +// This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) + +use std::{fmt, str::FromStr}; + +use bytes::{Bytes, BytesMut}; +use http::Error as HttpError; +use mime::Mime; +use percent_encoding::{AsciiSet, CONTROLS}; + +pub use http::header::*; + +use crate::error::ParseError; +use crate::httpmessage::HttpMessage; + +mod common; +pub(crate) mod map; +mod shared; +pub use self::common::*; +#[doc(hidden)] +pub use self::shared::*; + +#[doc(hidden)] +pub use self::map::GetAll; +pub use self::map::HeaderMap; + +/// A trait for any object that will represent a header field and value. +pub trait Header +where + Self: IntoHeaderValue, +{ + /// Returns the name of the header field + fn name() -> HeaderName; + + /// Parse a header + fn parse(msg: &T) -> Result; +} + +/// A trait for any object that can be Converted to a `HeaderValue` +pub trait IntoHeaderValue: Sized { + /// The type returned in the event of a conversion error. + type Error: Into; + + /// Try to convert value to a Header value. + fn try_into(self) -> Result; +} + +impl IntoHeaderValue for HeaderValue { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + Ok(self) + } +} + +impl<'a> IntoHeaderValue for &'a str { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + self.parse() + } +} + +impl<'a> IntoHeaderValue for &'a [u8] { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_bytes(self) + } +} + +impl IntoHeaderValue for Bytes { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(self) + } +} + +impl IntoHeaderValue for Vec { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(self)) + } +} + +impl IntoHeaderValue for String { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(self)) + } +} + +impl IntoHeaderValue for usize { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + let s = format!("{}", self); + HeaderValue::from_shared(Bytes::from(s)) + } +} + +impl IntoHeaderValue for u64 { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + let s = format!("{}", self); + HeaderValue::from_shared(Bytes::from(s)) + } +} + +impl IntoHeaderValue for Mime { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(format!("{}", self))) + } +} + +/// Represents supported types of content encodings +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ContentEncoding { + /// Automatically select encoding based on encoding negotiation + Auto, + /// A format using the Brotli algorithm + Br, + /// A format using the zlib structure with deflate algorithm + Deflate, + /// Gzip algorithm + Gzip, + /// Indicates the identity function (i.e. no compression, nor modification) + Identity, +} + +impl ContentEncoding { + #[inline] + /// Is the content compressed? + pub fn is_compression(self) -> bool { + match self { + ContentEncoding::Identity | ContentEncoding::Auto => false, + _ => true, + } + } + + #[inline] + /// Convert content encoding to string + pub fn as_str(self) -> &'static str { + match self { + ContentEncoding::Br => "br", + ContentEncoding::Gzip => "gzip", + ContentEncoding::Deflate => "deflate", + ContentEncoding::Identity | ContentEncoding::Auto => "identity", + } + } + + #[inline] + /// default quality value + pub fn quality(self) -> f64 { + match self { + ContentEncoding::Br => 1.1, + ContentEncoding::Gzip => 1.0, + ContentEncoding::Deflate => 0.9, + ContentEncoding::Identity | ContentEncoding::Auto => 0.1, + } + } +} + +impl<'a> From<&'a str> for ContentEncoding { + fn from(s: &'a str) -> ContentEncoding { + let s = s.trim(); + + if s.eq_ignore_ascii_case("br") { + ContentEncoding::Br + } else if s.eq_ignore_ascii_case("gzip") { + ContentEncoding::Gzip + } else if s.eq_ignore_ascii_case("deflate") { + ContentEncoding::Deflate + } else { + ContentEncoding::Identity + } + } +} + +#[doc(hidden)] +pub(crate) struct Writer { + buf: BytesMut, +} + +impl Writer { + fn new() -> Writer { + Writer { + buf: BytesMut::new(), + } + } + fn take(&mut self) -> Bytes { + self.buf.take().freeze() + } +} + +impl fmt::Write for Writer { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + self.buf.extend_from_slice(s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +#[doc(hidden)] +/// Reads a comma-delimited raw header into a Vec. +pub fn from_comma_delimited<'a, I: Iterator + 'a, T: FromStr>( + all: I, +) -> Result, ParseError> { + let mut result = Vec::new(); + for h in all { + let s = h.to_str().map_err(|_| ParseError::Header)?; + result.extend( + s.split(',') + .filter_map(|x| match x.trim() { + "" => None, + y => Some(y), + }) + .filter_map(|x| x.trim().parse().ok()), + ) + } + Ok(result) +} + +#[inline] +#[doc(hidden)] +/// Reads a single string when parsing a header. +pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { + if let Some(line) = val { + let line = line.to_str().map_err(|_| ParseError::Header)?; + if !line.is_empty() { + return T::from_str(line).or(Err(ParseError::Header)); + } + } + Err(ParseError::Header) +} + +#[inline] +#[doc(hidden)] +/// Format an array into a comma-delimited string. +pub fn fmt_comma_delimited(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result +where + T: fmt::Display, +{ + let mut iter = parts.iter(); + if let Some(part) = iter.next() { + fmt::Display::fmt(part, f)?; + } + for part in iter { + f.write_str(", ")?; + fmt::Display::fmt(part, f)?; + } + Ok(()) +} + +// From hyper v0.11.27 src/header/parsing.rs + +/// The value part of an extended parameter consisting of three parts: +/// the REQUIRED character set name (`charset`), the OPTIONAL language information (`language_tag`), +/// and a character sequence representing the actual value (`value`), separated by single quote +/// characters. It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +#[derive(Clone, Debug, PartialEq)] +pub struct ExtendedValue { + /// The character set that is used to encode the `value` to a string. + pub charset: Charset, + /// The human language details of the `value`, if available. + pub language_tag: Option, + /// The parameter value, as expressed in octets. + pub value: Vec, +} + +/// Parses extended header parameter values (`ext-value`), as defined in +/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// +/// Extended values are denoted by parameter names that end with `*`. +/// +/// ## ABNF +/// +/// ```text +/// ext-value = charset "'" [ language ] "'" value-chars +/// ; like RFC 2231's +/// ; (see [RFC2231], Section 7) +/// +/// charset = "UTF-8" / "ISO-8859-1" / mime-charset +/// +/// mime-charset = 1*mime-charsetc +/// mime-charsetc = ALPHA / DIGIT +/// / "!" / "#" / "$" / "%" / "&" +/// / "+" / "-" / "^" / "_" / "`" +/// / "{" / "}" / "~" +/// ; as in Section 2.3 of [RFC2978] +/// ; except that the single quote is not included +/// ; SHOULD be registered in the IANA charset registry +/// +/// language = +/// +/// value-chars = *( pct-encoded / attr-char ) +/// +/// pct-encoded = "%" HEXDIG HEXDIG +/// ; see [RFC3986], Section 2.1 +/// +/// attr-char = ALPHA / DIGIT +/// / "!" / "#" / "$" / "&" / "+" / "-" / "." +/// / "^" / "_" / "`" / "|" / "~" +/// ; token except ( "*" / "'" / "%" ) +/// ``` +pub fn parse_extended_value( + val: &str, +) -> Result { + // Break into three pieces separated by the single-quote character + let mut parts = val.splitn(3, '\''); + + // Interpret the first piece as a Charset + let charset: Charset = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(n) => FromStr::from_str(n).map_err(|_| crate::error::ParseError::Header)?, + }; + + // Interpret the second piece as a language tag + let language_tag: Option = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some("") => None, + Some(s) => match s.parse() { + Ok(lt) => Some(lt), + Err(_) => return Err(crate::error::ParseError::Header), + }, + }; + + // Interpret the third piece as a sequence of value characters + let value: Vec = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(v) => percent_encoding::percent_decode(v.as_bytes()).collect(), + }; + + Ok(ExtendedValue { + value, + charset, + language_tag, + }) +} + +impl fmt::Display for ExtendedValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let encoded_value = + percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); + if let Some(ref lang) = self.language_tag { + write!(f, "{}'{}'{}", self.charset, lang, encoded_value) + } else { + write!(f, "{}''{}", self.charset, encoded_value) + } + } +} + +/// Percent encode a sequence of bytes with a character set defined in +/// [https://tools.ietf.org/html/rfc5987#section-3.2][url] +/// +/// [url]: https://tools.ietf.org/html/rfc5987#section-3.2 +pub fn http_percent_encode(f: &mut fmt::Formatter, bytes: &[u8]) -> fmt::Result { + let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); + fmt::Display::fmt(&encoded, f) +} + +/// Convert http::HeaderMap to a HeaderMap +impl From for HeaderMap { + fn from(map: http::HeaderMap) -> HeaderMap { + let mut new_map = HeaderMap::with_capacity(map.capacity()); + for (h, v) in map.iter() { + new_map.append(h.clone(), v.clone()); + } + new_map + } +} + +// This encode set is used for HTTP header values and is defined at +// https://tools.ietf.org/html/rfc5987#section-3.2 +pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS + .add(b' ') + .add(b'"') + .add(b'%') + .add(b'\'') + .add(b'(') + .add(b')') + .add(b'*') + .add(b',') + .add(b'/') + .add(b':') + .add(b';') + .add(b'<') + .add(b'-') + .add(b'>') + .add(b'?') + .add(b'[') + .add(b'\\') + .add(b']') + .add(b'{') + .add(b'}'); + +#[cfg(test)] +mod tests { + use super::shared::Charset; + use super::{parse_extended_value, ExtendedValue}; + use language_tags::LanguageTag; + + #[test] + fn test_parse_extended_value_with_encoding_and_language_tag() { + let expected_language_tag = "en".parse::().unwrap(); + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode character U+00A3 (POUND SIGN) + let result = parse_extended_value("iso-8859-1'en'%A3%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Iso_8859_1, extended_value.charset); + assert!(extended_value.language_tag.is_some()); + assert_eq!(expected_language_tag, extended_value.language_tag.unwrap()); + assert_eq!( + vec![163, b' ', b'r', b'a', b't', b'e', b's'], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_with_encoding() { + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode characters U+00A3 (POUND SIGN) + // and U+20AC (EURO SIGN) + let result = parse_extended_value("UTF-8''%c2%a3%20and%20%e2%82%ac%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Ext("UTF-8".to_string()), extended_value.charset); + assert!(extended_value.language_tag.is_none()); + assert_eq!( + vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_missing_language_tag_and_encoding() { + // From: https://greenbytes.de/tech/tc2231/#attwithfn2231quot2 + let result = parse_extended_value("foo%20bar.html"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted() { + let result = parse_extended_value("UTF-8'missing third part"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted_blank() { + let result = parse_extended_value("blank second part'"); + assert!(result.is_err()); + } + + #[test] + fn test_fmt_extended_value_with_encoding_and_language_tag() { + let extended_value = ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: Some("en".parse().expect("Could not parse language tag")), + value: vec![163, b' ', b'r', b'a', b't', b'e', b's'], + }; + assert_eq!("ISO-8859-1'en'%A3%20rates", format!("{}", extended_value)); + } + + #[test] + fn test_fmt_extended_value_with_encoding() { + let extended_value = ExtendedValue { + charset: Charset::Ext("UTF-8".to_string()), + language_tag: None, + value: vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + }; + assert_eq!( + "UTF-8''%C2%A3%20and%20%E2%82%AC%20rates", + format!("{}", extended_value) + ); + } +} diff --git a/src/header/shared/charset.rs b/actix-http/src/header/shared/charset.rs similarity index 82% rename from src/header/shared/charset.rs rename to actix-http/src/header/shared/charset.rs index 765b34afa..ec3fe3854 100644 --- a/src/header/shared/charset.rs +++ b/actix-http/src/header/shared/charset.rs @@ -1,20 +1,18 @@ -#![allow(unused, deprecated)] use std::fmt::{self, Display}; use std::str::FromStr; -use std::ascii::AsciiExt; use self::Charset::*; /// A Mime charset. /// -/// The string representation is normalised to upper case. +/// The string representation is normalized to upper case. /// /// See [http://www.iana.org/assignments/character-sets/character-sets.xhtml][url]. /// /// [url]: http://www.iana.org/assignments/character-sets/character-sets.xhtml -#[derive(Clone,Debug,PartialEq)] +#[derive(Clone, Debug, PartialEq)] #[allow(non_camel_case_types)] -pub enum Charset{ +pub enum Charset { /// US ASCII Us_Ascii, /// ISO-8859-1 @@ -64,11 +62,11 @@ pub enum Charset{ /// KOI8-R Koi8_R, /// An arbitrary charset specified as a string - Ext(String) + Ext(String), } impl Charset { - fn name(&self) -> &str { + fn label(&self) -> &str { match *self { Us_Ascii => "US-ASCII", Iso_8859_1 => "ISO-8859-1", @@ -92,22 +90,23 @@ impl Charset { Iso_8859_8_E => "ISO-8859-8-E", Iso_8859_8_I => "ISO-8859-8-I", Gb2312 => "GB2312", - Big5 => "5", + Big5 => "big5", Koi8_R => "KOI8-R", - Ext(ref s) => s + Ext(ref s) => s, } } } impl Display for Charset { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.name()) + f.write_str(self.label()) } } impl FromStr for Charset { - type Err = ::Error; - fn from_str(s: &str) -> ::Result { + type Err = crate::Error; + + fn from_str(s: &str) -> crate::Result { Ok(match s.to_ascii_uppercase().as_ref() { "US-ASCII" => Us_Ascii, "ISO-8859-1" => Iso_8859_1, @@ -131,20 +130,20 @@ impl FromStr for Charset { "ISO-8859-8-E" => Iso_8859_8_E, "ISO-8859-8-I" => Iso_8859_8_I, "GB2312" => Gb2312, - "5" => Big5, + "big5" => Big5, "KOI8-R" => Koi8_R, - s => Ext(s.to_owned()) + s => Ext(s.to_owned()), }) } } #[test] fn test_parse() { - assert_eq!(Us_Ascii,"us-ascii".parse().unwrap()); - assert_eq!(Us_Ascii,"US-Ascii".parse().unwrap()); - assert_eq!(Us_Ascii,"US-ASCII".parse().unwrap()); - assert_eq!(Shift_Jis,"Shift-JIS".parse().unwrap()); - assert_eq!(Ext("ABCD".to_owned()),"abcd".parse().unwrap()); + assert_eq!(Us_Ascii, "us-ascii".parse().unwrap()); + assert_eq!(Us_Ascii, "US-Ascii".parse().unwrap()); + assert_eq!(Us_Ascii, "US-ASCII".parse().unwrap()); + assert_eq!(Shift_Jis, "Shift-JIS".parse().unwrap()); + assert_eq!(Ext("ABCD".to_owned()), "abcd".parse().unwrap()); } #[test] diff --git a/src/header/shared/encoding.rs b/actix-http/src/header/shared/encoding.rs similarity index 77% rename from src/header/shared/encoding.rs rename to actix-http/src/header/shared/encoding.rs index 6381ac7eb..af7404828 100644 --- a/src/header/shared/encoding.rs +++ b/actix-http/src/header/shared/encoding.rs @@ -1,7 +1,8 @@ -use std::fmt; -use std::str; +use std::{fmt, str}; -pub use self::Encoding::{Chunked, Brotli, Gzip, Deflate, Compress, Identity, EncodingExt, Trailers}; +pub use self::Encoding::{ + Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip, Identity, Trailers, +}; /// A value to represent an encoding used in `Transfer-Encoding` /// or `Accept-Encoding` header. @@ -22,7 +23,7 @@ pub enum Encoding { /// The `trailers` encoding. Trailers, /// Some other encoding that is less common, can be any String. - EncodingExt(String) + EncodingExt(String), } impl fmt::Display for Encoding { @@ -35,14 +36,14 @@ impl fmt::Display for Encoding { Compress => "compress", Identity => "identity", Trailers => "trailers", - EncodingExt(ref s) => s.as_ref() + EncodingExt(ref s) => s.as_ref(), }) } } impl str::FromStr for Encoding { - type Err = ::error::ParseError; - fn from_str(s: &str) -> Result { + type Err = crate::error::ParseError; + fn from_str(s: &str) -> Result { match s { "chunked" => Ok(Chunked), "br" => Ok(Brotli), @@ -51,7 +52,7 @@ impl str::FromStr for Encoding { "compress" => Ok(Compress), "identity" => Ok(Identity), "trailers" => Ok(Trailers), - _ => Ok(EncodingExt(s.to_owned())) + _ => Ok(EncodingExt(s.to_owned())), } } } diff --git a/src/header/shared/entity.rs b/actix-http/src/header/shared/entity.rs similarity index 72% rename from src/header/shared/entity.rs rename to actix-http/src/header/shared/entity.rs index 08a66b4f1..da02dc193 100644 --- a/src/header/shared/entity.rs +++ b/actix-http/src/header/shared/entity.rs @@ -1,21 +1,24 @@ -use std::str::FromStr; use std::fmt::{self, Display, Write}; -use header::{HeaderValue, Writer, IntoHeaderValue, InvalidHeaderValueBytes}; +use std::str::FromStr; + +use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer}; /// check that each char in the slice is either: /// 1. `%x21`, or /// 2. in the range `%x23` to `%x7E`, or /// 3. above `%x80` fn check_slice_validity(slice: &str) -> bool { - slice.bytes().all(|c| - c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80')) + slice + .bytes() + .all(|c| c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80')) } /// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) /// /// An entity tag consists of a string enclosed by two literal double quotes. /// Preceding the first double quote is an optional weakness indicator, -/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and `W/"xyzzy"`. +/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and +/// `W/"xyzzy"`. /// /// # ABNF /// @@ -28,9 +31,9 @@ fn check_slice_validity(slice: &str) -> bool { /// ``` /// /// # Comparison -/// To check if two entity tags are equivalent in an application always use the `strong_eq` or -/// `weak_eq` methods based on the context of the Tag. Only use `==` to check if two tags are -/// identical. +/// To check if two entity tags are equivalent in an application always use the +/// `strong_eq` or `weak_eq` methods based on the context of the Tag. Only use +/// `==` to check if two tags are identical. /// /// The example below shows the results for a set of entity-tag pairs and /// both the weak and strong comparison function results: @@ -46,7 +49,7 @@ pub struct EntityTag { /// Weakness indicator for the tag pub weak: bool, /// The opaque string in between the DQUOTEs - tag: String + tag: String, } impl EntityTag { @@ -85,8 +88,8 @@ impl EntityTag { self.tag = tag } - /// For strong comparison two entity-tags are equivalent if both are not weak and their - /// opaque-tags match character-by-character. + /// For strong comparison two entity-tags are equivalent if both are not + /// weak and their opaque-tags match character-by-character. pub fn strong_eq(&self, other: &EntityTag) -> bool { !self.weak && !other.weak && self.tag == other.tag } @@ -120,26 +123,36 @@ impl Display for EntityTag { } impl FromStr for EntityTag { - type Err = ::error::ParseError; + type Err = crate::error::ParseError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> Result { let length: usize = s.len(); let slice = &s[..]; // Early exits if it doesn't terminate in a DQUOTE. if !slice.ends_with('"') || slice.len() < 2 { - return Err(::error::ParseError::Header); + return Err(crate::error::ParseError::Header); } // The etag is weak if its first char is not a DQUOTE. - if slice.len() >= 2 && slice.starts_with('"') - && check_slice_validity(&slice[1..length-1]) { + if slice.len() >= 2 + && slice.starts_with('"') + && check_slice_validity(&slice[1..length - 1]) + { // No need to check if the last char is a DQUOTE, // we already did that above. - return Ok(EntityTag { weak: false, tag: slice[1..length-1].to_owned() }); - } else if slice.len() >= 4 && slice.starts_with("W/\"") - && check_slice_validity(&slice[3..length-1]) { - return Ok(EntityTag { weak: true, tag: slice[3..length-1].to_owned() }); + return Ok(EntityTag { + weak: false, + tag: slice[1..length - 1].to_owned(), + }); + } else if slice.len() >= 4 + && slice.starts_with("W/\"") + && check_slice_validity(&slice[3..length - 1]) + { + return Ok(EntityTag { + weak: true, + tag: slice[3..length - 1].to_owned(), + }); } - Err(::error::ParseError::Header) + Err(crate::error::ParseError::Header) } } @@ -149,7 +162,7 @@ impl IntoHeaderValue for EntityTag { fn try_into(self) -> Result { let mut wrt = Writer::new(); write!(wrt, "{}", self).unwrap(); - unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.take()))} + HeaderValue::from_shared(wrt.take()) } } @@ -160,22 +173,35 @@ mod tests { #[test] fn test_etag_parse_success() { // Expected success - assert_eq!("\"foobar\"".parse::().unwrap(), - EntityTag::strong("foobar".to_owned())); - assert_eq!("\"\"".parse::().unwrap(), - EntityTag::strong("".to_owned())); - assert_eq!("W/\"weaktag\"".parse::().unwrap(), - EntityTag::weak("weaktag".to_owned())); - assert_eq!("W/\"\x65\x62\"".parse::().unwrap(), - EntityTag::weak("\x65\x62".to_owned())); - assert_eq!("W/\"\"".parse::().unwrap(), EntityTag::weak("".to_owned())); + assert_eq!( + "\"foobar\"".parse::().unwrap(), + EntityTag::strong("foobar".to_owned()) + ); + assert_eq!( + "\"\"".parse::().unwrap(), + EntityTag::strong("".to_owned()) + ); + assert_eq!( + "W/\"weaktag\"".parse::().unwrap(), + EntityTag::weak("weaktag".to_owned()) + ); + assert_eq!( + "W/\"\x65\x62\"".parse::().unwrap(), + EntityTag::weak("\x65\x62".to_owned()) + ); + assert_eq!( + "W/\"\"".parse::().unwrap(), + EntityTag::weak("".to_owned()) + ); } #[test] fn test_etag_parse_failures() { // Expected failures assert!("no-dquotes".parse::().is_err()); - assert!("w/\"the-first-w-is-case-sensitive\"".parse::().is_err()); + assert!("w/\"the-first-w-is-case-sensitive\"" + .parse::() + .is_err()); assert!("".parse::().is_err()); assert!("\"unmatched-dquotes1".parse::().is_err()); assert!("unmatched-dquotes2\"".parse::().is_err()); @@ -184,10 +210,19 @@ mod tests { #[test] fn test_etag_fmt() { - assert_eq!(format!("{}", EntityTag::strong("foobar".to_owned())), "\"foobar\""); + assert_eq!( + format!("{}", EntityTag::strong("foobar".to_owned())), + "\"foobar\"" + ); assert_eq!(format!("{}", EntityTag::strong("".to_owned())), "\"\""); - assert_eq!(format!("{}", EntityTag::weak("weak-etag".to_owned())), "W/\"weak-etag\""); - assert_eq!(format!("{}", EntityTag::weak("\u{0065}".to_owned())), "W/\"\x65\""); + assert_eq!( + format!("{}", EntityTag::weak("weak-etag".to_owned())), + "W/\"weak-etag\"" + ); + assert_eq!( + format!("{}", EntityTag::weak("\u{0065}".to_owned())), + "W/\"\x65\"" + ); assert_eq!(format!("{}", EntityTag::weak("".to_owned())), "W/\"\""); } diff --git a/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs similarity index 63% rename from src/header/shared/httpdate.rs rename to actix-http/src/header/shared/httpdate.rs index b2fcf5270..350f77bbe 100644 --- a/src/header/shared/httpdate.rs +++ b/actix-http/src/header/shared/httpdate.rs @@ -3,13 +3,11 @@ use std::io::Write; use std::str::FromStr; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use time; -use bytes::{BytesMut, BufMut}; +use bytes::{BufMut, BytesMut}; use http::header::{HeaderValue, InvalidHeaderValueBytes}; -use error::ParseError; -use header::IntoHeaderValue; - +use crate::error::ParseError; +use crate::header::IntoHeaderValue; /// A timestamp with HTTP formatting and parsing #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -19,11 +17,10 @@ impl FromStr for HttpDate { type Err = ParseError; fn from_str(s: &str) -> Result { - match time::strptime(s, "%a, %d %b %Y %T %Z").or_else(|_| { - time::strptime(s, "%A, %d-%b-%y %T %Z") - }).or_else(|_| { - time::strptime(s, "%c") - }) { + match time::strptime(s, "%a, %d %b %Y %T %Z") + .or_else(|_| time::strptime(s, "%A, %d-%b-%y %T %Z")) + .or_else(|_| time::strptime(s, "%c")) + { Ok(t) => Ok(HttpDate(t)), Err(_) => Err(ParseError::Header), } @@ -47,11 +44,14 @@ impl From for HttpDate { let tmspec = match sys.duration_since(UNIX_EPOCH) { Ok(dur) => { time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32) - }, + } Err(err) => { let neg = err.duration(); - time::Timespec::new(-(neg.as_secs() as i64), -(neg.subsec_nanos() as i32)) - }, + time::Timespec::new( + -(neg.as_secs() as i64), + -(neg.subsec_nanos() as i32), + ) + } }; HttpDate(time::at_utc(tmspec)) } @@ -63,7 +63,7 @@ impl IntoHeaderValue for HttpDate { fn try_into(self) -> Result { let mut wrt = BytesMut::with_capacity(29).writer(); write!(wrt, "{}", self.0.rfc822()).unwrap(); - unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.get_mut().take().freeze()))} + HeaderValue::from_shared(wrt.get_mut().take().freeze()) } } @@ -80,18 +80,39 @@ impl From for SystemTime { #[cfg(test)] mod tests { - use time::Tm; use super::HttpDate; + use time::Tm; const NOV_07: HttpDate = HttpDate(Tm { - tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8, tm_mday: 7, tm_mon: 10, tm_year: 94, - tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0}); + tm_nsec: 0, + tm_sec: 37, + tm_min: 48, + tm_hour: 8, + tm_mday: 7, + tm_mon: 10, + tm_year: 94, + tm_wday: 0, + tm_isdst: 0, + tm_yday: 0, + tm_utcoff: 0, + }); #[test] fn test_date() { - assert_eq!("Sun, 07 Nov 1994 08:48:37 GMT".parse::().unwrap(), NOV_07); - assert_eq!("Sunday, 07-Nov-94 08:48:37 GMT".parse::().unwrap(), NOV_07); - assert_eq!("Sun Nov 7 08:48:37 1994".parse::().unwrap(), NOV_07); + assert_eq!( + "Sun, 07 Nov 1994 08:48:37 GMT".parse::().unwrap(), + NOV_07 + ); + assert_eq!( + "Sunday, 07-Nov-94 08:48:37 GMT" + .parse::() + .unwrap(), + NOV_07 + ); + assert_eq!( + "Sun Nov 7 08:48:37 1994".parse::().unwrap(), + NOV_07 + ); assert!("this-is-no-date".parse::().is_err()); } } diff --git a/src/header/shared/mod.rs b/actix-http/src/header/shared/mod.rs similarity index 81% rename from src/header/shared/mod.rs rename to actix-http/src/header/shared/mod.rs index 04ff7f41a..f2bc91634 100644 --- a/src/header/shared/mod.rs +++ b/actix-http/src/header/shared/mod.rs @@ -4,11 +4,11 @@ pub use self::charset::Charset; pub use self::encoding::Encoding; pub use self::entity::EntityTag; pub use self::httpdate::HttpDate; +pub use self::quality_item::{q, qitem, Quality, QualityItem}; pub use language_tags::LanguageTag; -pub use self::quality_item::{Quality, QualityItem, qitem, q}; mod charset; -mod entity; mod encoding; +mod entity; mod httpdate; mod quality_item; diff --git a/src/header/shared/quality_item.rs b/actix-http/src/header/shared/quality_item.rs similarity index 77% rename from src/header/shared/quality_item.rs rename to actix-http/src/header/shared/quality_item.rs index aa56866ac..fc3930c5e 100644 --- a/src/header/shared/quality_item.rs +++ b/actix-http/src/header/shared/quality_item.rs @@ -1,9 +1,4 @@ -#![allow(unused, deprecated)] -use std::ascii::AsciiExt; -use std::cmp; -use std::default::Default; -use std::fmt; -use std::str; +use std::{cmp, fmt, str}; use self::internal::IntoQuality; @@ -13,11 +8,13 @@ use self::internal::IntoQuality; /// /// # Implementation notes /// -/// The quality value is defined as a number between 0 and 1 with three decimal places. This means -/// there are 1001 possible values. Since floating point numbers are not exact and the smallest -/// floating point data type (`f32`) consumes four bytes, hyper uses an `u16` value to store the -/// quality internally. For performance reasons you may set quality directly to a value between -/// 0 and 1000 e.g. `Quality(532)` matches the quality `q=0.532`. +/// The quality value is defined as a number between 0 and 1 with three decimal +/// places. This means there are 1001 possible values. Since floating point +/// numbers are not exact and the smallest floating point data type (`f32`) +/// consumes four bytes, hyper uses an `u16` value to store the +/// quality internally. For performance reasons you may set quality directly to +/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality +/// `q=0.532`. /// /// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1) /// gives more information on quality values in HTTP header fields. @@ -57,21 +54,21 @@ impl cmp::PartialOrd for QualityItem { impl fmt::Display for QualityItem { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(fmt::Display::fmt(&self.item, f)); + fmt::Display::fmt(&self.item, f)?; match self.quality.0 { 1000 => Ok(()), 0 => f.write_str("; q=0"), - x => write!(f, "; q=0.{}", format!("{:03}", x).trim_right_matches('0')) + x => write!(f, "; q=0.{}", format!("{:03}", x).trim_end_matches('0')), } } } impl str::FromStr for QualityItem { - type Err = ::error::ParseError; + type Err = crate::error::ParseError; - fn from_str(s: &str) -> Result, ::error::ParseError> { + fn from_str(s: &str) -> Result, crate::error::ParseError> { if !s.is_ascii() { - return Err(::error::ParseError::Header); + return Err(crate::error::ParseError::Header); } // Set defaults used if parsing fails. let mut raw_item = s; @@ -80,13 +77,13 @@ impl str::FromStr for QualityItem { let parts: Vec<&str> = s.rsplitn(2, ';').map(|x| x.trim()).collect(); if parts.len() == 2 { if parts[0].len() < 2 { - return Err(::error::ParseError::Header); + return Err(crate::error::ParseError::Header); } let start = &parts[0][0..2]; if start == "q=" || start == "Q=" { let q_part = &parts[0][2..parts[0].len()]; if q_part.len() > 5 { - return Err(::error::ParseError::Header); + return Err(crate::error::ParseError::Header); } match q_part.parse::() { Ok(q_value) => { @@ -94,17 +91,17 @@ impl str::FromStr for QualityItem { quality = q_value; raw_item = parts[1]; } else { - return Err(::error::ParseError::Header); + return Err(crate::error::ParseError::Header); } - }, - Err(_) => return Err(::error::ParseError::Header), + } + Err(_) => return Err(crate::error::ParseError::Header), } } } match raw_item.parse::() { // we already checked above that the quality is within range Ok(item) => Ok(QualityItem::new(item, from_f32(quality))), - Err(_) => Err(::error::ParseError::Header), + Err(_) => Err(crate::error::ParseError::Header), } } } @@ -114,7 +111,10 @@ fn from_f32(f: f32) -> Quality { // this function is only used internally. A check that `f` is within range // should be done before calling this method. Just in case, this // debug_assert should catch if we were forgetful - debug_assert!(f >= 0f32 && f <= 1f32, "q value must be between 0.0 and 1.0"); + debug_assert!( + f >= 0f32 && f <= 1f32, + "q value must be between 0.0 and 1.0" + ); Quality((f * 1000f32) as u16) } @@ -125,7 +125,7 @@ pub fn qitem(item: T) -> QualityItem { } /// Convenience function to create a `Quality` from a float or integer. -/// +/// /// Implemented for `u16` and `f32`. Panics if value is out of range. pub fn q(val: T) -> Quality { val.into_quality() @@ -147,7 +147,10 @@ mod internal { impl IntoQuality for f32 { fn into_quality(self) -> Quality { - assert!(self >= 0f32 && self <= 1f32, "float must be between 0.0 and 1.0"); + assert!( + self >= 0f32 && self <= 1f32, + "float must be between 0.0 and 1.0" + ); super::from_f32(self) } } @@ -159,7 +162,6 @@ mod internal { } } - pub trait Sealed {} impl Sealed for u16 {} impl Sealed for f32 {} @@ -167,8 +169,8 @@ mod internal { #[cfg(test)] mod tests { - use super::*; use super::super::encoding::*; + use super::*; #[test] fn test_quality_item_fmt_q_1() { @@ -183,7 +185,7 @@ mod tests { #[test] fn test_quality_item_fmt_q_05() { // Custom value - let x = QualityItem{ + let x = QualityItem { item: EncodingExt("identity".to_owned()), quality: Quality(500), }; @@ -193,7 +195,7 @@ mod tests { #[test] fn test_quality_item_fmt_q_0() { // Custom value - let x = QualityItem{ + let x = QualityItem { item: EncodingExt("identity".to_owned()), quality: Quality(0), }; @@ -203,22 +205,46 @@ mod tests { #[test] fn test_quality_item_from_str1() { let x: Result, _> = "chunked".parse(); - assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), }); + assert_eq!( + x.unwrap(), + QualityItem { + item: Chunked, + quality: Quality(1000), + } + ); } #[test] fn test_quality_item_from_str2() { let x: Result, _> = "chunked; q=1".parse(); - assert_eq!(x.unwrap(), QualityItem{ item: Chunked, quality: Quality(1000), }); + assert_eq!( + x.unwrap(), + QualityItem { + item: Chunked, + quality: Quality(1000), + } + ); } #[test] fn test_quality_item_from_str3() { let x: Result, _> = "gzip; q=0.5".parse(); - assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(500), }); + assert_eq!( + x.unwrap(), + QualityItem { + item: Gzip, + quality: Quality(500), + } + ); } #[test] fn test_quality_item_from_str4() { let x: Result, _> = "gzip; q=0.273".parse(); - assert_eq!(x.unwrap(), QualityItem{ item: Gzip, quality: Quality(273), }); + assert_eq!( + x.unwrap(), + QualityItem { + item: Gzip, + quality: Quality(273), + } + ); } #[test] fn test_quality_item_from_str5() { @@ -245,14 +271,14 @@ mod tests { #[test] #[should_panic] // FIXME - 32-bit msvc unwinding broken - #[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)] + #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)] fn test_quality_invalid() { q(-1.0); } #[test] #[should_panic] // FIXME - 32-bit msvc unwinding broken - #[cfg_attr(all(target_arch="x86", target_env="msvc"), ignore)] + #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)] fn test_quality_invalid2() { q(2.0); } diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs new file mode 100644 index 000000000..84403d8fd --- /dev/null +++ b/actix-http/src/helpers.rs @@ -0,0 +1,235 @@ +use std::{io, mem, ptr, slice}; + +use bytes::{BufMut, BytesMut}; +use http::Version; + +use crate::extensions::Extensions; + +const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\ + 2021222324252627282930313233343536373839\ + 4041424344454647484950515253545556575859\ + 6061626364656667686970717273747576777879\ + 8081828384858687888990919293949596979899"; + +pub(crate) const STATUS_LINE_BUF_SIZE: usize = 13; + +pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) { + let mut buf: [u8; STATUS_LINE_BUF_SIZE] = [ + b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', b' ', b' ', b' ', b' ', b' ', + ]; + match version { + Version::HTTP_2 => buf[5] = b'2', + Version::HTTP_10 => buf[7] = b'0', + Version::HTTP_09 => { + buf[5] = b'0'; + buf[7] = b'9'; + } + _ => (), + } + + let mut curr: isize = 12; + let buf_ptr = buf.as_mut_ptr(); + let lut_ptr = DEC_DIGITS_LUT.as_ptr(); + let four = n > 999; + + // decode 2 more chars, if > 2 chars + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); + } + + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { + *buf_ptr.offset(curr) = (n as u8) + b'0'; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping( + lut_ptr.offset(d1 as isize), + buf_ptr.offset(curr), + 2, + ); + } + } + + bytes.put_slice(&buf); + if four { + bytes.put(b' '); + } +} + +/// NOTE: bytes object has to contain enough space +pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { + if n < 10 { + let mut buf: [u8; 21] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n', + ]; + buf[18] = (n as u8) + b'0'; + bytes.put_slice(&buf); + } else if n < 100 { + let mut buf: [u8; 22] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n', + ]; + let d1 = n << 1; + unsafe { + ptr::copy_nonoverlapping( + DEC_DIGITS_LUT.as_ptr().add(d1), + buf.as_mut_ptr().offset(18), + 2, + ); + } + bytes.put_slice(&buf); + } else if n < 1000 { + let mut buf: [u8; 23] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', b'\n', + ]; + // decode 2 more chars, if > 2 chars + let d1 = (n % 100) << 1; + n /= 100; + unsafe { + ptr::copy_nonoverlapping( + DEC_DIGITS_LUT.as_ptr().add(d1), + buf.as_mut_ptr().offset(19), + 2, + ) + }; + + // decode last 1 + buf[18] = (n as u8) + b'0'; + + bytes.put_slice(&buf); + } else { + bytes.put_slice(b"\r\ncontent-length: "); + convert_usize(n, bytes); + } +} + +pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { + let mut curr: isize = 39; + let mut buf: [u8; 41] = unsafe { mem::MaybeUninit::uninit().assume_init() }; + buf[39] = b'\r'; + buf[40] = b'\n'; + let buf_ptr = buf.as_mut_ptr(); + let lut_ptr = DEC_DIGITS_LUT.as_ptr(); + + // eagerly decode 4 characters at a time + while n >= 10_000 { + let rem = (n % 10_000) as isize; + n /= 10_000; + + let d1 = (rem / 100) << 1; + let d2 = (rem % 100) << 1; + curr -= 4; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2); + } + } + + // if we reach here numbers are <= 9999, so at most 4 chars long + let mut n = n as isize; // possibly reduce 64bit math + + // decode 2 more chars, if > 2 chars + if n >= 100 { + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + } + } + + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { + *buf_ptr.offset(curr) = (n as u8) + b'0'; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + } + } + + unsafe { + bytes.extend_from_slice(slice::from_raw_parts( + buf_ptr.offset(curr), + 41 - curr as usize, + )); + } +} + +pub(crate) struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub(crate) trait DataFactory { + fn set(&self, ext: &mut Extensions); +} + +pub(crate) struct Data(pub(crate) T); + +impl DataFactory for Data { + fn set(&self, ext: &mut Extensions) { + ext.insert(self.0.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_write_content_length() { + let mut bytes = BytesMut::new(); + bytes.reserve(50); + write_content_length(0, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]); + bytes.reserve(50); + write_content_length(9, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]); + bytes.reserve(50); + write_content_length(10, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]); + bytes.reserve(50); + write_content_length(99, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]); + bytes.reserve(50); + write_content_length(100, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]); + bytes.reserve(50); + write_content_length(101, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]); + bytes.reserve(50); + write_content_length(998, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]); + bytes.reserve(50); + write_content_length(1000, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); + bytes.reserve(50); + write_content_length(1001, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); + bytes.reserve(50); + write_content_length(5909, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); + } +} diff --git a/actix-http/src/httpcodes.rs b/actix-http/src/httpcodes.rs new file mode 100644 index 000000000..0c7f23fc8 --- /dev/null +++ b/actix-http/src/httpcodes.rs @@ -0,0 +1,87 @@ +//! Basic http responses +#![allow(non_upper_case_globals)] +use http::StatusCode; + +use crate::response::{Response, ResponseBuilder}; + +macro_rules! STATIC_RESP { + ($name:ident, $status:expr) => { + #[allow(non_snake_case, missing_docs)] + pub fn $name() -> ResponseBuilder { + ResponseBuilder::new($status) + } + }; +} + +impl Response { + STATIC_RESP!(Ok, StatusCode::OK); + STATIC_RESP!(Created, StatusCode::CREATED); + STATIC_RESP!(Accepted, StatusCode::ACCEPTED); + STATIC_RESP!( + NonAuthoritativeInformation, + StatusCode::NON_AUTHORITATIVE_INFORMATION + ); + + STATIC_RESP!(NoContent, StatusCode::NO_CONTENT); + STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT); + STATIC_RESP!(PartialContent, StatusCode::PARTIAL_CONTENT); + STATIC_RESP!(MultiStatus, StatusCode::MULTI_STATUS); + STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED); + + STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); + STATIC_RESP!(MovedPermanently, StatusCode::MOVED_PERMANENTLY); + STATIC_RESP!(Found, StatusCode::FOUND); + STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER); + STATIC_RESP!(NotModified, StatusCode::NOT_MODIFIED); + STATIC_RESP!(UseProxy, StatusCode::USE_PROXY); + STATIC_RESP!(TemporaryRedirect, StatusCode::TEMPORARY_REDIRECT); + STATIC_RESP!(PermanentRedirect, StatusCode::PERMANENT_REDIRECT); + + STATIC_RESP!(BadRequest, StatusCode::BAD_REQUEST); + STATIC_RESP!(NotFound, StatusCode::NOT_FOUND); + STATIC_RESP!(Unauthorized, StatusCode::UNAUTHORIZED); + STATIC_RESP!(PaymentRequired, StatusCode::PAYMENT_REQUIRED); + STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN); + STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED); + STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE); + STATIC_RESP!( + ProxyAuthenticationRequired, + StatusCode::PROXY_AUTHENTICATION_REQUIRED + ); + STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT); + STATIC_RESP!(Conflict, StatusCode::CONFLICT); + STATIC_RESP!(Gone, StatusCode::GONE); + STATIC_RESP!(LengthRequired, StatusCode::LENGTH_REQUIRED); + STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED); + STATIC_RESP!(PreconditionRequired, StatusCode::PRECONDITION_REQUIRED); + STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE); + STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG); + STATIC_RESP!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE); + STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE); + STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); + STATIC_RESP!(UnprocessableEntity, StatusCode::UNPROCESSABLE_ENTITY); + STATIC_RESP!(TooManyRequests, StatusCode::TOO_MANY_REQUESTS); + + STATIC_RESP!(InternalServerError, StatusCode::INTERNAL_SERVER_ERROR); + STATIC_RESP!(NotImplemented, StatusCode::NOT_IMPLEMENTED); + STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY); + STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE); + STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT); + STATIC_RESP!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED); + STATIC_RESP!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES); + STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE); + STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED); +} + +#[cfg(test)] +mod tests { + use crate::body::Body; + use crate::response::Response; + use http::StatusCode; + + #[test] + fn test_build() { + let resp = Response::Ok().body(Body::Empty); + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/actix-http/src/httpmessage.rs b/actix-http/src/httpmessage.rs new file mode 100644 index 000000000..05d668c10 --- /dev/null +++ b/actix-http/src/httpmessage.rs @@ -0,0 +1,261 @@ +use std::cell::{Ref, RefMut}; +use std::str; + +use encoding_rs::{Encoding, UTF_8}; +use http::header; +use mime::Mime; + +use crate::cookie::Cookie; +use crate::error::{ContentTypeError, CookieParseError, ParseError}; +use crate::extensions::Extensions; +use crate::header::{Header, HeaderMap}; +use crate::payload::Payload; + +struct Cookies(Vec>); + +/// Trait that implements general purpose operations on http messages +pub trait HttpMessage: Sized { + /// Type of message payload stream + type Stream; + + /// Read the message headers. + fn headers(&self) -> &HeaderMap; + + /// Message payload stream + fn take_payload(&mut self) -> Payload; + + /// Request's extensions container + fn extensions(&self) -> Ref; + + /// Mutable reference to a the request's extensions container + fn extensions_mut(&self) -> RefMut; + + #[doc(hidden)] + /// Get a header + fn get_header(&self) -> Option + where + Self: Sized, + { + if self.headers().contains_key(H::name()) { + H::parse(self).ok() + } else { + None + } + } + + /// Read the request content type. If request does not contain + /// *Content-Type* header, empty str get returned. + fn content_type(&self) -> &str { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return content_type.split(';').next().unwrap().trim(); + } + } + "" + } + + /// Get content type encoding + /// + /// UTF-8 is used by default, If request charset is not set. + fn encoding(&self) -> Result<&'static Encoding, ContentTypeError> { + if let Some(mime_type) = self.mime_type()? { + if let Some(charset) = mime_type.get_param("charset") { + if let Some(enc) = + Encoding::for_label_no_replacement(charset.as_str().as_bytes()) + { + Ok(enc) + } else { + Err(ContentTypeError::UnknownEncoding) + } + } else { + Ok(UTF_8) + } + } else { + Ok(UTF_8) + } + } + + /// Convert the request content type to a known mime type. + fn mime_type(&self) -> Result, ContentTypeError> { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return match content_type.parse() { + Ok(mt) => Ok(Some(mt)), + Err(_) => Err(ContentTypeError::ParseError), + }; + } else { + return Err(ContentTypeError::ParseError); + } + } + Ok(None) + } + + /// Check if request has chunked transfer encoding + fn chunked(&self) -> Result { + if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + Ok(s.to_lowercase().contains("chunked")) + } else { + Err(ParseError::Header) + } + } else { + Ok(false) + } + } + + /// Load request cookies. + #[inline] + fn cookies(&self) -> Result>>, CookieParseError> { + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(header::COOKIE) { + let s = + str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + for cookie_str in s.split(';').map(|s| s.trim()) { + if !cookie_str.is_empty() { + cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); + } + } + } + self.extensions_mut().insert(Cookies(cookies)); + } + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } + + /// Return request cookie. + fn cookie(&self, name: &str) -> Option> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies.iter() { + if cookie.name() == name { + return Some(cookie.to_owned()); + } + } + } + None + } +} + +impl<'a, T> HttpMessage for &'a mut T +where + T: HttpMessage, +{ + type Stream = T::Stream; + + fn headers(&self) -> &HeaderMap { + (**self).headers() + } + + /// Message payload stream + fn take_payload(&mut self) -> Payload { + (**self).take_payload() + } + + /// Request's extensions container + fn extensions(&self) -> Ref { + (**self).extensions() + } + + /// Mutable reference to a the request's extensions container + fn extensions_mut(&self) -> RefMut { + (**self).extensions_mut() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use encoding_rs::ISO_8859_2; + use mime; + + use super::*; + use crate::test::TestRequest; + + #[test] + fn test_content_type() { + let req = TestRequest::with_header("content-type", "text/plain").finish(); + assert_eq!(req.content_type(), "text/plain"); + let req = + TestRequest::with_header("content-type", "application/json; charset=utf=8") + .finish(); + assert_eq!(req.content_type(), "application/json"); + let req = TestRequest::default().finish(); + assert_eq!(req.content_type(), ""); + } + + #[test] + fn test_mime_type() { + let req = TestRequest::with_header("content-type", "application/json").finish(); + assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); + let req = TestRequest::default().finish(); + assert_eq!(req.mime_type().unwrap(), None); + let req = + TestRequest::with_header("content-type", "application/json; charset=utf-8") + .finish(); + let mt = req.mime_type().unwrap().unwrap(); + assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); + assert_eq!(mt.type_(), mime::APPLICATION); + assert_eq!(mt.subtype(), mime::JSON); + } + + #[test] + fn test_mime_type_error() { + let req = TestRequest::with_header( + "content-type", + "applicationadfadsfasdflknadsfklnadsfjson", + ) + .finish(); + assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); + } + + #[test] + fn test_encoding() { + let req = TestRequest::default().finish(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header("content-type", "application/json").finish(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header( + "content-type", + "application/json; charset=ISO-8859-2", + ) + .finish(); + assert_eq!(ISO_8859_2, req.encoding().unwrap()); + } + + #[test] + fn test_encoding_error() { + let req = TestRequest::with_header("content-type", "applicatjson").finish(); + assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); + + let req = TestRequest::with_header( + "content-type", + "application/json; charset=kkkttktk", + ) + .finish(); + assert_eq!( + Some(ContentTypeError::UnknownEncoding), + req.encoding().err() + ); + } + + #[test] + fn test_chunked() { + let req = TestRequest::default().finish(); + assert!(!req.chunked().unwrap()); + + let req = + TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); + assert!(req.chunked().unwrap()); + + let req = TestRequest::default() + .header( + header::TRANSFER_ENCODING, + Bytes::from_static(b"some va\xadscc\xacas0xsdasdlue"), + ) + .finish(); + assert!(req.chunked().is_err()); + } +} diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs new file mode 100644 index 000000000..b57fdddce --- /dev/null +++ b/actix-http/src/lib.rs @@ -0,0 +1,66 @@ +//! Basic http primitives for actix-net framework. +#![allow( + clippy::type_complexity, + clippy::too_many_arguments, + clippy::new_without_default, + clippy::borrow_interior_mutable_const, + clippy::write_with_newline +)] + +#[macro_use] +extern crate log; + +pub mod body; +mod builder; +pub mod client; +mod cloneable; +mod config; +pub mod encoding; +mod extensions; +mod header; +mod helpers; +mod httpcodes; +pub mod httpmessage; +mod message; +mod payload; +mod request; +mod response; +mod service; + +pub mod cookie; +pub mod error; +pub mod h1; +pub mod h2; +pub mod test; +pub mod ws; + +pub use self::builder::HttpServiceBuilder; +pub use self::config::{KeepAlive, ServiceConfig}; +pub use self::error::{Error, ResponseError, Result}; +pub use self::extensions::Extensions; +pub use self::httpmessage::HttpMessage; +pub use self::message::{Message, RequestHead, RequestHeadType, ResponseHead}; +pub use self::payload::{Payload, PayloadStream}; +pub use self::request::Request; +pub use self::response::{Response, ResponseBuilder}; +pub use self::service::HttpService; + +pub mod http { + //! Various HTTP related types + + // re-exports + pub use http::header::{HeaderName, HeaderValue}; + pub use http::uri::PathAndQuery; + pub use http::{uri, Error, HttpTryFrom, Uri}; + pub use http::{Method, StatusCode, Version}; + + pub use crate::cookie::{Cookie, CookieBuilder}; + pub use crate::header::HeaderMap; + + /// Various http headers + pub mod header { + pub use crate::header::*; + } + pub use crate::header::ContentEncoding; + pub use crate::message::ConnectionType; +} diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs new file mode 100644 index 000000000..5994ed39e --- /dev/null +++ b/actix-http/src/message.rs @@ -0,0 +1,495 @@ +use std::cell::{Ref, RefCell, RefMut}; +use std::net; +use std::rc::Rc; + +use bitflags::bitflags; +use copyless::BoxHelper; + +use crate::extensions::Extensions; +use crate::header::HeaderMap; +use crate::http::{header, Method, StatusCode, Uri, Version}; + +/// Represents various types of connection +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ConnectionType { + /// Close connection after response + Close, + /// Keep connection alive after response + KeepAlive, + /// Connection is upgraded to different type + Upgrade, +} + +bitflags! { + pub(crate) struct Flags: u8 { + const CLOSE = 0b0000_0001; + const KEEP_ALIVE = 0b0000_0010; + const UPGRADE = 0b0000_0100; + const EXPECT = 0b0000_1000; + const NO_CHUNKING = 0b0001_0000; + const CAMEL_CASE = 0b0010_0000; + } +} + +#[doc(hidden)] +pub trait Head: Default + 'static { + fn clear(&mut self); + + fn pool() -> &'static MessagePool; +} + +#[derive(Debug)] +pub struct RequestHead { + pub uri: Uri, + pub method: Method, + pub version: Version, + pub headers: HeaderMap, + pub extensions: RefCell, + pub peer_addr: Option, + flags: Flags, +} + +impl Default for RequestHead { + fn default() -> RequestHead { + RequestHead { + uri: Uri::default(), + method: Method::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + flags: Flags::empty(), + peer_addr: None, + extensions: RefCell::new(Extensions::new()), + } + } +} + +impl Head for RequestHead { + fn clear(&mut self) { + self.flags = Flags::empty(); + self.headers.clear(); + self.extensions.borrow_mut().clear(); + } + + fn pool() -> &'static MessagePool { + REQUEST_POOL.with(|p| *p) + } +} + +impl RequestHead { + /// Message extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.extensions.borrow() + } + + /// Mutable reference to a the message's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.extensions.borrow_mut() + } + + /// Read the message headers. + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Mutable reference to the message headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// Is to uppercase headers with Camel-Case. + /// Befault is `false` + #[inline] + pub fn camel_case_headers(&self) -> bool { + self.flags.contains(Flags::CAMEL_CASE) + } + + /// Set `true` to send headers which are uppercased with Camel-Case. + #[inline] + pub fn set_camel_case_headers(&mut self, val: bool) { + if val { + self.flags.insert(Flags::CAMEL_CASE); + } else { + self.flags.remove(Flags::CAMEL_CASE); + } + } + + #[inline] + /// Set connection type of the message + pub fn set_connection_type(&mut self, ctype: ConnectionType) { + match ctype { + ConnectionType::Close => self.flags.insert(Flags::CLOSE), + ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), + ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), + } + } + + #[inline] + /// Connection type + pub fn connection_type(&self) -> ConnectionType { + if self.flags.contains(Flags::CLOSE) { + ConnectionType::Close + } else if self.flags.contains(Flags::KEEP_ALIVE) { + ConnectionType::KeepAlive + } else if self.flags.contains(Flags::UPGRADE) { + ConnectionType::Upgrade + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + /// Connection upgrade status + pub fn upgrade(&self) -> bool { + if let Some(hdr) = self.headers().get(header::CONNECTION) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("upgrade") + } else { + false + } + } else { + false + } + } + + #[inline] + /// Get response body chunking state + pub fn chunked(&self) -> bool { + !self.flags.contains(Flags::NO_CHUNKING) + } + + #[inline] + pub fn no_chunking(&mut self, val: bool) { + if val { + self.flags.insert(Flags::NO_CHUNKING); + } else { + self.flags.remove(Flags::NO_CHUNKING); + } + } + + #[inline] + /// Request contains `EXPECT` header + pub fn expect(&self) -> bool { + self.flags.contains(Flags::EXPECT) + } + + #[inline] + pub(crate) fn set_expect(&mut self) { + self.flags.insert(Flags::EXPECT); + } +} + +#[derive(Debug)] +pub enum RequestHeadType { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestHeadType { + pub fn extra_headers(&self) -> Option<&HeaderMap> { + match self { + RequestHeadType::Owned(_) => None, + RequestHeadType::Rc(_, headers) => headers.as_ref(), + } + } +} + +impl AsRef for RequestHeadType { + fn as_ref(&self) -> &RequestHead { + match self { + RequestHeadType::Owned(head) => &head, + RequestHeadType::Rc(head, _) => head.as_ref(), + } + } +} + +impl From for RequestHeadType { + fn from(head: RequestHead) -> Self { + RequestHeadType::Owned(head) + } +} + +#[derive(Debug)] +pub struct ResponseHead { + pub version: Version, + pub status: StatusCode, + pub headers: HeaderMap, + pub reason: Option<&'static str>, + pub(crate) extensions: RefCell, + flags: Flags, +} + +impl ResponseHead { + /// Create new instance of `ResponseHead` type + #[inline] + pub fn new(status: StatusCode) -> ResponseHead { + ResponseHead { + status, + version: Version::default(), + headers: HeaderMap::with_capacity(12), + reason: None, + flags: Flags::empty(), + extensions: RefCell::new(Extensions::new()), + } + } + + /// Message extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.extensions.borrow() + } + + /// Mutable reference to a the message's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.extensions.borrow_mut() + } + + #[inline] + /// Read the message headers. + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + #[inline] + /// Mutable reference to the message headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + #[inline] + /// Set connection type of the message + pub fn set_connection_type(&mut self, ctype: ConnectionType) { + match ctype { + ConnectionType::Close => self.flags.insert(Flags::CLOSE), + ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), + ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), + } + } + + #[inline] + pub fn connection_type(&self) -> ConnectionType { + if self.flags.contains(Flags::CLOSE) { + ConnectionType::Close + } else if self.flags.contains(Flags::KEEP_ALIVE) { + ConnectionType::KeepAlive + } else if self.flags.contains(Flags::UPGRADE) { + ConnectionType::Upgrade + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + #[inline] + /// Check if keep-alive is enabled + pub fn keep_alive(&self) -> bool { + self.connection_type() == ConnectionType::KeepAlive + } + + #[inline] + /// Check upgrade status of this message + pub fn upgrade(&self) -> bool { + self.connection_type() == ConnectionType::Upgrade + } + + /// Get custom reason for the response + #[inline] + pub fn reason(&self) -> &str { + if let Some(reason) = self.reason { + reason + } else { + self.status + .canonical_reason() + .unwrap_or("") + } + } + + #[inline] + pub(crate) fn ctype(&self) -> Option { + if self.flags.contains(Flags::CLOSE) { + Some(ConnectionType::Close) + } else if self.flags.contains(Flags::KEEP_ALIVE) { + Some(ConnectionType::KeepAlive) + } else if self.flags.contains(Flags::UPGRADE) { + Some(ConnectionType::Upgrade) + } else { + None + } + } + + #[inline] + /// Get response body chunking state + pub fn chunked(&self) -> bool { + !self.flags.contains(Flags::NO_CHUNKING) + } + + #[inline] + /// Set no chunking for payload + pub fn no_chunking(&mut self, val: bool) { + if val { + self.flags.insert(Flags::NO_CHUNKING); + } else { + self.flags.remove(Flags::NO_CHUNKING); + } + } +} + +pub struct Message { + head: Rc, +} + +impl Message { + /// Get new message from the pool of objects + pub fn new() -> Self { + T::pool().get_message() + } +} + +impl Clone for Message { + fn clone(&self) -> Self { + Message { + head: self.head.clone(), + } + } +} + +impl std::ops::Deref for Message { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.head.as_ref() + } +} + +impl std::ops::DerefMut for Message { + fn deref_mut(&mut self) -> &mut Self::Target { + Rc::get_mut(&mut self.head).expect("Multiple copies exist") + } +} + +impl Drop for Message { + fn drop(&mut self) { + if Rc::strong_count(&self.head) == 1 { + T::pool().release(self.head.clone()); + } + } +} + +pub(crate) struct BoxedResponseHead { + head: Option>, +} + +impl BoxedResponseHead { + /// Get new message from the pool of objects + pub fn new(status: StatusCode) -> Self { + RESPONSE_POOL.with(|p| p.get_message(status)) + } + + pub(crate) fn take(&mut self) -> Self { + BoxedResponseHead { + head: self.head.take(), + } + } +} + +impl std::ops::Deref for BoxedResponseHead { + type Target = ResponseHead; + + fn deref(&self) -> &Self::Target { + self.head.as_ref().unwrap() + } +} + +impl std::ops::DerefMut for BoxedResponseHead { + fn deref_mut(&mut self) -> &mut Self::Target { + self.head.as_mut().unwrap() + } +} + +impl Drop for BoxedResponseHead { + fn drop(&mut self) { + if let Some(head) = self.head.take() { + RESPONSE_POOL.with(move |p| p.release(head)) + } + } +} + +#[doc(hidden)] +/// Request's objects pool +pub struct MessagePool(RefCell>>); + +#[doc(hidden)] +#[allow(clippy::vec_box)] +/// Request's objects pool +pub struct BoxedResponsePool(RefCell>>); + +thread_local!(static REQUEST_POOL: &'static MessagePool = MessagePool::::create()); +thread_local!(static RESPONSE_POOL: &'static BoxedResponsePool = BoxedResponsePool::create()); + +impl MessagePool { + fn create() -> &'static MessagePool { + let pool = MessagePool(RefCell::new(Vec::with_capacity(128))); + Box::leak(Box::new(pool)) + } + + /// Get message from the pool + #[inline] + fn get_message(&'static self) -> Message { + if let Some(mut msg) = self.0.borrow_mut().pop() { + if let Some(r) = Rc::get_mut(&mut msg) { + r.clear(); + } + Message { head: msg } + } else { + Message { + head: Rc::new(T::default()), + } + } + } + + #[inline] + /// Release request instance + fn release(&self, msg: Rc) { + let v = &mut self.0.borrow_mut(); + if v.len() < 128 { + v.push(msg); + } + } +} + +impl BoxedResponsePool { + fn create() -> &'static BoxedResponsePool { + let pool = BoxedResponsePool(RefCell::new(Vec::with_capacity(128))); + Box::leak(Box::new(pool)) + } + + /// Get message from the pool + #[inline] + fn get_message(&'static self, status: StatusCode) -> BoxedResponseHead { + if let Some(mut head) = self.0.borrow_mut().pop() { + head.reason = None; + head.status = status; + head.headers.clear(); + head.flags = Flags::empty(); + BoxedResponseHead { head: Some(head) } + } else { + BoxedResponseHead { + head: Some(Box::alloc().init(ResponseHead::new(status))), + } + } + } + + #[inline] + /// Release request instance + fn release(&self, msg: Box) { + let v = &mut self.0.borrow_mut(); + if v.len() < 128 { + msg.extensions.borrow_mut().clear(); + v.push(msg); + } + } +} diff --git a/actix-http/src/payload.rs b/actix-http/src/payload.rs new file mode 100644 index 000000000..b3ec04d11 --- /dev/null +++ b/actix-http/src/payload.rs @@ -0,0 +1,67 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures::Stream; +use h2::RecvStream; + +use crate::error::PayloadError; + +/// Type represent boxed payload +pub type PayloadStream = Pin>>>; + +/// Type represent streaming payload +pub enum Payload { + None, + H1(crate::h1::Payload), + H2(crate::h2::Payload), + Stream(S), +} + +impl From for Payload { + fn from(v: crate::h1::Payload) -> Self { + Payload::H1(v) + } +} + +impl From for Payload { + fn from(v: crate::h2::Payload) -> Self { + Payload::H2(v) + } +} + +impl From for Payload { + fn from(v: RecvStream) -> Self { + Payload::H2(crate::h2::Payload::new(v)) + } +} + +impl From for Payload { + fn from(pl: PayloadStream) -> Self { + Payload::Stream(pl) + } +} + +impl Payload { + /// Takes current payload and replaces it with `None` value + pub fn take(&mut self) -> Payload { + std::mem::replace(self, Payload::None) + } +} + +impl Stream for Payload +where + S: Stream> + Unpin, +{ + type Item = Result; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.get_mut() { + Payload::None => Poll::Ready(None), + Payload::H1(ref mut pl) => pl.readany(cx), + Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx), + Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx), + } + } +} diff --git a/actix-http/src/request.rs b/actix-http/src/request.rs new file mode 100644 index 000000000..77ece01c5 --- /dev/null +++ b/actix-http/src/request.rs @@ -0,0 +1,209 @@ +use std::cell::{Ref, RefMut}; +use std::{fmt, net}; + +use http::{header, Method, Uri, Version}; + +use crate::extensions::Extensions; +use crate::header::HeaderMap; +use crate::httpmessage::HttpMessage; +use crate::message::{Message, RequestHead}; +use crate::payload::{Payload, PayloadStream}; + +/// Request +pub struct Request

    { + pub(crate) payload: Payload

    , + pub(crate) head: Message, +} + +impl

    HttpMessage for Request

    { + type Stream = P; + + #[inline] + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Request extensions + #[inline] + fn extensions(&self) -> Ref { + self.head.extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + fn extensions_mut(&self) -> RefMut { + self.head.extensions_mut() + } + + fn take_payload(&mut self) -> Payload

    { + std::mem::replace(&mut self.payload, Payload::None) + } +} + +impl From> for Request { + fn from(head: Message) -> Self { + Request { + head, + payload: Payload::None, + } + } +} + +impl Request { + /// Create new Request instance + pub fn new() -> Request { + Request { + head: Message::new(), + payload: Payload::None, + } + } +} + +impl

    Request

    { + /// Create new Request instance + pub fn with_payload(payload: Payload

    ) -> Request

    { + Request { + payload, + head: Message::new(), + } + } + + /// Create new Request instance + pub fn replace_payload(self, payload: Payload) -> (Request, Payload

    ) { + let pl = self.payload; + ( + Request { + payload, + head: self.head, + }, + pl, + ) + } + + /// Get request's payload + pub fn payload(&mut self) -> &mut Payload

    { + &mut self.payload + } + + /// Get request's payload + pub fn take_payload(&mut self) -> Payload

    { + std::mem::replace(&mut self.payload, Payload::None) + } + + /// Split request into request head and payload + pub fn into_parts(self) -> (Message, Payload

    ) { + (self.head, self.payload) + } + + #[inline] + /// Http message part of the request + pub fn head(&self) -> &RequestHead { + &*self.head + } + + #[inline] + #[doc(hidden)] + /// Mutable reference to a http message part of the request + pub fn head_mut(&mut self) -> &mut RequestHead { + &mut *self.head + } + + /// Mutable reference to the message's headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head_mut().headers + } + + /// Request's uri. + #[inline] + pub fn uri(&self) -> &Uri { + &self.head().uri + } + + /// Mutable reference to the request's uri. + #[inline] + pub fn uri_mut(&mut self) -> &mut Uri { + &mut self.head_mut().uri + } + + /// Read the Request method. + #[inline] + pub fn method(&self) -> &Method { + &self.head().method + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + /// The target path of this Request. + #[inline] + pub fn path(&self) -> &str { + self.head().uri.path() + } + + /// Check if request requires connection upgrade + #[inline] + pub fn upgrade(&self) -> bool { + if let Some(conn) = self.head().headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + return s.to_lowercase().contains("upgrade"); + } + } + self.head().method == Method::CONNECT + } + + /// Peer socket address + /// + /// Peer address is actual socket address, if proxy is used in front of + /// actix http server, then peer address would be address of this proxy. + #[inline] + pub fn peer_addr(&self) -> Option { + self.head().peer_addr + } +} + +impl

    fmt::Debug for Request

    Welcome

    - session counter = {} - -"#, counter); - - // response - Ok(HttpResponse::build(StatusCode::OK) - .content_type("text/html; charset=utf-8") - .body(&html)) - -} - -/// 404 handler -fn p404(req: HttpRequest) -> Result { - - // html - let html = r#"actix - basics - - back to home -

    404

    - -"#; - - // response - Ok(HttpResponse::build(StatusCode::NOT_FOUND) - .content_type("text/html; charset=utf-8") - .body(html)) -} - - -/// async handler -fn index_async(req: HttpRequest) -> FutureResult -{ - println!("{:?}", req); - - result(Ok(HttpResponse::Ok() - .content_type("text/html") - .body(format!("Hello {}!", req.match_info().get("name").unwrap())))) -} - -/// handler with path parameters like `/user/{name}/` -fn with_param(req: HttpRequest) -> HttpResponse -{ - println!("{:?}", req); - - HttpResponse::Ok() - .content_type("test/plain") - .body(format!("Hello {}!", req.match_info().get("name").unwrap())) -} - -fn main() { - env::set_var("RUST_LOG", "actix_web=debug"); - env::set_var("RUST_BACKTRACE", "1"); - env_logger::init(); - let sys = actix::System::new("basic-example"); - - let addr = HttpServer::new( - || App::new() - // enable logger - .middleware(middleware::Logger::default()) - // cookie session middleware - .middleware(middleware::SessionStorage::new( - middleware::CookieSessionBackend::build(&[0; 32]) - .secure(false) - .finish() - )) - // register favicon - .resource("/favicon.ico", |r| r.f(favicon)) - // register simple route, handle all methods - .resource("/index.html", |r| r.f(index)) - // with path parameters - .resource("/user/{name}/", |r| r.method(Method::GET).f(with_param)) - // async handler - .resource("/async/{name}", |r| r.method(Method::GET).a(index_async)) - .resource("/test", |r| r.f(|req| { - match *req.method() { - Method::GET => HttpResponse::Ok(), - Method::POST => HttpResponse::MethodNotAllowed(), - _ => HttpResponse::NotFound(), - } - })) - .resource("/error.html", |r| r.f(|req| { - error::InternalError::new( - io::Error::new(io::ErrorKind::Other, "test"), StatusCode::OK) - })) - // static files - .handler("/static/", fs::StaticFiles::new("../static/", true)) - // redirect - .resource("/", |r| r.method(Method::GET).f(|req| { - println!("{:?}", req); - - HttpResponse::Found() - .header("LOCATION", "/index.html") - .finish() - })) - // default - .default_resource(|r| { - r.method(Method::GET).f(p404); - r.route().filter(pred::Not(pred::Get())).f( - |req| HttpResponse::MethodNotAllowed()); - })) - - .bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080") - .shutdown_timeout(0) // <- Set shutdown timeout to 0 seconds (default 60s) - .start(); - - println!("Starting http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/client.rs b/examples/client.rs new file mode 100644 index 000000000..874e08e1b --- /dev/null +++ b/examples/client.rs @@ -0,0 +1,25 @@ +use actix_http::Error; + +#[actix_rt::main] +async fn main() -> Result<(), Error> { + std::env::set_var("RUST_LOG", "actix_http=trace"); + env_logger::init(); + + let client = awc::Client::new(); + + // Create request builder, configure request and send + let mut response = client + .get("https://www.rust-lang.org/") + .header("User-Agent", "Actix-web") + .send() + .await?; + + // server http response + println!("Response: {:?}", response); + + // read response body + let body = response.body().await?; + println!("Downloaded: {:?} bytes", body.len()); + + Ok(()) +} diff --git a/examples/diesel/.env b/examples/diesel/.env deleted file mode 100644 index 1fbc5af72..000000000 --- a/examples/diesel/.env +++ /dev/null @@ -1 +0,0 @@ -DATABASE_URL=file:test.db diff --git a/examples/diesel/Cargo.toml b/examples/diesel/Cargo.toml deleted file mode 100644 index 2551b9628..000000000 --- a/examples/diesel/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "diesel-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } - -futures = "0.1" -uuid = { version = "0.5", features = ["serde", "v4"] } -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" - -diesel = { version = "^1.1.0", features = ["sqlite", "r2d2"] } -r2d2 = "0.8" -dotenv = "0.10" diff --git a/examples/diesel/README.md b/examples/diesel/README.md deleted file mode 100644 index 922ba1e3b..000000000 --- a/examples/diesel/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# diesel - -Diesel's `Getting Started` guide using SQLite for Actix web - -## Usage - -### init database sqlite - -```bash -cargo install diesel_cli --no-default-features --features sqlite -cd actix-web/examples/diesel -echo "DATABASE_URL=file:test.db" > .env -diesel migration run -``` - -### server - -```bash -# if ubuntu : sudo apt-get install libsqlite3-dev -# if fedora : sudo dnf install libsqlite3x-devel -cd actix-web/examples/diesel -cargo run (or ``cargo watch -x run``) -# Started http server: 127.0.0.1:8080 -``` - -### web client - -[http://127.0.0.1:8080/NAME](http://127.0.0.1:8080/NAME) - -### sqlite client - -```bash -# if ubuntu : sudo apt-get install sqlite3 -# if fedora : sudo dnf install sqlite3x -sqlite3 test.db -sqlite> .tables -sqlite> select * from users; -``` - - -## Postgresql - -You will also find another complete example of diesel+postgresql on [https://github.com/TechEmpower/FrameworkBenchmarks/tree/master/frameworks/Rust/actix](https://github.com/TechEmpower/FrameworkBenchmarks/tree/master/frameworks/Rust/actix) \ No newline at end of file diff --git a/examples/diesel/migrations/20170124012402_create_users/down.sql b/examples/diesel/migrations/20170124012402_create_users/down.sql deleted file mode 100644 index 9951735c4..000000000 --- a/examples/diesel/migrations/20170124012402_create_users/down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE users diff --git a/examples/diesel/migrations/20170124012402_create_users/up.sql b/examples/diesel/migrations/20170124012402_create_users/up.sql deleted file mode 100644 index d88d44fb7..000000000 --- a/examples/diesel/migrations/20170124012402_create_users/up.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE TABLE users ( - id VARCHAR NOT NULL PRIMARY KEY, - name VARCHAR NOT NULL -) diff --git a/examples/diesel/src/db.rs b/examples/diesel/src/db.rs deleted file mode 100644 index 13b376823..000000000 --- a/examples/diesel/src/db.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! Db executor actor -use uuid; -use diesel; -use actix_web::*; -use actix::prelude::*; -use diesel::prelude::*; -use diesel::r2d2::{ Pool, ConnectionManager }; - -use models; -use schema; - -/// This is db executor actor. We are going to run 3 of them in parallel. -pub struct DbExecutor(pub Pool>); - -/// This is only message that this actor can handle, but it is easy to extend number of -/// messages. -pub struct CreateUser { - pub name: String, -} - -impl Message for CreateUser { - type Result = Result; -} - -impl Actor for DbExecutor { - type Context = SyncContext; -} - -impl Handler for DbExecutor { - type Result = Result; - - fn handle(&mut self, msg: CreateUser, _: &mut Self::Context) -> Self::Result { - use self::schema::users::dsl::*; - - let uuid = format!("{}", uuid::Uuid::new_v4()); - let new_user = models::NewUser { - id: &uuid, - name: &msg.name, - }; - - let conn: &SqliteConnection = &self.0.get().unwrap(); - - diesel::insert_into(users) - .values(&new_user) - .execute(conn) - .expect("Error inserting person"); - - let mut items = users - .filter(id.eq(&uuid)) - .load::(conn) - .expect("Error loading person"); - - Ok(items.pop().unwrap()) - } -} diff --git a/examples/diesel/src/main.rs b/examples/diesel/src/main.rs deleted file mode 100644 index 2fd7087ce..000000000 --- a/examples/diesel/src/main.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Actix web diesel example -//! -//! Diesel does not support tokio, so we have to run it in separate threads. -//! Actix supports sync actors by default, so we going to create sync actor that use diesel. -//! Technically sync actors are worker style actors, multiple of them -//! can run in parallel and process messages from same queue. -extern crate serde; -extern crate serde_json; -#[macro_use] -extern crate serde_derive; -#[macro_use] -extern crate diesel; -extern crate r2d2; -extern crate uuid; -extern crate futures; -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use actix::prelude::*; -use actix_web::{http, server, middleware, - App, Path, State, HttpResponse, AsyncResponder, FutureResponse}; - -use diesel::prelude::*; -use diesel::r2d2::{ Pool, ConnectionManager }; -use futures::future::Future; - -mod db; -mod models; -mod schema; - -use db::{CreateUser, DbExecutor}; - - -/// State with DbExecutor address -struct AppState { - db: Addr, -} - -/// Async request handler -fn index(name: Path, state: State) -> FutureResponse { - // send async `CreateUser` message to a `DbExecutor` - state.db.send(CreateUser{name: name.into_inner()}) - .from_err() - .and_then(|res| { - match res { - Ok(user) => Ok(HttpResponse::Ok().json(user)), - Err(_) => Ok(HttpResponse::InternalServerError().into()) - } - }) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - let sys = actix::System::new("diesel-example"); - - // Start 3 db executor actors - let manager = ConnectionManager::::new("test.db"); - let pool = r2d2::Pool::builder().build(manager).expect("Failed to create pool."); - - let addr = SyncArbiter::start(3, move || { - DbExecutor(pool.clone()) - }); - - // Start http server - server::new(move || { - App::with_state(AppState{db: addr.clone()}) - // enable logger - .middleware(middleware::Logger::default()) - .resource("/{name}", |r| r.method(http::Method::GET).with2(index))}) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/diesel/src/models.rs b/examples/diesel/src/models.rs deleted file mode 100644 index 315d59f13..000000000 --- a/examples/diesel/src/models.rs +++ /dev/null @@ -1,14 +0,0 @@ -use super::schema::users; - -#[derive(Serialize, Queryable)] -pub struct User { - pub id: String, - pub name: String, -} - -#[derive(Insertable)] -#[table_name = "users"] -pub struct NewUser<'a> { - pub id: &'a str, - pub name: &'a str, -} diff --git a/examples/diesel/src/schema.rs b/examples/diesel/src/schema.rs deleted file mode 100644 index 51aa40b89..000000000 --- a/examples/diesel/src/schema.rs +++ /dev/null @@ -1,6 +0,0 @@ -table! { - users (id) { - id -> Text, - name -> Text, - } -} diff --git a/examples/diesel/test.db b/examples/diesel/test.db deleted file mode 100644 index 65e590a6e..000000000 Binary files a/examples/diesel/test.db and /dev/null differ diff --git a/examples/hello-world/Cargo.toml b/examples/hello-world/Cargo.toml deleted file mode 100644 index 156a1ada6..000000000 --- a/examples/hello-world/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "hello-world" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs deleted file mode 100644 index 2af478947..000000000 --- a/examples/hello-world/src/main.rs +++ /dev/null @@ -1,28 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use actix_web::{App, HttpRequest, server, middleware}; - - -fn index(_req: HttpRequest) -> &'static str { - "Hello world!" -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - let sys = actix::System::new("hello-world"); - - server::new( - || App::new() - // enable logger - .middleware(middleware::Logger::default()) - .resource("/index.html", |r| r.f(|_| "Hello world!")) - .resource("/", |r| r.f(index))) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/http-proxy/Cargo.toml b/examples/http-proxy/Cargo.toml deleted file mode 100644 index 7b9597bff..000000000 --- a/examples/http-proxy/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "http-proxy" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -futures = "0.1" -actix = "0.5" -actix-web = { path = "../../", features=["alpn"] } diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs deleted file mode 100644 index a69fff88d..000000000 --- a/examples/http-proxy/src/main.rs +++ /dev/null @@ -1,57 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate futures; -extern crate env_logger; - -use futures::{Future, Stream}; -use actix_web::{client, server, middleware, - App, AsyncResponder, Body, - HttpRequest, HttpResponse, HttpMessage, Error}; - -/// Stream client request response and then send body to a server response -fn index(_req: HttpRequest) -> Box> { - client::ClientRequest::get("https://www.rust-lang.org/en-US/") - .finish().unwrap() - .send() - .map_err(Error::from) // <- convert SendRequestError to an Error - .and_then( - |resp| resp.body() // <- this is MessageBody type, resolves to complete body - .from_err() // <- convert PayloadError to a Error - .and_then(|body| { // <- we got complete body, now send as server response - Ok(HttpResponse::Ok().body(body)) - })) - .responder() -} - -/// streaming client request to a streaming server response -fn streaming(_req: HttpRequest) -> Box> { - // send client request - client::ClientRequest::get("https://www.rust-lang.org/en-US/") - .finish().unwrap() - .send() // <- connect to host and send request - .map_err(Error::from) // <- convert SendRequestError to an Error - .and_then(|resp| { // <- we received client response - Ok(HttpResponse::Ok() - // read one chunk from client response and send this chunk to a server response - // .from_err() converts PayloadError to a Error - .body(Body::Streaming(Box::new(resp.from_err())))) - }) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - let sys = actix::System::new("http-proxy"); - - let _addr = server::new( - || App::new() - .middleware(middleware::Logger::default()) - .resource("/streaming", |r| r.f(streaming)) - .resource("/", |r| r.f(index))) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/json/Cargo.toml b/examples/json/Cargo.toml deleted file mode 100644 index bf117c704..000000000 --- a/examples/json/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "json-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -bytes = "0.4" -futures = "0.1" -env_logger = "*" - -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" -json = "*" - -actix = "0.5" -actix-web = { path="../../" } diff --git a/examples/json/README.md b/examples/json/README.md deleted file mode 100644 index 167c3909f..000000000 --- a/examples/json/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# json - -Json's `Getting Started` guide using json (serde-json or json-rust) for Actix web - -## Usage - -### server - -```bash -cd actix-web/examples/json -cargo run -# Started http server: 127.0.0.1:8080 -``` - -### web client - -With [Postman](https://www.getpostman.com/) or [Rested](moz-extension://60daeb1c-5b1b-4afd-9842-0579ed34dfcb/dist/index.html) - -- POST / (embed serde-json): - - - method : ``POST`` - - url : ``http://127.0.0.1:8080/`` - - header : ``Content-Type`` = ``application/json`` - - body (raw) : ``{"name": "Test user", "number": 100}`` - -- POST /manual (manual serde-json): - - - method : ``POST`` - - url : ``http://127.0.0.1:8080/manual`` - - header : ``Content-Type`` = ``application/json`` - - body (raw) : ``{"name": "Test user", "number": 100}`` - -- POST /mjsonrust (manual json-rust): - - - method : ``POST`` - - url : ``http://127.0.0.1:8080/mjsonrust`` - - header : ``Content-Type`` = ``application/json`` - - body (raw) : ``{"name": "Test user", "number": 100}`` (you can also test ``{notjson}``) - -### python client - -- ``pip install aiohttp`` -- ``python client.py`` - -if ubuntu : - -- ``pip3 install aiohttp`` -- ``python3 client.py`` diff --git a/examples/json/client.py b/examples/json/client.py deleted file mode 100644 index e89ffe096..000000000 --- a/examples/json/client.py +++ /dev/null @@ -1,18 +0,0 @@ -# This script could be used for actix-web multipart example test -# just start server and run client.py - -import json -import asyncio -import aiohttp - -async def req(): - resp = await aiohttp.ClientSession().request( - "post", 'http://localhost:8080/', - data=json.dumps({"name": "Test user", "number": 100}), - headers={"content-type": "application/json"}) - print(str(resp)) - print(await resp.text()) - assert 200 == resp.status - - -asyncio.get_event_loop().run_until_complete(req()) diff --git a/examples/json/src/main.rs b/examples/json/src/main.rs deleted file mode 100644 index 34730366e..000000000 --- a/examples/json/src/main.rs +++ /dev/null @@ -1,107 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate bytes; -extern crate futures; -extern crate env_logger; -extern crate serde_json; -#[macro_use] extern crate serde_derive; -#[macro_use] extern crate json; - -use actix_web::{ - middleware, http, error, server, - App, AsyncResponder, HttpRequest, HttpResponse, HttpMessage, Error, Json}; - -use bytes::BytesMut; -use futures::{Future, Stream}; -use json::JsonValue; - -#[derive(Debug, Serialize, Deserialize)] -struct MyObj { - name: String, - number: i32, -} - -/// This handler uses `HttpRequest::json()` for loading serde json object. -fn index(req: HttpRequest) -> Box> { - req.json() - .from_err() // convert all errors into `Error` - .and_then(|val: MyObj| { - println!("model: {:?}", val); - Ok(HttpResponse::Ok().json(val)) // <- send response - }) - .responder() -} - -/// This handler uses `With` helper for loading serde json object. -fn extract_item(item: Json) -> HttpResponse { - println!("model: {:?}", &item); - HttpResponse::Ok().json(item.0) // <- send response -} - -const MAX_SIZE: usize = 262_144; // max payload size is 256k - -/// This handler manually load request payload and parse serde json -fn index_manual(req: HttpRequest) -> Box> { - // HttpRequest is stream of Bytes objects - req - // `Future::from_err` acts like `?` in that it coerces the error type from - // the future into the final error type - .from_err() - - // `fold` will asynchronously read each chunk of the request body and - // call supplied closure, then it resolves to result of closure - .fold(BytesMut::new(), move |mut body, chunk| { - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - Err(error::ErrorBadRequest("overflow")) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - // `Future::and_then` can be used to merge an asynchronous workflow with a - // synchronous workflow - .and_then(|body| { - // body is loaded, now we can deserialize serde-json - let obj = serde_json::from_slice::(&body)?; - Ok(HttpResponse::Ok().json(obj)) // <- send response - }) - .responder() -} - -/// This handler manually load request payload and parse json-rust -fn index_mjsonrust(req: HttpRequest) -> Box> { - req.concat2() - .from_err() - .and_then(|body| { - // body is loaded, now we can deserialize json-rust - let result = json::parse(std::str::from_utf8(&body).unwrap()); // return Result - let injson: JsonValue = match result { Ok(v) => v, Err(e) => object!{"err" => e.to_string() } }; - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(injson.dump())) - }) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("json-example"); - - let _ = server::new(|| { - App::new() - // enable logger - .middleware(middleware::Logger::default()) - .resource("/extractor/{name}/{number}/", - |r| r.method(http::Method::GET).with(extract_item)) - .resource("/manual", |r| r.method(http::Method::POST).f(index_manual)) - .resource("/mjsonrust", |r| r.method(http::Method::POST).f(index_mjsonrust)) - .resource("/", |r| r.method(http::Method::POST).f(index))}) - .bind("127.0.0.1:8080").unwrap() - .shutdown_timeout(1) - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/juniper/Cargo.toml b/examples/juniper/Cargo.toml deleted file mode 100644 index 9e52b0a83..000000000 --- a/examples/juniper/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "juniper-example" -version = "0.1.0" -authors = ["pyros2097 "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } - -futures = "0.1" -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" - -juniper = "0.9.2" diff --git a/examples/juniper/README.md b/examples/juniper/README.md deleted file mode 100644 index 2ac0eac4e..000000000 --- a/examples/juniper/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# juniper - -Juniper integration for Actix web - -### server - -```bash -cd actix-web/examples/juniper -cargo run (or ``cargo watch -x run``) -# Started http server: 127.0.0.1:8080 -``` - -### web client - -[http://127.0.0.1:8080/graphiql](http://127.0.0.1:8080/graphiql) diff --git a/examples/juniper/src/main.rs b/examples/juniper/src/main.rs deleted file mode 100644 index 97319afea..000000000 --- a/examples/juniper/src/main.rs +++ /dev/null @@ -1,112 +0,0 @@ -//! Actix web juniper example -//! -//! A simple example integrating juniper in actix-web -extern crate serde; -extern crate serde_json; -#[macro_use] -extern crate serde_derive; -#[macro_use] -extern crate juniper; -extern crate futures; -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use actix::prelude::*; -use actix_web::{ - middleware, http::{self, header::CONTENT_TYPE}, server, - App, AsyncResponder, HttpRequest, HttpResponse, HttpMessage, Error}; -use juniper::http::graphiql::graphiql_source; -use juniper::http::GraphQLRequest; - -use futures::future::Future; - -mod schema; - -use schema::Schema; -use schema::create_schema; - -struct State { - executor: Addr, -} - -#[derive(Serialize, Deserialize)] -pub struct GraphQLData(GraphQLRequest); - -impl Message for GraphQLData { - type Result = Result; -} - -pub struct GraphQLExecutor { - schema: std::sync::Arc -} - -impl GraphQLExecutor { - fn new(schema: std::sync::Arc) -> GraphQLExecutor { - GraphQLExecutor { - schema: schema, - } - } -} - -impl Actor for GraphQLExecutor { - type Context = SyncContext; -} - -impl Handler for GraphQLExecutor { - type Result = Result; - - fn handle(&mut self, msg: GraphQLData, _: &mut Self::Context) -> Self::Result { - let res = msg.0.execute(&self.schema, &()); - let res_text = serde_json::to_string(&res)?; - Ok(res_text) - } -} - -fn graphiql(_req: HttpRequest) -> Result { - let html = graphiql_source("http://127.0.0.1:8080/graphql"); - Ok(HttpResponse::Ok() - .content_type("text/html; charset=utf-8") - .body(html)) -} - -fn graphql(req: HttpRequest) -> Box> { - let executor = req.state().executor.clone(); - req.json() - .from_err() - .and_then(move |val: GraphQLData| { - executor.send(val) - .from_err() - .and_then(|res| { - match res { - Ok(user) => Ok(HttpResponse::Ok().header(CONTENT_TYPE, "application/json").body(user)), - Err(_) => Ok(HttpResponse::InternalServerError().into()) - } - }) - }) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("juniper-example"); - - let schema = std::sync::Arc::new(create_schema()); - let addr = SyncArbiter::start(3, move || { - GraphQLExecutor::new(schema.clone()) - }); - - // Start http server - let _ = server::new(move || { - App::with_state(State{executor: addr.clone()}) - // enable logger - .middleware(middleware::Logger::default()) - .resource("/graphql", |r| r.method(http::Method::POST).h(graphql)) - .resource("/graphiql", |r| r.method(http::Method::GET).h(graphiql))}) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/juniper/src/schema.rs b/examples/juniper/src/schema.rs deleted file mode 100644 index 2b4cf3042..000000000 --- a/examples/juniper/src/schema.rs +++ /dev/null @@ -1,58 +0,0 @@ -use juniper::FieldResult; -use juniper::RootNode; - -#[derive(GraphQLEnum)] -enum Episode { - NewHope, - Empire, - Jedi, -} - -#[derive(GraphQLObject)] -#[graphql(description = "A humanoid creature in the Star Wars universe")] -struct Human { - id: String, - name: String, - appears_in: Vec, - home_planet: String, -} - -#[derive(GraphQLInputObject)] -#[graphql(description = "A humanoid creature in the Star Wars universe")] -struct NewHuman { - name: String, - appears_in: Vec, - home_planet: String, -} - -pub struct QueryRoot; - -graphql_object!(QueryRoot: () |&self| { - field human(&executor, id: String) -> FieldResult { - Ok(Human{ - id: "1234".to_owned(), - name: "Luke".to_owned(), - appears_in: vec![Episode::NewHope], - home_planet: "Mars".to_owned(), - }) - } -}); - -pub struct MutationRoot; - -graphql_object!(MutationRoot: () |&self| { - field createHuman(&executor, new_human: NewHuman) -> FieldResult { - Ok(Human{ - id: "1234".to_owned(), - name: new_human.name, - appears_in: new_human.appears_in, - home_planet: new_human.home_planet, - }) - } -}); - -pub type Schema = RootNode<'static, QueryRoot, MutationRoot>; - -pub fn create_schema() -> Schema { - Schema::new(QueryRoot {}, MutationRoot {}) -} diff --git a/examples/multipart/Cargo.toml b/examples/multipart/Cargo.toml deleted file mode 100644 index b5235d7e7..000000000 --- a/examples/multipart/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "multipart-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[[bin]] -name = "multipart" -path = "src/main.rs" - -[dependencies] -env_logger = "*" -futures = "0.1" -actix = "0.5" -actix-web = { path="../../" } diff --git a/examples/multipart/README.md b/examples/multipart/README.md deleted file mode 100644 index 348d28687..000000000 --- a/examples/multipart/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# multipart - -Multipart's `Getting Started` guide for Actix web - -## Usage - -### server - -```bash -cd actix-web/examples/multipart -cargo run (or ``cargo watch -x run``) -# Started http server: 127.0.0.1:8080 -``` - -### client - -- ``pip install aiohttp`` -- ``python client.py`` -- you must see in server console multipart fields - -if ubuntu : - -- ``pip3 install aiohttp`` -- ``python3 client.py`` diff --git a/examples/multipart/client.py b/examples/multipart/client.py deleted file mode 100644 index afc07f17d..000000000 --- a/examples/multipart/client.py +++ /dev/null @@ -1,34 +0,0 @@ -# This script could be used for actix-web multipart example test -# just start server and run client.py - -import asyncio -import aiohttp - -async def req1(): - with aiohttp.MultipartWriter() as writer: - writer.append('test') - writer.append_json({'passed': True}) - - resp = await aiohttp.ClientSession().request( - "post", 'http://localhost:8080/multipart', - data=writer, headers=writer.headers) - print(resp) - assert 200 == resp.status - - -async def req2(): - with aiohttp.MultipartWriter() as writer: - writer.append('test') - writer.append_json({'passed': True}) - writer.append(open('src/main.rs')) - - resp = await aiohttp.ClientSession().request( - "post", 'http://localhost:8080/multipart', - data=writer, headers=writer.headers) - print(resp) - assert 200 == resp.status - - -loop = asyncio.get_event_loop() -loop.run_until_complete(req1()) -loop.run_until_complete(req2()) diff --git a/examples/multipart/src/main.rs b/examples/multipart/src/main.rs deleted file mode 100644 index cac76d30c..000000000 --- a/examples/multipart/src/main.rs +++ /dev/null @@ -1,61 +0,0 @@ -#![allow(unused_variables)] -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate futures; - -use actix::*; -use actix_web::{ - http, middleware, multipart, server, - App, AsyncResponder, HttpRequest, HttpResponse, HttpMessage, Error}; - -use futures::{Future, Stream}; -use futures::future::{result, Either}; - - -fn index(req: HttpRequest) -> Box> -{ - println!("{:?}", req); - - req.multipart() // <- get multipart stream for current request - .from_err() // <- convert multipart errors - .and_then(|item| { // <- iterate over multipart items - match item { - // Handle multipart Field - multipart::MultipartItem::Field(field) => { - println!("==== FIELD ==== {:?}", field); - - // Field in turn is stream of *Bytes* object - Either::A( - field.map_err(Error::from) - .map(|chunk| { - println!("-- CHUNK: \n{}", - std::str::from_utf8(&chunk).unwrap());}) - .finish()) - }, - multipart::MultipartItem::Nested(mp) => { - // Or item could be nested Multipart stream - Either::B(result(Ok(()))) - } - } - }) - .finish() // <- Stream::finish() combinator from actix - .map(|_| HttpResponse::Ok().into()) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("multipart-example"); - - let _ = server::new( - || App::new() - .middleware(middleware::Logger::default()) // <- logger - .resource("/multipart", |r| r.method(http::Method::POST).a(index))) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Starting http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/protobuf/Cargo.toml b/examples/protobuf/Cargo.toml deleted file mode 100644 index 3bb56869f..000000000 --- a/examples/protobuf/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "protobuf-example" -version = "0.1.0" -authors = ["kingxsp "] - -[dependencies] -bytes = "0.4" -futures = "0.1" -failure = "0.1" -env_logger = "*" - -prost = "0.2.0" -prost-derive = "0.2.0" - -actix = "0.5" -actix-web = { path="../../" } diff --git a/examples/protobuf/client.py b/examples/protobuf/client.py deleted file mode 100644 index ab91365d8..000000000 --- a/examples/protobuf/client.py +++ /dev/null @@ -1,66 +0,0 @@ -# just start server and run client.py - -# wget https://github.com/google/protobuf/releases/download/v3.5.1/protobuf-python-3.5.1.zip -# unzip protobuf-python-3.5.1.zip.1 -# cd protobuf-3.5.1/python/ -# python3.6 setup.py install - -# pip3.6 install --upgrade pip -# pip3.6 install aiohttp - -#!/usr/bin/env python -import test_pb2 -import traceback -import sys - -import asyncio -import aiohttp - -def op(): - try: - obj = test_pb2.MyObj() - obj.number = 9 - obj.name = 'USB' - - #Serialize - sendDataStr = obj.SerializeToString() - #print serialized string value - print('serialized string:', sendDataStr) - #------------------------# - # message transmission # - #------------------------# - receiveDataStr = sendDataStr - receiveData = test_pb2.MyObj() - - #Deserialize - receiveData.ParseFromString(receiveDataStr) - print('pares serialize string, return: devId = ', receiveData.number, ', name = ', receiveData.name) - except(Exception, e): - print(Exception, ':', e) - print(traceback.print_exc()) - errInfo = sys.exc_info() - print(errInfo[0], ':', errInfo[1]) - - -async def fetch(session): - obj = test_pb2.MyObj() - obj.number = 9 - obj.name = 'USB' - async with session.post('http://localhost:8080/', data=obj.SerializeToString(), - headers={"content-type": "application/protobuf"}) as resp: - print(resp.status) - data = await resp.read() - receiveObj = test_pb2.MyObj() - receiveObj.ParseFromString(data) - print(receiveObj) - -async def go(loop): - obj = test_pb2.MyObj() - obj.number = 9 - obj.name = 'USB' - async with aiohttp.ClientSession(loop=loop) as session: - await fetch(session) - -loop = asyncio.get_event_loop() -loop.run_until_complete(go(loop)) -loop.close() \ No newline at end of file diff --git a/examples/protobuf/src/main.rs b/examples/protobuf/src/main.rs deleted file mode 100644 index ae61e0e46..000000000 --- a/examples/protobuf/src/main.rs +++ /dev/null @@ -1,57 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate bytes; -extern crate futures; -#[macro_use] -extern crate failure; -extern crate env_logger; -extern crate prost; -#[macro_use] -extern crate prost_derive; - -use futures::Future; -use actix_web::{ - http, middleware, server, - App, AsyncResponder, HttpRequest, HttpResponse, Error}; - -mod protobuf; -use protobuf::ProtoBufResponseBuilder; - - -#[derive(Clone, Debug, PartialEq, Message)] -pub struct MyObj { - #[prost(int32, tag="1")] - pub number: i32, - #[prost(string, tag="2")] - pub name: String, -} - - -/// This handler uses `ProtoBufMessage` for loading protobuf object. -fn index(req: HttpRequest) -> Box> { - protobuf::ProtoBufMessage::new(req) - .from_err() // convert all errors into `Error` - .and_then(|val: MyObj| { - println!("model: {:?}", val); - Ok(HttpResponse::Ok().protobuf(val)?) // <- send response - }) - .responder() -} - - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - let sys = actix::System::new("protobuf-example"); - - let _ = server::new(|| { - App::new() - .middleware(middleware::Logger::default()) - .resource("/", |r| r.method(http::Method::POST).f(index))}) - .bind("127.0.0.1:8080").unwrap() - .shutdown_timeout(1) - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/protobuf/src/protobuf.rs b/examples/protobuf/src/protobuf.rs deleted file mode 100644 index 2b117fe76..000000000 --- a/examples/protobuf/src/protobuf.rs +++ /dev/null @@ -1,168 +0,0 @@ -use bytes::{Bytes, BytesMut}; -use futures::{Poll, Future, Stream}; - -use bytes::IntoBuf; -use prost::Message; -use prost::DecodeError as ProtoBufDecodeError; -use prost::EncodeError as ProtoBufEncodeError; - -use actix_web::http::header::{CONTENT_TYPE, CONTENT_LENGTH}; -use actix_web::{Responder, HttpMessage, HttpRequest, HttpResponse}; -use actix_web::dev::HttpResponseBuilder; -use actix_web::error::{Error, PayloadError, ResponseError}; - - -#[derive(Fail, Debug)] -pub enum ProtoBufPayloadError { - /// Payload size is bigger than 256k - #[fail(display="Payload size is bigger than 256k")] - Overflow, - /// Content type error - #[fail(display="Content type error")] - ContentType, - /// Serialize error - #[fail(display="ProtoBud serialize error: {}", _0)] - Serialize(#[cause] ProtoBufEncodeError), - /// Deserialize error - #[fail(display="ProtoBud deserialize error: {}", _0)] - Deserialize(#[cause] ProtoBufDecodeError), - /// Payload error - #[fail(display="Error that occur during reading payload: {}", _0)] - Payload(#[cause] PayloadError), -} - -impl ResponseError for ProtoBufPayloadError { - - fn error_response(&self) -> HttpResponse { - match *self { - ProtoBufPayloadError::Overflow => HttpResponse::PayloadTooLarge().into(), - _ => HttpResponse::BadRequest().into(), - } - } -} - -impl From for ProtoBufPayloadError { - fn from(err: PayloadError) -> ProtoBufPayloadError { - ProtoBufPayloadError::Payload(err) - } -} - -impl From for ProtoBufPayloadError { - fn from(err: ProtoBufDecodeError) -> ProtoBufPayloadError { - ProtoBufPayloadError::Deserialize(err) - } -} - -#[derive(Debug)] -pub struct ProtoBuf(pub T); - -impl Responder for ProtoBuf { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - let mut buf = Vec::new(); - self.0.encode(&mut buf) - .map_err(|e| Error::from(ProtoBufPayloadError::Serialize(e))) - .and_then(|()| { - Ok(HttpResponse::Ok() - .content_type("application/protobuf") - .body(buf) - .into()) - }) - } -} - -pub struct ProtoBufMessage{ - limit: usize, - ct: &'static str, - req: Option, - fut: Option>>, -} - -impl ProtoBufMessage { - - /// Create `ProtoBufMessage` for request. - pub fn new(req: T) -> Self { - ProtoBufMessage{ - limit: 262_144, - req: Some(req), - fut: None, - ct: "application/protobuf", - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } - - /// Set allowed content type. - /// - /// By default *application/protobuf* content type is used. Set content type - /// to empty string if you want to disable content type check. - pub fn content_type(mut self, ct: &'static str) -> Self { - self.ct = ct; - self - } -} - -impl Future for ProtoBufMessage - where T: HttpMessage + Stream + 'static -{ - type Item = U; - type Error = ProtoBufPayloadError; - - fn poll(&mut self) -> Poll { - if let Some(req) = self.req.take() { - if let Some(len) = req.headers().get(CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(ProtoBufPayloadError::Overflow); - } - } else { - return Err(ProtoBufPayloadError::Overflow); - } - } - } - // check content-type - if !self.ct.is_empty() && req.content_type() != self.ct { - return Err(ProtoBufPayloadError::ContentType) - } - - let limit = self.limit; - let fut = req.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(ProtoBufPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .and_then(|body| Ok(::decode(&mut body.into_buf())?)); - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("ProtoBufBody could not be used second time").poll() - } -} - - -pub trait ProtoBufResponseBuilder { - - fn protobuf(&mut self, value: T) -> Result; -} - -impl ProtoBufResponseBuilder for HttpResponseBuilder { - - fn protobuf(&mut self, value: T) -> Result { - self.header(CONTENT_TYPE, "application/protobuf"); - - let mut body = Vec::new(); - value.encode(&mut body).map_err(|e| ProtoBufPayloadError::Serialize(e))?; - Ok(self.body(body)) - } -} diff --git a/examples/protobuf/test.proto b/examples/protobuf/test.proto deleted file mode 100644 index 8ec278ca4..000000000 --- a/examples/protobuf/test.proto +++ /dev/null @@ -1,6 +0,0 @@ -syntax = "proto3"; - -message MyObj { - int32 number = 1; - string name = 2; -} \ No newline at end of file diff --git a/examples/protobuf/test_pb2.py b/examples/protobuf/test_pb2.py deleted file mode 100644 index 05e71f3a6..000000000 --- a/examples/protobuf/test_pb2.py +++ /dev/null @@ -1,76 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: test.proto - -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='test.proto', - package='', - syntax='proto3', - serialized_pb=_b('\n\ntest.proto\"%\n\x05MyObj\x12\x0e\n\x06number\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3') -) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - - - - -_MYOBJ = _descriptor.Descriptor( - name='MyObj', - full_name='MyObj', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='number', full_name='MyObj.number', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='name', full_name='MyObj.name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=14, - serialized_end=51, -) - -DESCRIPTOR.message_types_by_name['MyObj'] = _MYOBJ - -MyObj = _reflection.GeneratedProtocolMessageType('MyObj', (_message.Message,), dict( - DESCRIPTOR = _MYOBJ, - __module__ = 'test_pb2' - # @@protoc_insertion_point(class_scope:MyObj) - )) -_sym_db.RegisterMessage(MyObj) - - -# @@protoc_insertion_point(module_scope) diff --git a/examples/r2d2/Cargo.toml b/examples/r2d2/Cargo.toml deleted file mode 100644 index ab9590a43..000000000 --- a/examples/r2d2/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "r2d2-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } - -futures = "0.1" -uuid = { version = "0.5", features = ["serde", "v4"] } -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" - -r2d2 = "*" -r2d2_sqlite = "*" -rusqlite = "*" diff --git a/examples/r2d2/src/db.rs b/examples/r2d2/src/db.rs deleted file mode 100644 index 6e2ddc09f..000000000 --- a/examples/r2d2/src/db.rs +++ /dev/null @@ -1,41 +0,0 @@ -//! Db executor actor -use std::io; -use uuid; -use actix_web::*; -use actix::prelude::*; -use r2d2::Pool; -use r2d2_sqlite::SqliteConnectionManager; - - -/// This is db executor actor. We are going to run 3 of them in parallel. -pub struct DbExecutor(pub Pool); - -/// This is only message that this actor can handle, but it is easy to extend number of -/// messages. -pub struct CreateUser { - pub name: String, -} - -impl Message for CreateUser { - type Result = Result; -} - -impl Actor for DbExecutor { - type Context = SyncContext; -} - -impl Handler for DbExecutor { - type Result = Result; - - fn handle(&mut self, msg: CreateUser, _: &mut Self::Context) -> Self::Result { - let conn = self.0.get().unwrap(); - - let uuid = format!("{}", uuid::Uuid::new_v4()); - conn.execute("INSERT INTO users (id, name) VALUES ($1, $2)", - &[&uuid, &msg.name]).unwrap(); - - Ok(conn.query_row("SELECT name FROM users WHERE id=$1", &[&uuid], |row| { - row.get(0) - }).map_err(|_| io::Error::new(io::ErrorKind::Other, "db error"))?) - } -} diff --git a/examples/r2d2/src/main.rs b/examples/r2d2/src/main.rs deleted file mode 100644 index a3cf637c7..000000000 --- a/examples/r2d2/src/main.rs +++ /dev/null @@ -1,65 +0,0 @@ -//! Actix web r2d2 example -extern crate serde; -extern crate serde_json; -extern crate uuid; -extern crate futures; -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate r2d2; -extern crate r2d2_sqlite; -extern crate rusqlite; - -use actix::prelude::*; -use actix_web::{ - middleware, http, server, App, AsyncResponder, HttpRequest, HttpResponse, Error}; -use futures::future::Future; -use r2d2_sqlite::SqliteConnectionManager; - -mod db; -use db::{CreateUser, DbExecutor}; - - -/// State with DbExecutor address -struct State { - db: Addr, -} - -/// Async request handler -fn index(req: HttpRequest) -> Box> { - let name = &req.match_info()["name"]; - - req.state().db.send(CreateUser{name: name.to_owned()}) - .from_err() - .and_then(|res| { - match res { - Ok(user) => Ok(HttpResponse::Ok().json(user)), - Err(_) => Ok(HttpResponse::InternalServerError().into()) - } - }) - .responder() -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=debug"); - env_logger::init(); - let sys = actix::System::new("r2d2-example"); - - // r2d2 pool - let manager = SqliteConnectionManager::file("test.db"); - let pool = r2d2::Pool::new(manager).unwrap(); - - // Start db executor actors - let addr = SyncArbiter::start(3, move || DbExecutor(pool.clone())); - - // Start http server - let _ = server::new(move || { - App::with_state(State{db: addr.clone()}) - // enable logger - .middleware(middleware::Logger::default()) - .resource("/{name}", |r| r.method(http::Method::GET).a(index))}) - .bind("127.0.0.1:8080").unwrap() - .start(); - - let _ = sys.run(); -} diff --git a/examples/r2d2/test.db b/examples/r2d2/test.db deleted file mode 100644 index 3ea0c83d7..000000000 Binary files a/examples/r2d2/test.db and /dev/null differ diff --git a/examples/redis-session/Cargo.toml b/examples/redis-session/Cargo.toml deleted file mode 100644 index cfa102d11..000000000 --- a/examples/redis-session/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "redis-session" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = "0.4" -actix-redis = { version = "0.2", features = ["web"] } diff --git a/examples/redis-session/src/main.rs b/examples/redis-session/src/main.rs deleted file mode 100644 index 36df16559..000000000 --- a/examples/redis-session/src/main.rs +++ /dev/null @@ -1,48 +0,0 @@ -#![allow(unused_variables)] - -extern crate actix; -extern crate actix_web; -extern crate actix_redis; -extern crate env_logger; - -use actix_web::*; -use actix_web::middleware::RequestSession; -use actix_redis::RedisSessionBackend; - - -/// simple handler -fn index(mut req: HttpRequest) -> Result { - println!("{:?}", req); - - // session - if let Some(count) = req.session().get::("counter")? { - println!("SESSION value: {}", count); - req.session().set("counter", count+1)?; - } else { - req.session().set("counter", 1)?; - } - - Ok("Welcome!".into()) -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info,actix_redis=info"); - env_logger::init(); - let sys = actix::System::new("basic-example"); - - HttpServer::new( - || Application::new() - // enable logger - .middleware(middleware::Logger::default()) - // cookie session middleware - .middleware(middleware::SessionStorage::new( - RedisSessionBackend::new("127.0.0.1:6379", &[0; 32]) - )) - // register simple route, handle all methods - .resource("/", |r| r.f(index))) - .bind("0.0.0.0:8080").unwrap() - .threads(1) - .start(); - - let _ = sys.run(); -} diff --git a/examples/state/Cargo.toml b/examples/state/Cargo.toml deleted file mode 100644 index bd3ba2439..000000000 --- a/examples/state/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "state" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -futures = "*" -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } diff --git a/examples/state/README.md b/examples/state/README.md deleted file mode 100644 index 127ed2a0f..000000000 --- a/examples/state/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# state - -## Usage - -### server - -```bash -cd actix-web/examples/state -cargo run -# Started http server: 127.0.0.1:8080 -``` - -### web client - -- [http://localhost:8080/](http://localhost:8080/) diff --git a/examples/state/src/main.rs b/examples/state/src/main.rs deleted file mode 100644 index e3b0890bd..000000000 --- a/examples/state/src/main.rs +++ /dev/null @@ -1,77 +0,0 @@ -#![cfg_attr(feature="cargo-clippy", allow(needless_pass_by_value))] -//! There are two level of statefulness in actix-web. Application has state -//! that is shared across all handlers within same Application. -//! And individual handler can have state. - -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use std::cell::Cell; - -use actix::prelude::*; -use actix_web::{ - http, server, ws, middleware, App, HttpRequest, HttpResponse}; - -/// Application state -struct AppState { - counter: Cell, -} - -/// simple handle -fn index(req: HttpRequest) -> HttpResponse { - println!("{:?}", req); - req.state().counter.set(req.state().counter.get() + 1); - - HttpResponse::Ok().body(format!("Num of requests: {}", req.state().counter.get())) -} - -/// `MyWebSocket` counts how many messages it receives from peer, -/// websocket-client.py could be used for tests -struct MyWebSocket { - counter: usize, -} - -impl Actor for MyWebSocket { - type Context = ws::WebsocketContext; -} - -impl StreamHandler for MyWebSocket { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - self.counter += 1; - println!("WS({}): {:?}", self.counter, msg); - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(_) => { - ctx.stop(); - } - _ => (), - } - } -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("ws-example"); - - let _ = server::new( - || App::with_state(AppState{counter: Cell::new(0)}) - // enable logger - .middleware(middleware::Logger::default()) - // websocket route - .resource( - "/ws/", |r| - r.method(http::Method::GET).f( - |req| ws::start(req, MyWebSocket{counter: 0}))) - // register simple handler, handle all methods - .resource("/", |r| r.f(index))) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/static/actixLogo.png b/examples/static/actixLogo.png deleted file mode 100644 index 142e4e8d5..000000000 Binary files a/examples/static/actixLogo.png and /dev/null differ diff --git a/examples/static/favicon.ico b/examples/static/favicon.ico deleted file mode 100644 index 03018db5b..000000000 Binary files a/examples/static/favicon.ico and /dev/null differ diff --git a/examples/static/index.html b/examples/static/index.html deleted file mode 100644 index e59e13f12..000000000 --- a/examples/static/index.html +++ /dev/null @@ -1,90 +0,0 @@ - - - - - - - - -

    Chat!

    -
    -  | Status: - disconnected -
    -
    -
    -
    - - -
    - - diff --git a/examples/template_tera/Cargo.toml b/examples/template_tera/Cargo.toml deleted file mode 100644 index 8591fa50e..000000000 --- a/examples/template_tera/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "template-tera" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } -tera = "*" diff --git a/examples/template_tera/README.md b/examples/template_tera/README.md deleted file mode 100644 index 35829599f..000000000 --- a/examples/template_tera/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# template_tera - -Minimal example of using the template [tera](https://github.com/Keats/tera) that displays a form. - -## Usage - -### server - -```bash -cd actix-web/examples/template_tera -cargo run (or ``cargo watch -x run``) -# Started http server: 127.0.0.1:8080 -``` - -### web client - -- [http://localhost:8080](http://localhost:8080) diff --git a/examples/template_tera/src/main.rs b/examples/template_tera/src/main.rs deleted file mode 100644 index fb512d2c4..000000000 --- a/examples/template_tera/src/main.rs +++ /dev/null @@ -1,48 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate env_logger; -#[macro_use] -extern crate tera; - -use actix_web::{ - http, error, middleware, server, App, HttpRequest, HttpResponse, Error}; - - -struct State { - template: tera::Tera, // <- store tera template in application state -} - -fn index(req: HttpRequest) -> Result { - let s = if let Some(name) = req.query().get("name") { // <- submitted form - let mut ctx = tera::Context::new(); - ctx.add("name", &name.to_owned()); - ctx.add("text", &"Welcome!".to_owned()); - req.state().template.render("user.html", &ctx) - .map_err(|_| error::ErrorInternalServerError("Template error"))? - } else { - req.state().template.render("index.html", &tera::Context::new()) - .map_err(|_| error::ErrorInternalServerError("Template error"))? - }; - Ok(HttpResponse::Ok() - .content_type("text/html") - .body(s)) -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("tera-example"); - - let _ = server::new(|| { - let tera = compile_templates!(concat!(env!("CARGO_MANIFEST_DIR"), "/templates/**/*")); - - App::with_state(State{template: tera}) - // enable logger - .middleware(middleware::Logger::default()) - .resource("/", |r| r.method(http::Method::GET).f(index))}) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/template_tera/templates/index.html b/examples/template_tera/templates/index.html deleted file mode 100644 index d8a47bc09..000000000 --- a/examples/template_tera/templates/index.html +++ /dev/null @@ -1,17 +0,0 @@ - - - - - Actix web - - -

    Welcome!

    -

    -

    What is your name?

    -
    -
    -

    -
    -

    - - diff --git a/examples/template_tera/templates/user.html b/examples/template_tera/templates/user.html deleted file mode 100644 index cb5328915..000000000 --- a/examples/template_tera/templates/user.html +++ /dev/null @@ -1,13 +0,0 @@ - - - - - Actix web - - -

    Hi, {{ name }}!

    -

    - {{ text }} -

    - - diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml deleted file mode 100644 index a4706d419..000000000 --- a/examples/tls/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "tls-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[[bin]] -name = "server" -path = "src/main.rs" - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../", features=["alpn"] } -openssl = { version="0.10" } diff --git a/examples/tls/README.md b/examples/tls/README.md deleted file mode 100644 index 1bc9ba3b7..000000000 --- a/examples/tls/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# tls example - -## Usage - -### server - -```bash -cd actix-web/examples/tls -cargo run (or ``cargo watch -x run``) -# Started http server: 127.0.0.1:8443 -``` - -### web client - -- curl: ``curl -v https://127.0.0.1:8443/index.html --compress -k`` -- browser: [https://127.0.0.1:8443/index.html](https://127.0.0.1:8080/index.html) diff --git a/examples/tls/cert.pem b/examples/tls/cert.pem deleted file mode 100644 index 159aacea2..000000000 --- a/examples/tls/cert.pem +++ /dev/null @@ -1,31 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV -UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww -CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx -NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD -QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY -MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1 -sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U -NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy -voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr -odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND -xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA -CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI -yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U -UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO -vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un -CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN -BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk -3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI -JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD -JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL -d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu -ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC -CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur -y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7 -YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh -g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt -tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y -1QU= ------END CERTIFICATE----- diff --git a/examples/tls/key.pem b/examples/tls/key.pem deleted file mode 100644 index aac387c64..000000000 --- a/examples/tls/key.pem +++ /dev/null @@ -1,51 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP -n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M -IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5 -4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ -WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk -oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli -JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6 -/stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD -YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP -wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA -69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA -AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/ -9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm -YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR -6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM -ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI -7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab -L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+ -vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ -b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz -0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL -OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI -6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC -71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g -9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu -bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb -IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga -/BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc -KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2 -iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP -tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD -jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY -l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj -gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh -Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q -1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW -t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI -fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9 -5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt -+oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc -3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf -cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T -qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU -DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K -5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc -fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc -Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ -4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6 -I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c= ------END RSA PRIVATE KEY----- diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs deleted file mode 100644 index 809af1716..000000000 --- a/examples/tls/src/main.rs +++ /dev/null @@ -1,49 +0,0 @@ -#![allow(unused_variables)] -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate openssl; - -use openssl::ssl::{SslMethod, SslAcceptor, SslFiletype}; -use actix_web::{ - http, middleware, server, App, HttpRequest, HttpResponse, Error}; - - -/// simple handle -fn index(req: HttpRequest) -> Result { - println!("{:?}", req); - Ok(HttpResponse::Ok() - .content_type("text/plain") - .body("Welcome!")) -} - -fn main() { - if ::std::env::var("RUST_LOG").is_err() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - } - env_logger::init(); - let sys = actix::System::new("ws-example"); - - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_private_key_file("key.pem", SslFiletype::PEM).unwrap(); - builder.set_certificate_chain_file("cert.pem").unwrap(); - - let _ = server::new( - || App::new() - // enable logger - .middleware(middleware::Logger::default()) - // register simple handler, handle all methods - .resource("/index.html", |r| r.f(index)) - // with path parameters - .resource("/", |r| r.method(http::Method::GET).f(|req| { - HttpResponse::Found() - .header("LOCATION", "/index.html") - .finish() - }))) - .bind("127.0.0.1:8443").unwrap() - .start_ssl(builder).unwrap(); - - println!("Started http server: 127.0.0.1:8443"); - let _ = sys.run(); -} diff --git a/examples/uds.rs b/examples/uds.rs new file mode 100644 index 000000000..8db4cf230 --- /dev/null +++ b/examples/uds.rs @@ -0,0 +1,53 @@ +use actix_web::{ + get, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer, +}; + +#[get("/resource1/{name}/index.html")] +async fn index(req: HttpRequest, name: web::Path) -> String { + println!("REQ: {:?}", req); + format!("Hello: {}!\r\n", name) +} + +async fn index_async(req: HttpRequest) -> Result<&'static str, Error> { + println!("REQ: {:?}", req); + Ok("Hello world!\r\n") +} + +#[get("/")] +async fn no_params() -> &'static str { + "Hello world!\r\n" +} + +#[cfg(unix)] +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); + env_logger::init(); + + HttpServer::new(|| { + App::new() + .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) + .wrap(middleware::Compress::default()) + .wrap(middleware::Logger::default()) + .service(index) + .service(no_params) + .service( + web::resource("/resource2/index.html") + .wrap( + middleware::DefaultHeaders::new().header("X-Version-R2", "0.3"), + ) + .default_service( + web::route().to(|| HttpResponse::MethodNotAllowed()), + ) + .route(web::get().to(index_async)), + ) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) + }) + .bind_uds("/Users/fafhrd91/uds-test")? + .workers(1) + .start() + .await +} + +#[cfg(not(unix))] +fn main() {} diff --git a/examples/unix-socket/Cargo.toml b/examples/unix-socket/Cargo.toml deleted file mode 100644 index a7c31f212..000000000 --- a/examples/unix-socket/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "unix-socket" -version = "0.1.0" -authors = ["Messense Lv "] - -[dependencies] -env_logger = "0.5" -actix = "0.5" -actix-web = { path = "../../" } -tokio-uds = "0.1" diff --git a/examples/unix-socket/README.md b/examples/unix-socket/README.md deleted file mode 100644 index 03b0066a2..000000000 --- a/examples/unix-socket/README.md +++ /dev/null @@ -1,14 +0,0 @@ -## Unix domain socket example - -```bash -$ curl --unix-socket /tmp/actix-uds.socket http://localhost/ -Hello world! -``` - -Although this will only one thread for handling incoming connections -according to the -[documentation](https://actix.github.io/actix-web/actix_web/struct.HttpServer.html#method.start_incoming). - -And it does not delete the socket file (`/tmp/actix-uds.socket`) when stopping -the server so it will fail to start next time you run it unless you delete -the socket file manually. diff --git a/examples/unix-socket/src/main.rs b/examples/unix-socket/src/main.rs deleted file mode 100644 index aeb749d10..000000000 --- a/examples/unix-socket/src/main.rs +++ /dev/null @@ -1,32 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate tokio_uds; - -use actix::*; -use actix_web::*; -use tokio_uds::UnixListener; - - -fn index(_req: HttpRequest) -> &'static str { - "Hello world!" -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - let sys = actix::System::new("unix-socket"); - - let listener = UnixListener::bind( - "/tmp/actix-uds.socket", Arbiter::handle()).expect("bind failed"); - HttpServer::new( - || App::new() - // enable logger - .middleware(middleware::Logger::default()) - .resource("/index.html", |r| r.f(|_| "Hello world!")) - .resource("/", |r| r.f(index))) - .start_incoming(listener.incoming(), false); - - println!("Started http server: /tmp/actix-uds.socket"); - let _ = sys.run(); -} diff --git a/examples/web-cors/README.md b/examples/web-cors/README.md deleted file mode 100644 index 6dd3d77ff..000000000 --- a/examples/web-cors/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# Actix Web CORS example - -## start -1 - backend server -```bash -$ cd web-cors/backend -$ cargo run -``` -2 - frontend server -```bash -$ cd web-cors/frontend -$ npm install -$ npm run dev -``` -then open browser 'http://localhost:1234/' diff --git a/examples/web-cors/backend/.gitignore b/examples/web-cors/backend/.gitignore deleted file mode 100644 index 250b626d5..000000000 --- a/examples/web-cors/backend/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ - -/target/ -**/*.rs.bk -Cargo.lock \ No newline at end of file diff --git a/examples/web-cors/backend/Cargo.toml b/examples/web-cors/backend/Cargo.toml deleted file mode 100644 index cffc895fa..000000000 --- a/examples/web-cors/backend/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "actix-web-cors" -version = "0.1.0" -authors = ["krircc "] -workspace = "../../../" - -[dependencies] -serde = "1.0" -serde_derive = "1.0" -serde_json = "1.0" -http = "0.1" - -actix = "0.5" -actix-web = { path = "../../../" } -dotenv = "0.10" -env_logger = "0.5" -futures = "0.1" diff --git a/examples/web-cors/backend/src/main.rs b/examples/web-cors/backend/src/main.rs deleted file mode 100644 index 599be2c94..000000000 --- a/examples/web-cors/backend/src/main.rs +++ /dev/null @@ -1,43 +0,0 @@ -#[macro_use] extern crate serde_derive; -extern crate serde; -extern crate serde_json; -extern crate futures; -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use std::env; -use actix_web::{http, middleware, server, App}; - -mod user; -use user::info; - - -fn main() { - env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - - let sys = actix::System::new("Actix-web-CORS"); - - server::new( - || App::new() - .middleware(middleware::Logger::default()) - .resource("/user/info", |r| { - middleware::cors::Cors::build() - .allowed_origin("http://localhost:1234") - .allowed_methods(vec!["GET", "POST"]) - .allowed_headers( - vec![http::header::AUTHORIZATION, - http::header::ACCEPT, - http::header::CONTENT_TYPE]) - .max_age(3600) - .finish().expect("Can not create CORS middleware") - .register(r); - r.method(http::Method::POST).a(info); - })) - .bind("127.0.0.1:8000").unwrap() - .shutdown_timeout(200) - .start(); - - let _ = sys.run(); -} diff --git a/examples/web-cors/backend/src/user.rs b/examples/web-cors/backend/src/user.rs deleted file mode 100644 index 364430fce..000000000 --- a/examples/web-cors/backend/src/user.rs +++ /dev/null @@ -1,19 +0,0 @@ -use actix_web::{AsyncResponder, Error, HttpMessage, HttpResponse, HttpRequest}; -use futures::Future; - - -#[derive(Deserialize,Serialize, Debug)] -struct Info { - username: String, - email: String, - password: String, - confirm_password: String, -} - -pub fn info(req: HttpRequest) -> Box> { - req.json() - .from_err() - .and_then(|res: Info| { - Ok(HttpResponse::Ok().json(res)) - }).responder() -} diff --git a/examples/web-cors/frontend/.babelrc b/examples/web-cors/frontend/.babelrc deleted file mode 100644 index 002b4aa0d..000000000 --- a/examples/web-cors/frontend/.babelrc +++ /dev/null @@ -1,3 +0,0 @@ -{ - "presets": ["env"] -} diff --git a/examples/web-cors/frontend/.gitignore b/examples/web-cors/frontend/.gitignore deleted file mode 100644 index 8875af865..000000000 --- a/examples/web-cors/frontend/.gitignore +++ /dev/null @@ -1,14 +0,0 @@ -.DS_Store -node_modules/ -/dist/ -.cache -npm-debug.log* -yarn-debug.log* -yarn-error.log* - -# Editor directories and files -.idea -*.suo -*.ntvs* -*.njsproj -*.sln diff --git a/examples/web-cors/frontend/index.html b/examples/web-cors/frontend/index.html deleted file mode 100644 index d71de81cc..000000000 --- a/examples/web-cors/frontend/index.html +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - webapp - - -
    - - - - \ No newline at end of file diff --git a/examples/web-cors/frontend/package.json b/examples/web-cors/frontend/package.json deleted file mode 100644 index 7ce2f641d..000000000 --- a/examples/web-cors/frontend/package.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "name": "actix-web-cors", - "version": "0.1.0", - "description": "webapp", - "main": "main.js", - "scripts": { - "dev": "rm -rf dist/ && NODE_ENV=development parcel index.html", - "build": "NODE_ENV=production parcel build index.html", - "test": "echo \"Error: no test specified\" && exit 1" - }, - "license": "ISC", - "dependencies": { - "vue": "^2.5.13", - "vue-router": "^3.0.1", - "axios": "^0.17.1" - }, - "devDependencies": { - "babel-preset-env": "^1.6.1", - "parcel-bundler": "^1.4.1", - "parcel-plugin-vue": "^1.5.0" - } -} diff --git a/examples/web-cors/frontend/src/app.vue b/examples/web-cors/frontend/src/app.vue deleted file mode 100644 index 0c054c206..000000000 --- a/examples/web-cors/frontend/src/app.vue +++ /dev/null @@ -1,145 +0,0 @@ - - - - - \ No newline at end of file diff --git a/examples/web-cors/frontend/src/main.js b/examples/web-cors/frontend/src/main.js deleted file mode 100644 index df1e4b7cb..000000000 --- a/examples/web-cors/frontend/src/main.js +++ /dev/null @@ -1,11 +0,0 @@ -import Vue from 'vue' -import App from './app' - -new Vue({ - el: '#app', - render: h => h(App) -}) - -if (module.hot) { - module.hot.accept(); -} \ No newline at end of file diff --git a/examples/websocket-chat/Cargo.toml b/examples/websocket-chat/Cargo.toml deleted file mode 100644 index 389ccd346..000000000 --- a/examples/websocket-chat/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "websocket-example" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[[bin]] -name = "server" -path = "src/main.rs" - -[[bin]] -name = "client" -path = "src/client.rs" - -[dependencies] -rand = "*" -bytes = "0.4" -byteorder = "1.1" -futures = "0.1" -tokio-io = "0.1" -tokio-core = "0.1" -env_logger = "*" - -serde = "1.0" -serde_json = "1.0" -serde_derive = "1.0" - -actix = "0.5" -actix-web = { path="../../" } diff --git a/examples/websocket-chat/README.md b/examples/websocket-chat/README.md deleted file mode 100644 index a01dd68b7..000000000 --- a/examples/websocket-chat/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Websocket chat example - -This is extension of the -[actix chat example](https://github.com/actix/actix/tree/master/examples/chat) - -Added features: - -* Browser WebSocket client -* Chat server runs in separate thread -* Tcp listener runs in separate thread - -## Server - -Chat server listens for incoming tcp connections. Server can access several types of message: - -* `\list` - list all available rooms -* `\join name` - join room, if room does not exist, create new one -* `\name name` - set session name -* `some message` - just string, send message to all peers in same room -* client has to send heartbeat `Ping` messages, if server does not receive a heartbeat message for 10 seconds connection gets dropped - -To start server use command: `cargo run --bin server` - -## Client - -Client connects to server. Reads input from stdin and sends to server. - -To run client use command: `cargo run --bin client` - -## WebSocket Browser Client - -Open url: [http://localhost:8080/](http://localhost:8080/) diff --git a/examples/websocket-chat/client.py b/examples/websocket-chat/client.py deleted file mode 100755 index 8a1bd9aee..000000000 --- a/examples/websocket-chat/client.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -"""websocket cmd client for wssrv.py example.""" -import argparse -import asyncio -import signal -import sys - -import aiohttp - - -def start_client(loop, url): - name = input('Please enter your name: ') - - # send request - ws = yield from aiohttp.ClientSession().ws_connect(url, autoclose=False, autoping=False) - - # input reader - def stdin_callback(): - line = sys.stdin.buffer.readline().decode('utf-8') - if not line: - loop.stop() - else: - ws.send_str(name + ': ' + line) - loop.add_reader(sys.stdin.fileno(), stdin_callback) - - @asyncio.coroutine - def dispatch(): - while True: - msg = yield from ws.receive() - - if msg.type == aiohttp.WSMsgType.TEXT: - print('Text: ', msg.data.strip()) - elif msg.type == aiohttp.WSMsgType.BINARY: - print('Binary: ', msg.data) - elif msg.type == aiohttp.WSMsgType.PING: - ws.pong() - elif msg.type == aiohttp.WSMsgType.PONG: - print('Pong received') - else: - if msg.type == aiohttp.WSMsgType.CLOSE: - yield from ws.close() - elif msg.type == aiohttp.WSMsgType.ERROR: - print('Error during receive %s' % ws.exception()) - elif msg.type == aiohttp.WSMsgType.CLOSED: - pass - - break - - yield from dispatch() - - -ARGS = argparse.ArgumentParser( - description="websocket console client for wssrv.py example.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') - -if __name__ == '__main__': - args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - url = 'http://{}:{}/ws/'.format(args.host, args.port) - - loop = asyncio.get_event_loop() - loop.add_signal_handler(signal.SIGINT, loop.stop) - asyncio.Task(start_client(loop, url)) - loop.run_forever() diff --git a/examples/websocket-chat/src/client.rs b/examples/websocket-chat/src/client.rs deleted file mode 100644 index e2e6a7c84..000000000 --- a/examples/websocket-chat/src/client.rs +++ /dev/null @@ -1,153 +0,0 @@ -#[macro_use] extern crate actix; -extern crate bytes; -extern crate byteorder; -extern crate futures; -extern crate tokio_io; -extern crate tokio_core; -extern crate serde; -extern crate serde_json; -#[macro_use] extern crate serde_derive; - -use std::{io, net, process, thread}; -use std::str::FromStr; -use std::time::Duration; -use futures::Future; -use tokio_io::AsyncRead; -use tokio_io::io::WriteHalf; -use tokio_io::codec::FramedRead; -use tokio_core::net::TcpStream; -use actix::prelude::*; - -mod codec; - - -fn main() { - let sys = actix::System::new("chat-client"); - - // Connect to server - let addr = net::SocketAddr::from_str("127.0.0.1:12345").unwrap(); - Arbiter::handle().spawn( - TcpStream::connect(&addr, Arbiter::handle()) - .and_then(|stream| { - let addr: Addr = ChatClient::create(|ctx| { - let (r, w) = stream.split(); - ChatClient::add_stream(FramedRead::new(r, codec::ClientChatCodec), ctx); - ChatClient{ - framed: actix::io::FramedWrite::new( - w, codec::ClientChatCodec, ctx)}}); - - // start console loop - thread::spawn(move|| { - loop { - let mut cmd = String::new(); - if io::stdin().read_line(&mut cmd).is_err() { - println!("error"); - return - } - - addr.do_send(ClientCommand(cmd)); - } - }); - - futures::future::ok(()) - }) - .map_err(|e| { - println!("Can not connect to server: {}", e); - process::exit(1) - }) - ); - - println!("Running chat client"); - sys.run(); -} - - -struct ChatClient { - framed: actix::io::FramedWrite, codec::ClientChatCodec>, -} - -#[derive(Message)] -struct ClientCommand(String); - -impl Actor for ChatClient { - type Context = Context; - - fn started(&mut self, ctx: &mut Context) { - // start heartbeats otherwise server will disconnect after 10 seconds - self.hb(ctx) - } - - fn stopped(&mut self, _: &mut Context) { - println!("Disconnected"); - - // Stop application on disconnect - Arbiter::system().do_send(actix::msgs::SystemExit(0)); - } -} - -impl ChatClient { - fn hb(&self, ctx: &mut Context) { - ctx.run_later(Duration::new(1, 0), |act, ctx| { - act.framed.write(codec::ChatRequest::Ping); - act.hb(ctx); - }); - } -} - -impl actix::io::WriteHandler for ChatClient {} - -/// Handle stdin commands -impl Handler for ChatClient { - type Result = (); - - fn handle(&mut self, msg: ClientCommand, _: &mut Context) { - let m = msg.0.trim(); - if m.is_empty() { - return - } - - // we check for /sss type of messages - if m.starts_with('/') { - let v: Vec<&str> = m.splitn(2, ' ').collect(); - match v[0] { - "/list" => { - self.framed.write(codec::ChatRequest::List); - }, - "/join" => { - if v.len() == 2 { - self.framed.write(codec::ChatRequest::Join(v[1].to_owned())); - } else { - println!("!!! room name is required"); - } - }, - _ => println!("!!! unknown command"), - } - } else { - self.framed.write(codec::ChatRequest::Message(m.to_owned())); - } - } -} - -/// Server communication - -impl StreamHandler for ChatClient { - - fn handle(&mut self, msg: codec::ChatResponse, _: &mut Context) { - match msg { - codec::ChatResponse::Message(ref msg) => { - println!("message: {}", msg); - } - codec::ChatResponse::Joined(ref msg) => { - println!("!!! joined: {}", msg); - } - codec::ChatResponse::Rooms(rooms) => { - println!("\n!!! Available rooms:"); - for room in rooms { - println!("{}", room); - } - println!(""); - } - _ => (), - } - } -} diff --git a/examples/websocket-chat/src/codec.rs b/examples/websocket-chat/src/codec.rs deleted file mode 100644 index 03638241b..000000000 --- a/examples/websocket-chat/src/codec.rs +++ /dev/null @@ -1,123 +0,0 @@ -#![allow(dead_code)] -use std::io; -use serde_json as json; -use byteorder::{BigEndian , ByteOrder}; -use bytes::{BytesMut, BufMut}; -use tokio_io::codec::{Encoder, Decoder}; - -/// Client request -#[derive(Serialize, Deserialize, Debug, Message)] -#[serde(tag="cmd", content="data")] -pub enum ChatRequest { - /// List rooms - List, - /// Join rooms - Join(String), - /// Send message - Message(String), - /// Ping - Ping -} - -/// Server response -#[derive(Serialize, Deserialize, Debug, Message)] -#[serde(tag="cmd", content="data")] -pub enum ChatResponse { - Ping, - - /// List of rooms - Rooms(Vec), - - /// Joined - Joined(String), - - /// Message - Message(String), -} - -/// Codec for Client -> Server transport -pub struct ChatCodec; - -impl Decoder for ChatCodec -{ - type Item = ChatRequest; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - let size = { - if src.len() < 2 { - return Ok(None) - } - BigEndian::read_u16(src.as_ref()) as usize - }; - - if src.len() >= size + 2 { - src.split_to(2); - let buf = src.split_to(size); - Ok(Some(json::from_slice::(&buf)?)) - } else { - Ok(None) - } - } -} - -impl Encoder for ChatCodec -{ - type Item = ChatResponse; - type Error = io::Error; - - fn encode(&mut self, msg: ChatResponse, dst: &mut BytesMut) -> Result<(), Self::Error> { - let msg = json::to_string(&msg).unwrap(); - let msg_ref: &[u8] = msg.as_ref(); - - dst.reserve(msg_ref.len() + 2); - dst.put_u16::(msg_ref.len() as u16); - dst.put(msg_ref); - - Ok(()) - } -} - - -/// Codec for Server -> Client transport -pub struct ClientChatCodec; - -impl Decoder for ClientChatCodec -{ - type Item = ChatResponse; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - let size = { - if src.len() < 2 { - return Ok(None) - } - BigEndian::read_u16(src.as_ref()) as usize - }; - - if src.len() >= size + 2 { - src.split_to(2); - let buf = src.split_to(size); - Ok(Some(json::from_slice::(&buf)?)) - } else { - Ok(None) - } - } -} - -impl Encoder for ClientChatCodec -{ - type Item = ChatRequest; - type Error = io::Error; - - fn encode(&mut self, msg: ChatRequest, dst: &mut BytesMut) -> Result<(), Self::Error> { - let msg = json::to_string(&msg).unwrap(); - let msg_ref: &[u8] = msg.as_ref(); - - dst.reserve(msg_ref.len() + 2); - dst.put_u16::(msg_ref.len() as u16); - dst.put(msg_ref); - - Ok(()) - } -} diff --git a/examples/websocket-chat/src/main.rs b/examples/websocket-chat/src/main.rs deleted file mode 100644 index 1de3900c4..000000000 --- a/examples/websocket-chat/src/main.rs +++ /dev/null @@ -1,208 +0,0 @@ -#![allow(unused_variables)] -extern crate rand; -extern crate bytes; -extern crate byteorder; -extern crate futures; -extern crate tokio_io; -extern crate tokio_core; -extern crate env_logger; -extern crate serde; -extern crate serde_json; -#[macro_use] extern crate serde_derive; - -#[macro_use] -extern crate actix; -extern crate actix_web; - -use std::time::Instant; - -use actix::*; -use actix_web::{http, fs, ws, App, HttpRequest, HttpResponse, HttpServer, Error}; - -mod codec; -mod server; -mod session; - -/// This is our websocket route state, this state is shared with all route instances -/// via `HttpContext::state()` -struct WsChatSessionState { - addr: Addr, -} - -/// Entry point for our route -fn chat_route(req: HttpRequest) -> Result { - ws::start( - req, - WsChatSession { - id: 0, - hb: Instant::now(), - room: "Main".to_owned(), - name: None}) -} - -struct WsChatSession { - /// unique session id - id: usize, - /// Client must send ping at least once per 10 seconds, otherwise we drop connection. - hb: Instant, - /// joined room - room: String, - /// peer name - name: Option, -} - -impl Actor for WsChatSession { - type Context = ws::WebsocketContext; - - /// Method is called on actor start. - /// We register ws session with ChatServer - fn started(&mut self, ctx: &mut Self::Context) { - // register self in chat server. `AsyncContext::wait` register - // future within context, but context waits until this future resolves - // before processing any other events. - // HttpContext::state() is instance of WsChatSessionState, state is shared across all - // routes within application - let addr: Addr = ctx.address(); - ctx.state().addr.send(server::Connect{addr: addr.recipient()}) - .into_actor(self) - .then(|res, act, ctx| { - match res { - Ok(res) => act.id = res, - // something is wrong with chat server - _ => ctx.stop(), - } - fut::ok(()) - }).wait(ctx); - } - - fn stopping(&mut self, ctx: &mut Self::Context) -> Running { - // notify chat server - ctx.state().addr.do_send(server::Disconnect{id: self.id}); - Running::Stop - } -} - -/// Handle messages from chat server, we simply send it to peer websocket -impl Handler for WsChatSession { - type Result = (); - - fn handle(&mut self, msg: session::Message, ctx: &mut Self::Context) { - ctx.text(msg.0); - } -} - -/// WebSocket message handler -impl StreamHandler for WsChatSession { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - println!("WEBSOCKET MESSAGE: {:?}", msg); - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Pong(msg) => self.hb = Instant::now(), - ws::Message::Text(text) => { - let m = text.trim(); - // we check for /sss type of messages - if m.starts_with('/') { - let v: Vec<&str> = m.splitn(2, ' ').collect(); - match v[0] { - "/list" => { - // Send ListRooms message to chat server and wait for response - println!("List rooms"); - ctx.state().addr.send(server::ListRooms) - .into_actor(self) - .then(|res, _, ctx| { - match res { - Ok(rooms) => { - for room in rooms { - ctx.text(room); - } - }, - _ => println!("Something is wrong"), - } - fut::ok(()) - }).wait(ctx) - // .wait(ctx) pauses all events in context, - // so actor wont receive any new messages until it get list - // of rooms back - }, - "/join" => { - if v.len() == 2 { - self.room = v[1].to_owned(); - ctx.state().addr.do_send( - server::Join{id: self.id, name: self.room.clone()}); - - ctx.text("joined"); - } else { - ctx.text("!!! room name is required"); - } - }, - "/name" => { - if v.len() == 2 { - self.name = Some(v[1].to_owned()); - } else { - ctx.text("!!! name is required"); - } - }, - _ => ctx.text(format!("!!! unknown command: {:?}", m)), - } - } else { - let msg = if let Some(ref name) = self.name { - format!("{}: {}", name, m) - } else { - m.to_owned() - }; - // send message to chat server - ctx.state().addr.do_send( - server::Message{id: self.id, - msg: msg, - room: self.room.clone()}) - } - }, - ws::Message::Binary(bin) => - println!("Unexpected binary"), - ws::Message::Close(_) => { - ctx.stop(); - } - } - } -} - -fn main() { - let _ = env_logger::init(); - let sys = actix::System::new("websocket-example"); - - // Start chat server actor in separate thread - let server: Addr = Arbiter::start(|_| server::ChatServer::default()); - - // Start tcp server in separate thread - let srv = server.clone(); - Arbiter::new("tcp-server").do_send::( - msgs::Execute::new(move || { - session::TcpServer::new("127.0.0.1:12345", srv); - Ok(()) - })); - - // Create Http server with websocket support - let addr = HttpServer::new( - move || { - // Websocket sessions state - let state = WsChatSessionState { addr: server.clone() }; - - App::with_state(state) - // redirect to websocket.html - .resource("/", |r| r.method(http::Method::GET).f(|_| { - HttpResponse::Found() - .header("LOCATION", "/static/websocket.html") - .finish() - })) - // websocket - .resource("/ws/", |r| r.route().f(chat_route)) - // static resources - .handler("/static/", fs::StaticFiles::new("static/", true)) - }) - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/websocket-chat/src/server.rs b/examples/websocket-chat/src/server.rs deleted file mode 100644 index 8b735b852..000000000 --- a/examples/websocket-chat/src/server.rs +++ /dev/null @@ -1,197 +0,0 @@ -//! `ChatServer` is an actor. It maintains list of connection client session. -//! And manages available rooms. Peers send messages to other peers in same -//! room through `ChatServer`. - -use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use rand::{self, Rng, ThreadRng}; -use actix::prelude::*; - -use session; - -/// Message for chat server communications - -/// New chat session is created -#[derive(Message)] -#[rtype(usize)] -pub struct Connect { - pub addr: Recipient, -} - -/// Session is disconnected -#[derive(Message)] -pub struct Disconnect { - pub id: usize, -} - -/// Send message to specific room -#[derive(Message)] -pub struct Message { - /// Id of the client session - pub id: usize, - /// Peer message - pub msg: String, - /// Room name - pub room: String, -} - -/// List of available rooms -pub struct ListRooms; - -impl actix::Message for ListRooms { - type Result = Vec; -} - -/// Join room, if room does not exists create new one. -#[derive(Message)] -pub struct Join { - /// Client id - pub id: usize, - /// Room name - pub name: String, -} - -/// `ChatServer` manages chat rooms and responsible for coordinating chat session. -/// implementation is super primitive -pub struct ChatServer { - sessions: HashMap>, - rooms: HashMap>, - rng: RefCell, -} - -impl Default for ChatServer { - fn default() -> ChatServer { - // default room - let mut rooms = HashMap::new(); - rooms.insert("Main".to_owned(), HashSet::new()); - - ChatServer { - sessions: HashMap::new(), - rooms: rooms, - rng: RefCell::new(rand::thread_rng()), - } - } -} - -impl ChatServer { - /// Send message to all users in the room - fn send_message(&self, room: &str, message: &str, skip_id: usize) { - if let Some(sessions) = self.rooms.get(room) { - for id in sessions { - if *id != skip_id { - if let Some(addr) = self.sessions.get(id) { - let _ = addr.do_send(session::Message(message.to_owned())); - } - } - } - } - } -} - -/// Make actor from `ChatServer` -impl Actor for ChatServer { - /// We are going to use simple Context, we just need ability to communicate - /// with other actors. - type Context = Context; -} - -/// Handler for Connect message. -/// -/// Register new session and assign unique id to this session -impl Handler for ChatServer { - type Result = usize; - - fn handle(&mut self, msg: Connect, _: &mut Context) -> Self::Result { - println!("Someone joined"); - - // notify all users in same room - self.send_message(&"Main".to_owned(), "Someone joined", 0); - - // register session with random id - let id = self.rng.borrow_mut().gen::(); - self.sessions.insert(id, msg.addr); - - // auto join session to Main room - self.rooms.get_mut(&"Main".to_owned()).unwrap().insert(id); - - // send id back - id - } -} - -/// Handler for Disconnect message. -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: Disconnect, _: &mut Context) { - println!("Someone disconnected"); - - let mut rooms: Vec = Vec::new(); - - // remove address - if self.sessions.remove(&msg.id).is_some() { - // remove session from all rooms - for (name, sessions) in &mut self.rooms { - if sessions.remove(&msg.id) { - rooms.push(name.to_owned()); - } - } - } - // send message to other users - for room in rooms { - self.send_message(&room, "Someone disconnected", 0); - } - } -} - -/// Handler for Message message. -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: Message, _: &mut Context) { - self.send_message(&msg.room, msg.msg.as_str(), msg.id); - } -} - -/// Handler for `ListRooms` message. -impl Handler for ChatServer { - type Result = MessageResult; - - fn handle(&mut self, _: ListRooms, _: &mut Context) -> Self::Result { - let mut rooms = Vec::new(); - - for key in self.rooms.keys() { - rooms.push(key.to_owned()) - } - - MessageResult(rooms) - } -} - -/// Join room, send disconnect message to old room -/// send join message to new room -impl Handler for ChatServer { - type Result = (); - - fn handle(&mut self, msg: Join, _: &mut Context) { - let Join {id, name} = msg; - let mut rooms = Vec::new(); - - // remove session from all rooms - for (n, sessions) in &mut self.rooms { - if sessions.remove(&id) { - rooms.push(n.to_owned()); - } - } - // send message to other users - for room in rooms { - self.send_message(&room, "Someone disconnected", 0); - } - - if self.rooms.get_mut(&name).is_none() { - self.rooms.insert(name.clone(), HashSet::new()); - } - self.send_message(&name, "Someone connected", id); - self.rooms.get_mut(&name).unwrap().insert(id); - } -} diff --git a/examples/websocket-chat/src/session.rs b/examples/websocket-chat/src/session.rs deleted file mode 100644 index 7f28c6a4f..000000000 --- a/examples/websocket-chat/src/session.rs +++ /dev/null @@ -1,207 +0,0 @@ -//! `ClientSession` is an actor, it manages peer tcp connection and -//! proxies commands from peer to `ChatServer`. -use std::{io, net}; -use std::str::FromStr; -use std::time::{Instant, Duration}; -use futures::Stream; -use tokio_io::AsyncRead; -use tokio_io::io::WriteHalf; -use tokio_io::codec::FramedRead; -use tokio_core::net::{TcpStream, TcpListener}; - -use actix::prelude::*; - -use server::{self, ChatServer}; -use codec::{ChatRequest, ChatResponse, ChatCodec}; - - -/// Chat server sends this messages to session -#[derive(Message)] -pub struct Message(pub String); - -/// `ChatSession` actor is responsible for tcp peer communications. -pub struct ChatSession { - /// unique session id - id: usize, - /// this is address of chat server - addr: Addr, - /// Client must send ping at least once per 10 seconds, otherwise we drop connection. - hb: Instant, - /// joined room - room: String, - /// Framed wrapper - framed: actix::io::FramedWrite, ChatCodec>, -} - -impl Actor for ChatSession { - /// For tcp communication we are going to use `FramedContext`. - /// It is convenient wrapper around `Framed` object from `tokio_io` - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - // we'll start heartbeat process on session start. - self.hb(ctx); - - // register self in chat server. `AsyncContext::wait` register - // future within context, but context waits until this future resolves - // before processing any other events. - let addr: Addr = ctx.address(); - self.addr.send(server::Connect{addr: addr.recipient()}) - .into_actor(self) - .then(|res, act, ctx| { - match res { - Ok(res) => act.id = res, - // something is wrong with chat server - _ => ctx.stop(), - } - actix::fut::ok(()) - }).wait(ctx); - } - - fn stopping(&mut self, ctx: &mut Self::Context) -> Running { - // notify chat server - self.addr.do_send(server::Disconnect{id: self.id}); - Running::Stop - } -} - -impl actix::io::WriteHandler for ChatSession {} - -/// To use `Framed` we have to define Io type and Codec -impl StreamHandler for ChatSession { - - /// This is main event loop for client requests - fn handle(&mut self, msg: ChatRequest, ctx: &mut Context) { - match msg { - ChatRequest::List => { - // Send ListRooms message to chat server and wait for response - println!("List rooms"); - self.addr.send(server::ListRooms) - .into_actor(self) - .then(|res, act, ctx| { - match res { - Ok(rooms) => { - act.framed.write(ChatResponse::Rooms(rooms)); - }, - _ => println!("Something is wrong"), - } - actix::fut::ok(()) - }).wait(ctx) - // .wait(ctx) pauses all events in context, - // so actor wont receive any new messages until it get list of rooms back - }, - ChatRequest::Join(name) => { - println!("Join to room: {}", name); - self.room = name.clone(); - self.addr.do_send(server::Join{id: self.id, name: name.clone()}); - self.framed.write(ChatResponse::Joined(name)); - }, - ChatRequest::Message(message) => { - // send message to chat server - println!("Peer message: {}", message); - self.addr.do_send( - server::Message{id: self.id, - msg: message, room: - self.room.clone()}) - } - // we update heartbeat time on ping from peer - ChatRequest::Ping => - self.hb = Instant::now(), - } - } -} - -/// Handler for Message, chat server sends this message, we just send string to peer -impl Handler for ChatSession { - type Result = (); - - fn handle(&mut self, msg: Message, ctx: &mut Context) { - // send message to peer - self.framed.write(ChatResponse::Message(msg.0)); - } -} - -/// Helper methods -impl ChatSession { - - pub fn new(addr: Addr, - framed: actix::io::FramedWrite, ChatCodec>) -> ChatSession { - ChatSession {id: 0, addr: addr, hb: Instant::now(), - room: "Main".to_owned(), framed: framed} - } - - /// helper method that sends ping to client every second. - /// - /// also this method check heartbeats from client - fn hb(&self, ctx: &mut Context) { - ctx.run_later(Duration::new(1, 0), |act, ctx| { - // check client heartbeats - if Instant::now().duration_since(act.hb) > Duration::new(10, 0) { - // heartbeat timed out - println!("Client heartbeat failed, disconnecting!"); - - // notify chat server - act.addr.do_send(server::Disconnect{id: act.id}); - - // stop actor - ctx.stop(); - } - - act.framed.write(ChatResponse::Ping); - // if we can not send message to sink, sink is closed (disconnected) - act.hb(ctx); - }); - } -} - - -/// Define tcp server that will accept incoming tcp connection and create -/// chat actors. -pub struct TcpServer { - chat: Addr, -} - -impl TcpServer { - pub fn new(s: &str, chat: Addr) { - // Create server listener - let addr = net::SocketAddr::from_str("127.0.0.1:12345").unwrap(); - let listener = TcpListener::bind(&addr, Arbiter::handle()).unwrap(); - - // Our chat server `Server` is an actor, first we need to start it - // and then add stream on incoming tcp connections to it. - // TcpListener::incoming() returns stream of the (TcpStream, net::SocketAddr) items - // So to be able to handle this events `Server` actor has to implement - // stream handler `StreamHandler<(TcpStream, net::SocketAddr), io::Error>` - let _: () = TcpServer::create(|ctx| { - ctx.add_message_stream(listener.incoming() - .map_err(|_| ()) - .map(|(t, a)| TcpConnect(t, a))); - TcpServer{chat: chat} - }); - } -} - -/// Make actor from `Server` -impl Actor for TcpServer { - /// Every actor has to provide execution `Context` in which it can run. - type Context = Context; -} - -#[derive(Message)] -struct TcpConnect(TcpStream, net::SocketAddr); - -/// Handle stream of TcpStream's -impl Handler for TcpServer { - type Result = (); - - fn handle(&mut self, msg: TcpConnect, _: &mut Context) { - // For each incoming connection we create `ChatSession` actor - // with out chat server address. - let server = self.chat.clone(); - let _: () = ChatSession::create(|ctx| { - let (r, w) = msg.0.split(); - ChatSession::add_stream(FramedRead::new(r, ChatCodec), ctx); - ChatSession::new(server, actix::io::FramedWrite::new(w, ChatCodec, ctx)) - }); - } -} diff --git a/examples/websocket-chat/static/websocket.html b/examples/websocket-chat/static/websocket.html deleted file mode 100644 index e59e13f12..000000000 --- a/examples/websocket-chat/static/websocket.html +++ /dev/null @@ -1,90 +0,0 @@ - - - - - - - - -

    Chat!

    -
    -  | Status: - disconnected -
    -
    -
    -
    - - -
    - - diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml deleted file mode 100644 index 7b754f0d1..000000000 --- a/examples/websocket/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "websocket" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[[bin]] -name = "server" -path = "src/main.rs" - -[[bin]] -name = "client" -path = "src/client.rs" - -[dependencies] -env_logger = "*" -futures = "0.1" -tokio-core = "0.1" -actix = "0.5" -actix-web = { path="../../" } diff --git a/examples/websocket/README.md b/examples/websocket/README.md deleted file mode 100644 index 8ffcba822..000000000 --- a/examples/websocket/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# websocket - -Simple echo websocket server. - -## Usage - -### server - -```bash -cd actix-web/examples/websocket -cargo run -# Started http server: 127.0.0.1:8080 -``` - -### web client - -- [http://localhost:8080/ws/index.html](http://localhost:8080/ws/index.html) - -### python client - -- ``pip install aiohttp`` -- ``python websocket-client.py`` - -if ubuntu : - -- ``pip3 install aiohttp`` -- ``python3 websocket-client.py`` diff --git a/examples/websocket/src/client.rs b/examples/websocket/src/client.rs deleted file mode 100644 index 34ff24372..000000000 --- a/examples/websocket/src/client.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Simple websocket client. - -#![allow(unused_variables)] -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate futures; -extern crate tokio_core; - -use std::{io, thread}; -use std::time::Duration; - -use actix::*; -use futures::Future; -use actix_web::ws::{Message, ProtocolError, Client, ClientWriter}; - - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("ws-example"); - - Arbiter::handle().spawn( - Client::new("http://127.0.0.1:8080/ws/") - .connect() - .map_err(|e| { - println!("Error: {}", e); - () - }) - .map(|(reader, writer)| { - let addr: Addr = ChatClient::create(|ctx| { - ChatClient::add_stream(reader, ctx); - ChatClient(writer) - }); - - // start console loop - thread::spawn(move|| { - loop { - let mut cmd = String::new(); - if io::stdin().read_line(&mut cmd).is_err() { - println!("error"); - return - } - addr.do_send(ClientCommand(cmd)); - } - }); - - () - }) - ); - - let _ = sys.run(); -} - - -struct ChatClient(ClientWriter); - -#[derive(Message)] -struct ClientCommand(String); - -impl Actor for ChatClient { - type Context = Context; - - fn started(&mut self, ctx: &mut Context) { - // start heartbeats otherwise server will disconnect after 10 seconds - self.hb(ctx) - } - - fn stopped(&mut self, _: &mut Context) { - println!("Disconnected"); - - // Stop application on disconnect - Arbiter::system().do_send(actix::msgs::SystemExit(0)); - } -} - -impl ChatClient { - fn hb(&self, ctx: &mut Context) { - ctx.run_later(Duration::new(1, 0), |act, ctx| { - act.0.ping(""); - act.hb(ctx); - }); - } -} - -/// Handle stdin commands -impl Handler for ChatClient { - type Result = (); - - fn handle(&mut self, msg: ClientCommand, ctx: &mut Context) { - self.0.text(msg.0) - } -} - -/// Handle server websocket messages -impl StreamHandler for ChatClient { - - fn handle(&mut self, msg: Message, ctx: &mut Context) { - match msg { - Message::Text(txt) => println!("Server: {:?}", txt), - _ => () - } - } - - fn started(&mut self, ctx: &mut Context) { - println!("Connected"); - } - - fn finished(&mut self, ctx: &mut Context) { - println!("Server disconnected"); - ctx.stop() - } -} diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs deleted file mode 100644 index bcf2ee7ba..000000000 --- a/examples/websocket/src/main.rs +++ /dev/null @@ -1,66 +0,0 @@ -//! Simple echo websocket server. -//! Open `http://localhost:8080/ws/index.html` in browser -//! or [python console client](https://github.com/actix/actix-web/blob/master/examples/websocket-client.py) -//! could be used for testing. - -#![allow(unused_variables)] -extern crate actix; -extern crate actix_web; -extern crate env_logger; - -use actix::prelude::*; -use actix_web::{ - http, middleware, server, fs, ws, App, HttpRequest, HttpResponse, Error}; - -/// do websocket handshake and start `MyWebSocket` actor -fn ws_index(r: HttpRequest) -> Result { - ws::start(r, MyWebSocket) -} - -/// websocket connection is long running connection, it easier -/// to handle with an actor -struct MyWebSocket; - -impl Actor for MyWebSocket { - type Context = ws::WebsocketContext; -} - -/// Handler for `ws::Message` -impl StreamHandler for MyWebSocket { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - // process websocket messages - println!("WS: {:?}", msg); - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(_) => { - ctx.stop(); - } - _ => (), - } - } -} - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - let sys = actix::System::new("ws-example"); - - server::new( - || App::new() - // enable logger - .middleware(middleware::Logger::default()) - // websocket route - .resource("/ws/", |r| r.method(http::Method::GET).f(ws_index)) - // static files - .handler("/", fs::StaticFiles::new("../static/", true) - .index_file("index.html"))) - // start http server on 127.0.0.1:8080 - .bind("127.0.0.1:8080").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} diff --git a/examples/websocket/websocket-client.py b/examples/websocket/websocket-client.py deleted file mode 100755 index 8a1bd9aee..000000000 --- a/examples/websocket/websocket-client.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -"""websocket cmd client for wssrv.py example.""" -import argparse -import asyncio -import signal -import sys - -import aiohttp - - -def start_client(loop, url): - name = input('Please enter your name: ') - - # send request - ws = yield from aiohttp.ClientSession().ws_connect(url, autoclose=False, autoping=False) - - # input reader - def stdin_callback(): - line = sys.stdin.buffer.readline().decode('utf-8') - if not line: - loop.stop() - else: - ws.send_str(name + ': ' + line) - loop.add_reader(sys.stdin.fileno(), stdin_callback) - - @asyncio.coroutine - def dispatch(): - while True: - msg = yield from ws.receive() - - if msg.type == aiohttp.WSMsgType.TEXT: - print('Text: ', msg.data.strip()) - elif msg.type == aiohttp.WSMsgType.BINARY: - print('Binary: ', msg.data) - elif msg.type == aiohttp.WSMsgType.PING: - ws.pong() - elif msg.type == aiohttp.WSMsgType.PONG: - print('Pong received') - else: - if msg.type == aiohttp.WSMsgType.CLOSE: - yield from ws.close() - elif msg.type == aiohttp.WSMsgType.ERROR: - print('Error during receive %s' % ws.exception()) - elif msg.type == aiohttp.WSMsgType.CLOSED: - pass - - break - - yield from dispatch() - - -ARGS = argparse.ArgumentParser( - description="websocket console client for wssrv.py example.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') - -if __name__ == '__main__': - args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - url = 'http://{}:{}/ws/'.format(args.host, args.port) - - loop = asyncio.get_event_loop() - loop.add_signal_handler(signal.SIGINT, loop.stop) - asyncio.Task(start_client(loop, url)) - loop.run_forever() diff --git a/guide/book.toml b/guide/book.toml deleted file mode 100644 index 5549978d7..000000000 --- a/guide/book.toml +++ /dev/null @@ -1,7 +0,0 @@ -[book] -title = "Actix web" -description = "Actix web framework guide" -author = "Actix Project and Contributors" - -[output.html] -google-analytics = "UA-110322332-1" diff --git a/guide/src/SUMMARY.md b/guide/src/SUMMARY.md deleted file mode 100644 index d76840f9c..000000000 --- a/guide/src/SUMMARY.md +++ /dev/null @@ -1,16 +0,0 @@ -# Summary - -[Quickstart](./qs_1.md) -- [Getting Started](./qs_2.md) -- [Application](./qs_3.md) -- [Server](./qs_3_5.md) -- [Handler](./qs_4.md) -- [Errors](./qs_4_5.md) -- [URL Dispatch](./qs_5.md) -- [Request & Response](./qs_7.md) -- [Testing](./qs_8.md) -- [Middlewares](./qs_10.md) -- [Static file handling](./qs_12.md) -- [WebSockets](./qs_9.md) -- [HTTP/2](./qs_13.md) -- [Database integration](./qs_14.md) diff --git a/guide/src/qs_1.md b/guide/src/qs_1.md deleted file mode 100644 index e73f65627..000000000 --- a/guide/src/qs_1.md +++ /dev/null @@ -1,34 +0,0 @@ -# Quick start - -Before you can start writing a actix web applications, you’ll need a version of Rust installed. -We recommend you use rustup to install or configure such a version. - -## Install Rust - -Before we begin, we need to install Rust using the [rustup](https://www.rustup.rs/) installer: - -```bash -curl https://sh.rustup.rs -sSf | sh -``` - -If you already have rustup installed, run this command to ensure you have the latest version of Rust: - -```bash -rustup update -``` - -Actix web framework requires rust version 1.21 and up. - -## Running Examples - -The fastest way to start experimenting with actix web is to clone the actix web repository -and run the included examples in the examples/ directory. The following set of -commands runs the `basics` example: - -```bash -git clone https://github.com/actix/actix-web -cd actix-web/examples/basics -cargo run -``` - -Check [examples/](https://github.com/actix/actix-web/tree/master/examples) directory for more examples. diff --git a/guide/src/qs_10.md b/guide/src/qs_10.md deleted file mode 100644 index aaff39ae1..000000000 --- a/guide/src/qs_10.md +++ /dev/null @@ -1,245 +0,0 @@ -# Middleware - -Actix' middleware system allows to add additional behavior to request/response processing. -Middleware can hook into incoming request process and modify request or halt request -processing and return response early. Also it can hook into response processing. - -Typically middlewares are involved in the following actions: - -* Pre-process the Request -* Post-process a Response -* Modify application state -* Access external services (redis, logging, sessions) - -Middlewares are registered for each application and are executed in same order as -registration order. In general, a *middleware* is a type that implements the -[*Middleware trait*](../actix_web/middlewares/trait.Middleware.html). Each method -in this trait has a default implementation. Each method can return a result immediately -or a *future* object. - -Here is an example of a simple middleware that adds request and response headers: - -```rust -# extern crate http; -# extern crate actix_web; -use http::{header, HttpTryFrom}; -use actix_web::{App, HttpRequest, HttpResponse, Result}; -use actix_web::middleware::{Middleware, Started, Response}; - -struct Headers; // <- Our middleware - -/// Middleware implementation, middlewares are generic over application state, -/// so you can access state with `HttpRequest::state()` method. -impl Middleware for Headers { - - /// Method is called when request is ready. It may return - /// future, which should resolve before next middleware get called. - fn start(&self, req: &mut HttpRequest) -> Result { - req.headers_mut().insert( - header::CONTENT_TYPE, header::HeaderValue::from_static("text/plain")); - Ok(Started::Done) - } - - /// Method is called when handler returns response, - /// but before sending http message to peer. - fn response(&self, req: &mut HttpRequest, mut resp: HttpResponse) -> Result { - resp.headers_mut().insert( - header::HeaderName::try_from("X-VERSION").unwrap(), - header::HeaderValue::from_static("0.2")); - Ok(Response::Done(resp)) - } -} - -fn main() { - App::new() - .middleware(Headers) // <- Register middleware, this method can be called multiple times - .resource("/", |r| r.f(|_| HttpResponse::Ok())); -} -``` - -Actix provides several useful middlewares, like *logging*, *user sessions*, etc. - - -## Logging - -Logging is implemented as a middleware. -It is common to register a logging middleware as the first middleware for the application. -Logging middleware has to be registered for each application. *Logger* middleware -uses the standard log crate to log information. You should enable logger for *actix_web* -package to see access log ([env_logger](https://docs.rs/env_logger/*/env_logger/) or similar). - -### Usage - -Create `Logger` middleware with the specified `format`. -Default `Logger` can be created with `default` method, it uses the default format: - -```ignore - %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T -``` -```rust -# extern crate actix_web; -extern crate env_logger; -use actix_web::App; -use actix_web::middleware::Logger; - -fn main() { - std::env::set_var("RUST_LOG", "actix_web=info"); - env_logger::init(); - - App::new() - .middleware(Logger::default()) - .middleware(Logger::new("%a %{User-Agent}i")) - .finish(); -} -``` - -Here is an example of the default logging format: - -``` -INFO:actix_web::middleware::logger: 127.0.0.1:59934 [02/Dec/2017:00:21:43 -0800] "GET / HTTP/1.1" 302 0 "-" "curl/7.54.0" 0.000397 -INFO:actix_web::middleware::logger: 127.0.0.1:59947 [02/Dec/2017:00:22:40 -0800] "GET /index.html HTTP/1.1" 200 0 "-" "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:57.0) Gecko/20100101 Firefox/57.0" 0.000646 -``` - -### Format - - `%%` The percent sign - - `%a` Remote IP-address (IP-address of proxy if using reverse proxy) - - `%t` Time when the request was started to process - - `%P` The process ID of the child that serviced the request - - `%r` First line of request - - `%s` Response status code - - `%b` Size of response in bytes, including HTTP headers - - `%T` Time taken to serve the request, in seconds with floating fraction in .06f format - - `%D` Time taken to serve the request, in milliseconds - - `%{FOO}i` request.headers['FOO'] - - `%{FOO}o` response.headers['FOO'] - - `%{FOO}e` os.environ['FOO'] - - -## Default headers - -To set default response headers the `DefaultHeaders` middleware can be used. The -*DefaultHeaders* middleware does not set the header if response headers already contain -the specified header. - -```rust -# extern crate actix_web; -use actix_web::{http, middleware, App, HttpResponse}; - -fn main() { - let app = App::new() - .middleware( - middleware::DefaultHeaders::new() - .header("X-Version", "0.2")) - .resource("/test", |r| { - r.method(http::Method::GET).f(|req| HttpResponse::Ok()); - r.method(http::Method::HEAD).f(|req| HttpResponse::MethodNotAllowed()); - }) - .finish(); -} -``` - -## User sessions - -Actix provides a general solution for session management. The -[*Session storage*](../actix_web/middleware/struct.SessionStorage.html) middleware can be -used with different backend types to store session data in different backends. -By default only cookie session backend is implemented. Other backend implementations -could be added later. - -[*Cookie session backend*](../actix_web/middleware/struct.CookieSessionBackend.html) -uses signed cookies as session storage. *Cookie session backend* creates sessions which -are limited to storing fewer than 4000 bytes of data (as the payload must fit into a -single cookie). Internal server error is generated if session contains more than 4000 bytes. - -You need to pass a random value to the constructor of *CookieSessionBackend*. -This is private key for cookie session. When this value is changed, all session data is lost. -Note that whatever you write into your session is visible by the user (but not modifiable). - -In general case, you create -[*Session storage*](../actix_web/middleware/struct.SessionStorage.html) middleware -and initializes it with specific backend implementation, like *CookieSessionBackend*. -To access session data -[*HttpRequest::session()*](../actix_web/middleware/trait.RequestSession.html#tymethod.session) - has to be used. This method returns a -[*Session*](../actix_web/middleware/struct.Session.html) object, which allows to get or set -session data. - -```rust -# extern crate actix; -# extern crate actix_web; -use actix_web::*; -use actix_web::middleware::{RequestSession, SessionStorage, CookieSessionBackend}; - -fn index(mut req: HttpRequest) -> Result<&'static str> { - // access session data - if let Some(count) = req.session().get::("counter")? { - println!("SESSION value: {}", count); - req.session().set("counter", count+1)?; - } else { - req.session().set("counter", 1)?; - } - - Ok("Welcome!") -} - -fn main() { -# let sys = actix::System::new("basic-example"); - HttpServer::new( - || App::new() - .middleware(SessionStorage::new( // <- create session middleware - CookieSessionBackend::build(&[0; 32]) // <- create cookie session backend - .secure(false) - .finish() - ))) - .bind("127.0.0.1:59880").unwrap() - .start(); -# actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); -# let _ = sys.run(); -} -``` - -## Error handlers - -`ErrorHandlers` middleware allows to provide custom handlers for responses. - -You can use `ErrorHandlers::handler()` method to register a custom error handler -for specific status code. You can modify existing response or create completly new -one. Error handler can return response immediately or return future that resolves -to a response. - -```rust -# extern crate actix_web; -use actix_web::{ - App, HttpRequest, HttpResponse, Result, - http, middleware::Response, middleware::ErrorHandlers}; - -fn render_500(_: &mut HttpRequest, resp: HttpResponse) -> Result { - let mut builder = resp.into_builder(); - builder.header(http::header::CONTENT_TYPE, "application/json"); - Ok(Response::Done(builder.into())) -} - -fn main() { - let app = App::new() - .middleware( - ErrorHandlers::new() - .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500)) - .resource("/test", |r| { - r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); - }) - .finish(); -} -``` diff --git a/guide/src/qs_12.md b/guide/src/qs_12.md deleted file mode 100644 index 1da5f1ef9..000000000 --- a/guide/src/qs_12.md +++ /dev/null @@ -1,49 +0,0 @@ -# Static file handling - -## Individual file - -It is possible to serve static files with custom path pattern and `NamedFile`. To -match path tail we can use `[.*]` regex. - -```rust -# extern crate actix_web; -use std::path::PathBuf; -use actix_web::{App, HttpRequest, Result, http::Method, fs::NamedFile}; - -fn index(req: HttpRequest) -> Result { - let path: PathBuf = req.match_info().query("tail")?; - Ok(NamedFile::open(path)?) -} - -fn main() { - App::new() - .resource(r"/a/{tail:.*}", |r| r.method(Method::GET).f(index)) - .finish(); -} -``` - -## Directory - -To serve files from specific directory and sub-directories `StaticFiles` could be used. -`StaticFiles` must be registered with `App::handler()` method otherwise -it won't be able to serve sub-paths. - -```rust -# extern crate actix_web; -use actix_web::*; - -fn main() { - App::new() - .handler("/static", fs::StaticFiles::new(".", true)) - .finish(); -} -``` - -First parameter is a base directory. Second parameter is *show_index*, if it is set to *true* -directory listing would be returned for directories, if it is set to *false* -then *404 Not Found* would be returned instead of directory listing. - -Instead of showing files listing for directory, it is possible to redirect to specific -index file. Use -[*StaticFiles::index_file()*](../actix_web/s/struct.StaticFiles.html#method.index_file) -method to configure this redirect. diff --git a/guide/src/qs_13.md b/guide/src/qs_13.md deleted file mode 100644 index 753a9c16f..000000000 --- a/guide/src/qs_13.md +++ /dev/null @@ -1,44 +0,0 @@ -# HTTP/2.0 - -Actix web automatically upgrades connection to *HTTP/2.0* if possible. - -## Negotiation - -*HTTP/2.0* protocol over tls without prior knowledge requires -[tls alpn](https://tools.ietf.org/html/rfc7301). At the moment only -`rust-openssl` has support. Turn on the `alpn` feature to enable `alpn` negotiation. -With enabled `alpn` feature `HttpServer` provides the -[serve_tls](../actix_web/struct.HttpServer.html#method.serve_tls) method. - -```toml -[dependencies] -actix-web = { version = "0.3.3", features=["alpn"] } -openssl = { version="0.10", features = ["v110"] } -``` - -```rust,ignore -use std::fs::File; -use actix_web::*; -use openssl::ssl::{SslMethod, SslAcceptor, SslFiletype}; - -fn main() { - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_private_key_file("key.pem", SslFiletype::PEM).unwrap(); - builder.set_certificate_chain_file("cert.pem").unwrap(); - - HttpServer::new( - || App::new() - .resource("/index.html", |r| r.f(index))) - .bind("127.0.0.1:8080").unwrap(); - .serve_ssl(builder).unwrap(); -} -``` - -Upgrade to *HTTP/2.0* schema described in -[rfc section 3.2](https://http2.github.io/http2-spec/#rfc.section.3.2) is not supported. -Starting *HTTP/2* with prior knowledge is supported for both clear text connection -and tls connection. [rfc section 3.4](https://http2.github.io/http2-spec/#rfc.section.3.4) - -Please check [example](https://github.com/actix/actix-web/tree/master/examples/tls) -for a concrete example. diff --git a/guide/src/qs_14.md b/guide/src/qs_14.md deleted file mode 100644 index a805e7a58..000000000 --- a/guide/src/qs_14.md +++ /dev/null @@ -1,127 +0,0 @@ -# Database integration - -## Diesel - -At the moment of 1.0 release Diesel does not support asynchronous operations. -But it possible to use the `actix` synchronous actor system as a db interface api. -Technically sync actors are worker style actors, multiple of them -can be run in parallel and process messages from same queue (sync actors work in mpsc mode). - -Let's create a simple db api that can insert a new user row into an SQLite table. -We have to define sync actor and connection that this actor will use. The same approach -can be used for other databases. - -```rust,ignore -use actix::prelude::*; - -struct DbExecutor(SqliteConnection); - -impl Actor for DbExecutor { - type Context = SyncContext; -} -``` - -This is the definition of our actor. Now we need to define the *create user* message and response. - -```rust,ignore -struct CreateUser { - name: String, -} - -impl Message for CreateUser { - type Result = Result; -} -``` - -We can send a `CreateUser` message to the `DbExecutor` actor, and as a result we get a -`User` model instance. Now we need to define the actual handler implementation for this message. - -```rust,ignore -impl Handler for DbExecutor { - type Result = Result; - - fn handle(&mut self, msg: CreateUser, _: &mut Self::Context) -> Self::Result - { - use self::schema::users::dsl::*; - - // Create insertion model - let uuid = format!("{}", uuid::Uuid::new_v4()); - let new_user = models::NewUser { - id: &uuid, - name: &msg.name, - }; - - // normal diesel operations - diesel::insert_into(users) - .values(&new_user) - .execute(&self.0) - .expect("Error inserting person"); - - let mut items = users - .filter(id.eq(&uuid)) - .load::(&self.0) - .expect("Error loading person"); - - Ok(items.pop().unwrap()) - } -} -``` - -That's it. Now we can use the *DbExecutor* actor from any http handler or middleware. -All we need is to start *DbExecutor* actors and store the address in a state where http handler -can access it. - -```rust,ignore -/// This is state where we will store *DbExecutor* address. -struct State { - db: Addr, -} - -fn main() { - let sys = actix::System::new("diesel-example"); - - // Start 3 parallel db executors - let addr = SyncArbiter::start(3, || { - DbExecutor(SqliteConnection::establish("test.db").unwrap()) - }); - - // Start http server - HttpServer::new(move || { - App::with_state(State{db: addr.clone()}) - .resource("/{name}", |r| r.method(Method::GET).a(index))}) - .bind("127.0.0.1:8080").unwrap() - .start().unwrap(); - - println!("Started http server: 127.0.0.1:8080"); - let _ = sys.run(); -} -``` - -And finally we can use the address in a request handler. We get a message response -asynchronously, so the handler needs to return a future object, also `Route::a()` needs to be -used for async handler registration. - - -```rust,ignore -/// Async handler -fn index(req: HttpRequest) -> Box> { - let name = &req.match_info()["name"]; - - // Send message to `DbExecutor` actor - req.state().db.send(CreateUser{name: name.to_owned()}) - .from_err() - .and_then(|res| { - match res { - Ok(user) => Ok(HttpResponse::Ok().json(user)), - Err(_) => Ok(HttpResponse::InternalServerError().into()) - } - }) - .responder() -} -``` - -Full example is available in the -[examples directory](https://github.com/actix/actix-web/tree/master/examples/diesel/). - -More information on sync actors can be found in the -[actix documentation](https://docs.rs/actix/0.5.0/actix/sync/index.html). diff --git a/guide/src/qs_2.md b/guide/src/qs_2.md deleted file mode 100644 index e405775d4..000000000 --- a/guide/src/qs_2.md +++ /dev/null @@ -1,98 +0,0 @@ -# Getting Started - -Let’s create and run our first actix web application. We’ll create a new Cargo project -that depends on actix web and then run the application. - -In the previous section we already installed the required rust version. Now let's create new cargo projects. - -## Hello, world! - -Let’s write our first actix web application! Start by creating a new binary-based -Cargo project and changing into the new directory: - -```bash -cargo new hello-world --bin -cd hello-world -``` - -Now, add actix and actix web as dependencies of your project by ensuring your Cargo.toml -contains the following: - -```toml -[dependencies] -actix = "0.5" -actix-web = "0.4" -``` - -In order to implement a web server, first we need to create a request handler. - -A request handler is a function that accepts an `HttpRequest` instance as its only parameter -and returns a type that can be converted into `HttpResponse`: - -```rust -# extern crate actix_web; -# use actix_web::*; - fn index(req: HttpRequest) -> &'static str { - "Hello world!" - } -# fn main() {} -``` - -Next, create an `Application` instance and register the -request handler with the application's `resource` on a particular *HTTP method* and *path*:: - -```rust -# extern crate actix_web; -# use actix_web::*; -# fn index(req: HttpRequest) -> &'static str { -# "Hello world!" -# } -# fn main() { - App::new() - .resource("/", |r| r.f(index)); -# } -``` - -After that, the application instance can be used with `HttpServer` to listen for incoming -connections. The server accepts a function that should return an `HttpHandler` instance: - -```rust,ignore - HttpServer::new( - || App::new() - .resource("/", |r| r.f(index))) - .bind("127.0.0.1:8088")? - .run(); -``` - -That's it. Now, compile and run the program with `cargo run`. -Head over to ``http://localhost:8088/`` to see the results. - -Here is full source of main.rs file: - -```rust -# use std::thread; -extern crate actix_web; -use actix_web::{App, HttpRequest, HttpResponse, HttpServer}; - -fn index(req: HttpRequest) -> &'static str { - "Hello world!" -} - -fn main() { -# // In the doctest suite we can't run blocking code - deliberately leak a thread -# // If copying this example in show-all mode make sure you skip the thread spawn -# // call. -# thread::spawn(|| { - HttpServer::new( - || App::new() - .resource("/", |r| r.f(index))) - .bind("127.0.0.1:8088").expect("Can not bind to 127.0.0.1:8088") - .run(); -# }); -} -``` - -Note on the `actix` crate. Actix web framework is built on top of actix actor library. -`actix::System` initializes actor system, `HttpServer` is an actor and must run within a -properly configured actix system. For more information please check -[actix documentation](https://actix.github.io/actix/actix/) diff --git a/guide/src/qs_3.md b/guide/src/qs_3.md deleted file mode 100644 index bcfdee8ad..000000000 --- a/guide/src/qs_3.md +++ /dev/null @@ -1,109 +0,0 @@ -# Application - -Actix web provides some primitives to build web servers and applications with Rust. -It provides routing, middlewares, pre-processing of requests, and post-processing of responses, -websocket protocol handling, multipart streams, etc. - -All actix web servers are built around the `App` instance. -It is used for registering routes for resources, and middlewares. -It also stores application specific state that is shared across all handlers -within same application. - -Application acts as a namespace for all routes, i.e all routes for a specific application -have the same url path prefix. The application prefix always contains a leading "/" slash. -If supplied prefix does not contain leading slash, it gets inserted. -The prefix should consist of value path segments. i.e for an application with prefix `/app` -any request with the paths `/app`, `/app/` or `/app/test` would match, -but path `/application` would not match. - -```rust,ignore -# extern crate actix_web; -# extern crate tokio_core; -# use actix_web::{*, http::Method}; -# fn index(req: HttpRequest) -> &'static str { -# "Hello world!" -# } -# fn main() { - let app = App::new() - .prefix("/app") - .resource("/index.html", |r| r.method(Method::GET).f(index)) - .finish() -# } -``` - -In this example application with `/app` prefix and `index.html` resource -gets created. This resource is available as on `/app/index.html` url. -For more information check -[*URL Matching*](./qs_5.html#using-a-application-prefix-to-compose-applications) section. - -Multiple applications can be served with one server: - -```rust -# extern crate actix_web; -# extern crate tokio_core; -# use tokio_core::net::TcpStream; -# use std::net::SocketAddr; -use actix_web::{App, HttpResponse, HttpServer}; - -fn main() { - HttpServer::new(|| vec![ - App::new() - .prefix("/app1") - .resource("/", |r| r.f(|r| HttpResponse::Ok())), - App::new() - .prefix("/app2") - .resource("/", |r| r.f(|r| HttpResponse::Ok())), - App::new() - .resource("/", |r| r.f(|r| HttpResponse::Ok())), - ]); -} -``` - -All `/app1` requests route to the first application, `/app2` to the second and then all other to the third. -Applications get matched based on registration order, if an application with more general -prefix is registered before a less generic one, that would effectively block the less generic -application from getting matched. For example, if *application* with prefix "/" gets registered -as first application, it would match all incoming requests. - -## State - -Application state is shared with all routes and resources within the same application. -State can be accessed with the `HttpRequest::state()` method as a read-only, -but an interior mutability pattern with `RefCell` can be used to achieve state mutability. -State can be accessed with `HttpContext::state()` when using an http actor. -State is also available for route matching predicates and middlewares. - -Let's write a simple application that uses shared state. We are going to store request count -in the state: - -```rust -# extern crate actix; -# extern crate actix_web; -# -use std::cell::Cell; -use actix_web::{App, HttpRequest, http}; - -// This struct represents state -struct AppState { - counter: Cell, -} - -fn index(req: HttpRequest) -> String { - let count = req.state().counter.get() + 1; // <- get count - req.state().counter.set(count); // <- store new count in state - - format!("Request number: {}", count) // <- response with count -} - -fn main() { - App::with_state(AppState{counter: Cell::new(0)}) - .resource("/", |r| r.method(http::Method::GET).f(index)) - .finish(); -} -``` - -Note on application state, http server accepts an application factory rather than an application -instance. Http server constructs an application instance for each thread, so application state -must be constructed multiple times. If you want to share state between different threads, a -shared object should be used, like `Arc`. Application state does not need to be `Send` and `Sync` -but the application factory must be `Send` + `Sync`. diff --git a/guide/src/qs_3_5.md b/guide/src/qs_3_5.md deleted file mode 100644 index 274524024..000000000 --- a/guide/src/qs_3_5.md +++ /dev/null @@ -1,204 +0,0 @@ -# Server - -The [*HttpServer*](../actix_web/struct.HttpServer.html) type is responsible for -serving http requests. *HttpServer* accepts application factory as a parameter, -Application factory must have `Send` + `Sync` boundaries. More about that in the -*multi-threading* section. To bind to a specific socket address, `bind()` must be used. -This method can be called multiple times. To start the http server, one of the *start* -methods can be used. `start()` method starts a simple server, `start_tls()` or `start_ssl()` -starts ssl server. *HttpServer* is an actix actor, it has to be initialized -within a properly configured actix system: - -```rust -# extern crate actix; -# extern crate actix_web; -use actix::*; -use actix_web::{server, App, HttpResponse}; - -fn main() { - let sys = actix::System::new("guide"); - - server::new( - || App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .bind("127.0.0.1:59080").unwrap() - .start(); - -# actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); - let _ = sys.run(); -} -``` - -It is possible to start a server in a separate thread with the *spawn()* method. In that -case the server spawns a new thread and creates a new actix system in it. To stop -this server, send a `StopServer` message. - -Http server is implemented as an actix actor. It is possible to communicate with the server -via a messaging system. All start methods like `start()`, `start_ssl()`, etc. return the -address of the started http server. Actix http server accepts several messages: - -* `PauseServer` - Pause accepting incoming connections -* `ResumeServer` - Resume accepting incoming connections -* `StopServer` - Stop incoming connection processing, stop all workers and exit - -```rust -# extern crate futures; -# extern crate actix; -# extern crate actix_web; -# use futures::Future; -use std::thread; -use std::sync::mpsc; -use actix::*; -use actix_web::{server, App, HttpResponse, HttpServer}; - -fn main() { - let (tx, rx) = mpsc::channel(); - - thread::spawn(move || { - let sys = actix::System::new("http-server"); - let addr = server::new( - || App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") - .shutdown_timeout(60) // <- Set shutdown timeout to 60 seconds - .start(); - let _ = tx.send(addr); - let _ = sys.run(); - }); - - let addr = rx.recv().unwrap(); - let _ = addr.send( - server::StopServer{graceful:true}).wait(); // <- Send `StopServer` message to server. -} -``` - -## Multi-threading - -Http server automatically starts an number of http workers, by default -this number is equal to number of logical CPUs in the system. This number -can be overridden with the `HttpServer::threads()` method. - -```rust -# extern crate actix_web; -# extern crate tokio_core; -use actix_web::{App, HttpServer, HttpResponse}; - -fn main() { - HttpServer::new( - || App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .threads(4); // <- Start 4 workers -} -``` - -The server creates a separate application instance for each created worker. Application state -is not shared between threads, to share state `Arc` could be used. Application state -does not need to be `Send` and `Sync` but application factory must be `Send` + `Sync`. - -## SSL - -There are two features for ssl server: `tls` and `alpn`. The `tls` feature is for `native-tls` -integration and `alpn` is for `openssl`. - -```toml -[dependencies] -actix-web = { git = "https://github.com/actix/actix-web", features=["alpn"] } -``` - -```rust,ignore -use std::fs::File; -use actix_web::*; - -fn main() { - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_private_key_file("key.pem", SslFiletype::PEM).unwrap(); - builder.set_certificate_chain_file("cert.pem").unwrap(); - - HttpServer::new( - || App::new() - .resource("/index.html", |r| r.f(index))) - .bind("127.0.0.1:8080").unwrap() - .serve_ssl(builder).unwrap(); -} -``` - -Note on *HTTP/2.0* protocol over tls without prior knowledge, it requires -[tls alpn](https://tools.ietf.org/html/rfc7301). At the moment only -`openssl` has `alpn ` support. - -Please check [example](https://github.com/actix/actix-web/tree/master/examples/tls) -for a full example. - -## Keep-Alive - -Actix can wait for requests on a keep-alive connection. *Keep alive* -connection behavior is defined by server settings. - - * `75` or `Some(75)` or `KeepAlive::Timeout(75)` - enable 75 sec *keep alive* timer according - request and response settings. - * `None` or `KeepAlive::Disabled` - disable *keep alive*. - * `KeepAlive::Tcp(75)` - Use `SO_KEEPALIVE` socket option. - -```rust -# extern crate actix_web; -# extern crate tokio_core; -use actix_web::{server, App, HttpResponse}; - -fn main() { - server::new(|| - App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .keep_alive(75); // <- Set keep-alive to 75 seconds - - server::new(|| - App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .keep_alive(server::KeepAlive::Tcp(75)); // <- Use `SO_KEEPALIVE` socket option. - - server::new(|| - App::new() - .resource("/", |r| r.f(|_| HttpResponse::Ok()))) - .keep_alive(None); // <- Disable keep-alive -} -``` - -If first option is selected then *keep alive* state is -calculated based on the response's *connection-type*. By default -`HttpResponse::connection_type` is not defined in that case *keep alive* -defined by request's http version. Keep alive is off for *HTTP/1.0* -and is on for *HTTP/1.1* and *HTTP/2.0*. - -*Connection type* can be change with `HttpResponseBuilder::connection_type()` method. - -```rust -# extern crate actix_web; -use actix_web::{HttpRequest, HttpResponse, http}; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .connection_type(http::ConnectionType::Close) // <- Close connection - .force_close() // <- Alternative method - .finish() -} -# fn main() {} -``` - -## Graceful shutdown - -Actix http server supports graceful shutdown. After receiving a stop signal, workers -have a specific amount of time to finish serving requests. Workers still alive after the -timeout are force-dropped. By default the shutdown timeout is set to 30 seconds. -You can change this parameter with the `HttpServer::shutdown_timeout()` method. - -You can send a stop message to the server with the server address and specify if you want -graceful shutdown or not. The `start()` methods return address of the server. - -Http server handles several OS signals. *CTRL-C* is available on all OSs, -other signals are available on unix systems. - -* *SIGINT* - Force shutdown workers -* *SIGTERM* - Graceful shutdown workers -* *SIGQUIT* - Force shutdown workers - -It is possible to disable signal handling with `HttpServer::disable_signals()` method. diff --git a/guide/src/qs_4.md b/guide/src/qs_4.md deleted file mode 100644 index 5c31a78f5..000000000 --- a/guide/src/qs_4.md +++ /dev/null @@ -1,310 +0,0 @@ -# Handler - -A request handler can be any object that implements -[*Handler trait*](../actix_web/dev/trait.Handler.html). -Request handling happens in two stages. First the handler object is called. -Handler can return any object that implements -[*Responder trait*](../actix_web/trait.Responder.html#foreign-impls). -Then `respond_to()` is called on the returned object. And finally -result of the `respond_to()` call is converted to a `Reply` object. - -By default actix provides `Responder` implementations for some standard types, -like `&'static str`, `String`, etc. -For a complete list of implementations, check -[*Responder documentation*](../actix_web/trait.Responder.html#foreign-impls). - -Examples of valid handlers: - -```rust,ignore -fn index(req: HttpRequest) -> &'static str { - "Hello world!" -} -``` - -```rust,ignore -fn index(req: HttpRequest) -> String { - "Hello world!".to_owned() -} -``` - -```rust,ignore -fn index(req: HttpRequest) -> Bytes { - Bytes::from_static("Hello world!") -} -``` - -```rust,ignore -fn index(req: HttpRequest) -> Box> { - ... -} -``` - -Some notes on shared application state and handler state. If you noticed -*Handler* trait is generic over *S*, which defines application state type. So -application state is accessible from handler with the `HttpRequest::state()` method. -But state is accessible as a read-only reference - if you need mutable access to state -you have to implement it yourself. On other hand, handler can mutably access its own state -as the `handle` method takes a mutable reference to *self*. Beware, actix creates multiple copies -of application state and handlers, unique for each thread, so if you run your -application in several threads, actix will create the same amount as number of threads -of application state objects and handler objects. - -Here is an example of a handler that stores the number of processed requests: - -```rust -# extern crate actix_web; -use actix_web::{App, HttpRequest, HttpResponse, dev::Handler}; - -struct MyHandler(usize); - -impl Handler for MyHandler { - type Result = HttpResponse; - - /// Handle request - fn handle(&mut self, req: HttpRequest) -> Self::Result { - self.0 += 1; - HttpResponse::Ok().into() - } -} -# fn main() {} -``` - -This handler will work, but `self.0` will be different depending on the number of threads and -number of requests processed per thread. A proper implementation would use `Arc` and `AtomicUsize` - -```rust -# extern crate actix; -# extern crate actix_web; -use actix_web::{server, App, HttpRequest, HttpResponse, dev::Handler}; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; - -struct MyHandler(Arc); - -impl Handler for MyHandler { - type Result = HttpResponse; - - /// Handle request - fn handle(&mut self, req: HttpRequest) -> Self::Result { - self.0.fetch_add(1, Ordering::Relaxed); - HttpResponse::Ok().into() - } -} - -fn main() { - let sys = actix::System::new("example"); - - let inc = Arc::new(AtomicUsize::new(0)); - - server::new( - move || { - let cloned = inc.clone(); - App::new() - .resource("/", move |r| r.h(MyHandler(cloned))) - }) - .bind("127.0.0.1:8088").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8088"); -# actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); - let _ = sys.run(); -} -``` - -Be careful with synchronization primitives like *Mutex* or *RwLock*. Actix web framework -handles requests asynchronously; by blocking thread execution all concurrent -request handling processes would block. If you need to share or update some state -from multiple threads consider using the [actix](https://actix.github.io/actix/actix/) actor system. - -## Response with custom type - -To return a custom type directly from a handler function, the type needs to implement the `Responder` trait. -Let's create a response for a custom type that serializes to an `application/json` response: - -```rust -# extern crate actix; -# extern crate actix_web; -extern crate serde; -extern crate serde_json; -#[macro_use] extern crate serde_derive; -use actix_web::{App, HttpServer, HttpRequest, HttpResponse, Error, Responder, http}; - -#[derive(Serialize)] -struct MyObj { - name: &'static str, -} - -/// Responder -impl Responder for MyObj { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, req: HttpRequest) -> Result { - let body = serde_json::to_string(&self)?; - - // Create response and set content type - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(body)) - } -} - -/// Because `MyObj` implements `Responder`, it is possible to return it directly -fn index(req: HttpRequest) -> MyObj { - MyObj{name: "user"} -} - -fn main() { - let sys = actix::System::new("example"); - - HttpServer::new( - || App::new() - .resource("/", |r| r.method(http::Method::GET).f(index))) - .bind("127.0.0.1:8088").unwrap() - .start(); - - println!("Started http server: 127.0.0.1:8088"); -# actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); - let _ = sys.run(); -} -``` - -## Async handlers - -There are two different types of async handlers. - -Response objects can be generated asynchronously or more precisely, any type -that implements the [*Responder*](../actix_web/trait.Responder.html) trait. In this case the handler must return a `Future` object that resolves to the *Responder* type, i.e: - -```rust -# extern crate actix_web; -# extern crate futures; -# extern crate bytes; -# use actix_web::*; -# use bytes::Bytes; -# use futures::stream::once; -# use futures::future::{Future, result}; -fn index(req: HttpRequest) -> Box> { - - result(Ok(HttpResponse::Ok() - .content_type("text/html") - .body(format!("Hello!")))) - .responder() -} - -fn index2(req: HttpRequest) -> Box> { - result(Ok("Welcome!")) - .responder() -} - -fn main() { - App::new() - .resource("/async", |r| r.route().a(index)) - .resource("/", |r| r.route().a(index2)) - .finish(); -} -``` - -Or the response body can be generated asynchronously. In this case body -must implement stream trait `Stream`, i.e: - -```rust -# extern crate actix_web; -# extern crate futures; -# extern crate bytes; -# use actix_web::*; -# use bytes::Bytes; -# use futures::stream::once; -fn index(req: HttpRequest) -> HttpResponse { - let body = once(Ok(Bytes::from_static(b"test"))); - - HttpResponse::Ok() - .content_type("application/json") - .body(Body::Streaming(Box::new(body))) -} - -fn main() { - App::new() - .resource("/async", |r| r.f(index)) - .finish(); -} -``` - -Both methods can be combined. (i.e Async response with streaming body) - -It is possible to return a `Result` where the `Result::Item` type can be `Future`. -In this example the `index` handler can return an error immediately or return a -future that resolves to a `HttpResponse`. - -```rust -# extern crate actix_web; -# extern crate futures; -# extern crate bytes; -# use actix_web::*; -# use bytes::Bytes; -# use futures::stream::once; -# use futures::future::{Future, result}; -fn index(req: HttpRequest) -> Result>, Error> { - if is_error() { - Err(error::ErrorBadRequest("bad request")) - } else { - Ok(Box::new( - result(Ok(HttpResponse::Ok() - .content_type("text/html") - .body(format!("Hello!")))))) - } -} -# -# fn is_error() -> bool { true } -# fn main() { -# App::new() -# .resource("/async", |r| r.route().f(index)) -# .finish(); -# } -``` - -## Different return types (Either) - -Sometimes you need to return different types of responses. For example -you can do error check and return error and return async response otherwise. -Or any result that requires two different types. -For this case the [*Either*](../actix_web/enum.Either.html) type can be used. -*Either* allows combining two different responder types into a single type. - -```rust -# extern crate actix_web; -# extern crate futures; -# use actix_web::*; -# use futures::future::Future; -use futures::future::result; -use actix_web::{Either, Error, HttpResponse}; - -type RegisterResult = Either>>; - -fn index(req: HttpRequest) -> RegisterResult { - if is_a_variant() { // <- choose variant A - Either::A( - HttpResponse::BadRequest().body("Bad data")) - } else { - Either::B( // <- variant B - result(Ok(HttpResponse::Ok() - .content_type("text/html") - .body(format!("Hello!")))).responder()) - } -} -# fn is_a_variant() -> bool { true } -# fn main() { -# App::new() -# .resource("/register", |r| r.f(index)) -# .finish(); -# } -``` - -## Tokio core handle - -Any actix web handler runs within a properly configured -[actix system](https://actix.github.io/actix/actix/struct.System.html) -and [arbiter](https://actix.github.io/actix/actix/struct.Arbiter.html). -You can always get access to the tokio handle via the -[Arbiter::handle()](https://actix.github.io/actix/actix/struct.Arbiter.html#method.handle) -method. diff --git a/guide/src/qs_4_5.md b/guide/src/qs_4_5.md deleted file mode 100644 index cf8c6ef36..000000000 --- a/guide/src/qs_4_5.md +++ /dev/null @@ -1,151 +0,0 @@ -# Errors - -Actix uses [`Error` type](../actix_web/error/struct.Error.html) -and [`ResponseError` trait](../actix_web/error/trait.ResponseError.html) -for handling handler's errors. -Any error that implements the `ResponseError` trait can be returned as an error value. -*Handler* can return an *Result* object; actix by default provides -`Responder` implementation for compatible result types. Here is the implementation -definition: - -```rust,ignore -impl> Responder for Result -``` - -And any error that implements `ResponseError` can be converted into an `Error` object. -For example, if the *handler* function returns `io::Error`, it would be converted -into an `HttpInternalServerError` response. Implementation for `io::Error` is provided -by default. - -```rust -# extern crate actix_web; -# use actix_web::*; -use std::io; - -fn index(req: HttpRequest) -> io::Result { - Ok(fs::NamedFile::open("static/index.html")?) -} -# -# fn main() { -# App::new() -# .resource(r"/a/index.html", |r| r.f(index)) -# .finish(); -# } -``` - -## Custom error response - -To add support for custom errors, all we need to do is just implement the `ResponseError` trait -for the custom error type. The `ResponseError` trait has a default implementation -for the `error_response()` method: it generates a *500* response. - -```rust -# extern crate actix_web; -#[macro_use] extern crate failure; -use actix_web::*; - -#[derive(Fail, Debug)] -#[fail(display="my error")] -struct MyError { - name: &'static str -} - -/// Use default implementation for `error_response()` method -impl error::ResponseError for MyError {} - -fn index(req: HttpRequest) -> Result<&'static str, MyError> { - Err(MyError{name: "test"}) -} -# -# fn main() { -# App::new() -# .resource(r"/a/index.html", |r| r.f(index)) -# .finish(); -# } -``` - -In this example the *index* handler always returns a *500* response. But it is easy -to return different responses for different types of errors. - -```rust -# extern crate actix_web; -#[macro_use] extern crate failure; -use actix_web::{App, HttpRequest, HttpResponse, http, error}; - -#[derive(Fail, Debug)] -enum MyError { - #[fail(display="internal error")] - InternalError, - #[fail(display="bad request")] - BadClientData, - #[fail(display="timeout")] - Timeout, -} - -impl error::ResponseError for MyError { - fn error_response(&self) -> HttpResponse { - match *self { - MyError::InternalError => HttpResponse::new( - http::StatusCode::INTERNAL_SERVER_ERROR), - MyError::BadClientData => HttpResponse::new( - http::StatusCode::BAD_REQUEST), - MyError::Timeout => HttpResponse::new( - http::StatusCode::GATEWAY_TIMEOUT), - } - } -} - -fn index(req: HttpRequest) -> Result<&'static str, MyError> { - Err(MyError::BadClientData) -} -# -# fn main() { -# App::new() -# .resource(r"/a/index.html", |r| r.f(index)) -# .finish(); -# } -``` - -## Error helpers - -Actix provides a set of error helper types. It is possible to use them for generating -specific error responses. We can use helper types for the first example with custom error. - -```rust -# extern crate actix_web; -#[macro_use] extern crate failure; -use actix_web::*; - -#[derive(Debug)] -struct MyError { - name: &'static str -} - -fn index(req: HttpRequest) -> Result<&'static str> { - let result: Result<&'static str, MyError> = Err(MyError{name: "test"}); - - Ok(result.map_err(|e| error::ErrorBadRequest(e))?) -} -# fn main() { -# App::new() -# .resource(r"/a/index.html", |r| r.f(index)) -# .finish(); -# } -``` - -In this example, a *BAD REQUEST* response is generated for the `MyError` error. - -## Error logging - -Actix logs all errors with the log level `WARN`. If log level set to `DEBUG` -and `RUST_BACKTRACE` is enabled, the backtrace gets logged. The Error type uses -the cause's error backtrace if available. If the underlying failure does not provide -a backtrace, a new backtrace is constructed pointing to that conversion point -(rather than the origin of the error). This construction only happens if there -is no underlying backtrace; if it does have a backtrace, no new backtrace is constructed. - -You can enable backtrace and debug logging with following command: - -``` ->> RUST_BACKTRACE=1 RUST_LOG=actix_web=debug cargo run -``` diff --git a/guide/src/qs_5.md b/guide/src/qs_5.md deleted file mode 100644 index f97840a06..000000000 --- a/guide/src/qs_5.md +++ /dev/null @@ -1,622 +0,0 @@ -# URL Dispatch - -URL dispatch provides a simple way for mapping URLs to `Handler` code using a simple pattern -matching language. If one of the patterns matches the path information associated with a request, -a particular handler object is invoked. A handler is a specific object that implements the -`Handler` trait, defined in your application, that receives the request and returns -a response object. More information is available in the [handler section](../qs_4.html). - -## Resource configuration - -Resource configuration is the act of adding a new resources to an application. -A resource has a name, which acts as an identifier to be used for URL generation. -The name also allows developers to add routes to existing resources. -A resource also has a pattern, meant to match against the *PATH* portion of a *URL*, -it does not match against the *QUERY* portion (the portion following the scheme and -port, e.g., */foo/bar* in the *URL* *http://localhost:8080/foo/bar?q=value*). - -The [App::resource](../actix_web/struct.App.html#method.resource) methods -add a single resource to application routing table. This method accepts a *path pattern* -and a resource configuration function. - -```rust -# extern crate actix_web; -# use actix_web::{App, HttpRequest, HttpResponse, http::Method}; -# -# fn index(req: HttpRequest) -> HttpResponse { -# unimplemented!() -# } -# -fn main() { - App::new() - .resource("/prefix", |r| r.f(index)) - .resource("/user/{name}", - |r| r.method(Method::GET).f(|req| HttpResponse::Ok())) - .finish(); -} -``` - -The *Configuration function* has the following type: - -```rust,ignore - FnOnce(&mut Resource<_>) -> () -``` - -The *Configuration function* can set a name and register specific routes. -If a resource does not contain any route or does not have any matching routes it -returns *NOT FOUND* http response. - -## Configuring a Route - -Resource contains a set of routes. Each route in turn has a set of predicates and a handler. -New routes can be created with `Resource::route()` method which returns a reference -to new *Route* instance. By default the *route* does not contain any predicates, so matches -all requests and the default handler is `HttpNotFound`. - -The application routes incoming requests based on route criteria which are defined during -resource registration and route registration. Resource matches all routes it contains in -the order the routes were registered via `Resource::route()`. A *Route* can contain -any number of *predicates* but only one handler. - -```rust -# extern crate actix_web; -# use actix_web::*; - -fn main() { - App::new() - .resource("/path", |resource| - resource.route() - .filter(pred::Get()) - .filter(pred::Header("content-type", "text/plain")) - .f(|req| HttpResponse::Ok()) - ) - .finish(); -} -``` - -In this example `HttpResponse::Ok()` is returned for *GET* requests, -if request contains `Content-Type` header and value of this header is *text/plain* -and path equals to `/path`. Resource calls handle of the first matching route. -If a resource can not match any route a "NOT FOUND" response is returned. - -[*Resource::route()*](../actix_web/struct.Resource.html#method.route) returns a -[*Route*](../actix_web/struct.Route.html) object. Route can be configured with a -builder-like pattern. Following configuration methods are available: - -* [*Route::filter()*](../actix_web/struct.Route.html#method.filter) registers a new predicate. - Any number of predicates can be registered for each route. - -* [*Route::f()*](../actix_web/struct.Route.html#method.f) registers handler function - for this route. Only one handler can be registered. Usually handler registration - is the last config operation. Handler function can be a function or closure and has the type - `Fn(HttpRequest) -> R + 'static` - -* [*Route::h()*](../actix_web/struct.Route.html#method.h) registers a handler object - that implements the `Handler` trait. This is similar to `f()` method - only one handler can - be registered. Handler registration is the last config operation. - -* [*Route::a()*](../actix_web/struct.Route.html#method.a) registers an async handler - function for this route. Only one handler can be registered. Handler registration - is the last config operation. Handler function can be a function or closure and has the type - `Fn(HttpRequest) -> Future + 'static` - -## Route matching - -The main purpose of route configuration is to match (or not match) the request's `path` -against a URL path pattern. `path` represents the path portion of the URL that was requested. - -The way that *actix* does this is very simple. When a request enters the system, -for each resource configuration declaration present in the system, actix checks -the request's path against the pattern declared. This checking happens in the order that -the routes were declared via `App::resource()` method. If resource can not be found, -the *default resource* is used as the matched resource. - -When a route configuration is declared, it may contain route predicate arguments. All route -predicates associated with a route declaration must be `true` for the route configuration to -be used for a given request during a check. If any predicate in the set of route predicate -arguments provided to a route configuration returns `false` during a check, that route is -skipped and route matching continues through the ordered set of routes. - -If any route matches, the route matching process stops and the handler associated with -the route is invoked. - -If no route matches after all route patterns are exhausted, a *NOT FOUND* response get returned. - -## Resource pattern syntax - -The syntax of the pattern matching language used by actix in the pattern -argument is straightforward. - -The pattern used in route configuration may start with a slash character. If the pattern -does not start with a slash character, an implicit slash will be prepended -to it at matching time. For example, the following patterns are equivalent: - -``` -{foo}/bar/baz -``` - -and: - -``` -/{foo}/bar/baz -``` - -A *variable part* (replacement marker) is specified in the form *{identifier}*, -where this means "accept any characters up to the next slash character and use this -as the name in the `HttpRequest.match_info()` object". - -A replacement marker in a pattern matches the regular expression `[^{}/]+`. - -A match_info is the `Params` object representing the dynamic parts extracted from a -*URL* based on the routing pattern. It is available as *request.match_info*. For example, the -following pattern defines one literal segment (foo) and two replacement markers (baz, and bar): - -``` -foo/{baz}/{bar} -``` - -The above pattern will match these URLs, generating the following match information: - -``` -foo/1/2 -> Params {'baz':'1', 'bar':'2'} -foo/abc/def -> Params {'baz':'abc', 'bar':'def'} -``` - -It will not match the following patterns however: - -``` -foo/1/2/ -> No match (trailing slash) -bar/abc/def -> First segment literal mismatch -``` - -The match for a segment replacement marker in a segment will be done only up to -the first non-alphanumeric character in the segment in the pattern. So, for instance, -if this route pattern was used: - -``` -foo/{name}.html -``` - -The literal path */foo/biz.html* will match the above route pattern, and the match result -will be `Params{'name': 'biz'}`. However, the literal path */foo/biz* will not match, -because it does not contain a literal *.html* at the end of the segment represented -by *{name}.html* (it only contains biz, not biz.html). - -To capture both segments, two replacement markers can be used: - -``` -foo/{name}.{ext} -``` - -The literal path */foo/biz.html* will match the above route pattern, and the match -result will be *Params{'name': 'biz', 'ext': 'html'}*. This occurs because there is a -literal part of *.* (period) between the two replacement markers *{name}* and *{ext}*. - -Replacement markers can optionally specify a regular expression which will be used to decide -whether a path segment should match the marker. To specify that a replacement marker should -match only a specific set of characters as defined by a regular expression, you must use a -slightly extended form of replacement marker syntax. Within braces, the replacement marker -name must be followed by a colon, then directly thereafter, the regular expression. The default -regular expression associated with a replacement marker *[^/]+* matches one or more characters -which are not a slash. For example, under the hood, the replacement marker *{foo}* can more -verbosely be spelled as *{foo:[^/]+}*. You can change this to be an arbitrary regular expression -to match an arbitrary sequence of characters, such as *{foo:\d+}* to match only digits. - -Segments must contain at least one character in order to match a segment replacement marker. -For example, for the URL */abc/*: - -* */abc/{foo}* will not match. -* */{foo}/* will match. - -Note that path will be URL-unquoted and decoded into valid unicode string before -matching pattern and values representing matched path segments will be URL-unquoted too. -So for instance, the following pattern: - -``` -foo/{bar} -``` - -When matching the following URL: - -``` -http://example.com/foo/La%20Pe%C3%B1a -``` - -The matchdict will look like so (the value is URL-decoded): - -``` -Params{'bar': 'La Pe\xf1a'} -``` - -Literal strings in the path segment should represent the decoded value of the -path provided to actix. You don't want to use a URL-encoded value in the pattern. -For example, rather than this: - -``` -/Foo%20Bar/{baz} -``` - -You'll want to use something like this: - -``` -/Foo Bar/{baz} -``` - -It is possible to get "tail match". For this purpose custom regex has to be used. - -``` -foo/{bar}/{tail:.*} -``` - -The above pattern will match these URLs, generating the following match information: - -``` -foo/1/2/ -> Params{'bar':'1', 'tail': '2/'} -foo/abc/def/a/b/c -> Params{'bar':u'abc', 'tail': 'def/a/b/c'} -``` - -## Match information - -All values representing matched path segments are available in -[`HttpRequest::match_info`](../actix_web/struct.HttpRequest.html#method.match_info). -Specific values can be retrieved with -[`Params::get()`](../actix_web/dev/struct.Params.html#method.get). - -Any matched parameter can be deserialized into a specific type if the type -implements the `FromParam` trait. For example most standard integer types -the trait, i.e.: - -```rust -# extern crate actix_web; -use actix_web::*; - -fn index(req: HttpRequest) -> Result { - let v1: u8 = req.match_info().query("v1")?; - let v2: u8 = req.match_info().query("v2")?; - Ok(format!("Values {} {}", v1, v2)) -} - -fn main() { - App::new() - .resource(r"/a/{v1}/{v2}/", |r| r.f(index)) - .finish(); -} -``` - -For this example for path '/a/1/2/', values v1 and v2 will resolve to "1" and "2". - -It is possible to create a `PathBuf` from a tail path parameter. The returned `PathBuf` is -percent-decoded. If a segment is equal to "..", the previous segment (if -any) is skipped. - -For security purposes, if a segment meets any of the following conditions, -an `Err` is returned indicating the condition met: - - * Decoded segment starts with any of: `.` (except `..`), `*` - * Decoded segment ends with any of: `:`, `>`, `<` - * Decoded segment contains any of: `/` - * On Windows, decoded segment contains any of: '\' - * Percent-encoding results in invalid UTF8. - -As a result of these conditions, a `PathBuf` parsed from request path parameter is -safe to interpolate within, or use as a suffix of, a path without additional checks. - -```rust -# extern crate actix_web; -use std::path::PathBuf; -use actix_web::{App, HttpRequest, Result, http::Method}; - -fn index(req: HttpRequest) -> Result { - let path: PathBuf = req.match_info().query("tail")?; - Ok(format!("Path {:?}", path)) -} - -fn main() { - App::new() - .resource(r"/a/{tail:.*}", |r| r.method(Method::GET).f(index)) - .finish(); -} -``` - -List of `FromParam` implementations can be found in -[api docs](../actix_web/dev/trait.FromParam.html#foreign-impls) - -## Path information extractor - -Actix provides functionality for type safe request path information extraction. -It uses *serde* package as a deserialization library. -[Path](../actix_web/struct.Path.html) extracts information, the destination type -has to implement *serde's *`Deserialize` trait. - -```rust -# extern crate actix_web; -#[macro_use] extern crate serde_derive; -use actix_web::{App, Path, Result, http::Method}; - -#[derive(Deserialize)] -struct Info { - username: String, -} - -// extract path info using serde -fn index(info: Path) -> Result { - Ok(format!("Welcome {}!", info.username)) -} - -fn main() { - let app = App::new() - .resource("/{username}/index.html", // <- define path parameters - |r| r.method(Method::GET).with(index)); -} -``` - -It also possible to extract path information to a tuple, in this case you don't need -to define extra type, just use tuple for as a `Path` generic type. - -Here is previous example re-written using tuple instead of specific type. - -```rust -# extern crate actix_web; -use actix_web::{App, Path, Result, http::Method}; - -// extract path info using serde -fn index(info: Path<(String, u32)>) -> Result { - Ok(format!("Welcome {}! id: {}", info.0, info.1)) -} - -fn main() { - let app = App::new() - .resource("/{username}/{id}/index.html", // <- define path parameters - |r| r.method(Method::GET).with(index)); -} -``` - -[Query](../actix_web/struct.Query.html) provides similar functionality for -request query parameters. - - -## Generating resource URLs - -Use the [HttpRequest.url_for()](../actix_web/struct.HttpRequest.html#method.url_for) -method to generate URLs based on resource patterns. For example, if you've configured a -resource with the name "foo" and the pattern "{a}/{b}/{c}", you might do this: - -```rust -# extern crate actix_web; -# use actix_web::{App, HttpRequest, HttpResponse, http::Method}; -# -fn index(req: HttpRequest) -> HttpResponse { - let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource - HttpResponse::Ok().into() -} - -fn main() { - let app = App::new() - .resource("/test/{a}/{b}/{c}", |r| { - r.name("foo"); // <- set resource name, then it could be used in `url_for` - r.method(Method::GET).f(|_| HttpResponse::Ok()); - }) - .finish(); -} -``` - -This would return something like the string *http://example.com/test/1/2/3* (at least if -the current protocol and hostname implied http://example.com). -`url_for()` method returns [*Url object*](https://docs.rs/url/1.6.0/url/struct.Url.html) so you -can modify this url (add query parameters, anchor, etc). -`url_for()` could be called only for *named* resources otherwise error get returned. - -## External resources - -Resources that are valid URLs, can be registered as external resources. They are useful -for URL generation purposes only and are never considered for matching at request time. - -```rust -# extern crate actix_web; -use actix_web::{App, HttpRequest, HttpResponse, Error}; - -fn index(mut req: HttpRequest) -> Result { - let url = req.url_for("youtube", &["oHg5SJYRHA0"])?; - assert_eq!(url.as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); - Ok(HttpResponse::Ok().into()) -} - -fn main() { - let app = App::new() - .resource("/index.html", |r| r.f(index)) - .external_resource("youtube", "https://youtube.com/watch/{video_id}") - .finish(); -} -``` - -## Path normalization and redirecting to slash-appended routes - -By normalizing it means: - - - Add a trailing slash to the path. - - Double slashes are replaced by one. - -The handler returns as soon as it finds a path that resolves -correctly. The order if all enable is 1) merge, 3) both merge and append -and 3) append. If the path resolves with -at least one of those conditions, it will redirect to the new path. - -If *append* is *true* append slash when needed. If a resource is -defined with trailing slash and the request doesn't have one, it will -be appended automatically. - -If *merge* is *true*, merge multiple consecutive slashes in the path into one. - -This handler designed to be use as a handler for application's *default resource*. - -```rust -# extern crate actix_web; -# #[macro_use] extern crate serde_derive; -# use actix_web::*; -use actix_web::http::NormalizePath; -# -# fn index(req: HttpRequest) -> HttpResponse { -# HttpResponse::Ok().into() -# } -fn main() { - let app = App::new() - .resource("/resource/", |r| r.f(index)) - .default_resource(|r| r.h(NormalizePath::default())) - .finish(); -} -``` - -In this example `/resource`, `//resource///` will be redirected to `/resource/`. - -In this example path normalization handler is registered for all methods, -but you should not rely on this mechanism to redirect *POST* requests. The redirect of the -slash-appending *Not Found* will turn a *POST* request into a GET, losing any -*POST* data in the original request. - -It is possible to register path normalization only for *GET* requests only: - -```rust -# extern crate actix_web; -# #[macro_use] extern crate serde_derive; -use actix_web::{App, HttpRequest, http::Method, http::NormalizePath}; -# -# fn index(req: HttpRequest) -> &'static str { -# "test" -# } -fn main() { - let app = App::new() - .resource("/resource/", |r| r.f(index)) - .default_resource(|r| r.method(Method::GET).h(NormalizePath::default())) - .finish(); -} -``` - -## Using an Application Prefix to Compose Applications - -The `App::prefix()`" method allows to set a specific application prefix. -This prefix represents a resource prefix that will be prepended to all resource patterns added -by the resource configuration. This can be used to help mount a set of routes at a different -location than the included callable's author intended while still maintaining the same -resource names. - -For example: - -```rust -# extern crate actix_web; -# use actix_web::*; -# -fn show_users(req: HttpRequest) -> HttpResponse { - unimplemented!() -} - -fn main() { - App::new() - .prefix("/users") - .resource("/show", |r| r.f(show_users)) - .finish(); -} -``` - -In the above example, the *show_users* route will have an effective route pattern of -*/users/show* instead of */show* because the application's prefix argument will be prepended -to the pattern. The route will then only match if the URL path is */users/show*, -and when the `HttpRequest.url_for()` function is called with the route name show_users, -it will generate a URL with that same path. - -## Custom route predicates - -You can think of a predicate as a simple function that accepts a *request* object reference -and returns *true* or *false*. Formally, a predicate is any object that implements the -[`Predicate`](../actix_web/pred/trait.Predicate.html) trait. Actix provides -several predicates, you can check [functions section](../actix_web/pred/index.html#functions) -of api docs. - -Here is a simple predicate that check that a request contains a specific *header*: - -```rust -# extern crate actix_web; -# use actix_web::*; -use actix_web::{http, pred::Predicate, App, HttpRequest}; - -struct ContentTypeHeader; - -impl Predicate for ContentTypeHeader { - - fn check(&self, req: &mut HttpRequest) -> bool { - req.headers().contains_key(http::header::CONTENT_TYPE) - } -} - -fn main() { - App::new() - .resource("/index.html", |r| - r.route() - .filter(ContentTypeHeader) - .f(|_| HttpResponse::Ok())); -} -``` - -In this example *index* handler will be called only if request contains *CONTENT-TYPE* header. - -Predicates have access to the application's state via `HttpRequest::state()`. -Also predicates can store extra information in -[request extensions](../actix_web/struct.HttpRequest.html#method.extensions). - -### Modifying predicate values - -You can invert the meaning of any predicate value by wrapping it in a `Not` predicate. -For example if you want to return "METHOD NOT ALLOWED" response for all methods -except "GET": - -```rust -# extern crate actix_web; -# extern crate http; -# use actix_web::*; -use actix_web::{pred, App, HttpResponse}; - -fn main() { - App::new() - .resource("/index.html", |r| - r.route() - .filter(pred::Not(pred::Get())) - .f(|req| HttpResponse::MethodNotAllowed())) - .finish(); -} -``` - -The `Any` predicate accepts a list of predicates and matches if any of the supplied -predicates match. i.e: - -```rust,ignore - pred::Any(pred::Get()).or(pred::Post()) -``` - -The `All` predicate accepts a list of predicates and matches if all of the supplied -predicates match. i.e: - -```rust,ignore - pred::All(pred::Get()).and(pred::Header("content-type", "plain/text")) -``` - -## Changing the default Not Found response - -If the path pattern can not be found in the routing table or a resource can not find matching -route, the default resource is used. The default response is *NOT FOUND*. -It is possible to override the *NOT FOUND* response with `App::default_resource()`. -This method accepts a *configuration function* same as normal resource configuration -with `App::resource()` method. - -```rust -# extern crate actix_web; -use actix_web::{App, HttpResponse, http::Method, pred}; - -fn main() { - App::new() - .default_resource(|r| { - r.method(Method::GET).f(|req| HttpResponse::NotFound()); - r.route().filter(pred::Not(pred::Get())) - .f(|req| HttpResponse::MethodNotAllowed()); - }) -# .finish(); -} -``` diff --git a/guide/src/qs_7.md b/guide/src/qs_7.md deleted file mode 100644 index fab21a34b..000000000 --- a/guide/src/qs_7.md +++ /dev/null @@ -1,317 +0,0 @@ -# Request & Response - -## Response - -A builder-like pattern is used to construct an instance of `HttpResponse`. -`HttpResponse` provides several methods that return a `HttpResponseBuilder` instance, -which implements various convenience methods that helps building responses. -Check [documentation](../actix_web/dev/struct.HttpResponseBuilder.html) -for type descriptions. The methods `.body`, `.finish`, `.json` finalize response creation and -return a constructed *HttpResponse* instance. If this methods is called for the same -builder instance multiple times, the builder will panic. - -```rust -# extern crate actix_web; -use actix_web::{HttpRequest, HttpResponse, http::ContentEncoding}; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_encoding(ContentEncoding::Br) - .content_type("plain/text") - .header("X-Hdr", "sample") - .body("data") -} -# fn main() {} -``` - -## Content encoding - -Actix automatically *compresses*/*decompresses* payloads. Following codecs are supported: - - * Brotli - * Gzip - * Deflate - * Identity - - If request headers contain a `Content-Encoding` header, the request payload is decompressed - according to the header value. Multiple codecs are not supported, i.e: `Content-Encoding: br, gzip`. - -Response payload is compressed based on the *content_encoding* parameter. -By default `ContentEncoding::Auto` is used. If `ContentEncoding::Auto` is selected -then compression depends on the request's `Accept-Encoding` header. -`ContentEncoding::Identity` can be used to disable compression. -If another content encoding is selected the compression is enforced for this codec. For example, -to enable `brotli` use `ContentEncoding::Br`: - -```rust -# extern crate actix_web; -use actix_web::{HttpRequest, HttpResponse, http::ContentEncoding}; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_encoding(ContentEncoding::Br) - .body("data") -} -# fn main() {} -``` - - -## JSON Request - -There are several options for json body deserialization. - -The first option is to use *Json* extractor. You define handler function -that accepts `Json` as a parameter and use `.with()` method for registering -this handler. It is also possible to accept arbitrary valid json object by -using `serde_json::Value` as a type `T` - -```rust -# extern crate actix_web; -#[macro_use] extern crate serde_derive; -use actix_web::{App, Json, Result, http}; - -#[derive(Deserialize)] -struct Info { - username: String, -} - -/// extract `Info` using serde -fn index(info: Json) -> Result { - Ok(format!("Welcome {}!", info.username)) -} - -fn main() { - let app = App::new().resource( - "/index.html", - |r| r.method(http::Method::POST).with(index)); // <- use `with` extractor -} -``` - -The second option is to use *HttpResponse::json()*. This method returns a -[*JsonBody*](../actix_web/dev/struct.JsonBody.html) object which resolves into -the deserialized value. - -```rust -# extern crate actix; -# extern crate actix_web; -# extern crate futures; -# extern crate serde_json; -# #[macro_use] extern crate serde_derive; -# use actix_web::*; -# use futures::Future; -#[derive(Debug, Serialize, Deserialize)] -struct MyObj { - name: String, - number: i32, -} - -fn index(mut req: HttpRequest) -> Box> { - req.json().from_err() - .and_then(|val: MyObj| { - println!("model: {:?}", val); - Ok(HttpResponse::Ok().json(val)) // <- send response - }) - .responder() -} -# fn main() {} -``` - -Or you can manually load the payload into memory and then deserialize it. -Here is a simple example. We will deserialize a *MyObj* struct. We need to load the request -body first and then deserialize the json into an object. - -```rust -# extern crate actix_web; -# extern crate futures; -# use actix_web::*; -# #[macro_use] extern crate serde_derive; -extern crate serde_json; -use futures::{Future, Stream}; - -#[derive(Serialize, Deserialize)] -struct MyObj {name: String, number: i32} - -fn index(req: HttpRequest) -> Box> { - // `concat2` will asynchronously read each chunk of the request body and - // return a single, concatenated, chunk - req.concat2() - // `Future::from_err` acts like `?` in that it coerces the error type from - // the future into the final error type - .from_err() - // `Future::and_then` can be used to merge an asynchronous workflow with a - // synchronous workflow - .and_then(|body| { // <- body is loaded, now we can deserialize json - let obj = serde_json::from_slice::(&body)?; - Ok(HttpResponse::Ok().json(obj)) // <- send response - }) - .responder() -} -# fn main() {} -``` - -A complete example for both options is available in -[examples directory](https://github.com/actix/actix-web/tree/master/examples/json/). - - -## JSON Response - -The `Json` type allows to respond with well-formed JSON data: simply return a value of -type Json where T is the type of a structure to serialize into *JSON*. The -type `T` must implement the `Serialize` trait from *serde*. - -```rust -# extern crate actix_web; -#[macro_use] extern crate serde_derive; -use actix_web::{App, HttpRequest, Json, Result, http::Method}; - -#[derive(Serialize)] -struct MyObj { - name: String, -} - -fn index(req: HttpRequest) -> Result> { - Ok(Json(MyObj{name: req.match_info().query("name")?})) -} - -fn main() { - App::new() - .resource(r"/a/{name}", |r| r.method(Method::GET).f(index)) - .finish(); -} -``` - -## Chunked transfer encoding - -Actix automatically decodes *chunked* encoding. `HttpRequest::payload()` already contains -the decoded byte stream. If the request payload is compressed with one of the supported -compression codecs (br, gzip, deflate) the byte stream is decompressed. - -Chunked encoding on response can be enabled with `HttpResponseBuilder::chunked()`. -But this takes effect only for `Body::Streaming(BodyStream)` or `Body::StreamingContext` bodies. -Also if response payload compression is enabled and streaming body is used, chunked encoding -is enabled automatically. - -Enabling chunked encoding for *HTTP/2.0* responses is forbidden. - -```rust -# extern crate bytes; -# extern crate actix_web; -# extern crate futures; -# use futures::Stream; -use actix_web::*; -use bytes::Bytes; -use futures::stream::once; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .chunked() - .body(Body::Streaming(Box::new(once(Ok(Bytes::from_static(b"data")))))) -} -# fn main() {} -``` - -## Multipart body - -Actix provides multipart stream support. -[*Multipart*](../actix_web/multipart/struct.Multipart.html) is implemented as -a stream of multipart items, each item can be a -[*Field*](../actix_web/multipart/struct.Field.html) or a nested *Multipart* stream. -`HttpResponse::multipart()` returns the *Multipart* stream for the current request. - -In simple form multipart stream handling can be implemented similar to this example - -```rust,ignore -# extern crate actix_web; -use actix_web::*; - -fn index(req: HttpRequest) -> Box> { - req.multipart() // <- get multipart stream for current request - .and_then(|item| { // <- iterate over multipart items - match item { - // Handle multipart Field - multipart::MultipartItem::Field(field) => { - println!("==== FIELD ==== {:?} {:?}", field.headers(), field.content_type()); - - Either::A( - // Field in turn is a stream of *Bytes* objects - field.map(|chunk| { - println!("-- CHUNK: \n{}", - std::str::from_utf8(&chunk).unwrap());}) - .fold((), |_, _| result(Ok(())))) - }, - multipart::MultipartItem::Nested(mp) => { - // Or item could be nested Multipart stream - Either::B(result(Ok(()))) - } - } - }) -} -``` - -A full example is available in the -[examples directory](https://github.com/actix/actix-web/tree/master/examples/multipart/). - -## Urlencoded body - -Actix provides support for *application/x-www-form-urlencoded* encoded bodies. -`HttpResponse::urlencoded()` returns a -[*UrlEncoded*](../actix_web/dev/struct.UrlEncoded.html) future, which resolves -to the deserialized instance, the type of the instance must implement the -`Deserialize` trait from *serde*. The *UrlEncoded* future can resolve into -a error in several cases: - -* content type is not `application/x-www-form-urlencoded` -* transfer encoding is `chunked`. -* content-length is greater than 256k -* payload terminates with error. - -```rust -# extern crate actix_web; -# extern crate futures; -#[macro_use] extern crate serde_derive; -use actix_web::*; -use futures::future::{Future, ok}; - -#[derive(Deserialize)] -struct FormData { - username: String, -} - -fn index(mut req: HttpRequest) -> Box> { - req.urlencoded::() // <- get UrlEncoded future - .from_err() - .and_then(|data| { // <- deserialized instance - println!("USERNAME: {:?}", data.username); - ok(HttpResponse::Ok().into()) - }) - .responder() -} -# fn main() {} -``` - -## Streaming request - -*HttpRequest* is a stream of `Bytes` objects. It can be used to read the request -body payload. - -In this example handle reads the request payload chunk by chunk and prints every chunk. - -```rust -# extern crate actix_web; -# extern crate futures; -# use futures::future::result; -use actix_web::*; -use futures::{Future, Stream}; - - -fn index(mut req: HttpRequest) -> Box> { - req.from_err() - .fold((), |_, chunk| { - println!("Chunk: {:?}", chunk); - result::<_, error::PayloadError>(Ok(())) - }) - .map(|_| HttpResponse::Ok().finish()) - .responder() -} -# fn main() {} -``` diff --git a/guide/src/qs_8.md b/guide/src/qs_8.md deleted file mode 100644 index 380f9e0e7..000000000 --- a/guide/src/qs_8.md +++ /dev/null @@ -1,176 +0,0 @@ -# Testing - -Every application should be well tested. Actix provides tools to perform unit and -integration tests. - -## Unit tests - -For unit testing actix provides a request builder type and simple handler runner. -[*TestRequest*](../actix_web/test/struct.TestRequest.html) implements a builder-like pattern. -You can generate a `HttpRequest` instance with `finish()` or you can -run your handler with `run()` or `run_async()`. - -```rust -# extern crate actix_web; -use actix_web::{http, test, HttpRequest, HttpResponse, HttpMessage}; - -fn index(req: HttpRequest) -> HttpResponse { - if let Some(hdr) = req.headers().get(http::header::CONTENT_TYPE) { - if let Ok(s) = hdr.to_str() { - return HttpResponse::Ok().into() - } - } - HttpResponse::BadRequest().into() -} - -fn main() { - let resp = test::TestRequest::with_header("content-type", "text/plain") - .run(index) - .unwrap(); - assert_eq!(resp.status(), http::StatusCode::OK); - - let resp = test::TestRequest::default() - .run(index) - .unwrap(); - assert_eq!(resp.status(), http::StatusCode::BAD_REQUEST); -} -``` - - -## Integration tests - -There are several methods how you can test your application. Actix provides -[*TestServer*](../actix_web/test/struct.TestServer.html) -server that can be used to run the whole application of just specific handlers -in real http server. *TestServer::get()*, *TestServer::post()* or *TestServer::client()* -methods can be used to send requests to the test server. - -In simple form *TestServer* can be configured to use handler. *TestServer::new* method -accepts configuration function, only argument for this function is *test application* -instance. You can check the [api documentation](../actix_web/test/struct.TestApp.html) -for more information. - -```rust -# extern crate actix_web; -use actix_web::{HttpRequest, HttpResponse, HttpMessage}; -use actix_web::test::TestServer; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok().into() -} - -fn main() { - let mut srv = TestServer::new(|app| app.handler(index)); // <- Start new test server - - let request = srv.get().finish().unwrap(); // <- create client request - let response = srv.execute(request.send()).unwrap(); // <- send request to the server - assert!(response.status().is_success()); // <- check response - - let bytes = srv.execute(response.body()).unwrap(); // <- read response body -} -``` - -The other option is to use an application factory. In this case you need to pass the factory -function same way as you would for real http server configuration. - -```rust -# extern crate actix_web; -use actix_web::{http, test, App, HttpRequest, HttpResponse}; - -fn index(req: HttpRequest) -> HttpResponse { - HttpResponse::Ok().into() -} - -/// This function get called by http server. -fn create_app() -> App { - App::new() - .resource("/test", |r| r.h(index)) -} - -fn main() { - let mut srv = test::TestServer::with_factory(create_app); // <- Start new test server - - let request = srv.client( - http::Method::GET, "/test").finish().unwrap(); // <- create client request - let response = srv.execute(request.send()).unwrap(); // <- send request to the server - - assert!(response.status().is_success()); // <- check response -} -``` - -If you need more complex application configuration, for example you may need to -initialize application state or start `SyncActor`'s for diesel interation, you -can use `TestServer::build_with_state()` method. This method accepts closure -that has to construct application state. This closure runs when actix system is -configured already, so you can initialize any additional actors. - -```rust,ignore -#[test] -fn test() { - let srv = TestServer::build_with_state(|| { // <- construct builder with config closure - // we can start diesel actors - let addr = SyncArbiter::start(3, || { - DbExecutor(SqliteConnection::establish("test.db").unwrap()) - }); - // then we can construct custom state, or it could be `()` - MyState{addr: addr} - }) - .start(|app| { // <- register server handlers and start test server - app.resource( - "/{username}/index.html", |r| r.with( - |p: Path| format!("Welcome {}!", p.username))); - }); - - // now we can run our test code -); -``` - -## WebSocket server tests - -It is possible to register a *handler* with `TestApp::handler()` that -initiates a web socket connection. *TestServer* provides `ws()` which connects to -the websocket server and returns ws reader and writer objects. *TestServer* also -provides an `execute()` method which runs future objects to completion and returns -result of the future computation. - -Here is a simple example that shows how to test server websocket handler. - -```rust -# extern crate actix; -# extern crate actix_web; -# extern crate futures; -# extern crate http; -# extern crate bytes; - -use actix_web::*; -use futures::Stream; -# use actix::prelude::*; - -struct Ws; // <- WebSocket actor - -impl Actor for Ws { - type Context = ws::WebsocketContext; -} - -impl StreamHandler for Ws { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - match msg { - ws::Message::Text(text) => ctx.text(text), - _ => (), - } - } -} - -fn main() { - let mut srv = test::TestServer::new( // <- start our server with ws handler - |app| app.handler(|req| ws::start(req, Ws))); - - let (reader, mut writer) = srv.ws().unwrap(); // <- connect to ws server - - writer.text("text"); // <- send message to server - - let (item, reader) = srv.execute(reader.into_future()).unwrap(); // <- wait for one message - assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); -} -``` diff --git a/guide/src/qs_9.md b/guide/src/qs_9.md deleted file mode 100644 index 158ba2513..000000000 --- a/guide/src/qs_9.md +++ /dev/null @@ -1,48 +0,0 @@ -# WebSockets - -Actix supports WebSockets out-of-the-box. It is possible to convert a request's `Payload` -to a stream of [*ws::Message*](../actix_web/ws/enum.Message.html) with -a [*ws::WsStream*](../actix_web/ws/struct.WsStream.html) and then use stream -combinators to handle actual messages. But it is simpler to handle websocket communications -with an http actor. - -This is example of a simple websocket echo server: - -```rust -# extern crate actix; -# extern crate actix_web; -use actix::*; -use actix_web::*; - -/// Define http actor -struct Ws; - -impl Actor for Ws { - type Context = ws::WebsocketContext; -} - -/// Handler for ws::Message message -impl StreamHandler for Ws { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - _ => (), - } - } -} - -fn main() { - App::new() - .resource("/ws/", |r| r.f(|req| ws::start(req, Ws))) // <- register websocket route - .finish(); -} -``` - -A simple websocket echo server example is available in the -[examples directory](https://github.com/actix/actix-web/blob/master/examples/websocket). - -An example chat server with the ability to chat over a websocket or tcp connection -is available in [websocket-chat directory](https://github.com/actix/actix-web/tree/master/examples/websocket-chat/) diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..94bd11d51 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 89 +reorder_imports = true diff --git a/src/app.rs b/src/app.rs new file mode 100644 index 000000000..d67817d21 --- /dev/null +++ b/src/app.rs @@ -0,0 +1,677 @@ +use std::cell::RefCell; +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::rc::Rc; + +use actix_http::body::{Body, MessageBody}; +use actix_service::boxed::{self, BoxServiceFactory}; +use actix_service::{ + apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, +}; +use futures::future::{FutureExt, LocalBoxFuture}; + +use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; +use crate::config::{AppConfig, AppConfigInner, ServiceConfig}; +use crate::data::{Data, DataFactory}; +use crate::dev::ResourceDef; +use crate::error::Error; +use crate::resource::Resource; +use crate::route::Route; +use crate::service::{ + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, + ServiceResponse, +}; + +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type FnDataFactory = + Box LocalBoxFuture<'static, Result, ()>>>; + +/// Application builder - structure that follows the builder pattern +/// for building application instances. +pub struct App { + endpoint: T, + services: Vec>, + default: Option>, + factory_ref: Rc>>, + data: Vec>, + data_factories: Vec, + config: AppConfigInner, + external: Vec, + _t: PhantomData<(B)>, +} + +impl App { + /// Create application builder. Application can be configured with a builder-like pattern. + pub fn new() -> Self { + let fref = Rc::new(RefCell::new(None)); + App { + endpoint: AppEntry::new(fref.clone()), + data: Vec::new(), + data_factories: Vec::new(), + services: Vec::new(), + default: None, + factory_ref: fref, + config: AppConfigInner::default(), + external: Vec::new(), + _t: PhantomData, + } + } +} + +impl App +where + B: MessageBody, + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + /// Set application data. Application data could be accessed + /// by using `Data` extractor where `T` is data type. + /// + /// **Note**: http server accepts an application factory rather than + /// an application instance. Http server constructs an application + /// instance for each thread, thus application data must be constructed + /// multiple times. If you want to share data between different + /// threads, a shared object should be used, e.g. `Arc`. Application + /// data does not need to be `Send` or `Sync`. + /// + /// ```rust + /// use std::cell::Cell; + /// use actix_web::{web, App, HttpResponse, Responder}; + /// + /// struct MyData { + /// counter: Cell, + /// } + /// + /// async fn index(data: web::Data) -> impl Responder { + /// data.counter.set(data.counter.get() + 1); + /// HttpResponse::Ok() + /// } + /// + /// fn main() { + /// let app = App::new() + /// .data(MyData{ counter: Cell::new(0) }) + /// .service( + /// web::resource("/index.html").route( + /// web::get().to(index))); + /// } + /// ``` + pub fn data(mut self, data: U) -> Self { + self.data.push(Box::new(Data::new(data))); + self + } + + /// Set application data factory. This function is + /// similar to `.data()` but it accepts data factory. Data object get + /// constructed asynchronously during application initialization. + pub fn data_factory(mut self, data: F) -> Self + where + F: Fn() -> Out + 'static, + Out: Future> + 'static, + D: 'static, + E: std::fmt::Debug, + { + self.data_factories.push(Box::new(move || { + { + let fut = data(); + async move { + match fut.await { + Err(e) => { + log::error!("Can not construct data instance: {:?}", e); + Err(()) + } + Ok(data) => { + let data: Box = Box::new(Data::new(data)); + Ok(data) + } + } + } + } + .boxed_local() + })); + self + } + + /// Set application data. Application data could be accessed + /// by using `Data` extractor where `T` is data type. + pub fn register_data(mut self, data: Data) -> Self { + self.data.push(Box::new(data)); + self + } + + /// Run external configuration as part of the application building + /// process + /// + /// This function is useful for moving parts of configuration to a + /// different module or even library. For example, + /// some of the resource's configuration could be moved to different module. + /// + /// ```rust + /// # extern crate actix_web; + /// use actix_web::{web, middleware, App, HttpResponse}; + /// + /// // this function could be located in different module + /// fn config(cfg: &mut web::ServiceConfig) { + /// cfg.service(web::resource("/test") + /// .route(web::get().to(|| HttpResponse::Ok())) + /// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) + /// ); + /// } + /// + /// fn main() { + /// let app = App::new() + /// .wrap(middleware::Logger::default()) + /// .configure(config) // <- register resources + /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); + /// } + /// ``` + pub fn configure(mut self, f: F) -> Self + where + F: FnOnce(&mut ServiceConfig), + { + let mut cfg = ServiceConfig::new(); + f(&mut cfg); + self.data.extend(cfg.data); + self.services.extend(cfg.services); + self.external.extend(cfg.external); + self + } + + /// Configure route for a specific path. + /// + /// This is a simplified version of the `App::service()` method. + /// This method can be used multiple times with same path, in that case + /// multiple resources with one route would be registered for same resource path. + /// + /// ```rust + /// use actix_web::{web, App, HttpResponse}; + /// + /// async fn index(data: web::Path<(String, String)>) -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new() + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())); + /// } + /// ``` + pub fn route(self, path: &str, mut route: Route) -> Self { + self.service( + Resource::new(path) + .add_guards(route.take_guards()) + .route(route), + ) + } + + /// Register http service. + /// + /// Http service is any type that implements `HttpServiceFactory` trait. + /// + /// Actix web provides several services implementations: + /// + /// * *Resource* is an entry in resource table which corresponds to requested URL. + /// * *Scope* is a set of resources with common root path. + /// * "StaticFiles" is a service for static files support + pub fn service(mut self, factory: F) -> Self + where + F: HttpServiceFactory + 'static, + { + self.services + .push(Box::new(ServiceFactoryWrapper::new(factory))); + self + } + + /// Set server host name. + /// + /// Host name is used by application router as a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default host name is set to a "localhost" value. + pub fn hostname(mut self, val: &str) -> Self { + self.config.host = val.to_owned(); + self + } + + /// Default service to be used if no matching resource could be found. + /// + /// It is possible to use services like `Resource`, `Route`. + /// + /// ```rust + /// use actix_web::{web, App, HttpResponse}; + /// + /// async fn index() -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new() + /// .service( + /// web::resource("/index.html").route(web::get().to(index))) + /// .default_service( + /// web::route().to(|| HttpResponse::NotFound())); + /// } + /// ``` + /// + /// It is also possible to use static files as default service. + /// + /// ```rust + /// use actix_web::{web, App, HttpResponse}; + /// + /// fn main() { + /// let app = App::new() + /// .service( + /// web::resource("/index.html").to(|| HttpResponse::Ok())) + /// .default_service( + /// web::to(|| HttpResponse::NotFound()) + /// ); + /// } + /// ``` + pub fn default_service(mut self, f: F) -> Self + where + F: IntoServiceFactory, + U: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + > + 'static, + U::InitError: fmt::Debug, + { + // create and configure default resource + self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( + |e| log::error!("Can not construct default service: {:?}", e), + )))); + + self + } + + /// Register an external resource. + /// + /// External resources are useful for URL generation purposes only + /// and are never considered for matching at request time. Calls to + /// `HttpRequest::url_for()` will work as expected. + /// + /// ```rust + /// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; + /// + /// async fn index(req: HttpRequest) -> Result { + /// let url = req.url_for("youtube", &["asdlkjqme"])?; + /// assert_eq!(url.as_str(), "https://youtube.com/watch/asdlkjqme"); + /// Ok(HttpResponse::Ok().into()) + /// } + /// + /// fn main() { + /// let app = App::new() + /// .service(web::resource("/index.html").route( + /// web::get().to(index))) + /// .external_resource("youtube", "https://youtube.com/watch/{video_id}"); + /// } + /// ``` + pub fn external_resource(mut self, name: N, url: U) -> Self + where + N: AsRef, + U: AsRef, + { + let mut rdef = ResourceDef::new(url.as_ref()); + *rdef.name_mut() = name.as_ref().to_string(); + self.external.push(rdef); + self + } + + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound and/or outbound processing in the request + /// lifecycle (request -> response), modifying request/response as + /// necessary, across all requests managed by the *Application*. + /// + /// Use middleware when you need to read or modify *every* request or + /// response in some way. + /// + /// Notice that the keyword for registering middleware is `wrap`. As you + /// register middleware using `wrap` in the App builder, imagine wrapping + /// layers around an inner App. The first middleware layer exposed to a + /// Request is the outermost layer-- the *last* registered in + /// the builder chain. Consequently, the *first* middleware registered + /// in the builder chain is the *last* to execute during request processing. + /// + /// ```rust + /// use actix_service::Service; + /// use actix_web::{middleware, web, App}; + /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// + /// async fn index() -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new() + /// .wrap(middleware::Logger::default()) + /// .route("/index.html", web::get().to(index)); + /// } + /// ``` + pub fn wrap( + self, + mw: M, + ) -> App< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + B1, + > + where + M: Transform< + T::Service, + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + B1: MessageBody, + { + App { + endpoint: apply(mw, self.endpoint), + data: self.data, + data_factories: self.data_factories, + services: self.services, + default: self.default, + factory_ref: self.factory_ref, + config: self.config, + external: self.external, + _t: PhantomData, + } + } + + /// Registers middleware, in the form of a closure, that runs during inbound + /// and/or outbound processing in the request lifecycle (request -> response), + /// modifying request/response as necessary, across all requests managed by + /// the *Application*. + /// + /// Use middleware when you need to read or modify *every* request or response in some way. + /// + /// ```rust + /// use actix_service::Service; + /// use actix_web::{web, App}; + /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// + /// async fn index() -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new() + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; + /// res.headers_mut().insert( + /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), + /// ); + /// Ok(res) + /// } + /// }) + /// .route("/index.html", web::get().to(index)); + /// } + /// ``` + pub fn wrap_fn( + self, + mw: F, + ) -> App< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + B1, + > + where + B1: MessageBody, + F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + R: Future, Error>>, + { + App { + endpoint: apply_fn_factory(self.endpoint, mw), + data: self.data, + data_factories: self.data_factories, + services: self.services, + default: self.default, + factory_ref: self.factory_ref, + config: self.config, + external: self.external, + _t: PhantomData, + } + } +} + +impl IntoServiceFactory> for App +where + B: MessageBody, + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + fn into_factory(self) -> AppInit { + AppInit { + data: Rc::new(self.data), + data_factories: Rc::new(self.data_factories), + endpoint: self.endpoint, + services: Rc::new(RefCell::new(self.services)), + external: RefCell::new(self.external), + default: self.default, + factory_ref: self.factory_ref, + config: RefCell::new(AppConfig(Rc::new(self.config))), + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::Service; + use bytes::Bytes; + use futures::future::ok; + + use super::*; + use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{web, HttpRequest, HttpResponse}; + + #[actix_rt::test] + async fn test_default_resource() { + let mut srv = init_service( + App::new().service(web::resource("/test").to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/blah").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let mut srv = init_service( + App::new() + .service(web::resource("/test").to(|| HttpResponse::Ok())) + .service( + web::resource("/test2") + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::Created())) + }) + .route(web::get().to(|| HttpResponse::Ok())), + ) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::MethodNotAllowed())) + }), + ) + .await; + + let req = TestRequest::with_uri("/blah").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let req = TestRequest::with_uri("/test2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test2") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + } + + #[actix_rt::test] + async fn test_data_factory() { + let mut srv = + init_service(App::new().data_factory(|| ok::<_, ()>(10usize)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let mut srv = + init_service(App::new().data_factory(|| ok::<_, ()>(10u32)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[actix_rt::test] + async fn test_wrap() { + let mut srv = init_service( + App::new() + .wrap( + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + ) + .route("/test", web::get().to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_router_wrap() { + let mut srv = init_service( + App::new() + .route("/test", web::get().to(|| HttpResponse::Ok())) + .wrap( + DefaultHeaders::new() + .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_wrap_fn() { + let mut srv = init_service( + App::new() + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } + }) + .service(web::resource("/test").to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_router_wrap_fn() { + let mut srv = init_service( + App::new() + .route("/test", web::get().to(|| HttpResponse::Ok())) + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } + }), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_external_resource() { + let mut srv = init_service( + App::new() + .external_resource("youtube", "https://youtube.com/watch/{video_id}") + .route( + "/test", + web::get().to(|req: HttpRequest| { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["12345"]).unwrap() + )) + }), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + } +} diff --git a/src/app_service.rs b/src/app_service.rs new file mode 100644 index 000000000..3fa5a6eed --- /dev/null +++ b/src/app_service.rs @@ -0,0 +1,481 @@ +use std::cell::RefCell; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_http::{Extensions, Request, Response}; +use actix_router::{Path, ResourceDef, ResourceInfo, Router, Url}; +use actix_server_config::ServerConfig; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{service_fn, Service, ServiceFactory}; +use futures::future::{ok, FutureExt, LocalBoxFuture}; + +use crate::config::{AppConfig, AppService}; +use crate::data::DataFactory; +use crate::error::Error; +use crate::guard::Guard; +use crate::request::{HttpRequest, HttpRequestPool}; +use crate::rmap::ResourceMap; +use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; + +type Guards = Vec>; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type BoxResponse = LocalBoxFuture<'static, Result>; +type FnDataFactory = + Box LocalBoxFuture<'static, Result, ()>>>; + +/// Service factory to convert `Request` to a `ServiceRequest`. +/// It also executes data factories. +pub struct AppInit +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + pub(crate) endpoint: T, + pub(crate) data: Rc>>, + pub(crate) data_factories: Rc>, + pub(crate) config: RefCell, + pub(crate) services: Rc>>>, + pub(crate) default: Option>, + pub(crate) factory_ref: Rc>>, + pub(crate) external: RefCell>, +} + +impl ServiceFactory for AppInit +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + type Config = ServerConfig; + type Request = Request; + type Response = ServiceResponse; + type Error = T::Error; + type InitError = T::InitError; + type Service = AppInitService; + type Future = AppInitResult; + + fn new_service(&self, cfg: &ServerConfig) -> Self::Future { + // update resource default service + let default = self.default.clone().unwrap_or_else(|| { + Rc::new(boxed::factory(service_fn(|req: ServiceRequest| { + ok(req.into_response(Response::NotFound().finish())) + }))) + }); + + // App config + { + let mut c = self.config.borrow_mut(); + let loc_cfg = Rc::get_mut(&mut c.0).unwrap(); + loc_cfg.secure = cfg.secure(); + loc_cfg.addr = cfg.local_addr(); + } + + let mut config = AppService::new( + self.config.borrow().clone(), + default.clone(), + self.data.clone(), + ); + + // register services + std::mem::replace(&mut *self.services.borrow_mut(), Vec::new()) + .into_iter() + .for_each(|mut srv| srv.register(&mut config)); + + let mut rmap = ResourceMap::new(ResourceDef::new("")); + + let (config, services) = config.into_services(); + + // complete pipeline creation + *self.factory_ref.borrow_mut() = Some(AppRoutingFactory { + default, + services: Rc::new( + services + .into_iter() + .map(|(mut rdef, srv, guards, nested)| { + rmap.add(&mut rdef, nested); + (rdef, srv, RefCell::new(guards)) + }) + .collect(), + ), + }); + + // external resources + for mut rdef in std::mem::replace(&mut *self.external.borrow_mut(), Vec::new()) { + rmap.add(&mut rdef, None); + } + + // complete ResourceMap tree creation + let rmap = Rc::new(rmap); + rmap.finish(rmap.clone()); + + AppInitResult { + endpoint: None, + endpoint_fut: self.endpoint.new_service(&()), + data: self.data.clone(), + data_factories: Vec::new(), + data_factories_fut: self.data_factories.iter().map(|f| f()).collect(), + config, + rmap, + _t: PhantomData, + } + } +} + +#[pin_project::pin_project] +pub struct AppInitResult +where + T: ServiceFactory, +{ + endpoint: Option, + #[pin] + endpoint_fut: T::Future, + rmap: Rc, + config: AppConfig, + data: Rc>>, + data_factories: Vec>, + data_factories_fut: Vec, ()>>>, + _t: PhantomData, +} + +impl Future for AppInitResult +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + type Output = Result, ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + // async data factories + let mut idx = 0; + while idx < this.data_factories_fut.len() { + match Pin::new(&mut this.data_factories_fut[idx]).poll(cx)? { + Poll::Ready(f) => { + this.data_factories.push(f); + let _ = this.data_factories_fut.remove(idx); + } + Poll::Pending => idx += 1, + } + } + + if this.endpoint.is_none() { + if let Poll::Ready(srv) = this.endpoint_fut.poll(cx)? { + *this.endpoint = Some(srv); + } + } + + if this.endpoint.is_some() && this.data_factories_fut.is_empty() { + // create app data container + let mut data = Extensions::new(); + for f in this.data.iter() { + f.create(&mut data); + } + + for f in this.data_factories.iter() { + f.create(&mut data); + } + + Poll::Ready(Ok(AppInitService { + service: this.endpoint.take().unwrap(), + rmap: this.rmap.clone(), + config: this.config.clone(), + data: Rc::new(data), + pool: HttpRequestPool::create(), + })) + } else { + Poll::Pending + } + } +} + +/// Service to convert `Request` to a `ServiceRequest` +pub struct AppInitService +where + T: Service, Error = Error>, +{ + service: T, + rmap: Rc, + config: AppConfig, + data: Rc, + pool: &'static HttpRequestPool, +} + +impl Service for AppInitService +where + T: Service, Error = Error>, +{ + type Request = Request; + type Response = ServiceResponse; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let (head, payload) = req.into_parts(); + + let req = if let Some(mut req) = self.pool.get_request() { + let inner = Rc::get_mut(&mut req.0).unwrap(); + inner.path.get_mut().update(&head.uri); + inner.path.reset(); + inner.head = head; + inner.payload = payload; + inner.app_data = self.data.clone(); + req + } else { + HttpRequest::new( + Path::new(Url::new(head.uri.clone())), + head, + payload, + self.rmap.clone(), + self.config.clone(), + self.data.clone(), + self.pool, + ) + }; + self.service.call(ServiceRequest::new(req)) + } +} + +impl Drop for AppInitService +where + T: Service, Error = Error>, +{ + fn drop(&mut self) { + self.pool.clear(); + } +} + +pub struct AppRoutingFactory { + services: Rc>)>>, + default: Rc, +} + +impl ServiceFactory for AppRoutingFactory { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = AppRouting; + type Future = AppRoutingFactoryResponse; + + fn new_service(&self, _: &()) -> Self::Future { + AppRoutingFactoryResponse { + fut: self + .services + .iter() + .map(|(path, service, guards)| { + CreateAppRoutingItem::Future( + Some(path.clone()), + guards.borrow_mut().take(), + service.new_service(&()).boxed_local(), + ) + }) + .collect(), + default: None, + default_fut: Some(self.default.new_service(&())), + } + } +} + +type HttpServiceFut = LocalBoxFuture<'static, Result>; + +/// Create app service +#[doc(hidden)] +pub struct AppRoutingFactoryResponse { + fut: Vec, + default: Option, + default_fut: Option>>, +} + +enum CreateAppRoutingItem { + Future(Option, Option, HttpServiceFut), + Service(ResourceDef, Option, HttpService), +} + +impl Future for AppRoutingFactoryResponse { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut done = true; + + if let Some(ref mut fut) = self.default_fut { + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, + } + } + + // poll http services + for item in &mut self.fut { + let res = match item { + CreateAppRoutingItem::Future( + ref mut path, + ref mut guards, + ref mut fut, + ) => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(service)) => { + Some((path.take().unwrap(), guards.take(), service)) + } + Poll::Ready(Err(_)) => return Poll::Ready(Err(())), + Poll::Pending => { + done = false; + None + } + }, + CreateAppRoutingItem::Service(_, _, _) => continue, + }; + + if let Some((path, guards, service)) = res { + *item = CreateAppRoutingItem::Service(path, guards, service); + } + } + + if done { + let router = self + .fut + .drain(..) + .fold(Router::build(), |mut router, item| { + match item { + CreateAppRoutingItem::Service(path, guards, service) => { + router.rdef(path, service).2 = guards; + } + CreateAppRoutingItem::Future(_, _, _) => unreachable!(), + } + router + }); + Poll::Ready(Ok(AppRouting { + ready: None, + router: router.finish(), + default: self.default.take(), + })) + } else { + Poll::Pending + } + } +} + +pub struct AppRouting { + router: Router, + ready: Option<(ServiceRequest, ResourceInfo)>, + default: Option, +} + +impl Service for AppRouting { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = BoxResponse; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + if self.ready.is_none() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let res = self.router.recognize_mut_checked(&mut req, |req, guards| { + if let Some(ref guards) = guards { + for f in guards { + if !f.check(req.head()) { + return false; + } + } + } + true + }); + + if let Some((srv, _info)) = res { + srv.call(req) + } else if let Some(ref mut default) = self.default { + default.call(req) + } else { + let req = req.into_parts().0; + ok(ServiceResponse::new(req, Response::NotFound().finish())).boxed_local() + } + } +} + +/// Wrapper service for routing +pub struct AppEntry { + factory: Rc>>, +} + +impl AppEntry { + pub fn new(factory: Rc>>) -> Self { + AppEntry { factory } + } +} + +impl ServiceFactory for AppEntry { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = AppRouting; + type Future = AppRoutingFactoryResponse; + + fn new_service(&self, _: &()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + use crate::test::{init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + use actix_service::Service; + + struct DropData(Arc); + + impl Drop for DropData { + fn drop(&mut self) { + self.0.store(true, Ordering::Relaxed); + } + } + + #[actix_rt::test] + async fn test_drop_data() { + let data = Arc::new(AtomicBool::new(false)); + + { + let mut app = init_service( + App::new() + .data(DropData(data.clone())) + .service(web::resource("/test").to(|| HttpResponse::Ok())), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let _ = app.call(req).await.unwrap(); + } + assert!(data.load(Ordering::Relaxed)); + } +} diff --git a/src/application.rs b/src/application.rs deleted file mode 100644 index 38886efc5..000000000 --- a/src/application.rs +++ /dev/null @@ -1,639 +0,0 @@ -use std::mem; -use std::rc::Rc; -use std::cell::UnsafeCell; -use std::collections::HashMap; - -use handler::Reply; -use router::{Router, Resource}; -use resource::{ResourceHandler}; -use header::ContentEncoding; -use handler::{Handler, RouteHandler, WrapHandler}; -use httprequest::HttpRequest; -use pipeline::{Pipeline, PipelineHandler, HandlerType}; -use middleware::Middleware; -use server::{HttpHandler, IntoHttpHandler, HttpHandlerTask, ServerSettings}; - -#[deprecated(since="0.5.0", note="please use `actix_web::App` instead")] -pub type Application = App; - -/// Application -pub struct HttpApplication { - state: Rc, - prefix: String, - router: Router, - inner: Rc>>, - middlewares: Rc>>>, -} - -pub(crate) struct Inner { - prefix: usize, - default: ResourceHandler, - encoding: ContentEncoding, - resources: Vec>, - handlers: Vec<(String, Box>)>, -} - -impl PipelineHandler for Inner { - - fn encoding(&self) -> ContentEncoding { - self.encoding - } - - fn handle(&mut self, req: HttpRequest, htype: HandlerType) -> Reply { - match htype { - HandlerType::Normal(idx) => - self.resources[idx].handle(req, Some(&mut self.default)), - HandlerType::Handler(idx) => - self.handlers[idx].1.handle(req), - HandlerType::Default => - self.default.handle(req, None) - } - } -} - -impl HttpApplication { - - #[inline] - fn as_ref(&self) -> &Inner { - unsafe{&*self.inner.get()} - } - - #[inline] - fn get_handler(&self, req: &mut HttpRequest) -> HandlerType { - if let Some(idx) = self.router.recognize(req) { - HandlerType::Normal(idx) - } else { - let inner = self.as_ref(); - for idx in 0..inner.handlers.len() { - let &(ref prefix, _) = &inner.handlers[idx]; - let m = { - let path = &req.path()[inner.prefix..]; - path.starts_with(prefix) && ( - path.len() == prefix.len() || - path.split_at(prefix.len()).1.starts_with('/')) - }; - if m { - let path: &'static str = unsafe { - mem::transmute(&req.path()[inner.prefix+prefix.len()..]) }; - if path.is_empty() { - req.match_info_mut().add("tail", ""); - } else { - req.match_info_mut().add("tail", path.split_at(1).1); - } - return HandlerType::Handler(idx) - } - } - HandlerType::Default - } - } - - #[cfg(test)] - pub(crate) fn run(&mut self, mut req: HttpRequest) -> Reply { - let tp = self.get_handler(&mut req); - unsafe{&mut *self.inner.get()}.handle(req, tp) - } - - #[cfg(test)] - pub(crate) fn prepare_request(&self, req: HttpRequest) -> HttpRequest { - req.with_state(Rc::clone(&self.state), self.router.clone()) - } -} - -impl HttpHandler for HttpApplication { - - fn handle(&mut self, req: HttpRequest) -> Result, HttpRequest> { - let m = { - let path = req.path(); - path.starts_with(&self.prefix) && ( - path.len() == self.prefix.len() || - path.split_at(self.prefix.len()).1.starts_with('/')) - }; - if m { - let mut req = req.with_state(Rc::clone(&self.state), self.router.clone()); - let tp = self.get_handler(&mut req); - let inner = Rc::clone(&self.inner); - Ok(Box::new(Pipeline::new(req, Rc::clone(&self.middlewares), inner, tp))) - } else { - Err(req) - } - } -} - -struct ApplicationParts { - state: S, - prefix: String, - settings: ServerSettings, - default: ResourceHandler, - resources: Vec<(Resource, Option>)>, - handlers: Vec<(String, Box>)>, - external: HashMap, - encoding: ContentEncoding, - middlewares: Vec>>, -} - -/// Structure that follows the builder pattern for building application instances. -pub struct App { - parts: Option>, -} - -impl App<()> { - - /// Create application with empty state. Application can - /// be configured with builder-like pattern. - pub fn new() -> App<()> { - App { - parts: Some(ApplicationParts { - state: (), - prefix: "/".to_owned(), - settings: ServerSettings::default(), - default: ResourceHandler::default_not_found(), - resources: Vec::new(), - handlers: Vec::new(), - external: HashMap::new(), - encoding: ContentEncoding::Auto, - middlewares: Vec::new(), - }) - } - } -} - -impl Default for App<()> { - fn default() -> Self { - App::new() - } -} - -impl App where S: 'static { - - /// Create application with specific state. Application can be - /// configured with builder-like pattern. - /// - /// State is shared with all resources within same application and could be - /// accessed with `HttpRequest::state()` method. - pub fn with_state(state: S) -> App { - App { - parts: Some(ApplicationParts { - state, - prefix: "/".to_owned(), - settings: ServerSettings::default(), - default: ResourceHandler::default_not_found(), - resources: Vec::new(), - handlers: Vec::new(), - external: HashMap::new(), - middlewares: Vec::new(), - encoding: ContentEncoding::Auto, - }) - } - } - - /// Set application prefix - /// - /// Only requests that matches application's prefix get processed by this application. - /// Application prefix always contains leading "/" slash. If supplied prefix - /// does not contain leading slash, it get inserted. Prefix should - /// consists valid path segments. i.e for application with - /// prefix `/app` any request with following paths `/app`, `/app/` or `/app/test` - /// would match, but path `/application` would not match. - /// - /// In the following example only requests with "/app/" path prefix - /// get handled. Request with path "/app/test/" would be handled, - /// but request with path "/application" or "/other/..." would return *NOT FOUND* - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{http, App, HttpResponse}; - /// - /// fn main() { - /// let app = App::new() - /// .prefix("/app") - /// .resource("/test", |r| { - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); - /// }) - /// .finish(); - /// } - /// ``` - pub fn prefix>(mut self, prefix: P) -> App { - { - let parts = self.parts.as_mut().expect("Use after finish"); - let mut prefix = prefix.into(); - if !prefix.starts_with('/') { - prefix.insert(0, '/') - } - parts.prefix = prefix; - } - self - } - - /// Configure resource for specific path. - /// - /// Resource may have variable path also. For instance, a resource with - /// the path */a/{name}/c* would match all incoming requests with paths - /// such as */a/b/c*, */a/1/c*, and */a/etc/c*. - /// - /// A variable part is specified in the form `{identifier}`, where - /// the identifier can be used later in a request handler to access the matched - /// value for that part. This is done by looking up the identifier - /// in the `Params` object returned by `HttpRequest.match_info()` method. - /// - /// By default, each part matches the regular expression `[^{}/]+`. - /// - /// You can also specify a custom regex in the form `{identifier:regex}`: - /// - /// For instance, to route Get requests on any route matching `/users/{userid}/{friend}` and - /// store userid and friend in the exposed Params object: - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{http, App, HttpResponse}; - /// - /// fn main() { - /// let app = App::new() - /// .resource("/test", |r| { - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); - /// }); - /// } - /// ``` - pub fn resource(mut self, path: &str, f: F) -> App - where F: FnOnce(&mut ResourceHandler) + 'static - { - { - let parts = self.parts.as_mut().expect("Use after finish"); - - // add resource - let mut resource = ResourceHandler::default(); - f(&mut resource); - - let pattern = Resource::new(resource.get_name(), path); - parts.resources.push((pattern, Some(resource))); - } - self - } - - /// Default resource is used if no matched route could be found. - pub fn default_resource(mut self, f: F) -> App - where F: FnOnce(&mut ResourceHandler) + 'static - { - { - let parts = self.parts.as_mut().expect("Use after finish"); - f(&mut parts.default); - } - self - } - - /// Set default content encoding. `ContentEncoding::Auto` is set by default. - pub fn default_encoding(mut self, encoding: ContentEncoding) -> App - { - { - let parts = self.parts.as_mut().expect("Use after finish"); - parts.encoding = encoding; - } - self - } - - /// Register external resource. - /// - /// External resources are useful for URL generation purposes only and - /// are never considered for matching at request time. - /// Call to `HttpRequest::url_for()` will work as expected. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{App, HttpRequest, HttpResponse, Result}; - /// - /// fn index(mut req: HttpRequest) -> Result { - /// let url = req.url_for("youtube", &["oHg5SJYRHA0"])?; - /// assert_eq!(url.as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); - /// Ok(HttpResponse::Ok().into()) - /// } - /// - /// fn main() { - /// let app = App::new() - /// .resource("/index.html", |r| r.f(index)) - /// .external_resource("youtube", "https://youtube.com/watch/{video_id}") - /// .finish(); - /// } - /// ``` - pub fn external_resource(mut self, name: T, url: U) -> App - where T: AsRef, U: AsRef - { - { - let parts = self.parts.as_mut().expect("Use after finish"); - - if parts.external.contains_key(name.as_ref()) { - panic!("External resource {:?} is registered.", name.as_ref()); - } - parts.external.insert( - String::from(name.as_ref()), - Resource::external(name.as_ref(), url.as_ref())); - } - self - } - - /// Configure handler for specific path prefix. - /// - /// Path prefix consists valid path segments. i.e for prefix `/app` - /// any request with following paths `/app`, `/app/` or `/app/test` - /// would match, but path `/application` would not match. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{http, App, HttpRequest, HttpResponse}; - /// - /// fn main() { - /// let app = App::new() - /// .handler("/app", |req: HttpRequest| { - /// match *req.method() { - /// http::Method::GET => HttpResponse::Ok(), - /// http::Method::POST => HttpResponse::MethodNotAllowed(), - /// _ => HttpResponse::NotFound(), - /// }}); - /// } - /// ``` - pub fn handler>(mut self, path: &str, handler: H) -> App - { - { - let path = path.trim().trim_right_matches('/').to_owned(); - let parts = self.parts.as_mut().expect("Use after finish"); - parts.handlers.push((path, Box::new(WrapHandler::new(handler)))); - } - self - } - - /// Register a middleware - pub fn middleware>(mut self, mw: M) -> App { - self.parts.as_mut().expect("Use after finish") - .middlewares.push(Box::new(mw)); - self - } - - /// Run external configuration as part of application building process - /// - /// This function is useful for moving part of configuration to a different - /// module or event library. For example we can move some of the resources - /// configuration to different module. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{App, HttpResponse, http, fs, middleware}; - /// - /// // this function could be located in different module - /// fn config(app: App) -> App { - /// app - /// .resource("/test", |r| { - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); - /// }) - /// } - /// - /// fn main() { - /// let app = App::new() - /// .middleware(middleware::Logger::default()) - /// .configure(config) // <- register resources - /// .handler("/static", fs::StaticFiles::new(".", true)); - /// } - /// ``` - pub fn configure(self, cfg: F) -> App - where F: Fn(App) -> App - { - cfg(self) - } - - /// Finish application configuration and create HttpHandler object - pub fn finish(&mut self) -> HttpApplication { - let parts = self.parts.take().expect("Use after finish"); - let prefix = parts.prefix.trim().trim_right_matches('/'); - - let mut resources = parts.resources; - for (_, pattern) in parts.external { - resources.push((pattern, None)); - } - - let (router, resources) = Router::new(prefix, parts.settings, resources); - - let inner = Rc::new(UnsafeCell::new( - Inner { - prefix: prefix.len(), - default: parts.default, - encoding: parts.encoding, - handlers: parts.handlers, - resources, - } - )); - - HttpApplication { - state: Rc::new(parts.state), - prefix: prefix.to_owned(), - router: router.clone(), - middlewares: Rc::new(parts.middlewares), - inner, - } - } - - /// Convenience method for creating `Box` instance. - /// - /// This method is useful if you need to register multiple application instances - /// with different state. - /// - /// ```rust - /// # use std::thread; - /// # extern crate actix_web; - /// use actix_web::*; - /// - /// struct State1; - /// - /// struct State2; - /// - /// fn main() { - /// # thread::spawn(|| { - /// HttpServer::new(|| { vec![ - /// App::with_state(State1) - /// .prefix("/app1") - /// .resource("/", |r| r.f(|r| HttpResponse::Ok())) - /// .boxed(), - /// App::with_state(State2) - /// .prefix("/app2") - /// .resource("/", |r| r.f(|r| HttpResponse::Ok())) - /// .boxed() ]}) - /// .bind("127.0.0.1:8080").unwrap() - /// .run() - /// # }); - /// } - /// ``` - pub fn boxed(mut self) -> Box { - Box::new(self.finish()) - } -} - -impl IntoHttpHandler for App { - type Handler = HttpApplication; - - fn into_handler(mut self, settings: ServerSettings) -> HttpApplication { - { - let parts = self.parts.as_mut().expect("Use after finish"); - parts.settings = settings; - } - self.finish() - } -} - -impl<'a, S: 'static> IntoHttpHandler for &'a mut App { - type Handler = HttpApplication; - - fn into_handler(self, settings: ServerSettings) -> HttpApplication { - { - let parts = self.parts.as_mut().expect("Use after finish"); - parts.settings = settings; - } - self.finish() - } -} - -#[doc(hidden)] -impl Iterator for App { - type Item = HttpApplication; - - fn next(&mut self) -> Option { - if self.parts.is_some() { - Some(self.finish()) - } else { - None - } - } -} - - -#[cfg(test)] -mod tests { - use http::StatusCode; - use super::*; - use test::TestRequest; - use httprequest::HttpRequest; - use httpresponse::HttpResponse; - - #[test] - fn test_default_resource() { - let mut app = App::new() - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) - .finish(); - - let req = TestRequest::with_uri("/test").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/blah").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - - let mut app = App::new() - .default_resource(|r| r.f(|_| HttpResponse::MethodNotAllowed())) - .finish(); - let req = TestRequest::with_uri("/blah").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::METHOD_NOT_ALLOWED); - } - - #[test] - fn test_unhandled_prefix() { - let mut app = App::new() - .prefix("/test") - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) - .finish(); - assert!(app.handle(HttpRequest::default()).is_err()); - } - - #[test] - fn test_state() { - let mut app = App::with_state(10) - .resource("/", |r| r.f(|_| HttpResponse::Ok())) - .finish(); - let req = HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone()); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - } - - #[test] - fn test_prefix() { - let mut app = App::new() - .prefix("/test") - .resource("/blah", |r| r.f(|_| HttpResponse::Ok())) - .finish(); - let req = TestRequest::with_uri("/test").finish(); - let resp = app.handle(req); - assert!(resp.is_ok()); - - let req = TestRequest::with_uri("/test/").finish(); - let resp = app.handle(req); - assert!(resp.is_ok()); - - let req = TestRequest::with_uri("/test/blah").finish(); - let resp = app.handle(req); - assert!(resp.is_ok()); - - let req = TestRequest::with_uri("/testing").finish(); - let resp = app.handle(req); - assert!(resp.is_err()); - } - - #[test] - fn test_handler() { - let mut app = App::new() - .handler("/test", |_| HttpResponse::Ok()) - .finish(); - - let req = TestRequest::with_uri("/test").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/test/").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/test/app").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/testapp").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - - let req = TestRequest::with_uri("/blah").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - } - - #[test] - fn test_handler_prefix() { - let mut app = App::new() - .prefix("/app") - .handler("/test", |_| HttpResponse::Ok()) - .finish(); - - let req = TestRequest::with_uri("/test").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - - let req = TestRequest::with_uri("/app/test").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/app/test/").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/app/test/app").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::OK); - - let req = TestRequest::with_uri("/app/testapp").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - - let req = TestRequest::with_uri("/app/blah").finish(); - let resp = app.run(req); - assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); - - } - -} diff --git a/src/body.rs b/src/body.rs deleted file mode 100644 index 97b8850c8..000000000 --- a/src/body.rs +++ /dev/null @@ -1,365 +0,0 @@ -use std::{fmt, mem}; -use std::rc::Rc; -use std::sync::Arc; -use bytes::{Bytes, BytesMut}; -use futures::Stream; - -use error::Error; -use context::ActorHttpContext; -use handler::Responder; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - - -/// Type represent streaming body -pub type BodyStream = Box>; - -/// Represents various types of http message body. -pub enum Body { - /// Empty response. `Content-Length` header is set to `0` - Empty, - /// Specific response body. - Binary(Binary), - /// Unspecified streaming response. Developer is responsible for setting - /// right `Content-Length` or `Transfer-Encoding` headers. - Streaming(BodyStream), - /// Special body type for actor response. - Actor(Box), -} - -/// Represents various types of binary body. -/// `Content-Length` header is set to length of the body. -#[derive(Debug, PartialEq)] -pub enum Binary { - /// Bytes body - Bytes(Bytes), - /// Static slice - Slice(&'static [u8]), - /// Shared string body - SharedString(Rc), - /// Shared string body - #[doc(hidden)] - ArcSharedString(Arc), - /// Shared vec body - SharedVec(Arc>), -} - -impl Body { - /// Does this body streaming. - #[inline] - pub fn is_streaming(&self) -> bool { - match *self { - Body::Streaming(_) | Body::Actor(_) => true, - _ => false - } - } - - /// Is this binary body. - #[inline] - pub fn is_binary(&self) -> bool { - match *self { - Body::Binary(_) => true, - _ => false - } - } - - /// Create body from slice (copy) - pub fn from_slice(s: &[u8]) -> Body { - Body::Binary(Binary::Bytes(Bytes::from(s))) - } -} - -impl PartialEq for Body { - fn eq(&self, other: &Body) -> bool { - match *self { - Body::Empty => match *other { - Body::Empty => true, - _ => false, - }, - Body::Binary(ref b) => match *other { - Body::Binary(ref b2) => b == b2, - _ => false, - }, - Body::Streaming(_) | Body::Actor(_) => false, - } - } -} - -impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Body::Empty => write!(f, "Body::Empty"), - Body::Binary(ref b) => write!(f, "Body::Binary({:?})", b), - Body::Streaming(_) => write!(f, "Body::Streaming(_)"), - Body::Actor(_) => write!(f, "Body::Actor(_)"), - } - } -} - -impl From for Body where T: Into{ - fn from(b: T) -> Body { - Body::Binary(b.into()) - } -} - -impl From> for Body { - fn from(ctx: Box) -> Body { - Body::Actor(ctx) - } -} - -impl Binary { - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - #[inline] - pub fn len(&self) -> usize { - match *self { - Binary::Bytes(ref bytes) => bytes.len(), - Binary::Slice(slice) => slice.len(), - Binary::SharedString(ref s) => s.len(), - Binary::ArcSharedString(ref s) => s.len(), - Binary::SharedVec(ref s) => s.len(), - } - } - - /// Create binary body from slice - pub fn from_slice(s: &[u8]) -> Binary { - Binary::Bytes(Bytes::from(s)) - } - - /// Convert Binary to a Bytes instance - pub fn take(&mut self) -> Bytes { - mem::replace(self, Binary::Slice(b"")).into() - } -} - -impl Clone for Binary { - fn clone(&self) -> Binary { - match *self { - Binary::Bytes(ref bytes) => Binary::Bytes(bytes.clone()), - Binary::Slice(slice) => Binary::Bytes(Bytes::from(slice)), - Binary::SharedString(ref s) => Binary::SharedString(s.clone()), - Binary::ArcSharedString(ref s) => Binary::ArcSharedString(s.clone()), - Binary::SharedVec(ref s) => Binary::SharedVec(s.clone()), - } - } -} - -impl Into for Binary { - fn into(self) -> Bytes { - match self { - Binary::Bytes(bytes) => bytes, - Binary::Slice(slice) => Bytes::from(slice), - Binary::SharedString(s) => Bytes::from(s.as_str()), - Binary::ArcSharedString(s) => Bytes::from(s.as_str()), - Binary::SharedVec(s) => Bytes::from(AsRef::<[u8]>::as_ref(s.as_ref())), - } - } -} - -impl From<&'static str> for Binary { - fn from(s: &'static str) -> Binary { - Binary::Slice(s.as_ref()) - } -} - -impl From<&'static [u8]> for Binary { - fn from(s: &'static [u8]) -> Binary { - Binary::Slice(s) - } -} - -impl From> for Binary { - fn from(vec: Vec) -> Binary { - Binary::Bytes(Bytes::from(vec)) - } -} - -impl From for Binary { - fn from(s: String) -> Binary { - Binary::Bytes(Bytes::from(s)) - } -} - -impl<'a> From<&'a String> for Binary { - fn from(s: &'a String) -> Binary { - Binary::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s))) - } -} - -impl From for Binary { - fn from(s: Bytes) -> Binary { - Binary::Bytes(s) - } -} - -impl From for Binary { - fn from(s: BytesMut) -> Binary { - Binary::Bytes(s.freeze()) - } -} - -impl From> for Binary { - fn from(body: Rc) -> Binary { - Binary::SharedString(body) - } -} - -impl<'a> From<&'a Rc> for Binary { - fn from(body: &'a Rc) -> Binary { - Binary::SharedString(Rc::clone(body)) - } -} - -impl From> for Binary { - fn from(body: Arc) -> Binary { - Binary::ArcSharedString(body) - } -} - -impl<'a> From<&'a Arc> for Binary { - fn from(body: &'a Arc) -> Binary { - Binary::ArcSharedString(Arc::clone(body)) - } -} - -impl From>> for Binary { - fn from(body: Arc>) -> Binary { - Binary::SharedVec(body) - } -} - -impl<'a> From<&'a Arc>> for Binary { - fn from(body: &'a Arc>) -> Binary { - Binary::SharedVec(Arc::clone(body)) - } -} - -impl AsRef<[u8]> for Binary { - #[inline] - fn as_ref(&self) -> &[u8] { - match *self { - Binary::Bytes(ref bytes) => bytes.as_ref(), - Binary::Slice(slice) => slice, - Binary::SharedString(ref s) => s.as_bytes(), - Binary::ArcSharedString(ref s) => s.as_bytes(), - Binary::SharedVec(ref s) => s.as_ref().as_ref(), - } - } -} - -impl Responder for Binary { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("application/octet-stream") - .body(self)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_body_is_streaming() { - assert_eq!(Body::Empty.is_streaming(), false); - assert_eq!(Body::Binary(Binary::from("")).is_streaming(), false); - } - - #[test] - fn test_is_empty() { - assert_eq!(Binary::from("").is_empty(), true); - assert_eq!(Binary::from("test").is_empty(), false); - } - - #[test] - fn test_static_str() { - assert_eq!(Binary::from("test").len(), 4); - assert_eq!(Binary::from("test").as_ref(), "test".as_bytes()); - } - - #[test] - fn test_static_bytes() { - assert_eq!(Binary::from(b"test".as_ref()).len(), 4); - assert_eq!(Binary::from(b"test".as_ref()).as_ref(), "test".as_bytes()); - assert_eq!(Binary::from_slice(b"test".as_ref()).len(), 4); - assert_eq!(Binary::from_slice(b"test".as_ref()).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_vec() { - assert_eq!(Binary::from(Vec::from("test")).len(), 4); - assert_eq!(Binary::from(Vec::from("test")).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_bytes() { - assert_eq!(Binary::from(Bytes::from("test")).len(), 4); - assert_eq!(Binary::from(Bytes::from("test")).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_ref_string() { - let b = Rc::new("test".to_owned()); - assert_eq!(Binary::from(&b).len(), 4); - assert_eq!(Binary::from(&b).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_rc_string() { - let b = Rc::new("test".to_owned()); - assert_eq!(Binary::from(b.clone()).len(), 4); - assert_eq!(Binary::from(b.clone()).as_ref(), "test".as_bytes()); - assert_eq!(Binary::from(&b).len(), 4); - assert_eq!(Binary::from(&b).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_arc_string() { - let b = Arc::new("test".to_owned()); - assert_eq!(Binary::from(b.clone()).len(), 4); - assert_eq!(Binary::from(b.clone()).as_ref(), "test".as_bytes()); - assert_eq!(Binary::from(&b).len(), 4); - assert_eq!(Binary::from(&b).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_string() { - let b = "test".to_owned(); - assert_eq!(Binary::from(b.clone()).len(), 4); - assert_eq!(Binary::from(b.clone()).as_ref(), "test".as_bytes()); - assert_eq!(Binary::from(&b).len(), 4); - assert_eq!(Binary::from(&b).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_shared_vec() { - let b = Arc::new(Vec::from(&b"test"[..])); - assert_eq!(Binary::from(b.clone()).len(), 4); - assert_eq!(Binary::from(b.clone()).as_ref(), &b"test"[..]); - assert_eq!(Binary::from(&b).len(), 4); - assert_eq!(Binary::from(&b).as_ref(), &b"test"[..]); - } - - #[test] - fn test_bytes_mut() { - let b = BytesMut::from("test"); - assert_eq!(Binary::from(b.clone()).len(), 4); - assert_eq!(Binary::from(b).as_ref(), "test".as_bytes()); - } - - #[test] - fn test_binary_into() { - let bytes = Bytes::from_static(b"test"); - let b: Bytes = Binary::from("test").into(); - assert_eq!(b, bytes); - let b: Bytes = Binary::from(bytes.clone()).into(); - assert_eq!(b, bytes); - } -} diff --git a/src/client/connector.rs b/src/client/connector.rs deleted file mode 100644 index 8f2828935..000000000 --- a/src/client/connector.rs +++ /dev/null @@ -1,537 +0,0 @@ -use std::{fmt, io, time}; -use std::cell::RefCell; -use std::rc::Rc; -use std::net::Shutdown; -use std::time::{Duration, Instant}; -use std::collections::{HashMap, VecDeque}; - -use actix::{fut, Actor, ActorFuture, Context, AsyncContext, - Handler, Message, ActorResponse, Supervised}; -use actix::registry::ArbiterService; -use actix::fut::WrapFuture; -use actix::actors::{Connector, ConnectorError, Connect as ResolveConnect}; - -use http::{Uri, HttpTryFrom, Error as HttpError}; -use futures::{Async, Poll}; -use tokio_io::{AsyncRead, AsyncWrite}; - -#[cfg(feature="alpn")] -use openssl::ssl::{SslMethod, SslConnector, Error as OpensslError}; -#[cfg(feature="alpn")] -use tokio_openssl::SslConnectorExt; -#[cfg(feature="alpn")] -use futures::Future; - -#[cfg(all(feature="tls", not(feature="alpn")))] -use native_tls::{TlsConnector, Error as TlsError}; -#[cfg(all(feature="tls", not(feature="alpn")))] -use tokio_tls::TlsConnectorExt; -#[cfg(all(feature="tls", not(feature="alpn")))] -use futures::Future; - -use {HAS_OPENSSL, HAS_TLS}; -use server::IoStream; - - -#[derive(Debug)] -/// `Connect` type represents message that can be send to `ClientConnector` -/// with connection request. -pub struct Connect { - pub uri: Uri, - pub conn_timeout: Duration, -} - -impl Connect { - /// Create `Connect` message for specified `Uri` - pub fn new(uri: U) -> Result where Uri: HttpTryFrom { - Ok(Connect { - uri: Uri::try_from(uri).map_err(|e| e.into())?, - conn_timeout: Duration::from_secs(1) - }) - } -} - -impl Message for Connect { - type Result = Result; -} - -/// A set of errors that can occur during connecting to a http host -#[derive(Fail, Debug)] -pub enum ClientConnectorError { - /// Invalid url - #[fail(display="Invalid url")] - InvalidUrl, - - /// SSL feature is not enabled - #[fail(display="SSL is not supported")] - SslIsNotSupported, - - /// SSL error - #[cfg(feature="alpn")] - #[fail(display="{}", _0)] - SslError(#[cause] OpensslError), - - /// SSL error - #[cfg(all(feature="tls", not(feature="alpn")))] - #[fail(display="{}", _0)] - SslError(#[cause] TlsError), - - /// Connection error - #[fail(display = "{}", _0)] - Connector(#[cause] ConnectorError), - - /// Connection took too long - #[fail(display = "Timeout out while establishing connection")] - Timeout, - - /// Connector has been disconnected - #[fail(display = "Internal error: connector has been disconnected")] - Disconnected, - - /// Connection io error - #[fail(display = "{}", _0)] - IoError(#[cause] io::Error), -} - -impl From for ClientConnectorError { - fn from(err: ConnectorError) -> ClientConnectorError { - match err { - ConnectorError::Timeout => ClientConnectorError::Timeout, - _ => ClientConnectorError::Connector(err) - } - } -} - -pub struct ClientConnector { - #[cfg(all(feature="alpn"))] - connector: SslConnector, - #[cfg(all(feature="tls", not(feature="alpn")))] - connector: TlsConnector, - pool: Rc, -} - -impl Actor for ClientConnector { - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - self.collect(ctx); - } -} - -impl Supervised for ClientConnector {} - -impl ArbiterService for ClientConnector {} - -impl Default for ClientConnector { - fn default() -> ClientConnector { - #[cfg(all(feature="alpn"))] - { - let builder = SslConnector::builder(SslMethod::tls()).unwrap(); - ClientConnector { - connector: builder.build(), - pool: Rc::new(Pool::new()), - } - } - #[cfg(all(feature="tls", not(feature="alpn")))] - { - let builder = TlsConnector::builder().unwrap(); - ClientConnector { - connector: builder.build().unwrap(), - pool: Rc::new(Pool::new()), - } - } - - #[cfg(not(any(feature="alpn", feature="tls")))] - ClientConnector {pool: Rc::new(Pool::new())} - } -} - -impl ClientConnector { - - #[cfg(feature="alpn")] - /// Create `ClientConnector` actor with custom `SslConnector` instance. - /// - /// By default `ClientConnector` uses very simple ssl configuration. - /// With `with_connector` method it is possible to use custom `SslConnector` - /// object. - /// - /// ```rust - /// # #![cfg(feature="alpn")] - /// # extern crate actix; - /// # extern crate actix_web; - /// # extern crate futures; - /// # use futures::Future; - /// # use std::io::Write; - /// extern crate openssl; - /// use actix::prelude::*; - /// use actix_web::client::{Connect, ClientConnector}; - /// - /// use openssl::ssl::{SslMethod, SslConnector}; - /// - /// fn main() { - /// let sys = System::new("test"); - /// - /// // Start `ClientConnector` with custom `SslConnector` - /// let ssl_conn = SslConnector::builder(SslMethod::tls()).unwrap().build(); - /// let conn: Address<_> = ClientConnector::with_connector(ssl_conn).start(); - /// - /// Arbiter::handle().spawn({ - /// conn.send( - /// Connect::new("https://www.rust-lang.org").unwrap()) // <- connect to host - /// .map_err(|_| ()) - /// .and_then(|res| { - /// if let Ok(mut stream) = res { - /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); - /// } - /// # Arbiter::system().do_send(actix::msgs::SystemExit(0)); - /// Ok(()) - /// }) - /// }); - /// - /// sys.run(); - /// } - /// ``` - pub fn with_connector(connector: SslConnector) -> ClientConnector { - ClientConnector { connector, pool: Rc::new(Pool::new()) } - } - - fn collect(&mut self, ctx: &mut Context) { - self.pool.collect(); - ctx.run_later(Duration::from_secs(1), |act, ctx| act.collect(ctx)); - } -} - -impl Handler for ClientConnector { - type Result = ActorResponse; - - fn handle(&mut self, msg: Connect, _: &mut Self::Context) -> Self::Result { - let uri = &msg.uri; - let conn_timeout = msg.conn_timeout; - - // host name is required - if uri.host().is_none() { - return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)) - } - - // supported protocols - let proto = match uri.scheme_part() { - Some(scheme) => match Protocol::from(scheme.as_str()) { - Some(proto) => proto, - None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)), - }, - None => return ActorResponse::reply(Err(ClientConnectorError::InvalidUrl)), - }; - - // check ssl availability - if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS { - return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported)) - } - - let host = uri.host().unwrap().to_owned(); - let port = uri.port().unwrap_or_else(|| proto.port()); - let key = Key {host, port, ssl: proto.is_secure()}; - - let pool = if proto.is_http() { - if let Some(mut conn) = self.pool.query(&key) { - conn.pool = Some(self.pool.clone()); - return ActorResponse::async(fut::ok(conn)) - } else { - Some(Rc::clone(&self.pool)) - } - } else { - None - }; - - ActorResponse::async( - Connector::from_registry() - .send(ResolveConnect::host_and_port(&key.host, port) - .timeout(conn_timeout)) - .into_actor(self) - .map_err(|_, _, _| ClientConnectorError::Disconnected) - .and_then(move |res, _act, _| { - #[cfg(feature="alpn")] - match res { - Err(err) => fut::Either::B(fut::err(err.into())), - Ok(stream) => { - if proto.is_secure() { - fut::Either::A( - _act.connector.connect_async(&key.host, stream) - .map_err(ClientConnectorError::SslError) - .map(|stream| Connection::new( - key, pool, Box::new(stream))) - .into_actor(_act)) - } else { - fut::Either::B(fut::ok( - Connection::new(key, pool, Box::new(stream)))) - } - } - } - - #[cfg(all(feature="tls", not(feature="alpn")))] - match res { - Err(err) => fut::Either::B(fut::err(err.into())), - Ok(stream) => { - if proto.is_secure() { - fut::Either::A( - _act.connector.connect_async(&key.host, stream) - .map_err(ClientConnectorError::SslError) - .map(|stream| Connection::new( - key, pool, Box::new(stream))) - .into_actor(_act)) - } else { - fut::Either::B(fut::ok( - Connection::new(key, pool, Box::new(stream)))) - } - } - } - - #[cfg(not(any(feature="alpn", feature="tls")))] - match res { - Err(err) => fut::err(err.into()), - Ok(stream) => { - if proto.is_secure() { - fut::err(ClientConnectorError::SslIsNotSupported) - } else { - fut::ok(Connection::new(key, pool, Box::new(stream))) - } - } - } - })) - } -} - -#[derive(PartialEq, Hash, Debug, Clone, Copy)] -enum Protocol { - Http, - Https, - Ws, - Wss, -} - -impl Protocol { - fn from(s: &str) -> Option { - match s { - "http" => Some(Protocol::Http), - "https" => Some(Protocol::Https), - "ws" => Some(Protocol::Ws), - "wss" => Some(Protocol::Wss), - _ => None, - } - } - - fn is_http(&self) -> bool { - match *self { - Protocol::Https | Protocol::Http => true, - _ => false, - } - } - - fn is_secure(&self) -> bool { - match *self { - Protocol::Https | Protocol::Wss => true, - _ => false, - } - } - - fn port(&self) -> u16 { - match *self { - Protocol::Http | Protocol::Ws => 80, - Protocol::Https | Protocol::Wss => 443 - } - } -} - -#[derive(Hash, Eq, PartialEq, Clone, Debug)] -struct Key { - host: String, - port: u16, - ssl: bool, -} - -impl Key { - fn empty() -> Key { - Key{host: String::new(), port: 0, ssl: false} - } -} - -#[derive(Debug)] -struct Conn(Instant, Connection); - -pub struct Pool { - max_size: usize, - keep_alive: Duration, - max_lifetime: Duration, - pool: RefCell>>, - to_close: RefCell>, -} - -impl Pool { - fn new() -> Pool { - Pool { - max_size: 128, - keep_alive: Duration::from_secs(15), - max_lifetime: Duration::from_secs(75), - pool: RefCell::new(HashMap::new()), - to_close: RefCell::new(Vec::new()), - } - } - - fn collect(&self) { - let mut pool = self.pool.borrow_mut(); - let mut to_close = self.to_close.borrow_mut(); - - // check keep-alive - let now = Instant::now(); - for conns in pool.values_mut() { - while !conns.is_empty() { - if (now - conns[0].0) > self.keep_alive - || (now - conns[0].1.ts) > self.max_lifetime - { - let conn = conns.pop_front().unwrap().1; - to_close.push(conn); - } else { - break - } - } - } - - // check connections for shutdown - let mut idx = 0; - while idx < to_close.len() { - match AsyncWrite::shutdown(&mut to_close[idx]) { - Ok(Async::NotReady) => idx += 1, - _ => { - to_close.swap_remove(idx); - }, - } - } - } - - fn query(&self, key: &Key) -> Option { - let mut pool = self.pool.borrow_mut(); - let mut to_close = self.to_close.borrow_mut(); - - if let Some(ref mut connections) = pool.get_mut(key) { - let now = Instant::now(); - while let Some(conn) = connections.pop_back() { - // check if it still usable - if (now - conn.0) > self.keep_alive - || (now - conn.1.ts) > self.max_lifetime - { - to_close.push(conn.1); - } else { - let mut conn = conn.1; - let mut buf = [0; 2]; - match conn.stream().read(&mut buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Ok(n) if n > 0 => { - to_close.push(conn); - continue - }, - Ok(_) | Err(_) => continue, - } - return Some(conn) - } - } - } - None - } - - fn release(&self, conn: Connection) { - if (Instant::now() - conn.ts) < self.max_lifetime { - let mut pool = self.pool.borrow_mut(); - if !pool.contains_key(&conn.key) { - let key = conn.key.clone(); - let mut vec = VecDeque::new(); - vec.push_back(Conn(Instant::now(), conn)); - pool.insert(key, vec); - } else { - let vec = pool.get_mut(&conn.key).unwrap(); - vec.push_back(Conn(Instant::now(), conn)); - if vec.len() > self.max_size { - let conn = vec.pop_front().unwrap(); - self.to_close.borrow_mut().push(conn.1); - } - } - } else { - self.to_close.borrow_mut().push(conn); - } - } -} - - -pub struct Connection { - key: Key, - stream: Box, - pool: Option>, - ts: Instant, -} - -impl fmt::Debug for Connection { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Connection {}:{}", self.key.host, self.key.port) - } -} - -impl Connection { - fn new(key: Key, pool: Option>, stream: Box) -> Self { - Connection { - key, pool, stream, - ts: Instant::now(), - } - } - - pub fn stream(&mut self) -> &mut IoStream { - &mut *self.stream - } - - pub fn from_stream(io: T) -> Connection { - Connection::new(Key::empty(), None, Box::new(io)) - } - - pub fn release(mut self) { - if let Some(pool) = self.pool.take() { - pool.release(self) - } - } -} - -impl IoStream for Connection { - fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - IoStream::shutdown(&mut *self.stream, how) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - IoStream::set_nodelay(&mut *self.stream, nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - IoStream::set_linger(&mut *self.stream, dur) - } -} - -impl io::Read for Connection { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.stream.read(buf) - } -} - -impl AsyncRead for Connection {} - -impl io::Write for Connection { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.stream.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.stream.flush() - } -} - -impl AsyncWrite for Connection { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.stream.shutdown() - } -} diff --git a/src/client/mod.rs b/src/client/mod.rs deleted file mode 100644 index 5abe4ff6b..000000000 --- a/src/client/mod.rs +++ /dev/null @@ -1,30 +0,0 @@ -//! Http client -mod connector; -mod parser; -mod request; -mod response; -mod pipeline; -mod writer; - -pub use self::pipeline::{SendRequest, SendRequestError}; -pub use self::request::{ClientRequest, ClientRequestBuilder}; -pub use self::response::ClientResponse; -pub use self::connector::{Connect, Connection, ClientConnector, ClientConnectorError}; -pub(crate) use self::writer::HttpClientWriter; -pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError}; - -use error::ResponseError; -use httpresponse::HttpResponse; - - -/// Convert `SendRequestError` to a `HttpResponse` -impl ResponseError for SendRequestError { - - fn error_response(&self) -> HttpResponse { - match *self { - SendRequestError::Connector(_) => HttpResponse::BadGateway(), - _ => HttpResponse::InternalServerError(), - } - .into() - } -} diff --git a/src/client/parser.rs b/src/client/parser.rs deleted file mode 100644 index e0c494066..000000000 --- a/src/client/parser.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::mem; -use httparse; -use http::{Version, HttpTryFrom, HeaderMap, StatusCode}; -use http::header::{self, HeaderName, HeaderValue}; -use bytes::{Bytes, BytesMut}; -use futures::{Poll, Async}; - -use error::{ParseError, PayloadError}; - -use server::{utils, IoStream}; -use server::h1::{Decoder, chunked}; - -use super::ClientResponse; -use super::response::ClientMessage; - -const MAX_BUFFER_SIZE: usize = 131_072; -const MAX_HEADERS: usize = 96; - -#[derive(Default)] -pub struct HttpResponseParser { - decoder: Option, -} - -#[derive(Debug, Fail)] -pub enum HttpResponseParserError { - /// Server disconnected - #[fail(display="Server disconnected")] - Disconnect, - #[fail(display="{}", _0)] - Error(#[cause] ParseError), -} - -impl HttpResponseParser { - - pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut) - -> Poll - where T: IoStream - { - // if buf is empty parse_message will always return NotReady, let's avoid that - if buf.is_empty() { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => - return Err(HttpResponseParserError::Disconnect), - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(err) => - return Err(HttpResponseParserError::Error(err.into())) - } - } - - loop { - match HttpResponseParser::parse_message(buf) - .map_err(HttpResponseParserError::Error)? - { - Async::Ready((msg, decoder)) => { - self.decoder = decoder; - return Ok(Async::Ready(msg)); - }, - Async::NotReady => { - if buf.capacity() >= MAX_BUFFER_SIZE { - return Err(HttpResponseParserError::Error(ParseError::TooLarge)); - } - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => - return Err(HttpResponseParserError::Disconnect), - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => - return Err(HttpResponseParserError::Error(err.into())), - } - }, - } - } - } - - pub fn parse_payload(&mut self, io: &mut T, buf: &mut BytesMut) - -> Poll, PayloadError> - where T: IoStream - { - if self.decoder.is_some() { - loop { - // read payload - let not_ready = match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - if buf.is_empty() { - return Err(PayloadError::Incomplete) - } - true - } - Err(err) => return Err(err.into()), - Ok(Async::NotReady) => true, - _ => false, - }; - - match self.decoder.as_mut().unwrap().decode(buf) { - Ok(Async::Ready(Some(b))) => - return Ok(Async::Ready(Some(b))), - Ok(Async::Ready(None)) => { - self.decoder.take(); - return Ok(Async::Ready(None)) - } - Ok(Async::NotReady) => { - if not_ready { - return Ok(Async::NotReady) - } - } - Err(err) => return Err(err.into()), - } - } - } else { - Ok(Async::Ready(None)) - } - } - - fn parse_message(buf: &mut BytesMut) - -> Poll<(ClientResponse, Option), ParseError> - { - // Parse http message - let bytes_ptr = buf.as_ref().as_ptr() as usize; - let mut headers: [httparse::Header; MAX_HEADERS] = - unsafe{mem::uninitialized()}; - - let (len, version, status, headers_len) = { - let b = unsafe{ let b: &[u8] = buf; mem::transmute(b) }; - let mut resp = httparse::Response::new(&mut headers); - match resp.parse(b)? { - httparse::Status::Complete(len) => { - let version = if resp.version.unwrap_or(1) == 1 { - Version::HTTP_11 - } else { - Version::HTTP_10 - }; - let status = StatusCode::from_u16(resp.code.unwrap()) - .map_err(|_| ParseError::Status)?; - - (len, version, status, resp.headers.len()) - } - httparse::Status::Partial => return Ok(Async::NotReady), - } - }; - - let slice = buf.split_to(len).freeze(); - - // convert headers - let mut hdrs = HeaderMap::new(); - for header in headers[..headers_len].iter() { - if let Ok(name) = HeaderName::try_from(header.name) { - let v_start = header.value.as_ptr() as usize - bytes_ptr; - let v_end = v_start + header.value.len(); - let value = unsafe { - HeaderValue::from_shared_unchecked(slice.slice(v_start, v_end)) }; - hdrs.append(name, value); - } else { - return Err(ParseError::Header) - } - } - - let decoder = if status == StatusCode::SWITCHING_PROTOCOLS { - Some(Decoder::eof()) - } else if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { - // Content-Length - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - Some(Decoder::length(len)) - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header) - } - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header) - } - } else if chunked(&hdrs)? { - // Chunked encoding - Some(Decoder::chunked()) - } else { - None - }; - - if let Some(decoder) = decoder { - Ok(Async::Ready( - (ClientResponse::new( - ClientMessage{status, version, - headers: hdrs, cookies: None}), Some(decoder)))) - } else { - Ok(Async::Ready( - (ClientResponse::new( - ClientMessage{status, version, - headers: hdrs, cookies: None}), None))) - } - } -} diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs deleted file mode 100644 index 19ccf8927..000000000 --- a/src/client/pipeline.rs +++ /dev/null @@ -1,460 +0,0 @@ -use std::{io, mem}; -use std::time::Duration; -use bytes::{Bytes, BytesMut}; -use http::header::CONTENT_ENCODING; -use futures::{Async, Future, Poll}; -use futures::unsync::oneshot; -use tokio_core::reactor::Timeout; - -use actix::prelude::*; - -use error::Error; -use body::{Body, BodyStream}; -use context::{Frame, ActorHttpContext}; -use header::ContentEncoding; -use httpmessage::HttpMessage; -use error::PayloadError; -use server::WriterState; -use server::shared::SharedBytes; -use server::encoding::PayloadStream; -use super::{ClientRequest, ClientResponse}; -use super::{Connect, Connection, ClientConnector, ClientConnectorError}; -use super::HttpClientWriter; -use super::{HttpResponseParser, HttpResponseParserError}; - -/// A set of errors that can occur during sending request and reading response -#[derive(Fail, Debug)] -pub enum SendRequestError { - /// Response took too long - #[fail(display = "Timeout out while waiting for response")] - Timeout, - /// Failed to connect to host - #[fail(display="Failed to connect to host: {}", _0)] - Connector(#[cause] ClientConnectorError), - /// Error parsing response - #[fail(display="{}", _0)] - ParseError(#[cause] HttpResponseParserError), - /// Error reading response payload - #[fail(display="Error reading response payload: {}", _0)] - Io(#[cause] io::Error), -} - -impl From for SendRequestError { - fn from(err: io::Error) -> SendRequestError { - SendRequestError::Io(err) - } -} - -impl From for SendRequestError { - fn from(err: ClientConnectorError) -> SendRequestError { - match err { - ClientConnectorError::Timeout => SendRequestError::Timeout, - _ => SendRequestError::Connector(err), - } - } -} - -enum State { - New, - Connect(actix::dev::Request), - Connection(Connection), - Send(Box), - None, -} - -/// `SendRequest` is a `Future` which represents asynchronous request sending process. -#[must_use = "SendRequest does nothing unless polled"] -pub struct SendRequest { - req: ClientRequest, - state: State, - conn: Addr, - conn_timeout: Duration, - timeout: Option, -} - -impl SendRequest { - pub(crate) fn new(req: ClientRequest) -> SendRequest { - SendRequest::with_connector(req, ClientConnector::from_registry()) - } - - pub(crate) fn with_connector(req: ClientRequest, conn: Addr) - -> SendRequest - { - SendRequest{req, conn, - state: State::New, - timeout: None, - conn_timeout: Duration::from_secs(1) - } - } - - pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest - { - SendRequest{req, - state: State::Connection(conn), - conn: ClientConnector::from_registry(), - timeout: None, - conn_timeout: Duration::from_secs(1), - } - } - - /// Set request timeout - /// - /// Request timeout is a total time before response should be received. - /// Default value is 5 seconds. - pub fn timeout(mut self, timeout: Duration) -> Self { - self.timeout = Some(Timeout::new(timeout, Arbiter::handle()).unwrap()); - self - } - - /// Set connection timeout - /// - /// Connection timeout includes resolving hostname and actual connection to - /// the host. - /// Default value is 1 second. - pub fn conn_timeout(mut self, timeout: Duration) -> Self { - self.conn_timeout = timeout; - self - } - - fn poll_timeout(&mut self) -> Poll<(), SendRequestError> { - if self.timeout.is_none() { - self.timeout = Some(Timeout::new( - Duration::from_secs(5), Arbiter::handle()).unwrap()); - } - - match self.timeout.as_mut().unwrap().poll() { - Ok(Async::Ready(())) => Err(SendRequestError::Timeout), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => unreachable!() - } - } -} - -impl Future for SendRequest { - type Item = ClientResponse; - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - self.poll_timeout()?; - - loop { - let state = mem::replace(&mut self.state, State::None); - - match state { - State::New => - self.state = State::Connect(self.conn.send(Connect { - uri: self.req.uri().clone(), - conn_timeout: self.conn_timeout, - })), - State::Connect(mut conn) => match conn.poll() { - Ok(Async::NotReady) => { - self.state = State::Connect(conn); - return Ok(Async::NotReady); - }, - Ok(Async::Ready(result)) => match result { - Ok(stream) => { - self.state = State::Connection(stream) - }, - Err(err) => return Err(err.into()), - }, - Err(_) => return Err(SendRequestError::Connector( - ClientConnectorError::Disconnected)) - }, - State::Connection(conn) => { - let mut writer = HttpClientWriter::new(SharedBytes::default()); - writer.start(&mut self.req)?; - - let body = match self.req.replace_body(Body::Empty) { - Body::Streaming(stream) => IoBody::Payload(stream), - Body::Actor(ctx) => IoBody::Actor(ctx), - _ => IoBody::Done, - }; - - let pl = Box::new(Pipeline { - body, writer, - conn: Some(conn), - parser: Some(HttpResponseParser::default()), - parser_buf: BytesMut::new(), - disconnected: false, - drain: None, - decompress: None, - should_decompress: self.req.response_decompress(), - write_state: RunningState::Running, - }); - self.state = State::Send(pl); - }, - State::Send(mut pl) => { - pl.poll_write() - .map_err(|e| io::Error::new( - io::ErrorKind::Other, format!("{}", e).as_str()))?; - - match pl.parse() { - Ok(Async::Ready(mut resp)) => { - resp.set_pipeline(pl); - return Ok(Async::Ready(resp)) - }, - Ok(Async::NotReady) => { - self.state = State::Send(pl); - return Ok(Async::NotReady) - }, - Err(err) => return Err(SendRequestError::ParseError(err)) - } - } - State::None => unreachable!(), - } - } - } -} - - -pub(crate) struct Pipeline { - body: IoBody, - conn: Option, - writer: HttpClientWriter, - parser: Option, - parser_buf: BytesMut, - disconnected: bool, - drain: Option>, - decompress: Option, - should_decompress: bool, - write_state: RunningState, -} - -enum IoBody { - Payload(BodyStream), - Actor(Box), - Done, -} - -#[derive(Debug, PartialEq)] -enum RunningState { - Running, - Paused, - Done, -} - -impl RunningState { - #[inline] - fn pause(&mut self) { - if *self != RunningState::Done { - *self = RunningState::Paused - } - } - #[inline] - fn resume(&mut self) { - if *self != RunningState::Done { - *self = RunningState::Running - } - } -} - -impl Pipeline { - - fn release_conn(&mut self) { - if let Some(conn) = self.conn.take() { - conn.release() - } - } - - #[inline] - fn parse(&mut self) -> Poll { - if let Some(ref mut conn) = self.conn { - match self.parser.as_mut().unwrap().parse(conn, &mut self.parser_buf) { - Ok(Async::Ready(resp)) => { - // check content-encoding - if self.should_decompress { - if let Some(enc) = resp.headers().get(CONTENT_ENCODING) { - if let Ok(enc) = enc.to_str() { - match ContentEncoding::from(enc) { - ContentEncoding::Auto | ContentEncoding::Identity => (), - enc => self.decompress = Some(PayloadStream::new(enc)), - } - } - } - } - - Ok(Async::Ready(resp)) - } - val => val, - } - } else { - Ok(Async::NotReady) - } - } - - #[inline] - pub fn poll(&mut self) -> Poll, PayloadError> { - if self.conn.is_none() { - return Ok(Async::Ready(None)) - } - let conn: &mut Connection = unsafe{ mem::transmute(self.conn.as_mut().unwrap())}; - - let mut need_run = false; - - // need write? - if let Async::NotReady = self.poll_write() - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))? - { - need_run = true; - } - - // need read? - if self.parser.is_some() { - loop { - match self.parser.as_mut().unwrap() - .parse_payload(conn, &mut self.parser_buf)? - { - Async::Ready(Some(b)) => { - if let Some(ref mut decompress) = self.decompress { - match decompress.feed_data(b) { - Ok(Some(b)) => return Ok(Async::Ready(Some(b))), - Ok(None) => return Ok(Async::NotReady), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => - continue, - Err(err) => return Err(err.into()), - } - } else { - return Ok(Async::Ready(Some(b))) - } - }, - Async::Ready(None) => { - let _ = self.parser.take(); - break - } - Async::NotReady => return Ok(Async::NotReady), - } - } - } - - // eof - if let Some(mut decompress) = self.decompress.take() { - let res = decompress.feed_eof(); - if let Some(b) = res? { - self.release_conn(); - return Ok(Async::Ready(Some(b))) - } - } - - if need_run { - Ok(Async::NotReady) - } else { - self.release_conn(); - Ok(Async::Ready(None)) - } - } - - #[inline] - fn poll_write(&mut self) -> Poll<(), Error> { - if self.write_state == RunningState::Done || self.conn.is_none() { - return Ok(Async::Ready(())) - } - - let mut done = false; - - if self.drain.is_none() && self.write_state != RunningState::Paused { - 'outter: loop { - let result = match mem::replace(&mut self.body, IoBody::Done) { - IoBody::Payload(mut body) => { - match body.poll()? { - Async::Ready(None) => { - self.writer.write_eof()?; - self.disconnected = true; - break - }, - Async::Ready(Some(chunk)) => { - self.body = IoBody::Payload(body); - self.writer.write(chunk.into())? - } - Async::NotReady => { - done = true; - self.body = IoBody::Payload(body); - break - }, - } - }, - IoBody::Actor(mut ctx) => { - if self.disconnected { - ctx.disconnected(); - } - match ctx.poll()? { - Async::Ready(Some(vec)) => { - if vec.is_empty() { - self.body = IoBody::Actor(ctx); - break - } - let mut res = None; - for frame in vec { - match frame { - Frame::Chunk(None) => { - // info.context = Some(ctx); - self.disconnected = true; - self.writer.write_eof()?; - break 'outter - }, - Frame::Chunk(Some(chunk)) => - res = Some(self.writer.write(chunk)?), - Frame::Drain(fut) => self.drain = Some(fut), - } - } - self.body = IoBody::Actor(ctx); - if self.drain.is_some() { - self.write_state.resume(); - break - } - res.unwrap() - }, - Async::Ready(None) => { - done = true; - break - } - Async::NotReady => { - done = true; - self.body = IoBody::Actor(ctx); - break - } - } - }, - IoBody::Done => { - self.disconnected = true; - done = true; - break - } - }; - - match result { - WriterState::Pause => { - self.write_state.pause(); - break - } - WriterState::Done => { - self.write_state.resume() - }, - } - } - } - - // flush io but only if we need to - match self.writer.poll_completed(self.conn.as_mut().unwrap(), false) { - Ok(Async::Ready(_)) => { - if self.disconnected { - self.write_state = RunningState::Done; - } else { - self.write_state.resume(); - } - - // resolve drain futures - if let Some(tx) = self.drain.take() { - let _ = tx.send(()); - } - // restart io processing - if !done || self.write_state == RunningState::Done { - self.poll_write() - } else { - Ok(Async::NotReady) - } - }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => Err(err.into()), - } - } -} diff --git a/src/client/request.rs b/src/client/request.rs deleted file mode 100644 index 8f2967ab4..000000000 --- a/src/client/request.rs +++ /dev/null @@ -1,671 +0,0 @@ -use std::{fmt, mem}; -use std::fmt::Write as FmtWrite; -use std::io::Write; -use std::time::Duration; - -use actix::{Addr, Unsync}; -use cookie::{Cookie, CookieJar}; -use bytes::{Bytes, BytesMut, BufMut}; -use http::{uri, HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError}; -use http::header::{self, HeaderName, HeaderValue}; -use futures::Stream; -use serde_json; -use serde::Serialize; -use url::Url; -use percent_encoding::{USERINFO_ENCODE_SET, percent_encode}; - -use body::Body; -use error::Error; -use header::{ContentEncoding, Header, IntoHeaderValue}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use super::pipeline::SendRequest; -use super::connector::{Connection, ClientConnector}; - -/// An HTTP Client Request -pub struct ClientRequest { - uri: Uri, - method: Method, - version: Version, - headers: HeaderMap, - body: Body, - chunked: bool, - upgrade: bool, - timeout: Option, - encoding: ContentEncoding, - response_decompress: bool, - buffer_capacity: usize, - conn: ConnectionType, -} - -enum ConnectionType { - Default, - Connector(Addr), - Connection(Connection), -} - -impl Default for ClientRequest { - - fn default() -> ClientRequest { - ClientRequest { - uri: Uri::default(), - method: Method::default(), - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - body: Body::Empty, - chunked: false, - upgrade: false, - timeout: None, - encoding: ContentEncoding::Auto, - response_decompress: true, - buffer_capacity: 32_768, - conn: ConnectionType::Default, - } - } -} - -impl ClientRequest { - - /// Create request builder for `GET` request - pub fn get>(uri: U) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - builder.method(Method::GET).uri(uri); - builder - } - - /// Create request builder for `HEAD` request - pub fn head>(uri: U) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - builder.method(Method::HEAD).uri(uri); - builder - } - - /// Create request builder for `POST` request - pub fn post>(uri: U) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - builder.method(Method::POST).uri(uri); - builder - } - - /// Create request builder for `PUT` request - pub fn put>(uri: U) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - builder.method(Method::PUT).uri(uri); - builder - } - - /// Create request builder for `DELETE` request - pub fn delete>(uri: U) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - builder.method(Method::DELETE).uri(uri); - builder - } -} - -impl ClientRequest { - - /// Create client request builder - pub fn build() -> ClientRequestBuilder { - ClientRequestBuilder { - request: Some(ClientRequest::default()), - err: None, - cookies: None, - default_headers: true - } - } - - /// Create client request builder - pub fn build_from>(source: T) -> ClientRequestBuilder { - source.into() - } - - /// Get the request uri - #[inline] - pub fn uri(&self) -> &Uri { - &self.uri - } - - /// Set client request uri - #[inline] - pub fn set_uri(&mut self, uri: Uri) { - self.uri = uri - } - - /// Get the request method - #[inline] - pub fn method(&self) -> &Method { - &self.method - } - - /// Set http `Method` for the request - #[inline] - pub fn set_method(&mut self, method: Method) { - self.method = method - } - - /// Get http version for the request - #[inline] - pub fn version(&self) -> Version { - self.version - } - - /// Set http `Version` for the request - #[inline] - pub fn set_version(&mut self, version: Version) { - self.version = version - } - - /// Get the headers from the request - #[inline] - pub fn headers(&self) -> &HeaderMap { - &self.headers - } - - /// Get a mutable reference to the headers - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers - } - - /// is chunked encoding enabled - #[inline] - pub fn chunked(&self) -> bool { - self.chunked - } - - /// is upgrade request - #[inline] - pub fn upgrade(&self) -> bool { - self.upgrade - } - - /// Content encoding - #[inline] - pub fn content_encoding(&self) -> ContentEncoding { - self.encoding - } - - /// Decompress response payload - #[inline] - pub fn response_decompress(&self) -> bool { - self.response_decompress - } - - /// Requested write buffer capacity - pub fn write_buffer_capacity(&self) -> usize { - self.buffer_capacity - } - - /// Get body os this response - #[inline] - pub fn body(&self) -> &Body { - &self.body - } - - /// Set a body - pub fn set_body>(&mut self, body: B) { - self.body = body.into(); - } - - /// Extract body, replace it with Empty - pub(crate) fn replace_body(&mut self, body: Body) -> Body { - mem::replace(&mut self.body, body) - } - - /// Send request - /// - /// This method returns future that resolves to a ClientResponse - pub fn send(mut self) -> SendRequest { - let timeout = self.timeout.take(); - let send = match mem::replace(&mut self.conn, ConnectionType::Default) { - ConnectionType::Default => SendRequest::new(self), - ConnectionType::Connector(conn) => SendRequest::with_connector(self, conn), - ConnectionType::Connection(conn) => SendRequest::with_connection(self, conn), - }; - if let Some(timeout) = timeout { - send.timeout(timeout) - } else { - send - } - } -} - -impl fmt::Debug for ClientRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!(f, "\nClientRequest {:?} {}:{}\n", - self.version, self.method, self.uri); - let _ = write!(f, " headers:\n"); - for (key, val) in self.headers.iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } -} - -/// An HTTP Client request builder -/// -/// This type can be used to construct an instance of `ClientRequest` through a -/// builder-like pattern. -pub struct ClientRequestBuilder { - request: Option, - err: Option, - cookies: Option, - default_headers: bool -} - -impl ClientRequestBuilder { - /// Set HTTP uri of request. - #[inline] - pub fn uri>(&mut self, uri: U) -> &mut Self { - match Url::parse(uri.as_ref()) { - Ok(url) => self._uri(url.as_str()), - Err(_) => self._uri(uri.as_ref()), - } - } - - fn _uri(&mut self, url: &str) -> &mut Self { - match Uri::try_from(url) { - Ok(uri) => { - // set request host header - if let Some(host) = uri.host() { - self.set_header(header::HOST, host); - } - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.uri = uri; - } - }, - Err(e) => self.err = Some(e.into(),), - } - self - } - - /// Set HTTP method of this request. - #[inline] - pub fn method(&mut self, method: Method) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.method = method; - } - self - } - - /// Set HTTP method of this request. - #[inline] - pub fn get_method(&mut self) -> &Method { - let parts = parts(&mut self.request, &self.err) - .expect("cannot reuse request builder"); - &parts.method - } - - /// Set HTTP version of this request. - /// - /// By default requests's http version depends on network stream - #[inline] - pub fn version(&mut self, version: Version) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.version = version; - } - self - } - - /// Set a header. - /// - /// ```rust - /// # extern crate mime; - /// # extern crate actix_web; - /// # use actix_web::client::*; - /// # - /// use actix_web::{client, http}; - /// - /// fn main() { - /// let req = client::ClientRequest::build() - /// .set(http::header::Date::now()) - /// .set(http::header::ContentType(mime::TEXT_HTML)) - /// .finish().unwrap(); - /// } - /// ``` - #[doc(hidden)] - pub fn set(&mut self, hdr: H) -> &mut Self - { - if let Some(parts) = parts(&mut self.request, &self.err) { - match hdr.try_into() { - Ok(value) => { parts.headers.insert(H::name(), value); } - Err(e) => self.err = Some(e.into()), - } - } - self - } - - /// Append a header. - /// - /// Header get appended to existing header. - /// To override header use `set_header()` method. - /// - /// ```rust - /// # extern crate http; - /// # extern crate actix_web; - /// # use actix_web::client::*; - /// # - /// use http::header; - /// - /// fn main() { - /// let req = ClientRequest::build() - /// .header("X-TEST", "value") - /// .header(header::CONTENT_TYPE, "application/json") - /// .finish().unwrap(); - /// } - /// ``` - pub fn header(&mut self, key: K, value: V) -> &mut Self - where HeaderName: HttpTryFrom, V: IntoHeaderValue - { - if let Some(parts) = parts(&mut self.request, &self.err) { - match HeaderName::try_from(key) { - Ok(key) => { - match value.try_into() { - Ok(value) => { parts.headers.append(key, value); } - Err(e) => self.err = Some(e.into()), - } - }, - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set a header. - pub fn set_header(&mut self, key: K, value: V) -> &mut Self - where HeaderName: HttpTryFrom, V: IntoHeaderValue - { - if let Some(parts) = parts(&mut self.request, &self.err) { - match HeaderName::try_from(key) { - Ok(key) => { - match value.try_into() { - Ok(value) => { parts.headers.insert(key, value); } - Err(e) => self.err = Some(e.into()), - } - }, - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set content encoding. - /// - /// By default `ContentEncoding::Identity` is used. - #[inline] - pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.encoding = enc; - } - self - } - - /// Enables automatic chunked transfer encoding - #[inline] - pub fn chunked(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.chunked = true; - } - self - } - - /// Enable connection upgrade - #[inline] - pub fn upgrade(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.upgrade = true; - } - self - } - - /// Set request's content type - #[inline] - pub fn content_type(&mut self, value: V) -> &mut Self - where HeaderValue: HttpTryFrom - { - if let Some(parts) = parts(&mut self.request, &self.err) { - match HeaderValue::try_from(value) { - Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); }, - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set content length - #[inline] - pub fn content_length(&mut self, len: u64) -> &mut Self { - let mut wrt = BytesMut::new().writer(); - let _ = write!(wrt, "{}", len); - self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) - } - - /// Set a cookie - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{client, http}; - /// - /// fn main() { - /// let req = client::ClientRequest::build() - /// .cookie( - /// http::Cookie::build("name", "value") - /// .domain("www.rust-lang.org") - /// .path("/") - /// .secure(true) - /// .http_only(true) - /// .finish()) - /// .finish().unwrap(); - /// } - /// ``` - pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { - if self.cookies.is_none() { - let mut jar = CookieJar::new(); - jar.add(cookie.into_owned()); - self.cookies = Some(jar) - } else { - self.cookies.as_mut().unwrap().add(cookie.into_owned()); - } - self - } - - /// Do not add default request headers. - /// By default `Accept-Encoding` header is set. - pub fn no_default_headers(&mut self) -> &mut Self { - self.default_headers = false; - self - } - - /// Disable automatic decompress response body - pub fn disable_decompress(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.response_decompress = false; - } - self - } - - /// Set write buffer capacity - /// - /// Default buffer capacity is 32kb - pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.buffer_capacity = cap; - } - self - } - - /// Set request timeout - /// - /// Request timeout is a total time before response should be received. - /// Default value is 5 seconds. - pub fn timeout(&mut self, timeout: Duration) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.timeout = Some(timeout); - } - self - } - - /// Send request using custom connector - pub fn with_connector(&mut self, conn: Addr) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.conn = ConnectionType::Connector(conn); - } - self - } - - /// Send request using existing Connection - pub fn with_connection(&mut self, conn: Connection) -> &mut Self { - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.conn = ConnectionType::Connection(conn); - } - self - } - - /// This method calls provided closure with builder reference if value is true. - pub fn if_true(&mut self, value: bool, f: F) -> &mut Self - where F: FnOnce(&mut ClientRequestBuilder) - { - if value { - f(self); - } - self - } - - /// This method calls provided closure with builder reference if value is Some. - pub fn if_some(&mut self, value: Option, f: F) -> &mut Self - where F: FnOnce(T, &mut ClientRequestBuilder) - { - if let Some(val) = value { - f(val, self); - } - self - } - - /// Set a body and generate `ClientRequest`. - /// - /// `ClientRequestBuilder` can not be used after this call. - pub fn body>(&mut self, body: B) -> Result { - if let Some(e) = self.err.take() { - return Err(e.into()) - } - - if self.default_headers { - // enable br only for https - let https = - if let Some(parts) = parts(&mut self.request, &self.err) { - parts.uri.scheme_part() - .map(|s| s == &uri::Scheme::HTTPS).unwrap_or(true) - } else { - true - }; - - if https { - self.header(header::ACCEPT_ENCODING, "br, gzip, deflate"); - } else { - self.header(header::ACCEPT_ENCODING, "gzip, deflate"); - } - } - - let mut request = self.request.take().expect("cannot reuse request builder"); - - // set cookies - if let Some(ref mut jar) = self.cookies { - let mut cookie = String::new(); - for c in jar.delta() { - let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET); - let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); - let _ = write!(&mut cookie, "; {}={}", name, value); - } - request.headers.insert( - header::COOKIE, HeaderValue::from_str(&cookie.as_str()[2..]).unwrap()); - } - request.body = body.into(); - Ok(request) - } - - /// Set a json body and generate `ClientRequest` - /// - /// `ClientRequestBuilder` can not be used after this call. - pub fn json(&mut self, value: T) -> Result { - let body = serde_json::to_string(&value)?; - - let contains = if let Some(parts) = parts(&mut self.request, &self.err) { - parts.headers.contains_key(header::CONTENT_TYPE) - } else { - true - }; - if !contains { - self.header(header::CONTENT_TYPE, "application/json"); - } - - self.body(body) - } - - /// Set a streaming body and generate `ClientRequest`. - /// - /// `ClientRequestBuilder` can not be used after this call. - pub fn streaming(&mut self, stream: S) -> Result - where S: Stream + 'static, - E: Into, - { - self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into())))) - } - - /// Set an empty body and generate `ClientRequest` - /// - /// `ClientRequestBuilder` can not be used after this call. - pub fn finish(&mut self) -> Result { - self.body(Body::Empty) - } - - /// This method construct new `ClientRequestBuilder` - pub fn take(&mut self) -> ClientRequestBuilder { - ClientRequestBuilder { - request: self.request.take(), - err: self.err.take(), - cookies: self.cookies.take(), - default_headers: self.default_headers - } - } -} - -#[inline] -fn parts<'a>(parts: &'a mut Option, err: &Option) - -> Option<&'a mut ClientRequest> -{ - if err.is_some() { - return None - } - parts.as_mut() -} - -impl fmt::Debug for ClientRequestBuilder { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(ref parts) = self.request { - let res = write!(f, "\nClientRequestBuilder {:?} {}:{}\n", - parts.version, parts.method, parts.uri); - let _ = write!(f, " headers:\n"); - for (key, val) in parts.headers.iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } else { - write!(f, "ClientRequestBuilder(Consumed)") - } - } -} - -/// Create `ClientRequestBuilder` from `HttpRequest` -/// -/// It is useful for proxy requests. This implementation -/// copies all request's headers and method. -impl<'a, S: 'static> From<&'a HttpRequest> for ClientRequestBuilder { - fn from(req: &'a HttpRequest) -> ClientRequestBuilder { - let mut builder = ClientRequest::build(); - for (key, value) in req.headers() { - builder.header(key.clone(), value.clone()); - } - builder.method(req.method().clone()); - builder - } -} diff --git a/src/client/response.rs b/src/client/response.rs deleted file mode 100644 index 1a82d64bd..000000000 --- a/src/client/response.rs +++ /dev/null @@ -1,148 +0,0 @@ -use std::{fmt, str}; -use std::rc::Rc; -use std::cell::UnsafeCell; - -use bytes::Bytes; -use cookie::Cookie; -use futures::{Async, Poll, Stream}; -use http::{HeaderMap, StatusCode, Version}; -use http::header::{self, HeaderValue}; - -use httpmessage::HttpMessage; -use error::{CookieParseError, PayloadError}; - -use super::pipeline::Pipeline; - - -pub(crate) struct ClientMessage { - pub status: StatusCode, - pub version: Version, - pub headers: HeaderMap, - pub cookies: Option>>, -} - -impl Default for ClientMessage { - - fn default() -> ClientMessage { - ClientMessage { - status: StatusCode::OK, - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - cookies: None, - } - } -} - -/// An HTTP Client response -pub struct ClientResponse(Rc>, Option>); - -impl HttpMessage for ClientResponse { - /// Get the headers from the response. - #[inline] - fn headers(&self) -> &HeaderMap { - &self.as_ref().headers - } -} - -impl ClientResponse { - - pub(crate) fn new(msg: ClientMessage) -> ClientResponse { - ClientResponse(Rc::new(UnsafeCell::new(msg)), None) - } - - pub(crate) fn set_pipeline(&mut self, pl: Box) { - self.1 = Some(pl); - } - - #[inline] - fn as_ref(&self) -> &ClientMessage { - unsafe{ &*self.0.get() } - } - - #[inline] - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))] - fn as_mut(&self) -> &mut ClientMessage { - unsafe{ &mut *self.0.get() } - } - - /// Get the HTTP version of this response. - #[inline] - pub fn version(&self) -> Version { - self.as_ref().version - } - - /// Get the status from the server. - #[inline] - pub fn status(&self) -> StatusCode { - self.as_ref().status - } - - /// Load response cookies. - pub fn cookies(&self) -> Result<&Vec>, CookieParseError> { - if self.as_ref().cookies.is_none() { - let msg = self.as_mut(); - let mut cookies = Vec::new(); - for val in msg.headers.get_all(header::SET_COOKIE).iter() { - let s = str::from_utf8(val.as_bytes()).map_err(CookieParseError::from)?; - cookies.push(Cookie::parse_encoded(s)?.into_owned()); - } - msg.cookies = Some(cookies) - } - Ok(self.as_ref().cookies.as_ref().unwrap()) - } - - /// Return request cookie. - pub fn cookie(&self, name: &str) -> Option<&Cookie> { - if let Ok(cookies) = self.cookies() { - for cookie in cookies { - if cookie.name() == name { - return Some(cookie) - } - } - } - None - } -} - -impl fmt::Debug for ClientResponse { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!( - f, "\nClientResponse {:?} {}\n", self.version(), self.status()); - let _ = write!(f, " headers:\n"); - for (key, val) in self.headers().iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } -} - -/// Future that resolves to a complete request body. -impl Stream for ClientResponse { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll, Self::Error> { - if let Some(ref mut pl) = self.1 { - pl.poll() - } else { - Ok(Async::Ready(None)) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_debug() { - let resp = ClientResponse::new(ClientMessage::default()); - resp.as_mut().headers.insert( - header::COOKIE, HeaderValue::from_static("cookie1=value1")); - resp.as_mut().headers.insert( - header::COOKIE, HeaderValue::from_static("cookie2=value2")); - - let dbg = format!("{:?}", resp); - assert!(dbg.contains("ClientResponse")); - } -} diff --git a/src/client/writer.rs b/src/client/writer.rs deleted file mode 100644 index cd50359ce..000000000 --- a/src/client/writer.rs +++ /dev/null @@ -1,367 +0,0 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] - -use std::io::{self, Write}; -use std::cell::RefCell; -use std::fmt::Write as FmtWrite; - -use time::{self, Duration}; -use bytes::{BytesMut, BufMut}; -use futures::{Async, Poll}; -use tokio_io::AsyncWrite; -use http::{Version, HttpTryFrom}; -use http::header::{HeaderValue, DATE, - CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; -use flate2::Compression; -use flate2::write::{GzEncoder, DeflateEncoder}; -#[cfg(feature="brotli")] -use brotli2::write::BrotliEncoder; - -use body::{Body, Binary}; -use header::ContentEncoding; -use server::WriterState; -use server::shared::SharedBytes; -use server::encoding::{ContentEncoder, TransferEncoding}; - -use client::ClientRequest; - - -const AVERAGE_HEADER_SIZE: usize = 30; - -bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const UPGRADE = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const DISCONNECTED = 0b0000_1000; - } -} - -pub(crate) struct HttpClientWriter { - flags: Flags, - written: u64, - headers_size: u32, - buffer: SharedBytes, - buffer_capacity: usize, - encoder: ContentEncoder, -} - -impl HttpClientWriter { - - pub fn new(buffer: SharedBytes) -> HttpClientWriter { - let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone())); - HttpClientWriter { - flags: Flags::empty(), - written: 0, - headers_size: 0, - buffer_capacity: 0, - buffer, - encoder, - } - } - - pub fn disconnected(&mut self) { - self.buffer.take(); - } - - // pub fn keepalive(&self) -> bool { - // self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) - // } - - fn write_to_stream(&mut self, stream: &mut T) -> io::Result { - while !self.buffer.is_empty() { - match stream.write(self.buffer.as_ref()) { - Ok(0) => { - self.disconnected(); - return Ok(WriterState::Done); - }, - Ok(n) => { - let _ = self.buffer.split_to(n); - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if self.buffer.len() > self.buffer_capacity { - return Ok(WriterState::Pause) - } else { - return Ok(WriterState::Done) - } - } - Err(err) => return Err(err), - } - } - Ok(WriterState::Done) - } -} - -impl HttpClientWriter { - - pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> { - // prepare task - self.flags.insert(Flags::STARTED); - self.encoder = content_encoder(self.buffer.clone(), msg); - - if msg.upgrade() { - self.flags.insert(Flags::UPGRADE); - } - - // render message - { - // status line - write!(self.buffer, "{} {} {:?}\r\n", - msg.method(), - msg.uri().path_and_query().map(|u| u.as_str()).unwrap_or("/"), - msg.version())?; - - // write headers - let mut buffer = self.buffer.get_mut(); - if let Body::Binary(ref bytes) = *msg.body() { - buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len()); - } else { - buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE); - } - - for (key, value) in msg.headers() { - let v = value.as_ref(); - let k = key.as_str().as_bytes(); - buffer.reserve(k.len() + v.len() + 4); - buffer.put_slice(k); - buffer.put_slice(b": "); - buffer.put_slice(v); - buffer.put_slice(b"\r\n"); - } - - // set date header - if !msg.headers().contains_key(DATE) { - buffer.extend_from_slice(b"date: "); - set_date(&mut buffer); - buffer.extend_from_slice(b"\r\n\r\n"); - } else { - buffer.extend_from_slice(b"\r\n"); - } - self.headers_size = buffer.len() as u32; - - if msg.body().is_binary() { - if let Body::Binary(bytes) = msg.replace_body(Body::Empty) { - self.written += bytes.len() as u64; - self.encoder.write(bytes)?; - } - } else { - self.buffer_capacity = msg.write_buffer_capacity(); - } - } - Ok(()) - } - - pub fn write(&mut self, payload: Binary) -> io::Result { - self.written += payload.len() as u64; - if !self.flags.contains(Flags::DISCONNECTED) { - if self.flags.contains(Flags::UPGRADE) { - self.buffer.extend(payload); - } else { - self.encoder.write(payload)?; - } - } - - if self.buffer.len() > self.buffer_capacity { - Ok(WriterState::Pause) - } else { - Ok(WriterState::Done) - } - } - - pub fn write_eof(&mut self) -> io::Result<()> { - self.encoder.write_eof()?; - - if self.encoder.is_eof() { - Ok(()) - } else { - Err(io::Error::new(io::ErrorKind::Other, - "Last payload item, but eof is not reached")) - } - } - - #[inline] - pub fn poll_completed(&mut self, stream: &mut T, shutdown: bool) - -> Poll<(), io::Error> - { - match self.write_to_stream(stream) { - Ok(WriterState::Done) => { - if shutdown { - stream.shutdown() - } else { - Ok(Async::Ready(())) - } - }, - Ok(WriterState::Pause) => Ok(Async::NotReady), - Err(err) => Err(err) - } - } -} - - -fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder { - let version = req.version(); - let mut body = req.replace_body(Body::Empty); - let mut encoding = req.content_encoding(); - - let transfer = match body { - Body::Empty => { - req.headers_mut().remove(CONTENT_LENGTH); - TransferEncoding::length(0, buf) - }, - Body::Binary(ref mut bytes) => { - if encoding.is_compression() { - let tmp = SharedBytes::default(); - let transfer = TransferEncoding::eof(tmp.clone()); - let mut enc = match encoding { - ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::default())), - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::default())), - #[cfg(feature="brotli")] - ContentEncoding::Br => ContentEncoder::Br( - BrotliEncoder::new(transfer, 5)), - ContentEncoding::Identity => ContentEncoder::Identity(transfer), - ContentEncoding::Auto => unreachable!() - }; - // TODO return error! - let _ = enc.write(bytes.clone()); - let _ = enc.write_eof(); - *bytes = Binary::from(tmp.take()); - - req.headers_mut().insert( - CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); - encoding = ContentEncoding::Identity; - } - let mut b = BytesMut::new(); - let _ = write!(b, "{}", bytes.len()); - req.headers_mut().insert( - CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); - TransferEncoding::eof(buf) - }, - Body::Streaming(_) | Body::Actor(_) => { - if req.upgrade() { - if version == Version::HTTP_2 { - error!("Connection upgrade is forbidden for HTTP/2"); - } else { - req.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); - } - if encoding != ContentEncoding::Identity { - encoding = ContentEncoding::Identity; - req.headers_mut().remove(CONTENT_ENCODING); - } - TransferEncoding::eof(buf) - } else { - streaming_encoding(buf, version, req) - } - } - }; - - if encoding.is_compression() { - req.headers_mut().insert( - CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); - } - - req.replace_body(body); - match encoding { - ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::default())), - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::default())), - #[cfg(feature="brotli")] - ContentEncoding::Br => ContentEncoder::Br( - BrotliEncoder::new(transfer, 5)), - ContentEncoding::Identity | ContentEncoding::Auto => ContentEncoder::Identity(transfer), - } -} - -fn streaming_encoding(buf: SharedBytes, version: Version, req: &mut ClientRequest) - -> TransferEncoding { - if req.chunked() { - // Enable transfer encoding - req.headers_mut().remove(CONTENT_LENGTH); - if version == Version::HTTP_2 { - req.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof(buf) - } else { - req.headers_mut().insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked(buf) - } - } else { - // if Content-Length is specified, then use it as length hint - let (len, chunked) = - if let Some(len) = req.headers().get(CONTENT_LENGTH) { - // Content-Length - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - (Some(len), false) - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - (None, true) - }; - - if !chunked { - if let Some(len) = len { - TransferEncoding::length(len, buf) - } else { - TransferEncoding::eof(buf) - } - } else { - // Enable transfer encoding - match version { - Version::HTTP_11 => { - req.headers_mut().insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked(buf) - }, - _ => { - req.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof(buf) - } - } - } - } -} - - -// "Sun, 06 Nov 1994 08:49:37 GMT".len() -pub const DATE_VALUE_LENGTH: usize = 29; - -fn set_date(dst: &mut BytesMut) { - CACHED.with(|cache| { - let mut cache = cache.borrow_mut(); - let now = time::get_time(); - if now > cache.next_update { - cache.update(now); - } - dst.extend_from_slice(cache.buffer()); - }) -} - -struct CachedDate { - bytes: [u8; DATE_VALUE_LENGTH], - next_update: time::Timespec, -} - -thread_local!(static CACHED: RefCell = RefCell::new(CachedDate { - bytes: [0; DATE_VALUE_LENGTH], - next_update: time::Timespec::new(0, 0), -})); - -impl CachedDate { - fn buffer(&self) -> &[u8] { - &self.bytes[..] - } - - fn update(&mut self, now: time::Timespec) { - write!(&mut self.bytes[..], "{}", time::at_utc(now).rfc822()).unwrap(); - self.next_update = now + Duration::seconds(1); - self.next_update.nsec = 0; - } -} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 000000000..57ba10079 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,350 @@ +use std::net::SocketAddr; +use std::rc::Rc; + +use actix_http::Extensions; +use actix_router::ResourceDef; +use actix_service::{boxed, IntoServiceFactory, ServiceFactory}; + +use crate::data::{Data, DataFactory}; +use crate::error::Error; +use crate::guard::Guard; +use crate::resource::Resource; +use crate::rmap::ResourceMap; +use crate::route::Route; +use crate::service::{ + AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, + ServiceResponse, +}; + +type Guards = Vec>; +type HttpNewService = + boxed::BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; + +/// Application configuration +pub struct AppService { + config: AppConfig, + root: bool, + default: Rc, + services: Vec<( + ResourceDef, + HttpNewService, + Option, + Option>, + )>, + service_data: Rc>>, +} + +impl AppService { + /// Crate server settings instance + pub(crate) fn new( + config: AppConfig, + default: Rc, + service_data: Rc>>, + ) -> Self { + AppService { + config, + default, + service_data, + root: true, + services: Vec::new(), + } + } + + /// Check if root is beeing configured + pub fn is_root(&self) -> bool { + self.root + } + + pub(crate) fn into_services( + self, + ) -> ( + AppConfig, + Vec<( + ResourceDef, + HttpNewService, + Option, + Option>, + )>, + ) { + (self.config, self.services) + } + + pub(crate) fn clone_config(&self) -> Self { + AppService { + config: self.config.clone(), + default: self.default.clone(), + services: Vec::new(), + root: false, + service_data: self.service_data.clone(), + } + } + + /// Service configuration + pub fn config(&self) -> &AppConfig { + &self.config + } + + /// Default resource + pub fn default_service(&self) -> Rc { + self.default.clone() + } + + /// Set global route data + pub fn set_service_data(&self, extensions: &mut Extensions) -> bool { + for f in self.service_data.iter() { + f.create(extensions); + } + !self.service_data.is_empty() + } + + /// Register http service + pub fn register_service( + &mut self, + rdef: ResourceDef, + guards: Option>>, + factory: F, + nested: Option>, + ) where + F: IntoServiceFactory, + S: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + 'static, + { + self.services.push(( + rdef, + boxed::factory(factory.into_factory()), + guards, + nested, + )); + } +} + +#[derive(Clone)] +pub struct AppConfig(pub(crate) Rc); + +impl AppConfig { + pub(crate) fn new(inner: AppConfigInner) -> Self { + AppConfig(Rc::new(inner)) + } + + /// Set server host name. + /// + /// Host name is used by application router as a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default host name is set to a "localhost" value. + pub fn host(&self) -> &str { + &self.0.host + } + + /// Returns true if connection is secure(https) + pub fn secure(&self) -> bool { + self.0.secure + } + + /// Returns the socket address of the local half of this TCP connection + pub fn local_addr(&self) -> SocketAddr { + self.0.addr + } +} + +pub(crate) struct AppConfigInner { + pub(crate) secure: bool, + pub(crate) host: String, + pub(crate) addr: SocketAddr, +} + +impl Default for AppConfigInner { + fn default() -> AppConfigInner { + AppConfigInner { + secure: false, + addr: "127.0.0.1:8080".parse().unwrap(), + host: "localhost:8080".to_owned(), + } + } +} + +/// Service config is used for external configuration. +/// Part of application configuration could be offloaded +/// to set of external methods. This could help with +/// modularization of big application configuration. +pub struct ServiceConfig { + pub(crate) services: Vec>, + pub(crate) data: Vec>, + pub(crate) external: Vec, +} + +impl ServiceConfig { + pub(crate) fn new() -> Self { + Self { + services: Vec::new(), + data: Vec::new(), + external: Vec::new(), + } + } + + /// Set application data. Application data could be accessed + /// by using `Data` extractor where `T` is data type. + /// + /// This is same as `App::data()` method. + pub fn data(&mut self, data: S) -> &mut Self { + self.data.push(Box::new(Data::new(data))); + self + } + + /// Configure route for a specific path. + /// + /// This is same as `App::route()` method. + pub fn route(&mut self, path: &str, mut route: Route) -> &mut Self { + self.service( + Resource::new(path) + .add_guards(route.take_guards()) + .route(route), + ) + } + + /// Register http service. + /// + /// This is same as `App::service()` method. + pub fn service(&mut self, factory: F) -> &mut Self + where + F: HttpServiceFactory + 'static, + { + self.services + .push(Box::new(ServiceFactoryWrapper::new(factory))); + self + } + + /// Register an external resource. + /// + /// External resources are useful for URL generation purposes only + /// and are never considered for matching at request time. Calls to + /// `HttpRequest::url_for()` will work as expected. + /// + /// This is same as `App::external_service()` method. + pub fn external_resource(&mut self, name: N, url: U) -> &mut Self + where + N: AsRef, + U: AsRef, + { + let mut rdef = ResourceDef::new(url.as_ref()); + *rdef.name_mut() = name.as_ref().to_string(); + self.external.push(rdef); + self + } +} + +#[cfg(test)] +mod tests { + use actix_service::Service; + use bytes::Bytes; + + use super::*; + use crate::http::{Method, StatusCode}; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{web, App, HttpRequest, HttpResponse}; + + #[actix_rt::test] + async fn test_data() { + let cfg = |cfg: &mut ServiceConfig| { + cfg.data(10usize); + }; + + let mut srv = + init_service(App::new().configure(cfg).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + // #[actix_rt::test] + // async fn test_data_factory() { + // let cfg = |cfg: &mut ServiceConfig| { + // cfg.data_factory(|| { + // sleep(std::time::Duration::from_millis(50)).then(|_| { + // println!("READY"); + // Ok::<_, ()>(10usize) + // }) + // }); + // }; + + // let mut srv = + // init_service(App::new().configure(cfg).service( + // web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + // )); + // let req = TestRequest::default().to_request(); + // let resp = srv.call(req).await.unwrap(); + // assert_eq!(resp.status(), StatusCode::OK); + + // let cfg2 = |cfg: &mut ServiceConfig| { + // cfg.data_factory(|| Ok::<_, ()>(10u32)); + // }; + // let mut srv = init_service( + // App::new() + // .service(web::resource("/").to(|_: web::Data| HttpResponse::Ok())) + // .configure(cfg2), + // ); + // let req = TestRequest::default().to_request(); + // let resp = srv.call(req).await.unwrap(); + // assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + // } + + #[actix_rt::test] + async fn test_external_resource() { + let mut srv = init_service( + App::new() + .configure(|cfg| { + cfg.external_resource( + "youtube", + "https://youtube.com/watch/{video_id}", + ); + }) + .route( + "/test", + web::get().to(|req: HttpRequest| { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["12345"]).unwrap() + )) + }), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); + } + + #[actix_rt::test] + async fn test_service() { + let mut srv = init_service(App::new().configure(|cfg| { + cfg.service( + web::resource("/test").route(web::get().to(|| HttpResponse::Created())), + ) + .route("/index.html", web::get().to(|| HttpResponse::Ok())); + })) + .await; + + let req = TestRequest::with_uri("/test") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/index.html") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/src/context.rs b/src/context.rs deleted file mode 100644 index 5958f8919..000000000 --- a/src/context.rs +++ /dev/null @@ -1,250 +0,0 @@ -use std::mem; -use std::marker::PhantomData; -use futures::{Async, Future, Poll}; -use futures::sync::oneshot::Sender; -use futures::unsync::oneshot; -use smallvec::SmallVec; - -use actix::{Actor, ActorState, ActorContext, AsyncContext, - Addr, Handler, Message, SpawnHandle, Syn, Unsync}; -use actix::fut::ActorFuture; -use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope}; - -use body::{Body, Binary}; -use error::{Error, ErrorInternalServerError}; -use httprequest::HttpRequest; - - -pub trait ActorHttpContext: 'static { - fn disconnected(&mut self); - fn poll(&mut self) -> Poll>, Error>; -} - -#[derive(Debug)] -pub enum Frame { - Chunk(Option), - Drain(oneshot::Sender<()>), -} - -impl Frame { - pub fn len(&self) -> usize { - match *self { - Frame::Chunk(Some(ref bin)) => bin.len(), - _ => 0, - } - } -} - -/// Execution context for http actors -pub struct HttpContext where A: Actor>, -{ - inner: ContextImpl, - stream: Option>, - request: HttpRequest, - disconnected: bool, -} - -impl ActorContext for HttpContext where A: Actor -{ - fn stop(&mut self) { - self.inner.stop(); - } - fn terminate(&mut self) { - self.inner.terminate() - } - fn state(&self) -> ActorState { - self.inner.state() - } -} - -impl AsyncContext for HttpContext where A: Actor -{ - #[inline] - fn spawn(&mut self, fut: F) -> SpawnHandle - where F: ActorFuture + 'static - { - self.inner.spawn(fut) - } - #[inline] - fn wait(&mut self, fut: F) - where F: ActorFuture + 'static - { - self.inner.wait(fut) - } - #[doc(hidden)] - #[inline] - fn waiting(&self) -> bool { - self.inner.waiting() || self.inner.state() == ActorState::Stopping || - self.inner.state() == ActorState::Stopped - } - #[inline] - fn cancel_future(&mut self, handle: SpawnHandle) -> bool { - self.inner.cancel_future(handle) - } - #[doc(hidden)] - #[inline] - fn unsync_address(&mut self) -> Addr { - self.inner.unsync_address() - } - #[doc(hidden)] - #[inline] - fn sync_address(&mut self) -> Addr { - self.inner.sync_address() - } -} - -impl HttpContext where A: Actor { - - #[inline] - pub fn new(req: HttpRequest, actor: A) -> HttpContext { - HttpContext::from_request(req).actor(actor) - } - pub fn from_request(req: HttpRequest) -> HttpContext { - HttpContext { - inner: ContextImpl::new(None), - stream: None, - request: req, - disconnected: false, - } - } - #[inline] - pub fn actor(mut self, actor: A) -> HttpContext { - self.inner.set_actor(actor); - self - } -} - -impl HttpContext where A: Actor { - - /// Shared application state - #[inline] - pub fn state(&self) -> &S { - self.request.state() - } - - /// Incoming request - #[inline] - pub fn request(&mut self) -> &mut HttpRequest { - &mut self.request - } - - /// Write payload - #[inline] - pub fn write>(&mut self, data: B) { - if !self.disconnected { - self.add_frame(Frame::Chunk(Some(data.into()))); - } else { - warn!("Trying to write to disconnected response"); - } - } - - /// Indicate end of streaming payload. Also this method calls `Self::close`. - #[inline] - pub fn write_eof(&mut self) { - self.add_frame(Frame::Chunk(None)); - } - - /// Returns drain future - pub fn drain(&mut self) -> Drain { - let (tx, rx) = oneshot::channel(); - self.inner.modify(); - self.add_frame(Frame::Drain(tx)); - Drain::new(rx) - } - - /// Check if connection still open - #[inline] - pub fn connected(&self) -> bool { - !self.disconnected - } - - #[inline] - fn add_frame(&mut self, frame: Frame) { - if self.stream.is_none() { - self.stream = Some(SmallVec::new()); - } - self.stream.as_mut().map(|s| s.push(frame)); - self.inner.modify(); - } - - /// Handle of the running future - /// - /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. - pub fn handle(&self) -> SpawnHandle { - self.inner.curr_handle() - } -} - -impl ActorHttpContext for HttpContext where A: Actor, S: 'static { - - #[inline] - fn disconnected(&mut self) { - self.disconnected = true; - self.stop(); - } - - fn poll(&mut self) -> Poll>, Error> { - let ctx: &mut HttpContext = unsafe { - mem::transmute(self as &mut HttpContext) - }; - - if self.inner.alive() { - match self.inner.poll(ctx) { - Ok(Async::NotReady) | Ok(Async::Ready(())) => (), - Err(_) => return Err(ErrorInternalServerError("error")), - } - } - - // frames - if let Some(data) = self.stream.take() { - Ok(Async::Ready(Some(data))) - } else if self.inner.alive() { - Ok(Async::NotReady) - } else { - Ok(Async::Ready(None)) - } - } -} - -impl ToEnvelope for HttpContext - where A: Actor> + Handler, - M: Message + Send + 'static, M::Result: Send, -{ - fn pack(msg: M, tx: Option>) -> SyncEnvelope { - SyncEnvelope::new(msg, tx) - } -} - -impl From> for Body - where A: Actor>, - S: 'static -{ - fn from(ctx: HttpContext) -> Body { - Body::Actor(Box::new(ctx)) - } -} - -pub struct Drain { - fut: oneshot::Receiver<()>, - _a: PhantomData, -} - -impl Drain { - pub fn new(fut: oneshot::Receiver<()>) -> Self { - Drain { fut, _a: PhantomData } - } -} - -impl ActorFuture for Drain { - type Item = (); - type Error = (); - type Actor = A; - - #[inline] - fn poll(&mut self, - _: &mut A, - _: &mut ::Context) -> Poll - { - self.fut.poll().map_err(|_| ()) - } -} diff --git a/src/data.rs b/src/data.rs new file mode 100644 index 000000000..e8928188f --- /dev/null +++ b/src/data.rs @@ -0,0 +1,235 @@ +use std::ops::Deref; +use std::sync::Arc; + +use actix_http::error::{Error, ErrorInternalServerError}; +use actix_http::Extensions; +use futures::future::{err, ok, Ready}; + +use crate::dev::Payload; +use crate::extract::FromRequest; +use crate::request::HttpRequest; + +/// Application data factory +pub(crate) trait DataFactory { + fn create(&self, extensions: &mut Extensions) -> bool; +} + +/// Application data. +/// +/// Application data is an arbitrary data attached to the app. +/// Application data is available to all routes and could be added +/// during application configuration process +/// with `App::data()` method. +/// +/// Application data could be accessed by using `Data` +/// extractor where `T` is data type. +/// +/// **Note**: http server accepts an application factory rather than +/// an application instance. Http server constructs an application +/// instance for each thread, thus application data must be constructed +/// multiple times. If you want to share data between different +/// threads, a shareable object should be used, e.g. `Send + Sync`. Application +/// data does not need to be `Send` or `Sync`. Internally `Data` type +/// uses `Arc`. if your data implements `Send` + `Sync` traits you can +/// use `web::Data::new()` and avoid double `Arc`. +/// +/// If route data is not set for a handler, using `Data` extractor would +/// cause *Internal Server Error* response. +/// +/// ```rust +/// use std::sync::Mutex; +/// use actix_web::{web, App, HttpResponse, Responder}; +/// +/// struct MyData { +/// counter: usize, +/// } +/// +/// /// Use `Data` extractor to access data in handler. +/// async fn index(data: web::Data>) -> impl Responder { +/// let mut data = data.lock().unwrap(); +/// data.counter += 1; +/// HttpResponse::Ok() +/// } +/// +/// fn main() { +/// let data = web::Data::new(Mutex::new(MyData{ counter: 0 })); +/// +/// let app = App::new() +/// // Store `MyData` in application storage. +/// .register_data(data.clone()) +/// .service( +/// web::resource("/index.html").route( +/// web::get().to(index))); +/// } +/// ``` +#[derive(Debug)] +pub struct Data(Arc); + +impl Data { + /// Create new `Data` instance. + /// + /// Internally `Data` type uses `Arc`. if your data implements + /// `Send` + `Sync` traits you can use `web::Data::new()` and + /// avoid double `Arc`. + pub fn new(state: T) -> Data { + Data(Arc::new(state)) + } + + /// Get reference to inner app data. + pub fn get_ref(&self) -> &T { + self.0.as_ref() + } + + /// Convert to the internal Arc + pub fn into_inner(self) -> Arc { + self.0 + } +} + +impl Deref for Data { + type Target = T; + + fn deref(&self) -> &T { + self.0.as_ref() + } +} + +impl Clone for Data { + fn clone(&self) -> Data { + Data(self.0.clone()) + } +} + +impl FromRequest for Data { + type Config = (); + type Error = Error; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + if let Some(st) = req.get_app_data::() { + ok(st) + } else { + log::debug!( + "Failed to construct App-level Data extractor. \ + Request path: {:?}", + req.path() + ); + err(ErrorInternalServerError( + "App data is not configured, to configure use App::data()", + )) + } + } +} + +impl DataFactory for Data { + fn create(&self, extensions: &mut Extensions) -> bool { + if !extensions.contains::>() { + extensions.insert(Data(self.0.clone())); + true + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::Service; + + use super::*; + use crate::http::StatusCode; + use crate::test::{init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + + #[actix_rt::test] + async fn test_data_extractor() { + let mut srv = + init_service(App::new().data(10usize).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let mut srv = + init_service(App::new().data(10u32).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[actix_rt::test] + async fn test_register_data_extractor() { + let mut srv = + init_service(App::new().register_data(Data::new(10usize)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let mut srv = + init_service(App::new().register_data(Data::new(10u32)).service( + web::resource("/").to(|_: web::Data| HttpResponse::Ok()), + )) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[actix_rt::test] + async fn test_route_data_extractor() { + let mut srv = + init_service(App::new().service(web::resource("/").data(10usize).route( + web::get().to(|data: web::Data| { + let _ = data.clone(); + HttpResponse::Ok() + }), + ))) + .await; + + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // different type + let mut srv = init_service( + App::new().service( + web::resource("/") + .data(10u32) + .route(web::get().to(|_: web::Data| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[actix_rt::test] + async fn test_override_data() { + let mut srv = init_service(App::new().data(1usize).service( + web::resource("/").data(10usize).route(web::get().to( + |data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }, + )), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/src/de.rs b/src/de.rs deleted file mode 100644 index 659dc10a6..000000000 --- a/src/de.rs +++ /dev/null @@ -1,390 +0,0 @@ -use std::slice::Iter; -use std::borrow::Cow; -use std::convert::AsRef; -use serde::de::{self, Deserializer, Visitor, Error as DeError}; - -use httprequest::HttpRequest; - - -macro_rules! unsupported_type { - ($trait_fn:ident, $name:expr) => { - fn $trait_fn(self, _: V) -> Result - where V: Visitor<'de> - { - Err(de::value::Error::custom(concat!("unsupported type: ", $name))) - } - }; -} - -macro_rules! parse_single_value { - ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { - fn $trait_fn(self, visitor: V) -> Result - where V: Visitor<'de> - { - if self.req.match_info().len() != 1 { - Err(de::value::Error::custom( - format!("wrong number of parameters: {} expected 1", - self.req.match_info().len()).as_str())) - } else { - let v = self.req.match_info()[0].parse().map_err( - |_| de::value::Error::custom( - format!("can not parse {:?} to a {}", - &self.req.match_info()[0], $tp)))?; - visitor.$visit_fn(v) - } - } - } -} - -pub struct PathDeserializer<'de, S: 'de> { - req: &'de HttpRequest -} - -impl<'de, S: 'de> PathDeserializer<'de, S> { - pub fn new(req: &'de HttpRequest) -> Self { - PathDeserializer{req} - } -} - -impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> -{ - type Error = de::value::Error; - - fn deserialize_map(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_map(ParamsDeserializer{ - params: self.req.match_info().iter(), - current: None, - }) - } - - fn deserialize_struct(self, _: &'static str, _: &'static [&'static str], visitor: V) - -> Result - where V: Visitor<'de>, - { - self.deserialize_map(visitor) - } - - fn deserialize_unit(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit_struct(self, _: &'static str, visitor: V) - -> Result - where V: Visitor<'de> - { - self.deserialize_unit(visitor) - } - - fn deserialize_newtype_struct(self, _: &'static str, visitor: V) - -> Result - where V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn deserialize_tuple(self, len: usize, visitor: V) -> Result - where V: Visitor<'de> - { - if self.req.match_info().len() < len { - Err(de::value::Error::custom( - format!("wrong number of parameters: {} expected {}", - self.req.match_info().len(), len).as_str())) - } else { - visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) - } - } - - fn deserialize_tuple_struct(self, _: &'static str, len: usize, visitor: V) - -> Result - where V: Visitor<'de> - { - if self.req.match_info().len() < len { - Err(de::value::Error::custom( - format!("wrong number of parameters: {} expected {}", - self.req.match_info().len(), len).as_str())) - } else { - visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) - } - } - - fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], _: V) - -> Result - where V: Visitor<'de> - { - Err(de::value::Error::custom("unsupported type: enum")) - } - - fn deserialize_str(self, visitor: V) -> Result - where V: Visitor<'de>, - { - if self.req.match_info().len() != 1 { - Err(de::value::Error::custom( - format!("wrong number of parameters: {} expected 1", - self.req.match_info().len()).as_str())) - } else { - visitor.visit_str(&self.req.match_info()[0]) - } - } - - fn deserialize_seq(self, visitor: V) -> Result - where V: Visitor<'de> - { - visitor.visit_seq(ParamsSeq{params: self.req.match_info().iter()}) - } - - unsupported_type!(deserialize_any, "'any'"); - unsupported_type!(deserialize_bytes, "bytes"); - unsupported_type!(deserialize_option, "Option"); - unsupported_type!(deserialize_identifier, "identifier"); - unsupported_type!(deserialize_ignored_any, "ignored_any"); - - parse_single_value!(deserialize_bool, visit_bool, "bool"); - parse_single_value!(deserialize_i8, visit_i8, "i8"); - parse_single_value!(deserialize_i16, visit_i16, "i16"); - parse_single_value!(deserialize_i32, visit_i32, "i16"); - parse_single_value!(deserialize_i64, visit_i64, "i64"); - parse_single_value!(deserialize_u8, visit_u8, "u8"); - parse_single_value!(deserialize_u16, visit_u16, "u16"); - parse_single_value!(deserialize_u32, visit_u32, "u32"); - parse_single_value!(deserialize_u64, visit_u64, "u64"); - parse_single_value!(deserialize_f32, visit_f32, "f32"); - parse_single_value!(deserialize_f64, visit_f64, "f64"); - parse_single_value!(deserialize_string, visit_string, "String"); - parse_single_value!(deserialize_byte_buf, visit_string, "String"); - parse_single_value!(deserialize_char, visit_char, "char"); -} - -struct ParamsDeserializer<'de> { - params: Iter<'de, (Cow<'de, str>, Cow<'de, str>)>, - current: Option<(&'de str, &'de str)>, -} - -impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> -{ - type Error = de::value::Error; - - fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> - where K: de::DeserializeSeed<'de>, - { - self.current = self.params.next().map(|&(ref k, ref v)| (k.as_ref(), v.as_ref())); - match self.current { - Some((key, _)) => Ok(Some(seed.deserialize(Key{key})?)), - None => Ok(None), - } - } - - fn next_value_seed(&mut self, seed: V) -> Result - where V: de::DeserializeSeed<'de>, - { - if let Some((_, value)) = self.current.take() { - seed.deserialize(Value { value }) - } else { - Err(de::value::Error::custom("unexpected item")) - } - } -} - -struct Key<'de> { - key: &'de str, -} - -impl<'de> Deserializer<'de> for Key<'de> { - type Error = de::value::Error; - - fn deserialize_identifier(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_str(self.key) - } - - fn deserialize_any(self, _visitor: V) -> Result - where V: Visitor<'de>, - { - Err(de::value::Error::custom("Unexpected")) - } - - forward_to_deserialize_any! { - bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes - byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map struct enum ignored_any - } -} - -macro_rules! parse_value { - ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { - fn $trait_fn(self, visitor: V) -> Result - where V: Visitor<'de> - { - let v = self.value.parse().map_err( - |_| de::value::Error::custom( - format!("can not parse {:?} to a {}", self.value, $tp)))?; - visitor.$visit_fn(v) - } - } -} - -struct Value<'de> { - value: &'de str, -} - -impl<'de> Deserializer<'de> for Value<'de> -{ - type Error = de::value::Error; - - parse_value!(deserialize_bool, visit_bool, "bool"); - parse_value!(deserialize_i8, visit_i8, "i8"); - parse_value!(deserialize_i16, visit_i16, "i16"); - parse_value!(deserialize_i32, visit_i32, "i16"); - parse_value!(deserialize_i64, visit_i64, "i64"); - parse_value!(deserialize_u8, visit_u8, "u8"); - parse_value!(deserialize_u16, visit_u16, "u16"); - parse_value!(deserialize_u32, visit_u32, "u32"); - parse_value!(deserialize_u64, visit_u64, "u64"); - parse_value!(deserialize_f32, visit_f32, "f32"); - parse_value!(deserialize_f64, visit_f64, "f64"); - parse_value!(deserialize_string, visit_string, "String"); - parse_value!(deserialize_byte_buf, visit_string, "String"); - parse_value!(deserialize_char, visit_char, "char"); - - fn deserialize_ignored_any(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_unit() - } - - fn deserialize_unit_struct( - self, _: &'static str, visitor: V) -> Result - where V: Visitor<'de> - { - visitor.visit_unit() - } - - fn deserialize_bytes(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_borrowed_bytes(self.value.as_bytes()) - } - - fn deserialize_str(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_borrowed_str(self.value) - } - - fn deserialize_option(self, visitor: V) -> Result - where V: Visitor<'de>, - { - visitor.visit_some(self) - } - - fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], visitor: V) - -> Result - where V: Visitor<'de>, - { - visitor.visit_enum(ValueEnum {value: self.value}) - } - - fn deserialize_newtype_struct(self, _: &'static str, visitor: V) - -> Result - where V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn deserialize_tuple(self, _: usize, _: V) -> Result - where V: Visitor<'de> - { - Err(de::value::Error::custom("unsupported type: tuple")) - } - - fn deserialize_struct(self, _: &'static str, _: &'static [&'static str], _: V) - -> Result - where V: Visitor<'de> - { - Err(de::value::Error::custom("unsupported type: struct")) - } - - fn deserialize_tuple_struct(self, _: &'static str, _: usize, _: V) - -> Result - where V: Visitor<'de> - { - Err(de::value::Error::custom("unsupported type: tuple struct")) - } - - unsupported_type!(deserialize_any, "any"); - unsupported_type!(deserialize_seq, "seq"); - unsupported_type!(deserialize_map, "map"); - unsupported_type!(deserialize_identifier, "identifier"); -} - -struct ParamsSeq<'de> { - params: Iter<'de, (Cow<'de, str>, Cow<'de, str>)>, -} - -impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> -{ - type Error = de::value::Error; - - fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> - where T: de::DeserializeSeed<'de>, - { - match self.params.next() { - Some(item) => Ok(Some(seed.deserialize(Value { value: item.1.as_ref() })?)), - None => Ok(None), - } - } -} - -struct ValueEnum<'de> { - value: &'de str, -} - -impl<'de> de::EnumAccess<'de> for ValueEnum<'de> { - type Error = de::value::Error; - type Variant = UnitVariant; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> - where V: de::DeserializeSeed<'de>, - { - Ok((seed.deserialize(Key { key: self.value })?, UnitVariant)) - } -} - -struct UnitVariant; - -impl<'de> de::VariantAccess<'de> for UnitVariant { - type Error = de::value::Error; - - fn unit_variant(self) -> Result<(), Self::Error> { - Ok(()) - } - - fn newtype_variant_seed(self, _seed: T) -> Result - where T: de::DeserializeSeed<'de>, - { - Err(de::value::Error::custom("not supported")) - } - - fn tuple_variant(self, _len: usize, _visitor: V) -> Result - where V: Visitor<'de>, - { - Err(de::value::Error::custom("not supported")) - } - - fn struct_variant(self, _: &'static [&'static str], _: V) - -> Result - where V: Visitor<'de>, - { - Err(de::value::Error::custom("not supported")) - } -} diff --git a/src/error.rs b/src/error.rs index dc4ae78ec..2eec7c51b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,823 +1,192 @@ //! Error and Result module -use std::{io, fmt, result}; -use std::str::Utf8Error; -use std::string::FromUtf8Error; -use std::io::Error as IoError; - -use cookie; -use httparse; -use actix::MailboxError; -use futures::Canceled; -use failure::{self, Fail, Backtrace}; -use http2::Error as Http2Error; -use http::{header, StatusCode, Error as HttpError}; -use http::uri::InvalidUri; -use http_range::HttpRangeParseError; -use serde::de::value::Error as DeError; +pub use actix_http::error::*; +use derive_more::{Display, From}; use serde_json::error::Error as JsonError; -pub use url::ParseError as UrlParseError; +use url::ParseError as UrlParseError; -// re-exports -pub use cookie::{ParseError as CookieParseError}; - -use handler::Responder; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) -/// for actix web operations -/// -/// This typedef is generally used to avoid writing out `actix_web::error::Error` directly and -/// is otherwise a direct mapping to `Result`. -pub type Result = result::Result; - -/// General purpose actix web error -pub struct Error { - cause: Box, - backtrace: Option, -} - -impl Error { - - /// Returns a reference to the underlying cause of this Error. - // this should return &Fail but needs this https://github.com/rust-lang/rust/issues/5665 - pub fn cause(&self) -> &ResponseError { - self.cause.as_ref() - } -} - -/// Error that can be converted to `HttpResponse` -pub trait ResponseError: Fail { - - /// Create response for error - /// - /// Internal server error is generated by default. - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.cause, f) - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(bt) = self.cause.backtrace() { - write!(f, "{:?}\n\n{:?}", &self.cause, bt) - } else { - write!(f, "{:?}\n\n{:?}", &self.cause, self.backtrace.as_ref().unwrap()) - } - } -} - -/// `HttpResponse` for `Error` -impl From for HttpResponse { - fn from(err: Error) -> Self { - HttpResponse::from_error(err) - } -} - -/// `Error` for any error that implements `ResponseError` -impl From for Error { - fn from(err: T) -> Error { - let backtrace = if err.backtrace().is_none() { - Some(Backtrace::new()) - } else { - None - }; - Error { cause: Box::new(err), backtrace } - } -} - -/// Compatibility for `failure::Error` -impl ResponseError for failure::Compat - where T: fmt::Display + fmt::Debug + Sync + Send + 'static { } - -impl From for Error { - fn from(err: failure::Error) -> Error { - err.compat().into() - } -} - -/// `InternalServerError` for `JsonError` -impl ResponseError for JsonError {} - -/// `InternalServerError` for `UrlParseError` -impl ResponseError for UrlParseError {} - -/// Return `BAD_REQUEST` for `de::value::Error` -impl ResponseError for DeError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// Return `BAD_REQUEST` for `Utf8Error` -impl ResponseError for Utf8Error { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// Return `InternalServerError` for `HttpError`, -/// Response generation can return `HttpError`, so it is internal error -impl ResponseError for HttpError {} - -/// Return `InternalServerError` for `io::Error` -impl ResponseError for io::Error { - - fn error_response(&self) -> HttpResponse { - match self.kind() { - io::ErrorKind::NotFound => - HttpResponse::new(StatusCode::NOT_FOUND), - io::ErrorKind::PermissionDenied => - HttpResponse::new(StatusCode::FORBIDDEN), - _ => - HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -/// `BadRequest` for `InvalidHeaderValue` -impl ResponseError for header::InvalidHeaderValue { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// `BadRequest` for `InvalidHeaderValue` -impl ResponseError for header::InvalidHeaderValueBytes { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// `InternalServerError` for `futures::Canceled` -impl ResponseError for Canceled {} - -/// `InternalServerError` for `actix::MailboxError` -impl ResponseError for MailboxError {} - -/// A set of errors that can occur during parsing HTTP streams -#[derive(Fail, Debug)] -pub enum ParseError { - /// An invalid `Method`, such as `GE.T`. - #[fail(display="Invalid Method specified")] - Method, - /// An invalid `Uri`, such as `exam ple.domain`. - #[fail(display="Uri error: {}", _0)] - Uri(InvalidUri), - /// An invalid `HttpVersion`, such as `HTP/1.1` - #[fail(display="Invalid HTTP version specified")] - Version, - /// An invalid `Header`. - #[fail(display="Invalid Header provided")] - Header, - /// A message head is too large to be reasonable. - #[fail(display="Message head is too large")] - TooLarge, - /// A message reached EOF, but is not complete. - #[fail(display="Message is incomplete")] - Incomplete, - /// An invalid `Status`, such as `1337 ELITE`. - #[fail(display="Invalid Status provided")] - Status, - /// A timeout occurred waiting for an IO event. - #[allow(dead_code)] - #[fail(display="Timeout")] - Timeout, - /// An `io::Error` that occurred while trying to read or write to a network stream. - #[fail(display="IO error: {}", _0)] - Io(#[cause] IoError), - /// Parsing a field as string failed - #[fail(display="UTF8 error: {}", _0)] - Utf8(#[cause] Utf8Error), -} - -/// Return `BadRequest` for `ParseError` -impl ResponseError for ParseError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -impl From for ParseError { - fn from(err: IoError) -> ParseError { - ParseError::Io(err) - } -} - -impl From for ParseError { - fn from(err: InvalidUri) -> ParseError { - ParseError::Uri(err) - } -} - -impl From for ParseError { - fn from(err: Utf8Error) -> ParseError { - ParseError::Utf8(err) - } -} - -impl From for ParseError { - fn from(err: FromUtf8Error) -> ParseError { - ParseError::Utf8(err.utf8_error()) - } -} - -impl From for ParseError { - fn from(err: httparse::Error) -> ParseError { - match err { - httparse::Error::HeaderName | httparse::Error::HeaderValue | - httparse::Error::NewLine | httparse::Error::Token => ParseError::Header, - httparse::Error::Status => ParseError::Status, - httparse::Error::TooManyHeaders => ParseError::TooLarge, - httparse::Error::Version => ParseError::Version, - } - } -} - -#[derive(Fail, Debug)] -/// A set of errors that can occur during payload parsing -pub enum PayloadError { - /// A payload reached EOF, but is not complete. - #[fail(display="A payload reached EOF, but is not complete.")] - Incomplete, - /// Content encoding stream corruption - #[fail(display="Can not decode content-encoding.")] - EncodingCorrupted, - /// A payload reached size limit. - #[fail(display="A payload reached size limit.")] - Overflow, - /// A payload length is unknown. - #[fail(display="A payload length is unknown.")] - UnknownLength, - /// Io error - #[fail(display="{}", _0)] - Io(#[cause] IoError), - /// Http2 error - #[fail(display="{}", _0)] - Http2(#[cause] Http2Error), -} - -impl From for PayloadError { - fn from(err: IoError) -> PayloadError { - PayloadError::Io(err) - } -} - -/// `InternalServerError` for `PayloadError` -impl ResponseError for PayloadError {} - -/// Return `BadRequest` for `cookie::ParseError` -impl ResponseError for cookie::ParseError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// Http range header parsing error -#[derive(Fail, PartialEq, Debug)] -pub enum HttpRangeError { - /// Returned if range is invalid. - #[fail(display="Range header is invalid")] - InvalidRange, - /// Returned if first-byte-pos of all of the byte-range-spec - /// values is greater than the content size. - /// See `https://github.com/golang/go/commit/aa9b3d7` - #[fail(display="First-byte-pos of all of the byte-range-spec values is greater than the content size")] - NoOverlap, -} - -/// Return `BadRequest` for `HttpRangeError` -impl ResponseError for HttpRangeError { - fn error_response(&self) -> HttpResponse { - HttpResponse::with_body( - StatusCode::BAD_REQUEST, "Invalid Range header provided") - } -} - -impl From for HttpRangeError { - fn from(err: HttpRangeParseError) -> HttpRangeError { - match err { - HttpRangeParseError::InvalidRange => HttpRangeError::InvalidRange, - HttpRangeParseError::NoOverlap => HttpRangeError::NoOverlap, - } - } -} - -/// A set of errors that can occur during parsing multipart streams -#[derive(Fail, Debug)] -pub enum MultipartError { - /// Content-Type header is not found - #[fail(display="No Content-type header found")] - NoContentType, - /// Can not parse Content-Type header - #[fail(display="Can not parse Content-Type header")] - ParseContentType, - /// Multipart boundary is not found - #[fail(display="Multipart boundary is not found")] - Boundary, - /// Multipart stream is incomplete - #[fail(display="Multipart stream is incomplete")] - Incomplete, - /// Error during field parsing - #[fail(display="{}", _0)] - Parse(#[cause] ParseError), - /// Payload error - #[fail(display="{}", _0)] - Payload(#[cause] PayloadError), -} - -impl From for MultipartError { - fn from(err: ParseError) -> MultipartError { - MultipartError::Parse(err) - } -} - -impl From for MultipartError { - fn from(err: PayloadError) -> MultipartError { - MultipartError::Payload(err) - } -} - -/// Return `BadRequest` for `MultipartError` -impl ResponseError for MultipartError { - - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// Error during handling `Expect` header -#[derive(Fail, PartialEq, Debug)] -pub enum ExpectError { - /// Expect header value can not be converted to utf8 - #[fail(display="Expect header value can not be converted to utf8")] - Encoding, - /// Unknown expect value - #[fail(display="Unknown expect value")] - UnknownExpect, -} - -impl ResponseError for ExpectError { - fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::EXPECTATION_FAILED, "Unknown Expect") - } -} - -/// A set of error that can occure during parsing content type -#[derive(Fail, PartialEq, Debug)] -pub enum ContentTypeError { - /// Can not parse content type - #[fail(display="Can not parse content type")] - ParseError, - /// Unknown content encoding - #[fail(display="Unknown content encoding")] - UnknownEncoding, -} - -/// Return `BadRequest` for `ContentTypeError` -impl ResponseError for ContentTypeError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} - -/// A set of errors that can occur during parsing urlencoded payloads -#[derive(Fail, Debug)] -pub enum UrlencodedError { - /// Can not decode chunked transfer encoding - #[fail(display="Can not decode chunked transfer encoding")] - Chunked, - /// Payload size is bigger than 256k - #[fail(display="Payload size is bigger than 256k")] - Overflow, - /// Payload size is now known - #[fail(display="Payload size is now known")] - UnknownLength, - /// Content type error - #[fail(display="Content type error")] - ContentType, - /// Parse error - #[fail(display="Parse error")] - Parse, - /// Payload error - #[fail(display="Error that occur during reading payload: {}", _0)] - Payload(#[cause] PayloadError), -} - -/// Return `BadRequest` for `UrlencodedError` -impl ResponseError for UrlencodedError { - - fn error_response(&self) -> HttpResponse { - match *self { - UrlencodedError::Overflow => - HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), - UrlencodedError::UnknownLength => - HttpResponse::new(StatusCode::LENGTH_REQUIRED), - _ => - HttpResponse::new(StatusCode::BAD_REQUEST), - } - } -} - -impl From for UrlencodedError { - fn from(err: PayloadError) -> UrlencodedError { - UrlencodedError::Payload(err) - } -} - -/// A set of errors that can occur during parsing json payloads -#[derive(Fail, Debug)] -pub enum JsonPayloadError { - /// Payload size is bigger than 256k - #[fail(display="Payload size is bigger than 256k")] - Overflow, - /// Content type error - #[fail(display="Content type error")] - ContentType, - /// Deserialize error - #[fail(display="Json deserialize error: {}", _0)] - Deserialize(#[cause] JsonError), - /// Payload error - #[fail(display="Error that occur during reading payload: {}", _0)] - Payload(#[cause] PayloadError), -} - -/// Return `BadRequest` for `UrlencodedError` -impl ResponseError for JsonPayloadError { - - fn error_response(&self) -> HttpResponse { - match *self { - JsonPayloadError::Overflow => - HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), - _ => - HttpResponse::new(StatusCode::BAD_REQUEST), - } - } -} - -impl From for JsonPayloadError { - fn from(err: PayloadError) -> JsonPayloadError { - JsonPayloadError::Payload(err) - } -} - -impl From for JsonPayloadError { - fn from(err: JsonError) -> JsonPayloadError { - JsonPayloadError::Deserialize(err) - } -} - -/// Errors which can occur when attempting to interpret a segment string as a -/// valid path segment. -#[derive(Fail, Debug, PartialEq)] -pub enum UriSegmentError { - /// The segment started with the wrapped invalid character. - #[fail(display="The segment started with the wrapped invalid character")] - BadStart(char), - /// The segment contained the wrapped invalid character. - #[fail(display="The segment contained the wrapped invalid character")] - BadChar(char), - /// The segment ended with the wrapped invalid character. - #[fail(display="The segment ended with the wrapped invalid character")] - BadEnd(char), -} - -/// Return `BadRequest` for `UriSegmentError` -impl ResponseError for UriSegmentError { - - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::BAD_REQUEST) - } -} +use crate::http::StatusCode; +use crate::HttpResponse; +use serde_urlencoded::de; /// Errors which can occur when attempting to generate resource uri. -#[derive(Fail, Debug, PartialEq)] +#[derive(Debug, PartialEq, Display, From)] pub enum UrlGenerationError { - #[fail(display="Resource not found")] + /// Resource not found + #[display(fmt = "Resource not found")] ResourceNotFound, - #[fail(display="Not all path pattern covered")] + /// Not all path pattern covered + #[display(fmt = "Not all path pattern covered")] NotEnoughElements, - #[fail(display="Router is not available")] - RouterNotAvailable, - #[fail(display="{}", _0)] - ParseError(#[cause] UrlParseError), + /// URL parse error + #[display(fmt = "{}", _0)] + ParseError(UrlParseError), } /// `InternalServerError` for `UrlGeneratorError` impl ResponseError for UrlGenerationError {} -impl From for UrlGenerationError { - fn from(err: UrlParseError) -> Self { - UrlGenerationError::ParseError(err) - } +/// A set of errors that can occur during parsing urlencoded payloads +#[derive(Debug, Display, From)] +pub enum UrlencodedError { + /// Can not decode chunked transfer encoding + #[display(fmt = "Can not decode chunked transfer encoding")] + Chunked, + /// Payload size is bigger than allowed. (default: 256kB) + #[display( + fmt = "Urlencoded payload size is bigger ({} bytes) than allowed (default: {} bytes)", + size, + limit + )] + Overflow { size: usize, limit: usize }, + /// Payload size is now known + #[display(fmt = "Payload size is now known")] + UnknownLength, + /// Content type error + #[display(fmt = "Content type error")] + ContentType, + /// Parse error + #[display(fmt = "Parse error")] + Parse, + /// Payload error + #[display(fmt = "Error that occur during reading payload: {}", _0)] + Payload(PayloadError), } -/// Helper type that can wrap any error and generate custom response. -/// -/// In following example any `io::Error` will be converted into "BAD REQUEST" response -/// as opposite to *INNTERNAL SERVER ERROR* which is defined by default. -/// -/// ```rust -/// # extern crate actix_web; -/// # use actix_web::*; -/// use actix_web::fs::NamedFile; -/// -/// fn index(req: HttpRequest) -> Result { -/// let f = NamedFile::open("test.txt").map_err(error::ErrorBadRequest)?; -/// Ok(f) -/// } -/// # fn main() {} -/// ``` -pub struct InternalError { - cause: T, - status: StatusCode, - backtrace: Backtrace, -} - -unsafe impl Sync for InternalError {} -unsafe impl Send for InternalError {} - -impl InternalError { - pub fn new(cause: T, status: StatusCode) -> Self { - InternalError { - cause, - status, - backtrace: Backtrace::new(), +/// Return `BadRequest` for `UrlencodedError` +impl ResponseError for UrlencodedError { + fn status_code(&self) -> StatusCode { + match *self { + UrlencodedError::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE, + UrlencodedError::UnknownLength => StatusCode::LENGTH_REQUIRED, + _ => StatusCode::BAD_REQUEST, } } } -impl Fail for InternalError - where T: Send + Sync + fmt::Debug + 'static -{ - fn backtrace(&self) -> Option<&Backtrace> { - Some(&self.backtrace) - } +/// A set of errors that can occur during parsing json payloads +#[derive(Debug, Display, From)] +pub enum JsonPayloadError { + /// Payload size is bigger than allowed. (default: 32kB) + #[display(fmt = "Json payload size is bigger than allowed")] + Overflow, + /// Content type error + #[display(fmt = "Content type error")] + ContentType, + /// Deserialize error + #[display(fmt = "Json deserialize error: {}", _0)] + Deserialize(JsonError), + /// Payload error + #[display(fmt = "Error that occur during reading payload: {}", _0)] + Payload(PayloadError), } -impl fmt::Debug for InternalError - where T: Send + Sync + fmt::Debug + 'static -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&self.cause, f) - } -} - -impl fmt::Display for InternalError - where T: Send + Sync + fmt::Debug + 'static -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&self.cause, f) - } -} - -impl ResponseError for InternalError - where T: Send + Sync + fmt::Debug + 'static -{ +/// Return `BadRequest` for `JsonPayloadError` +impl ResponseError for JsonPayloadError { fn error_response(&self) -> HttpResponse { - HttpResponse::new(self.status) + match *self { + JsonPayloadError::Overflow => { + HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE) + } + _ => HttpResponse::new(StatusCode::BAD_REQUEST), + } } } -impl Responder for InternalError - where T: Send + Sync + fmt::Debug + 'static -{ - type Item = HttpResponse; - type Error = Error; +/// A set of errors that can occur during parsing request paths +#[derive(Debug, Display, From)] +pub enum PathError { + /// Deserialize error + #[display(fmt = "Path deserialize error: {}", _0)] + Deserialize(de::Error), +} - fn respond_to(self, _: HttpRequest) -> Result { - Err(self.into()) +/// Return `BadRequest` for `PathError` +impl ResponseError for PathError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST } } -/// Helper function that creates wrapper of any error and generate *BAD REQUEST* response. -#[allow(non_snake_case)] -pub fn ErrorBadRequest(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::BAD_REQUEST).into() +/// A set of errors that can occur during parsing query strings +#[derive(Debug, Display, From)] +pub enum QueryPayloadError { + /// Deserialize error + #[display(fmt = "Query deserialize error: {}", _0)] + Deserialize(de::Error), } -/// Helper function that creates wrapper of any error and generate *UNAUTHORIZED* response. -#[allow(non_snake_case)] -pub fn ErrorUnauthorized(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::UNAUTHORIZED).into() +/// Return `BadRequest` for `QueryPayloadError` +impl ResponseError for QueryPayloadError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } } -/// Helper function that creates wrapper of any error and generate *FORBIDDEN* response. -#[allow(non_snake_case)] -pub fn ErrorForbidden(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::FORBIDDEN).into() +/// Error type returned when reading body as lines. +#[derive(From, Display, Debug)] +pub enum ReadlinesError { + /// Error when decoding a line. + #[display(fmt = "Encoding error")] + /// Payload size is bigger than allowed. (default: 256kB) + EncodingError, + /// Payload error. + #[display(fmt = "Error that occur during reading payload: {}", _0)] + Payload(PayloadError), + /// Line limit exceeded. + #[display(fmt = "Line limit exceeded")] + LimitOverflow, + /// ContentType error. + #[display(fmt = "Content-type error")] + ContentTypeError(ContentTypeError), } -/// Helper function that creates wrapper of any error and generate *NOT FOUND* response. -#[allow(non_snake_case)] -pub fn ErrorNotFound(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::NOT_FOUND).into() -} - -/// Helper function that creates wrapper of any error and generate *METHOD NOT ALLOWED* response. -#[allow(non_snake_case)] -pub fn ErrorMethodNotAllowed(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() -} - -/// Helper function that creates wrapper of any error and generate *REQUEST TIMEOUT* response. -#[allow(non_snake_case)] -pub fn ErrorRequestTimeout(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() -} - -/// Helper function that creates wrapper of any error and generate *CONFLICT* response. -#[allow(non_snake_case)] -pub fn ErrorConflict(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::CONFLICT).into() -} - -/// Helper function that creates wrapper of any error and generate *GONE* response. -#[allow(non_snake_case)] -pub fn ErrorGone(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::GONE).into() -} - -/// Helper function that creates wrapper of any error and generate *PRECONDITION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorPreconditionFailed(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and generate *EXPECTATION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorExpectationFailed(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *INTERNAL SERVER ERROR* response. -#[allow(non_snake_case)] -pub fn ErrorInternalServerError(err: T) -> Error - where T: Send + Sync + fmt::Debug + 'static -{ - InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() +/// Return `BadRequest` for `ReadlinesError` +impl ResponseError for ReadlinesError { + fn status_code(&self) -> StatusCode { + match *self { + ReadlinesError::LimitOverflow => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, + } + } } #[cfg(test)] mod tests { - use std::env; - use std::error::Error as StdError; - use std::io; - use httparse; - use http::{StatusCode, Error as HttpError}; - use cookie::ParseError as CookieParseError; - use failure; use super::*; #[test] - #[cfg(actix_nightly)] - fn test_nightly() { - let resp: HttpResponse = IoError::new(io::ErrorKind::Other, "test").error_response(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[test] - fn test_into_response() { - let resp: HttpResponse = ParseError::Incomplete.error_response(); + fn test_urlencoded_error() { + let resp: HttpResponse = + UrlencodedError::Overflow { size: 0, limit: 0 }.error_response(); + assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); + let resp: HttpResponse = UrlencodedError::UnknownLength.error_response(); + assert_eq!(resp.status(), StatusCode::LENGTH_REQUIRED); + let resp: HttpResponse = UrlencodedError::ContentType.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } - let resp: HttpResponse = HttpRangeError::InvalidRange.error_response(); + #[test] + fn test_json_payload_error() { + let resp: HttpResponse = JsonPayloadError::Overflow.error_response(); + assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); + let resp: HttpResponse = JsonPayloadError::ContentType.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } - let resp: HttpResponse = CookieParseError::EmptyName.error_response(); + #[test] + fn test_query_payload_error() { + let resp: HttpResponse = QueryPayloadError::Deserialize( + serde_urlencoded::from_str::("bad query").unwrap_err(), + ) + .error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } - let resp: HttpResponse = MultipartError::Boundary.error_response(); + #[test] + fn test_readlines_error() { + let resp: HttpResponse = ReadlinesError::LimitOverflow.error_response(); + assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); + let resp: HttpResponse = ReadlinesError::EncodingError.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); - let resp: HttpResponse = err.error_response(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[test] - fn test_cause() { - let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.description().to_owned(); - let e = ParseError::Io(orig); - assert_eq!(format!("{}", e.cause().unwrap()), desc); - } - - #[test] - fn test_error_cause() { - let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.description().to_owned(); - let e = Error::from(orig); - assert_eq!(format!("{}", e.cause()), desc); - } - - #[test] - fn test_error_display() { - let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.description().to_owned(); - let e = Error::from(orig); - assert_eq!(format!("{}", e), desc); - } - - #[test] - fn test_error_http_response() { - let orig = io::Error::new(io::ErrorKind::Other, "other"); - let e = Error::from(orig); - let resp: HttpResponse = e.into(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[test] - fn test_range_error() { - let e: HttpRangeError = HttpRangeParseError::InvalidRange.into(); - assert_eq!(e, HttpRangeError::InvalidRange); - let e: HttpRangeError = HttpRangeParseError::NoOverlap.into(); - assert_eq!(e, HttpRangeError::NoOverlap); - } - - #[test] - fn test_expect_error() { - let resp: HttpResponse = ExpectError::Encoding.error_response(); - assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED); - let resp: HttpResponse = ExpectError::UnknownExpect.error_response(); - assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED); - } - - macro_rules! from { - ($from:expr => $error:pat) => { - match ParseError::from($from) { - e @ $error => { - assert!(format!("{}", e).len() >= 5); - } , - e => unreachable!("{:?}", e) - } - } - } - - macro_rules! from_and_cause { - ($from:expr => $error:pat) => { - match ParseError::from($from) { - e @ $error => { - let desc = format!("{}", e.cause().unwrap()); - assert_eq!(desc, $from.description().to_owned()); - }, - _ => unreachable!("{:?}", $from) - } - } - } - - #[test] - fn test_from() { - from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..)); - - from!(httparse::Error::HeaderName => ParseError::Header); - from!(httparse::Error::HeaderName => ParseError::Header); - from!(httparse::Error::HeaderValue => ParseError::Header); - from!(httparse::Error::NewLine => ParseError::Header); - from!(httparse::Error::Status => ParseError::Status); - from!(httparse::Error::Token => ParseError::Header); - from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); - from!(httparse::Error::Version => ParseError::Version); - } - - #[test] - fn failure_error() { - const NAME: &str = "RUST_BACKTRACE"; - let old_tb = env::var(NAME); - env::set_var(NAME, "0"); - let error = failure::err_msg("Hello!"); - let resp: Error = error.into(); - assert_eq!(format!("{:?}", resp), "Compat { error: ErrorMessage { msg: \"Hello!\" } }\n\n"); - match old_tb { - Ok(x) => env::set_var(NAME, x), - _ => env::remove_var(NAME), - } } } diff --git a/src/extract.rs b/src/extract.rs new file mode 100644 index 000000000..d43402c73 --- /dev/null +++ b/src/extract.rs @@ -0,0 +1,361 @@ +//! Request extractors +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_http::error::Error; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; + +use crate::dev::Payload; +use crate::request::HttpRequest; + +/// Trait implemented by types that can be extracted from request. +/// +/// Types that implement this trait can be used with `Route` handlers. +pub trait FromRequest: Sized { + /// The associated error which can be returned. + type Error: Into; + + /// Future that resolves to a Self + type Future: Future>; + + /// Configuration for this extractor + type Config: Default + 'static; + + /// Convert request to a Self + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future; + + /// Convert request to a Self + /// + /// This method uses `Payload::None` as payload stream. + fn extract(req: &HttpRequest) -> Self::Future { + Self::from_request(req, &mut Payload::None) + } + + /// Create and configure config instance. + fn configure(f: F) -> Self::Config + where + F: FnOnce(Self::Config) -> Self::Config, + { + f(Self::Config::default()) + } +} + +/// Optionally extract a field from the request +/// +/// If the FromRequest for T fails, return None rather than returning an error response +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; +/// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; +/// use rand; +/// +/// #[derive(Debug, Deserialize)] +/// struct Thing { +/// name: String +/// } +/// +/// impl FromRequest for Thing { +/// type Error = Error; +/// type Future = Ready>; +/// type Config = (); +/// +/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { +/// if rand::random() { +/// ok(Thing { name: "thingy".into() }) +/// } else { +/// err(ErrorBadRequest("no luck")) +/// } +/// +/// } +/// } +/// +/// /// extract `Thing` from request +/// async fn index(supplied_thing: Option) -> String { +/// match supplied_thing { +/// // Puns not intended +/// Some(thing) => format!("Got something: {:?}", thing), +/// None => format!("No thing!") +/// } +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/users/:first").route( +/// web::post().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for Option +where + T: FromRequest, + T::Future: 'static, +{ + type Config = T::Config; + type Error = Error; + type Future = LocalBoxFuture<'static, Result, Error>>; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + T::from_request(req, payload) + .then(|r| match r { + Ok(v) => ok(Some(v)), + Err(e) => { + log::debug!("Error for Option extractor: {}", e.into()); + ok(None) + } + }) + .boxed_local() + } +} + +/// Optionally extract a field from the request or extract the Error if unsuccessful +/// +/// If the `FromRequest` for T fails, inject Err into handler rather than returning an error response +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; +/// use actix_web::error::ErrorBadRequest; +/// use futures::future::{ok, err, Ready}; +/// use serde_derive::Deserialize; +/// use rand; +/// +/// #[derive(Debug, Deserialize)] +/// struct Thing { +/// name: String +/// } +/// +/// impl FromRequest for Thing { +/// type Error = Error; +/// type Future = Ready>; +/// type Config = (); +/// +/// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { +/// if rand::random() { +/// ok(Thing { name: "thingy".into() }) +/// } else { +/// err(ErrorBadRequest("no luck")) +/// } +/// } +/// } +/// +/// /// extract `Thing` from request +/// async fn index(supplied_thing: Result) -> String { +/// match supplied_thing { +/// Ok(thing) => format!("Got thing: {:?}", thing), +/// Err(e) => format!("Error extracting thing: {}", e) +/// } +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/users/:first").route(web::post().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for Result +where + T: FromRequest + 'static, + T::Error: 'static, + T::Future: 'static, +{ + type Config = T::Config; + type Error = Error; + type Future = LocalBoxFuture<'static, Result, Error>>; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + T::from_request(req, payload) + .then(|res| match res { + Ok(v) => ok(Ok(v)), + Err(e) => ok(Err(e)), + }) + .boxed_local() + } +} + +#[doc(hidden)] +impl FromRequest for () { + type Config = (); + type Error = Error; + type Future = Ready>; + + fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(()) + } +} + +macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { + + /// FromRequest implementation for tuple + #[doc(hidden)] + impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) + { + type Error = Error; + type Future = $fut_type<$($T),+>; + type Config = ($($T::Config),+); + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + $fut_type { + items: <($(Option<$T>,)+)>::default(), + futs: ($($T::from_request(req, payload),)+), + } + } + } + + #[doc(hidden)] + #[pin_project::pin_project] + pub struct $fut_type<$($T: FromRequest),+> { + items: ($(Option<$T>,)+), + futs: ($($T::Future,)+), + } + + impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> + { + type Output = Result<($($T,)+), Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + let mut ready = true; + $( + if this.items.$n.is_none() { + match unsafe { Pin::new_unchecked(&mut this.futs.$n) }.poll(cx) { + Poll::Ready(Ok(item)) => { + this.items.$n = Some(item); + } + Poll::Pending => ready = false, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + } + } + )+ + + if ready { + Poll::Ready(Ok( + ($(this.items.$n.take().unwrap(),)+) + )) + } else { + Poll::Pending + } + } + } +}); + +#[rustfmt::skip] +mod m { + use super::*; + +tuple_from_req!(TupleFromRequest1, (0, A)); +tuple_from_req!(TupleFromRequest2, (0, A), (1, B)); +tuple_from_req!(TupleFromRequest3, (0, A), (1, B), (2, C)); +tuple_from_req!(TupleFromRequest4, (0, A), (1, B), (2, C), (3, D)); +tuple_from_req!(TupleFromRequest5, (0, A), (1, B), (2, C), (3, D), (4, E)); +tuple_from_req!(TupleFromRequest6, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); +tuple_from_req!(TupleFromRequest7, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); +tuple_from_req!(TupleFromRequest8, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); +tuple_from_req!(TupleFromRequest9, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); +tuple_from_req!(TupleFromRequest10, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); +} + +#[cfg(test)] +mod tests { + use actix_http::http::header; + use bytes::Bytes; + use serde_derive::Deserialize; + + use super::*; + use crate::test::TestRequest; + use crate::types::{Form, FormConfig}; + + #[derive(Deserialize, Debug, PartialEq)] + struct Info { + hello: String, + } + + #[actix_rt::test] + async fn test_option() { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .data(FormConfig::default().limit(4096)) + .to_http_parts(); + + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(r, None); + + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(header::CONTENT_LENGTH, "9") + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!( + r, + Some(Form(Info { + hello: "world".into() + })) + ); + + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(header::CONTENT_LENGTH, "9") + .set_payload(Bytes::from_static(b"bye=world")) + .to_http_parts(); + + let r = Option::>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(r, None); + } + + #[actix_rt::test] + async fn test_result() { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(header::CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + + let r = Result::, Error>::from_request(&req, &mut pl) + .await + .unwrap() + .unwrap(); + assert_eq!( + r, + Form(Info { + hello: "world".into() + }) + ); + + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .header(header::CONTENT_LENGTH, "9") + .set_payload(Bytes::from_static(b"bye=world")) + .to_http_parts(); + + let r = Result::, Error>::from_request(&req, &mut pl) + .await + .unwrap(); + assert!(r.is_err()); + } +} diff --git a/src/extractor.rs b/src/extractor.rs deleted file mode 100644 index 2346365bc..000000000 --- a/src/extractor.rs +++ /dev/null @@ -1,461 +0,0 @@ -use std::str; -use std::ops::{Deref, DerefMut}; - -use bytes::Bytes; -use serde_urlencoded; -use serde::de::{self, DeserializeOwned}; -use futures::future::{Future, FutureResult, result}; -use encoding::all::UTF_8; -use encoding::types::{Encoding, DecoderTrap}; - -use error::{Error, ErrorBadRequest}; -use handler::{Either, FromRequest}; -use httprequest::HttpRequest; -use httpmessage::{HttpMessage, MessageBody, UrlEncoded}; -use de::PathDeserializer; - -/// Extract typed information from the request's path. -/// -/// ## Example -/// -/// ```rust -/// # extern crate bytes; -/// # extern crate actix_web; -/// # extern crate futures; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Path, Result, http}; -/// -/// /// extract path info from "/{username}/{count}/?index.html" url -/// /// {username} - deserializes to a String -/// /// {count} - - deserializes to a u32 -/// fn index(info: Path<(String, u32)>) -> Result { -/// Ok(format!("Welcome {}! {}", info.0, info.1)) -/// } -/// -/// fn main() { -/// let app = App::new().resource( -/// "/{username}/{count}/?index.html", // <- define path parameters -/// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor -/// } -/// ``` -/// -/// It is possible to extract path information to a specific type that implements -/// `Deserialize` trait from *serde*. -/// -/// ```rust -/// # extern crate bytes; -/// # extern crate actix_web; -/// # extern crate futures; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Path, Result, http}; -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// /// extract path info using serde -/// fn index(info: Path) -> Result { -/// Ok(format!("Welcome {}!", info.username)) -/// } -/// -/// fn main() { -/// let app = App::new().resource( -/// "/{username}/index.html", // <- define path parameters -/// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor -/// } -/// ``` -pub struct Path{ - inner: T -} - -impl AsRef for Path { - - fn as_ref(&self) -> &T { - &self.inner - } -} - -impl Deref for Path { - type Target = T; - - fn deref(&self) -> &T { - &self.inner - } -} - -impl DerefMut for Path { - fn deref_mut(&mut self) -> &mut T { - &mut self.inner - } -} - -impl Path { - /// Deconstruct to an inner value - pub fn into_inner(self) -> T { - self.inner - } -} - -impl FromRequest for Path - where T: DeserializeOwned, S: 'static -{ - type Result = FutureResult; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - let req = req.clone(); - result(de::Deserialize::deserialize(PathDeserializer::new(&req)) - .map_err(|e| e.into()) - .map(|inner| Path{inner})) - } -} - -/// Extract typed information from from the request's query. -/// -/// ## Example -/// -/// ```rust -/// # extern crate bytes; -/// # extern crate actix_web; -/// # extern crate futures; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Query, http}; -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// // use `with` extractor for query info -/// // this handler get called only if request's query contains `username` field -/// fn index(info: Query) -> String { -/// format!("Welcome {}!", info.username) -/// } -/// -/// fn main() { -/// let app = App::new().resource( -/// "/index.html", -/// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor -/// } -/// ``` -pub struct Query(T); - -impl Deref for Query { - type Target = T; - - fn deref(&self) -> &T { - &self.0 - } -} - -impl DerefMut for Query { - fn deref_mut(&mut self) -> &mut T { - &mut self.0 - } -} - -impl Query { - /// Deconstruct to a inner value - pub fn into_inner(self) -> T { - self.0 - } -} - -impl FromRequest for Query - where T: de::DeserializeOwned, S: 'static -{ - type Result = FutureResult; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - let req = req.clone(); - result(serde_urlencoded::from_str::(req.query_string()) - .map_err(|e| e.into()) - .map(Query)) - } -} - -/// Extract typed information from the request's body. -/// -/// To extract typed information from request's body, the type `T` must implement the -/// `Deserialize` trait from *serde*. -/// -/// ## Example -/// -/// It is possible to extract path information to a specific type that implements -/// `Deserialize` trait from *serde*. -/// -/// ```rust -/// # extern crate actix_web; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Form, Result}; -/// -/// #[derive(Deserialize)] -/// struct FormData { -/// username: String, -/// } -/// -/// /// extract form data using serde -/// /// this handle get called only if content type is *x-www-form-urlencoded* -/// /// and content of the request could be deserialized to a `FormData` struct -/// fn index(form: Form) -> Result { -/// Ok(format!("Welcome {}!", form.username)) -/// } -/// # fn main() {} -/// ``` -pub struct Form(pub T); - -impl Deref for Form { - type Target = T; - - fn deref(&self) -> &T { - &self.0 - } -} - -impl DerefMut for Form { - fn deref_mut(&mut self) -> &mut T { - &mut self.0 - } -} - -impl FromRequest for Form - where T: DeserializeOwned + 'static, S: 'static -{ - type Result = Box>; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - Box::new(UrlEncoded::new(req.clone()).from_err().map(Form)) - } -} - -/// Request payload extractor. -/// -/// Loads request's payload and construct Bytes instance. -/// -/// ## Example -/// -/// ```rust -/// extern crate bytes; -/// # extern crate actix_web; -/// use actix_web::{App, Result}; -/// -/// /// extract text data from request -/// fn index(body: bytes::Bytes) -> Result { -/// Ok(format!("Body {:?}!", body)) -/// } -/// # fn main() {} -/// ``` -impl FromRequest for Bytes -{ - type Result = Box>; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - Box::new(MessageBody::new(req.clone()).from_err()) - } -} - -/// Extract text information from the request's body. -/// -/// Text extractor automatically decode body according to the request's charset. -/// -/// ## Example -/// -/// ```rust -/// # extern crate actix_web; -/// use actix_web::{App, Result}; -/// -/// /// extract text data from request -/// fn index(body: String) -> Result { -/// Ok(format!("Body {}!", body)) -/// } -/// # fn main() {} -/// ``` -impl FromRequest for String -{ - type Result = Either, - Box>>; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - let encoding = match req.encoding() { - Err(_) => return Either::A( - result(Err(ErrorBadRequest("Unknown request charset")))), - Ok(encoding) => encoding, - }; - - Either::B(Box::new( - MessageBody::new(req.clone()) - .from_err() - .and_then(move |body| { - let enc: *const Encoding = encoding as *const Encoding; - if enc == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding.decode(&body, DecoderTrap::Strict) - .map_err(|_| ErrorBadRequest("Can not decode body"))?) - } - }))) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::Bytes; - use futures::{Async, Future}; - use http::header; - use router::{Router, Resource}; - use resource::ResourceHandler; - use test::TestRequest; - use server::ServerSettings; - - #[derive(Deserialize, Debug, PartialEq)] - struct Info { - hello: String, - } - - #[test] - fn test_bytes() { - let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); - req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - - match Bytes::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s, Bytes::from_static(b"hello=world")); - }, - _ => unreachable!(), - } - } - - #[test] - fn test_string() { - let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); - req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - - match String::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s, "hello=world"); - }, - _ => unreachable!(), - } - } - - #[test] - fn test_form() { - let mut req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "11") - .finish(); - req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - - match Form::::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.hello, "world"); - }, - _ => unreachable!(), - } - } - - #[derive(Deserialize)] - struct MyStruct { - key: String, - value: String, - } - - #[derive(Deserialize)] - struct Id { - id: String, - } - - #[derive(Deserialize)] - struct Test2 { - key: String, - value: u32, - } - - #[test] - fn test_request_extract() { - let mut req = TestRequest::with_uri("/name/user1/?id=test").finish(); - - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let mut routes = Vec::new(); - routes.push((Resource::new("index", "/{key}/{value}/"), Some(resource))); - let (router, _) = Router::new("", ServerSettings::default(), routes); - assert!(router.recognize(&mut req).is_some()); - - match Path::::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.key, "name"); - assert_eq!(s.value, "user1"); - }, - _ => unreachable!(), - } - - match Path::<(String, String)>::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.0, "name"); - assert_eq!(s.1, "user1"); - }, - _ => unreachable!(), - } - - match Query::::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.id, "test"); - }, - _ => unreachable!(), - } - - let mut req = TestRequest::with_uri("/name/32/").finish(); - assert!(router.recognize(&mut req).is_some()); - - match Path::::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.as_ref().key, "name"); - assert_eq!(s.value, 32); - }, - _ => unreachable!(), - } - - match Path::<(String, u8)>::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.0, "name"); - assert_eq!(s.1, 32); - }, - _ => unreachable!(), - } - - match Path::>::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.into_inner(), vec!["name".to_owned(), "32".to_owned()]); - }, - _ => unreachable!(), - } - } - - #[test] - fn test_extract_path_signle() { - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let mut routes = Vec::new(); - routes.push((Resource::new("index", "/{value}/"), Some(resource))); - let (router, _) = Router::new("", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/32/").finish(); - assert!(router.recognize(&mut req).is_some()); - - match Path::::from_request(&req).poll().unwrap() { - Async::Ready(s) => { - assert_eq!(s.into_inner(), 32); - }, - _ => unreachable!(), - } - } -} diff --git a/src/fs.rs b/src/fs.rs deleted file mode 100644 index 2d6c0a359..000000000 --- a/src/fs.rs +++ /dev/null @@ -1,580 +0,0 @@ -//! Static files support. - -// //! TODO: needs to re-implement actual files handling, current impl blocks -use std::{io, cmp}; -use std::io::{Read, Seek}; -use std::fmt::Write; -use std::fs::{File, DirEntry, Metadata}; -use std::path::{Path, PathBuf}; -use std::ops::{Deref, DerefMut}; -use std::time::{SystemTime, UNIX_EPOCH}; - -#[cfg(unix)] -use std::os::unix::fs::MetadataExt; - -use bytes::{Bytes, BytesMut, BufMut}; -use futures::{Async, Poll, Future, Stream}; -use futures_cpupool::{CpuPool, CpuFuture}; -use mime_guess::get_mime_type; - -use header; -use error::Error; -use param::FromParam; -use handler::{Handler, RouteHandler, WrapHandler, Responder, Reply}; -use http::{Method, StatusCode}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -/// A file with an associated name; responds with the Content-Type based on the -/// file extension. -#[derive(Debug)] -pub struct NamedFile { - path: PathBuf, - file: File, - md: Metadata, - modified: Option, - cpu_pool: Option, - only_get: bool, -} - -impl NamedFile { - /// Attempts to open a file in read-only mode. - /// - /// # Examples - /// - /// ```rust - /// use actix_web::fs::NamedFile; - /// - /// let file = NamedFile::open("foo.txt"); - /// ``` - pub fn open>(path: P) -> io::Result { - let file = File::open(path.as_ref())?; - let md = file.metadata()?; - let path = path.as_ref().to_path_buf(); - let modified = md.modified().ok(); - let cpu_pool = None; - Ok(NamedFile{path, file, md, modified, cpu_pool, only_get: false}) - } - - /// Allow only GET and HEAD methods - #[inline] - pub fn only_get(mut self) -> Self { - self.only_get = true; - self - } - - /// Returns reference to the underlying `File` object. - #[inline] - pub fn file(&self) -> &File { - &self.file - } - - /// Retrieve the path of this file. - /// - /// # Examples - /// - /// ```rust - /// # use std::io; - /// use actix_web::fs::NamedFile; - /// - /// # fn path() -> io::Result<()> { - /// let file = NamedFile::open("test.txt")?; - /// assert_eq!(file.path().as_os_str(), "foo.txt"); - /// # Ok(()) - /// # } - /// ``` - #[inline] - pub fn path(&self) -> &Path { - self.path.as_path() - } - - /// Set `CpuPool` to use - #[inline] - pub fn set_cpu_pool(mut self, cpu_pool: CpuPool) -> Self { - self.cpu_pool = Some(cpu_pool); - self - } - - fn etag(&self) -> Option { - // This etag format is similar to Apache's. - self.modified.as_ref().map(|mtime| { - let ino = { - #[cfg(unix)] - { self.md.ino() } - #[cfg(not(unix))] - { 0 } - }; - - let dur = mtime.duration_since(UNIX_EPOCH) - .expect("modification time must be after epoch"); - header::EntityTag::strong( - format!("{:x}:{:x}:{:x}:{:x}", - ino, self.md.len(), dur.as_secs(), - dur.subsec_nanos())) - }) - } - - fn last_modified(&self) -> Option { - self.modified.map(|mtime| mtime.into()) - } -} - -impl Deref for NamedFile { - type Target = File; - - fn deref(&self) -> &File { - &self.file - } -} - -impl DerefMut for NamedFile { - fn deref_mut(&mut self) -> &mut File { - &mut self.file - } -} - -/// Returns true if `req` has no `If-Match` header or one which matches `etag`. -fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - None | Some(header::IfMatch::Any) => true, - Some(header::IfMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.strong_eq(some_etag) { - return true; - } - } - } - false - } - } -} - -/// Returns true if `req` doesn't have an `If-None-Match` header matching `req`. -fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { - match req.get_header::() { - Some(header::IfNoneMatch::Any) => false, - Some(header::IfNoneMatch::Items(ref items)) => { - if let Some(some_etag) = etag { - for item in items { - if item.weak_eq(some_etag) { - return false; - } - } - } - true - } - None => true, - } -} - - -impl Responder for NamedFile { - type Item = HttpResponse; - type Error = io::Error; - - fn respond_to(self, req: HttpRequest) -> Result { - if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD { - return Ok(HttpResponse::MethodNotAllowed() - .header(header::CONTENT_TYPE, "text/plain") - .header(header::ALLOW, "GET, HEAD") - .body("This resource only supports GET and HEAD.")) - } - - let etag = self.etag(); - let last_modified = self.last_modified(); - - // check preconditions - let precondition_failed = if !any_match(etag.as_ref(), &req) { - true - } else if let (Some(ref m), Some(header::IfUnmodifiedSince(ref since))) = - (last_modified, req.get_header()) - { - m > since - } else { - false - }; - - // check last modified - let not_modified = if !none_match(etag.as_ref(), &req) { - true - } else if let (Some(ref m), Some(header::IfModifiedSince(ref since))) = - (last_modified, req.get_header()) - { - m <= since - } else { - false - }; - - let mut resp = HttpResponse::Ok(); - - resp - .if_some(self.path().extension(), |ext, resp| { - resp.set(header::ContentType(get_mime_type(&ext.to_string_lossy()))); - }) - .if_some(last_modified, |lm, resp| {resp.set(header::LastModified(lm));}) - .if_some(etag, |etag, resp| {resp.set(header::ETag(etag));}); - - if precondition_failed { - return Ok(resp.status(StatusCode::PRECONDITION_FAILED).finish()) - } else if not_modified { - return Ok(resp.status(StatusCode::NOT_MODIFIED).finish()) - } - - if *req.method() == Method::HEAD { - Ok(resp.finish()) - } else { - let reader = ChunkedReadFile { - size: self.md.len(), - offset: 0, - cpu_pool: self.cpu_pool.unwrap_or_else(|| req.cpu_pool().clone()), - file: Some(self.file), - fut: None, - }; - Ok(resp.streaming(reader)) - } - } -} - -/// A helper created from a `std::fs::File` which reads the file -/// chunk-by-chunk on a `CpuPool`. -pub struct ChunkedReadFile { - size: u64, - offset: u64, - cpu_pool: CpuPool, - file: Option, - fut: Option>, -} - -impl Stream for ChunkedReadFile { - type Item = Bytes; - type Error= Error; - - fn poll(&mut self) -> Poll, Error> { - if self.fut.is_some() { - return match self.fut.as_mut().unwrap().poll()? { - Async::Ready((file, bytes)) => { - self.fut.take(); - self.file = Some(file); - self.offset += bytes.len() as u64; - Ok(Async::Ready(Some(bytes))) - }, - Async::NotReady => Ok(Async::NotReady), - }; - } - - let size = self.size; - let offset = self.offset; - - if size == offset { - Ok(Async::Ready(None)) - } else { - let mut file = self.file.take().expect("Use after completion"); - self.fut = Some(self.cpu_pool.spawn_fn(move || { - let max_bytes = cmp::min(size.saturating_sub(offset), 65_536) as usize; - let mut buf = BytesMut::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; - let nbytes = file.read(unsafe{buf.bytes_mut()})?; - if nbytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()) - } - unsafe{buf.advance_mut(nbytes)}; - Ok((file, buf.freeze())) - })); - self.poll() - } - } -} - -/// A directory; responds with the generated directory listing. -#[derive(Debug)] -pub struct Directory{ - base: PathBuf, - path: PathBuf -} - -impl Directory { - pub fn new(base: PathBuf, path: PathBuf) -> Directory { - Directory { base, path } - } - - fn can_list(&self, entry: &io::Result) -> bool { - if let Ok(ref entry) = *entry { - if let Some(name) = entry.file_name().to_str() { - if name.starts_with('.') { - return false - } - } - if let Ok(ref md) = entry.metadata() { - let ft = md.file_type(); - return ft.is_dir() || ft.is_file() || ft.is_symlink() - } - } - false - } -} - -impl Responder for Directory { - type Item = HttpResponse; - type Error = io::Error; - - fn respond_to(self, req: HttpRequest) -> Result { - let index_of = format!("Index of {}", req.path()); - let mut body = String::new(); - let base = Path::new(req.path()); - - for entry in self.path.read_dir()? { - if self.can_list(&entry) { - let entry = entry.unwrap(); - let p = match entry.path().strip_prefix(&self.path) { - Ok(p) => base.join(p), - Err(_) => continue - }; - // show file url as relative to static path - let file_url = format!("{}", p.to_string_lossy()); - - // if file is a directory, add '/' to the end of the name - if let Ok(metadata) = entry.metadata() { - if metadata.is_dir() { - let _ = write!(body, "", - file_url, entry.file_name().to_string_lossy()); - } else { - let _ = write!(body, "
  • {}
  • ", - file_url, entry.file_name().to_string_lossy()); - } - } else { - continue - } - } - } - - let html = format!("\ - {}\ -

    {}

    \ -
      \ - {}\ -
    \n", index_of, index_of, body); - Ok(HttpResponse::Ok() - .content_type("text/html; charset=utf-8") - .body(html)) - } -} - -/// Static files handling -/// -/// `StaticFile` handler must be registered with `App::handler()` method, -/// because `StaticFile` handler requires access sub-path information. -/// -/// ```rust -/// # extern crate actix_web; -/// use actix_web::{fs, App}; -/// -/// fn main() { -/// let app = App::new() -/// .handler("/static", fs::StaticFiles::new(".", true)) -/// .finish(); -/// } -/// ``` -pub struct StaticFiles { - directory: PathBuf, - accessible: bool, - index: Option, - show_index: bool, - cpu_pool: CpuPool, - default: Box>, - _chunk_size: usize, - _follow_symlinks: bool, -} - -impl StaticFiles { - /// Create new `StaticFiles` instance - /// - /// `dir` - base directory - /// - /// `index` - show index for directory - pub fn new>(dir: T, index: bool) -> StaticFiles { - let dir = dir.into(); - - let (dir, access) = match dir.canonicalize() { - Ok(dir) => { - if dir.is_dir() { - (dir, true) - } else { - warn!("Is not directory `{:?}`", dir); - (dir, false) - } - }, - Err(err) => { - warn!("Static files directory `{:?}` error: {}", dir, err); - (dir, false) - } - }; - - StaticFiles { - directory: dir, - accessible: access, - index: None, - show_index: index, - cpu_pool: CpuPool::new(40), - default: Box::new(WrapHandler::new( - |_| HttpResponse::new(StatusCode::NOT_FOUND))), - _chunk_size: 0, - _follow_symlinks: false, - } - } - - /// Set index file - /// - /// Redirects to specific index file for directory "/" instead of - /// showing files listing. - pub fn index_file>(mut self, index: T) -> StaticFiles { - self.index = Some(index.into()); - self - } - - /// Sets default handler which is used when no matched file could be found. - pub fn default_handler>(mut self, handler: H) -> StaticFiles { - self.default = Box::new(WrapHandler::new(handler)); - self - } -} - -impl Handler for StaticFiles { - type Result = Result; - - fn handle(&mut self, req: HttpRequest) -> Self::Result { - if !self.accessible { - Ok(self.default.handle(req)) - } else { - let relpath = match req.match_info().get("tail").map(PathBuf::from_param) { - Some(Ok(path)) => path, - _ => return Ok(self.default.handle(req)) - }; - - // full filepath - let path = self.directory.join(&relpath).canonicalize()?; - - if path.is_dir() { - if let Some(ref redir_index) = self.index { - // TODO: Don't redirect, just return the index content. - // TODO: It'd be nice if there were a good usable URL manipulation library - let mut new_path: String = req.path().to_owned(); - for el in relpath.iter() { - new_path.push_str(&el.to_string_lossy()); - new_path.push('/'); - } - new_path.push_str(redir_index); - HttpResponse::Found() - .header(header::LOCATION, new_path.as_str()) - .finish() - .respond_to(req.without_state()) - } else if self.show_index { - Directory::new(self.directory.clone(), path) - .respond_to(req.without_state())? - .respond_to(req.without_state()) - } else { - Ok(self.default.handle(req)) - } - } else { - NamedFile::open(path)?.set_cpu_pool(self.cpu_pool.clone()) - .respond_to(req.without_state())? - .respond_to(req.without_state()) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use test::TestRequest; - use http::{header, Method, StatusCode}; - - #[test] - fn test_named_file() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap() - .set_cpu_pool(CpuPool::new(1)); - { file.file(); - let _f: &File = &file; } - { let _f: &mut File = &mut file; } - - let resp = file.respond_to(HttpRequest::default()).unwrap(); - assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/x-toml") - } - - #[test] - fn test_named_file_not_allowed() { - let req = TestRequest::default().method(Method::POST).finish(); - let file = NamedFile::open("Cargo.toml").unwrap(); - - let resp = file.only_get().respond_to(req).unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - } - - #[test] - fn test_named_file_any_method() { - let req = TestRequest::default().method(Method::POST).finish(); - let file = NamedFile::open("Cargo.toml").unwrap(); - let resp = file.respond_to(req).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_static_files() { - let mut st = StaticFiles::new(".", true); - st.accessible = false; - let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); - - st.accessible = true; - st.show_index = false; - let resp = st.handle(HttpRequest::default()).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); - - let mut req = HttpRequest::default(); - req.match_info_mut().add("tail", ""); - - st.show_index = true; - let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/html; charset=utf-8"); - assert!(resp.body().is_binary()); - assert!(format!("{:?}", resp.body()).contains("README.md")); - } - - #[test] - fn test_redirect_to_index() { - let mut st = StaticFiles::new(".", false).index_file("index.html"); - let mut req = HttpRequest::default(); - req.match_info_mut().add("tail", "guide"); - - let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.status(), StatusCode::FOUND); - assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html"); - - let mut req = HttpRequest::default(); - req.match_info_mut().add("tail", "guide/"); - - let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.status(), StatusCode::FOUND); - assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/guide/index.html"); - } - - #[test] - fn test_redirect_to_index_nested() { - let mut st = StaticFiles::new(".", false).index_file("Cargo.toml"); - let mut req = HttpRequest::default(); - req.match_info_mut().add("tail", "examples/basics"); - - let resp = st.handle(req).respond_to(HttpRequest::default()).unwrap(); - let resp = resp.as_response().expect("HTTP Response"); - assert_eq!(resp.status(), StatusCode::FOUND); - assert_eq!(resp.headers().get(header::LOCATION).unwrap(), "/examples/basics/Cargo.toml"); - } -} diff --git a/src/guard.rs b/src/guard.rs new file mode 100644 index 000000000..3db525f9a --- /dev/null +++ b/src/guard.rs @@ -0,0 +1,498 @@ +//! Route match guards. +//! +//! Guards are one of the ways how actix-web router chooses a +//! handler service. In essence it is just a function that accepts a +//! reference to a `RequestHead` instance and returns a boolean. +//! It is possible to add guards to *scopes*, *resources* +//! and *routes*. Actix provide several guards by default, like various +//! http methods, header, etc. To become a guard, type must implement `Guard` +//! trait. Simple functions coulds guards as well. +//! +//! Guards can not modify the request object. But it is possible +//! to store extra attributes on a request by using the `Extensions` container. +//! Extensions containers are available via the `RequestHead::extensions()` method. +//! +//! ```rust +//! use actix_web::{web, http, dev, guard, App, HttpResponse}; +//! +//! fn main() { +//! App::new().service(web::resource("/index.html").route( +//! web::route() +//! .guard(guard::Post()) +//! .guard(guard::fn_guard(|head| head.method == http::Method::GET)) +//! .to(|| HttpResponse::MethodNotAllowed())) +//! ); +//! } +//! ``` + +#![allow(non_snake_case)] +use actix_http::http::{self, header, uri::Uri, HttpTryFrom}; +use actix_http::RequestHead; + +/// Trait defines resource guards. Guards are used for route selection. +/// +/// Guards can not modify the request object. But it is possible +/// to store extra attributes on a request by using the `Extensions` container. +/// Extensions containers are available via the `RequestHead::extensions()` method. +pub trait Guard { + /// Check if request matches predicate + fn check(&self, request: &RequestHead) -> bool; +} + +/// Create guard object for supplied function. +/// +/// ```rust +/// use actix_web::{guard, web, App, HttpResponse}; +/// +/// fn main() { +/// App::new().service(web::resource("/index.html").route( +/// web::route() +/// .guard( +/// guard::fn_guard( +/// |req| req.headers() +/// .contains_key("content-type"))) +/// .to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +pub fn fn_guard(f: F) -> impl Guard +where + F: Fn(&RequestHead) -> bool, +{ + FnGuard(f) +} + +struct FnGuard bool>(F); + +impl Guard for FnGuard +where + F: Fn(&RequestHead) -> bool, +{ + fn check(&self, head: &RequestHead) -> bool { + (self.0)(head) + } +} + +impl Guard for F +where + F: Fn(&RequestHead) -> bool, +{ + fn check(&self, head: &RequestHead) -> bool { + (self)(head) + } +} + +/// Return guard that matches if any of supplied guards. +/// +/// ```rust +/// use actix_web::{web, guard, App, HttpResponse}; +/// +/// fn main() { +/// App::new().service(web::resource("/index.html").route( +/// web::route() +/// .guard(guard::Any(guard::Get()).or(guard::Post())) +/// .to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +pub fn Any(guard: F) -> AnyGuard { + AnyGuard(vec![Box::new(guard)]) +} + +/// Matches if any of supplied guards matche. +pub struct AnyGuard(Vec>); + +impl AnyGuard { + /// Add guard to a list of guards to check + pub fn or(mut self, guard: F) -> Self { + self.0.push(Box::new(guard)); + self + } +} + +impl Guard for AnyGuard { + fn check(&self, req: &RequestHead) -> bool { + for p in &self.0 { + if p.check(req) { + return true; + } + } + false + } +} + +/// Return guard that matches if all of the supplied guards. +/// +/// ```rust +/// use actix_web::{guard, web, App, HttpResponse}; +/// +/// fn main() { +/// App::new().service(web::resource("/index.html").route( +/// web::route() +/// .guard( +/// guard::All(guard::Get()).and(guard::Header("content-type", "text/plain"))) +/// .to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +pub fn All(guard: F) -> AllGuard { + AllGuard(vec![Box::new(guard)]) +} + +/// Matches if all of supplied guards. +pub struct AllGuard(Vec>); + +impl AllGuard { + /// Add new guard to the list of guards to check + pub fn and(mut self, guard: F) -> Self { + self.0.push(Box::new(guard)); + self + } +} + +impl Guard for AllGuard { + fn check(&self, request: &RequestHead) -> bool { + for p in &self.0 { + if !p.check(request) { + return false; + } + } + true + } +} + +/// Return guard that matches if supplied guard does not match. +pub fn Not(guard: F) -> NotGuard { + NotGuard(Box::new(guard)) +} + +#[doc(hidden)] +pub struct NotGuard(Box); + +impl Guard for NotGuard { + fn check(&self, request: &RequestHead) -> bool { + !self.0.check(request) + } +} + +/// Http method guard +#[doc(hidden)] +pub struct MethodGuard(http::Method); + +impl Guard for MethodGuard { + fn check(&self, request: &RequestHead) -> bool { + request.method == self.0 + } +} + +/// Guard to match *GET* http method +pub fn Get() -> MethodGuard { + MethodGuard(http::Method::GET) +} + +/// Predicate to match *POST* http method +pub fn Post() -> MethodGuard { + MethodGuard(http::Method::POST) +} + +/// Predicate to match *PUT* http method +pub fn Put() -> MethodGuard { + MethodGuard(http::Method::PUT) +} + +/// Predicate to match *DELETE* http method +pub fn Delete() -> MethodGuard { + MethodGuard(http::Method::DELETE) +} + +/// Predicate to match *HEAD* http method +pub fn Head() -> MethodGuard { + MethodGuard(http::Method::HEAD) +} + +/// Predicate to match *OPTIONS* http method +pub fn Options() -> MethodGuard { + MethodGuard(http::Method::OPTIONS) +} + +/// Predicate to match *CONNECT* http method +pub fn Connect() -> MethodGuard { + MethodGuard(http::Method::CONNECT) +} + +/// Predicate to match *PATCH* http method +pub fn Patch() -> MethodGuard { + MethodGuard(http::Method::PATCH) +} + +/// Predicate to match *TRACE* http method +pub fn Trace() -> MethodGuard { + MethodGuard(http::Method::TRACE) +} + +/// Predicate to match specified http method +pub fn Method(method: http::Method) -> MethodGuard { + MethodGuard(method) +} + +/// Return predicate that matches if request contains specified header and +/// value. +pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { + HeaderGuard( + header::HeaderName::try_from(name).unwrap(), + header::HeaderValue::from_static(value), + ) +} + +#[doc(hidden)] +pub struct HeaderGuard(header::HeaderName, header::HeaderValue); + +impl Guard for HeaderGuard { + fn check(&self, req: &RequestHead) -> bool { + if let Some(val) = req.headers.get(&self.0) { + return val == self.1; + } + false + } +} + +/// Return predicate that matches if request contains specified Host name. +/// +/// ```rust,ignore +/// # extern crate actix_web; +/// use actix_web::{guard::Host, App, HttpResponse}; +/// +/// fn main() { +/// App::new().resource("/index.html", |r| { +/// r.route() +/// .guard(Host("www.rust-lang.org")) +/// .f(|_| HttpResponse::MethodNotAllowed()) +/// }); +/// } +/// ``` +pub fn Host>(host: H) -> HostGuard { + HostGuard(host.as_ref().to_string(), None) +} + +fn get_host_uri(req: &RequestHead) -> Option { + use core::str::FromStr; + req.headers + .get(header::HOST) + .and_then(|host_value| host_value.to_str().ok()) + .or_else(|| req.uri.host()) + .map(|host: &str| Uri::from_str(host).ok()) + .and_then(|host_success| host_success) +} + +#[doc(hidden)] +pub struct HostGuard(String, Option); + +impl HostGuard { + /// Set request scheme to match + pub fn scheme>(mut self, scheme: H) -> HostGuard { + self.1 = Some(scheme.as_ref().to_string()); + self + } +} + +impl Guard for HostGuard { + fn check(&self, req: &RequestHead) -> bool { + let req_host_uri = if let Some(uri) = get_host_uri(req) { + uri + } else { + return false; + }; + + if let Some(uri_host) = req_host_uri.host() { + if self.0 != uri_host { + return false; + } + } else { + return false; + } + + if let Some(ref scheme) = self.1 { + if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { + return scheme == req_host_uri_scheme; + } + } + + true + } +} + +#[cfg(test)] +mod tests { + use actix_http::http::{header, Method}; + + use super::*; + use crate::test::TestRequest; + + #[test] + fn test_header() { + let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked") + .to_http_request(); + + let pred = Header("transfer-encoding", "chunked"); + assert!(pred.check(req.head())); + + let pred = Header("transfer-encoding", "other"); + assert!(!pred.check(req.head())); + + let pred = Header("content-type", "other"); + assert!(!pred.check(req.head())); + } + + #[test] + fn test_host() { + let req = TestRequest::default() + .header( + header::HOST, + header::HeaderValue::from_static("www.rust-lang.org"), + ) + .to_http_request(); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + + #[test] + fn test_host_scheme() { + let req = TestRequest::default() + .header( + header::HOST, + header::HeaderValue::from_static("https://www.rust-lang.org"), + ) + .to_http_request(); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("http"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + + #[test] + fn test_host_without_header() { + let req = TestRequest::default() + .uri("www.rust-lang.org") + .to_http_request(); + + let pred = Host("www.rust-lang.org"); + assert!(pred.check(req.head())); + + let pred = Host("www.rust-lang.org").scheme("https"); + assert!(pred.check(req.head())); + + let pred = Host("blog.rust-lang.org"); + assert!(!pred.check(req.head())); + + let pred = Host("blog.rust-lang.org").scheme("https"); + assert!(!pred.check(req.head())); + + let pred = Host("crates.io"); + assert!(!pred.check(req.head())); + + let pred = Host("localhost"); + assert!(!pred.check(req.head())); + } + + #[test] + fn test_methods() { + let req = TestRequest::default().to_http_request(); + let req2 = TestRequest::default() + .method(Method::POST) + .to_http_request(); + + assert!(Get().check(req.head())); + assert!(!Get().check(req2.head())); + assert!(Post().check(req2.head())); + assert!(!Post().check(req.head())); + + let r = TestRequest::default().method(Method::PUT).to_http_request(); + assert!(Put().check(r.head())); + assert!(!Put().check(req.head())); + + let r = TestRequest::default() + .method(Method::DELETE) + .to_http_request(); + assert!(Delete().check(r.head())); + assert!(!Delete().check(req.head())); + + let r = TestRequest::default() + .method(Method::HEAD) + .to_http_request(); + assert!(Head().check(r.head())); + assert!(!Head().check(req.head())); + + let r = TestRequest::default() + .method(Method::OPTIONS) + .to_http_request(); + assert!(Options().check(r.head())); + assert!(!Options().check(req.head())); + + let r = TestRequest::default() + .method(Method::CONNECT) + .to_http_request(); + assert!(Connect().check(r.head())); + assert!(!Connect().check(req.head())); + + let r = TestRequest::default() + .method(Method::PATCH) + .to_http_request(); + assert!(Patch().check(r.head())); + assert!(!Patch().check(req.head())); + + let r = TestRequest::default() + .method(Method::TRACE) + .to_http_request(); + assert!(Trace().check(r.head())); + assert!(!Trace().check(req.head())); + } + + #[test] + fn test_preds() { + let r = TestRequest::default() + .method(Method::TRACE) + .to_http_request(); + + assert!(Not(Get()).check(r.head())); + assert!(!Not(Trace()).check(r.head())); + + assert!(All(Trace()).and(Trace()).check(r.head())); + assert!(!All(Get()).and(Trace()).check(r.head())); + + assert!(Any(Get()).or(Trace()).check(r.head())); + assert!(!Any(Get()).or(Get()).check(r.head())); + } +} diff --git a/src/handler.rs b/src/handler.rs index 6041dc288..a7023422b 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,442 +1,291 @@ -use std::ops::Deref; +use std::convert::Infallible; +use std::future::Future; use std::marker::PhantomData; -use futures::Poll; -use futures::future::{Future, FutureResult, ok, err}; +use std::pin::Pin; +use std::task::{Context, Poll}; -use error::Error; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; +use actix_http::{Error, Response}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; +use futures::ready; +use pin_project::pin_project; -/// Trait defines object that could be registered as route handler -#[allow(unused_variables)] -pub trait Handler: 'static { +use crate::extract::FromRequest; +use crate::request::HttpRequest; +use crate::responder::Responder; +use crate::service::{ServiceRequest, ServiceResponse}; - /// The type of value that handler will return. - type Result: Responder; - - /// Handle request - fn handle(&mut self, req: HttpRequest) -> Self::Result; -} - -/// Trait implemented by types that generate responses for clients. -/// -/// Types that implement this trait can be used as the return type of a handler. -pub trait Responder { - /// The associated item which can be returned. - type Item: Into; - - /// The associated error which can be returned. - type Error: Into; - - /// Convert itself to `Reply` or `Error`. - fn respond_to(self, req: HttpRequest) -> Result; -} - -/// Trait implemented by types that can be extracted from request. -/// -/// Types that implement this trait can be used with `Route::with()` method. -pub trait FromRequest: Sized where S: 'static +/// Async handler converter factory +pub trait Factory: Clone + 'static +where + R: Future, + O: Responder, { - type Result: Future; - - fn from_request(req: &HttpRequest) -> Self::Result; + fn call(&self, param: T) -> R; } -/// Combines two different responder types into a single type -/// -/// ```rust -/// # extern crate actix_web; -/// # extern crate futures; -/// # use futures::future::Future; -/// use futures::future::result; -/// use actix_web::{Either, Error, HttpRequest, HttpResponse, AsyncResponder}; -/// -/// type RegisterResult = Either>>; -/// -/// -/// fn index(req: HttpRequest) -> RegisterResult { -/// if is_a_variant() { // <- choose variant A -/// Either::A( -/// HttpResponse::BadRequest().body("Bad data")) -/// } else { -/// Either::B( // <- variant B -/// result(Ok(HttpResponse::Ok() -/// .content_type("text/html") -/// .body("Hello!"))) -/// .responder()) -/// } -/// } -/// # fn is_a_variant() -> bool { true } -/// # fn main() {} -/// ``` -#[derive(Debug)] -pub enum Either { - /// First branch of the type - A(A), - /// Second branch of the type - B(B), -} - -impl Responder for Either - where A: Responder, B: Responder +impl Factory<(), R, O> for F +where + F: Fn() -> R + Clone + 'static, + R: Future, + O: Responder, { - type Item = Reply; - type Error = Error; + fn call(&self, _: ()) -> R { + (self)() + } +} - fn respond_to(self, req: HttpRequest) -> Result { - match self { - Either::A(a) => match a.respond_to(req) { - Ok(val) => Ok(val.into()), - Err(err) => Err(err.into()), - }, - Either::B(b) => match b.respond_to(req) { - Ok(val) => Ok(val.into()), - Err(err) => Err(err.into()), - }, +#[doc(hidden)] +pub struct Handler +where + F: Factory, + R: Future, + O: Responder, +{ + hnd: F, + _t: PhantomData<(T, R, O)>, +} + +impl Handler +where + F: Factory, + R: Future, + O: Responder, +{ + pub fn new(hnd: F) -> Self { + Handler { + hnd, + _t: PhantomData, } } } -impl Future for Either - where A: Future, - B: Future, +impl Clone for Handler +where + F: Factory, + R: Future, + O: Responder, { - type Item = I; - type Error = E; - - fn poll(&mut self) -> Poll { - match *self { - Either::A(ref mut fut) => fut.poll(), - Either::B(ref mut fut) => fut.poll(), + fn clone(&self) -> Self { + Handler { + hnd: self.hnd.clone(), + _t: PhantomData, } } } -/// Convenience trait that converts `Future` object to a `Boxed` future -/// -/// For example loading json from request's body is async operation. -/// -/// ```rust -/// # extern crate actix_web; -/// # extern crate futures; -/// # #[macro_use] extern crate serde_derive; -/// use futures::future::Future; -/// use actix_web::{ -/// App, HttpRequest, HttpResponse, HttpMessage, Error, AsyncResponder}; -/// -/// #[derive(Deserialize, Debug)] -/// struct MyObj { -/// name: String, -/// } -/// -/// fn index(mut req: HttpRequest) -> Box> { -/// req.json() // <- get JsonBody future -/// .from_err() -/// .and_then(|val: MyObj| { // <- deserialized value -/// Ok(HttpResponse::Ok().into()) -/// }) -/// // Construct boxed future by using `AsyncResponder::responder()` method -/// .responder() -/// } -/// # fn main() {} -/// ``` -pub trait AsyncResponder: Sized { - fn responder(self) -> Box>; -} - -impl AsyncResponder for F - where F: Future + 'static, - I: Responder + 'static, - E: Into + 'static, +impl Service for Handler +where + F: Factory, + R: Future, + O: Responder, { - fn responder(self) -> Box> { - Box::new(self) + type Request = (T, HttpRequest); + type Response = ServiceResponse; + type Error = Infallible; + type Future = HandlerServiceResponse; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { + HandlerServiceResponse { + fut: self.hnd.call(param), + fut2: None, + req: Some(req), + } } } -/// Handler for Fn() -impl Handler for F - where F: Fn(HttpRequest) -> R + 'static, - R: Responder + 'static +#[doc(hidden)] +#[pin_project] +pub struct HandlerServiceResponse +where + T: Future, + R: Responder, { - type Result = R; + #[pin] + fut: T, + #[pin] + fut2: Option, + req: Option, +} - fn handle(&mut self, req: HttpRequest) -> R { - (self)(req) +impl Future for HandlerServiceResponse +where + T: Future, + R: Responder, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut2.as_pin_mut() { + return match fut.poll(cx) { + Poll::Ready(Ok(res)) => { + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + } + }; + } + + match this.fut.poll(cx) { + Poll::Ready(res) => { + let fut = res.respond_to(this.req.as_ref().unwrap()); + self.as_mut().project().fut2.set(Some(fut)); + self.poll(cx) + } + Poll::Pending => Poll::Pending, + } } } -/// Represents response process. -pub struct Reply(ReplyItem); - -pub(crate) enum ReplyItem { - Message(HttpResponse), - Future(Box>), +/// Extract arguments from request +pub struct Extract { + service: S, + _t: PhantomData, } -impl Reply { +impl Extract { + pub fn new(service: S) -> Self { + Extract { + service, + _t: PhantomData, + } + } +} - /// Create async response - #[inline] - pub fn async(fut: F) -> Reply - where F: Future + 'static +impl ServiceFactory for Extract +where + S: Service< + Request = (T, HttpRequest), + Response = ServiceResponse, + Error = Infallible, + > + Clone, +{ + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = (Error, ServiceRequest); + type InitError = (); + type Service = ExtractService; + type Future = Ready>; + + fn new_service(&self, _: &()) -> Self::Future { + ok(ExtractService { + _t: PhantomData, + service: self.service.clone(), + }) + } +} + +pub struct ExtractService { + service: S, + _t: PhantomData, +} + +impl Service for ExtractService +where + S: Service< + Request = (T, HttpRequest), + Response = ServiceResponse, + Error = Infallible, + > + Clone, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = (Error, ServiceRequest); + type Future = ExtractResponse; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + let (req, mut payload) = req.into_parts(); + let fut = T::from_request(&req, &mut payload); + + ExtractResponse { + fut, + req, + fut_s: None, + service: self.service.clone(), + } + } +} + +#[pin_project] +pub struct ExtractResponse { + req: HttpRequest, + service: S, + #[pin] + fut: T::Future, + #[pin] + fut_s: Option, +} + +impl Future for ExtractResponse +where + S: Service< + Request = (T, HttpRequest), + Response = ServiceResponse, + Error = Infallible, + >, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.as_mut().project(); + + if let Some(fut) = this.fut_s.as_pin_mut() { + return fut.poll(cx).map_err(|_| panic!()); + } + + match ready!(this.fut.poll(cx)) { + Err(e) => { + let req = ServiceRequest::new(this.req.clone()); + Poll::Ready(Err((e.into(), req))) + } + Ok(item) => { + let fut = Some(this.service.call((item, this.req.clone()))); + self.as_mut().project().fut_s.set(fut); + self.poll(cx) + } + } + } +} + +/// FromRequest trait impl for tuples +macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => { + impl Factory<($($T,)+), Res, O> for Func + where Func: Fn($($T,)+) -> Res + Clone + 'static, + Res: Future, + O: Responder, { - Reply(ReplyItem::Future(Box::new(fut))) - } - - /// Send response - #[inline] - pub fn response>(response: R) -> Reply { - Reply(ReplyItem::Message(response.into())) - } - - #[inline] - pub(crate) fn into(self) -> ReplyItem { - self.0 - } - - #[cfg(test)] - pub(crate) fn as_response(&self) -> Option<&HttpResponse> { - match self.0 { - ReplyItem::Message(ref resp) => Some(resp), - _ => None, + fn call(&self, param: ($($T,)+)) -> Res { + (self)($(param.$n,)+) } } -} - -impl Responder for Reply { - type Item = Reply; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(self) - } -} - -impl Responder for HttpResponse { - type Item = Reply; - type Error = Error; - - #[inline] - fn respond_to(self, _: HttpRequest) -> Result { - Ok(Reply(ReplyItem::Message(self))) - } -} - -impl From for Reply { - - #[inline] - fn from(resp: HttpResponse) -> Reply { - Reply(ReplyItem::Message(resp)) - } -} - -impl> Responder for Result -{ - type Item = ::Item; - type Error = Error; - - fn respond_to(self, req: HttpRequest) -> Result { - match self { - Ok(val) => match val.respond_to(req) { - Ok(val) => Ok(val), - Err(err) => Err(err.into()), - }, - Err(err) => Err(err.into()), - } - } -} - -impl> From> for Reply { - #[inline] - fn from(res: Result) -> Self { - match res { - Ok(val) => val, - Err(err) => Reply(ReplyItem::Message(err.into().into())), - } - } -} - -impl> From> for Reply { - #[inline] - fn from(res: Result) -> Self { - match res { - Ok(val) => Reply(ReplyItem::Message(val)), - Err(err) => Reply(ReplyItem::Message(err.into().into())), - } - } -} - -impl From>> for Reply { - #[inline] - fn from(fut: Box>) -> Reply { - Reply(ReplyItem::Future(fut)) - } -} - -/// Convenience type alias -pub type FutureResponse = Box>; - -impl Responder for Box> - where I: Responder + 'static, - E: Into + 'static -{ - type Item = Reply; - type Error = Error; - - #[inline] - fn respond_to(self, req: HttpRequest) -> Result { - let fut = self.map_err(|e| e.into()) - .then(move |r| { - match r.respond_to(req) { - Ok(reply) => match reply.into().0 { - ReplyItem::Message(resp) => ok(resp), - _ => panic!("Nested async replies are not supported"), - }, - Err(e) => err(e), - } - }); - Ok(Reply::async(fut)) - } -} - -/// Trait defines object that could be registered as resource route -pub(crate) trait RouteHandler: 'static { - fn handle(&mut self, req: HttpRequest) -> Reply; -} - -/// Route handler wrapper for Handler -pub(crate) -struct WrapHandler - where H: Handler, - R: Responder, - S: 'static, -{ - h: H, - s: PhantomData, -} - -impl WrapHandler - where H: Handler, - R: Responder, - S: 'static, -{ - pub fn new(h: H) -> Self { - WrapHandler{h, s: PhantomData} - } -} - -impl RouteHandler for WrapHandler - where H: Handler, - R: Responder + 'static, - S: 'static, -{ - fn handle(&mut self, req: HttpRequest) -> Reply { - let req2 = req.without_state(); - match self.h.handle(req).respond_to(req2) { - Ok(reply) => reply.into(), - Err(err) => Reply::response(err.into()), - } - } -} - -/// Async route handler -pub(crate) -struct AsyncHandler - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static, - S: 'static, -{ - h: Box, - s: PhantomData, -} - -impl AsyncHandler - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static, - S: 'static, -{ - pub fn new(h: H) -> Self { - AsyncHandler{h: Box::new(h), s: PhantomData} - } -} - -impl RouteHandler for AsyncHandler - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static, - S: 'static, -{ - fn handle(&mut self, req: HttpRequest) -> Reply { - let req2 = req.without_state(); - let fut = (self.h)(req) - .map_err(|e| e.into()) - .then(move |r| { - match r.respond_to(req2) { - Ok(reply) => match reply.into().0 { - ReplyItem::Message(resp) => ok(resp), - _ => panic!("Nested async replies are not supported"), - }, - Err(e) => err(e), - } - }); - Reply::async(fut) - } -} - -/// Access an application state -/// -/// `S` - application state type -/// -/// ## Example -/// -/// ```rust -/// # extern crate bytes; -/// # extern crate actix_web; -/// # extern crate futures; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Path, State, http}; -/// -/// /// Application state -/// struct MyApp {msg: &'static str} -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// /// extract path info using serde -/// fn index(state: State, info: Path) -> String { -/// format!("{} {}!", state.msg, info.username) -/// } -/// -/// fn main() { -/// let app = App::with_state(MyApp{msg: "Welcome"}).resource( -/// "/{username}/index.html", // <- define path parameters -/// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor -/// } -/// ``` -pub struct State (HttpRequest); - -impl Deref for State { - type Target = S; - - fn deref(&self) -> &S { - self.0.state() - } -} - -impl FromRequest for State -{ - type Result = FutureResult; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - ok(State(req.clone())) - } +}); + +#[rustfmt::skip] +mod m { + use super::*; + +factory_tuple!((0, A)); +factory_tuple!((0, A), (1, B)); +factory_tuple!((0, A), (1, B), (2, C)); +factory_tuple!((0, A), (1, B), (2, C), (3, D)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); +factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); } diff --git a/src/header/common/content_disposition.rs b/src/header/common/content_disposition.rs deleted file mode 100644 index 0fcd6ee09..000000000 --- a/src/header/common/content_disposition.rs +++ /dev/null @@ -1,264 +0,0 @@ -// # References -// -// "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt -// "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt -// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc2388.txt -// Browser conformance tests at: http://greenbytes.de/tech/tc2231/ -// IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml - -use language_tags::LanguageTag; -use std::fmt; -use unicase; - -use header::{Header, Raw, parsing}; -use header::parsing::{parse_extended_value, http_percent_encode}; -use header::shared::Charset; - -/// The implied disposition of the content of the HTTP body. -#[derive(Clone, Debug, PartialEq)] -pub enum DispositionType { - /// Inline implies default processing - Inline, - /// Attachment implies that the recipient should prompt the user to save the response locally, - /// rather than process it normally (as per its media type). - Attachment, - /// Extension type. Should be handled by recipients the same way as Attachment - Ext(String) -} - -/// A parameter to the disposition type. -#[derive(Clone, Debug, PartialEq)] -pub enum DispositionParam { - /// A Filename consisting of a Charset, an optional LanguageTag, and finally a sequence of - /// bytes representing the filename - Filename(Charset, Option, Vec), - /// Extension type consisting of token and value. Recipients should ignore unrecognized - /// parameters. - Ext(String, String) -} - -/// A `Content-Disposition` header, (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266). -/// -/// The Content-Disposition response header field is used to convey -/// additional information about how to process the response payload, and -/// also can be used to attach additional metadata, such as the filename -/// to use when saving the response payload locally. -/// -/// # ABNF - -/// ```text -/// content-disposition = "Content-Disposition" ":" -/// disposition-type *( ";" disposition-parm ) -/// -/// disposition-type = "inline" | "attachment" | disp-ext-type -/// ; case-insensitive -/// -/// disp-ext-type = token -/// -/// disposition-parm = filename-parm | disp-ext-parm -/// -/// filename-parm = "filename" "=" value -/// | "filename*" "=" ext-value -/// -/// disp-ext-parm = token "=" value -/// | ext-token "=" ext-value -/// -/// ext-token = -/// ``` -/// -/// # Example -/// -/// ``` -/// use hyper::header::{Headers, ContentDisposition, DispositionType, DispositionParam, Charset}; -/// -/// let mut headers = Headers::new(); -/// headers.set(ContentDisposition { -/// disposition: DispositionType::Attachment, -/// parameters: vec![DispositionParam::Filename( -/// Charset::Iso_8859_1, // The character set for the bytes of the filename -/// None, // The optional language tag (see `language-tag` crate) -/// b"\xa9 Copyright 1989.txt".to_vec() // the actual bytes of the filename -/// )] -/// }); -/// ``` -#[derive(Clone, Debug, PartialEq)] -pub struct ContentDisposition { - /// The disposition - pub disposition: DispositionType, - /// Disposition parameters - pub parameters: Vec, -} - -impl Header for ContentDisposition { - fn header_name() -> &'static str { - static NAME: &'static str = "Content-Disposition"; - NAME - } - - fn parse_header(raw: &Raw) -> ::Result { - parsing::from_one_raw_str(raw).and_then(|s: String| { - let mut sections = s.split(';'); - let disposition = match sections.next() { - Some(s) => s.trim(), - None => return Err(::Error::Header), - }; - - let mut cd = ContentDisposition { - disposition: if unicase::eq_ascii(&*disposition, "inline") { - DispositionType::Inline - } else if unicase::eq_ascii(&*disposition, "attachment") { - DispositionType::Attachment - } else { - DispositionType::Ext(disposition.to_owned()) - }, - parameters: Vec::new(), - }; - - for section in sections { - let mut parts = section.splitn(2, '='); - - let key = if let Some(key) = parts.next() { - key.trim() - } else { - return Err(::Error::Header); - }; - - let val = if let Some(val) = parts.next() { - val.trim() - } else { - return Err(::Error::Header); - }; - - cd.parameters.push( - if unicase::eq_ascii(&*key, "filename") { - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), None, - val.trim_matches('"').as_bytes().to_owned()) - } else if unicase::eq_ascii(&*key, "filename*") { - let extended_value = try!(parse_extended_value(val)); - DispositionParam::Filename(extended_value.charset, extended_value.language_tag, extended_value.value) - } else { - DispositionParam::Ext(key.to_owned(), val.trim_matches('"').to_owned()) - } - ); - } - - Ok(cd) - }) - } - - #[inline] - fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { - f.fmt_line(self) - } -} - -impl fmt::Display for ContentDisposition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.disposition { - DispositionType::Inline => try!(write!(f, "inline")), - DispositionType::Attachment => try!(write!(f, "attachment")), - DispositionType::Ext(ref s) => try!(write!(f, "{}", s)), - } - for param in &self.parameters { - match *param { - DispositionParam::Filename(ref charset, ref opt_lang, ref bytes) => { - let mut use_simple_format: bool = false; - if opt_lang.is_none() { - if let Charset::Ext(ref ext) = *charset { - if unicase::eq_ascii(&**ext, "utf-8") { - use_simple_format = true; - } - } - } - if use_simple_format { - try!(write!(f, "; filename=\"{}\"", - match String::from_utf8(bytes.clone()) { - Ok(s) => s, - Err(_) => return Err(fmt::Error), - })); - } else { - try!(write!(f, "; filename*={}'", charset)); - if let Some(ref lang) = *opt_lang { - try!(write!(f, "{}", lang)); - }; - try!(write!(f, "'")); - try!(http_percent_encode(f, bytes)) - } - }, - DispositionParam::Ext(ref k, ref v) => try!(write!(f, "; {}=\"{}\"", k, v)), - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::{ContentDisposition,DispositionType,DispositionParam}; - use ::header::Header; - use ::header::shared::Charset; - - #[test] - fn test_parse_header() { - assert!(ContentDisposition::parse_header(&"".into()).is_err()); - - let a = "form-data; dummy=3; name=upload;\r\n filename=\"sample.png\"".into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let b = ContentDisposition { - disposition: DispositionType::Ext("form-data".to_owned()), - parameters: vec![ - DispositionParam::Ext("dummy".to_owned(), "3".to_owned()), - DispositionParam::Ext("name".to_owned(), "upload".to_owned()), - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - "sample.png".bytes().collect()) ] - }; - assert_eq!(a, b); - - let a = "attachment; filename=\"image.jpg\"".into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let b = ContentDisposition { - disposition: DispositionType::Attachment, - parameters: vec![ - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - "image.jpg".bytes().collect()) ] - }; - assert_eq!(a, b); - - let a = "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates".into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let b = ContentDisposition { - disposition: DispositionType::Attachment, - parameters: vec![ - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - vec![0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, - 0xe2, 0x82, 0xac, 0x20, b'r', b'a', b't', b'e', b's']) ] - }; - assert_eq!(a, b); - } - - #[test] - fn test_display() { - let as_string = "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; - let a = as_string.into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let display_rendered = format!("{}",a); - assert_eq!(as_string, display_rendered); - - let a = "attachment; filename*=UTF-8''black%20and%20white.csv".into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let display_rendered = format!("{}",a); - assert_eq!("attachment; filename=\"black and white.csv\"".to_owned(), display_rendered); - - let a = "attachment; filename=colourful.csv".into(); - let a: ContentDisposition = ContentDisposition::parse_header(&a).unwrap(); - let display_rendered = format!("{}",a); - assert_eq!("attachment; filename=\"colourful.csv\"".to_owned(), display_rendered); - } -} diff --git a/src/header/mod.rs b/src/header/mod.rs deleted file mode 100644 index 2e57eef80..000000000 --- a/src/header/mod.rs +++ /dev/null @@ -1,248 +0,0 @@ -//! Various http headers -// This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) - -use std::fmt; -use std::str::FromStr; - -use bytes::{Bytes, BytesMut}; -use modhttp::{Error as HttpError}; -use modhttp::header::GetAll; -use mime::Mime; - -pub use modhttp::header::*; - -use error::ParseError; -use httpmessage::HttpMessage; - -mod common; -mod shared; -#[doc(hidden)] -pub use self::common::*; -#[doc(hidden)] -pub use self::shared::*; - - -#[doc(hidden)] -/// A trait for any object that will represent a header field and value. -pub trait Header where Self: IntoHeaderValue { - - /// Returns the name of the header field - fn name() -> HeaderName; - - /// Parse a header - fn parse(msg: &T) -> Result; -} - -#[doc(hidden)] -/// A trait for any object that can be Converted to a `HeaderValue` -pub trait IntoHeaderValue: Sized { - /// The type returned in the event of a conversion error. - type Error: Into; - - /// Cast from PyObject to a concrete Python object type. - fn try_into(self) -> Result; -} - -impl IntoHeaderValue for HeaderValue { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - Ok(self) - } -} - -impl<'a> IntoHeaderValue for &'a str { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - self.parse() - } -} - -impl<'a> IntoHeaderValue for &'a [u8] { - type Error = InvalidHeaderValue; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_bytes(self) - } -} - -impl IntoHeaderValue for Bytes { - type Error = InvalidHeaderValueBytes; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_shared(self) - } -} - -impl IntoHeaderValue for Vec { - type Error = InvalidHeaderValueBytes; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(self)) - } -} - -impl IntoHeaderValue for String { - type Error = InvalidHeaderValueBytes; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(self)) - } -} - -impl IntoHeaderValue for Mime { - type Error = InvalidHeaderValueBytes; - - #[inline] - fn try_into(self) -> Result { - HeaderValue::from_shared(Bytes::from(format!("{}", self))) - } -} - -/// Represents supported types of content encodings -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum ContentEncoding { - /// Automatically select encoding based on encoding negotiation - Auto, - /// A format using the Brotli algorithm - #[cfg(feature="brotli")] - Br, - /// A format using the zlib structure with deflate algorithm - Deflate, - /// Gzip algorithm - Gzip, - /// Indicates the identity function (i.e. no compression, nor modification) - Identity, -} - -impl ContentEncoding { - - #[inline] - pub fn is_compression(&self) -> bool { - match *self { - ContentEncoding::Identity | ContentEncoding::Auto => false, - _ => true - } - } - - #[inline] - pub fn as_str(&self) -> &'static str { - match *self { - #[cfg(feature="brotli")] - ContentEncoding::Br => "br", - ContentEncoding::Gzip => "gzip", - ContentEncoding::Deflate => "deflate", - ContentEncoding::Identity | ContentEncoding::Auto => "identity", - } - } - - #[inline] - /// default quality value - pub fn quality(&self) -> f64 { - match *self { - #[cfg(feature="brotli")] - ContentEncoding::Br => 1.1, - ContentEncoding::Gzip => 1.0, - ContentEncoding::Deflate => 0.9, - ContentEncoding::Identity | ContentEncoding::Auto => 0.1, - } - } -} - -// TODO: remove memory allocation -impl<'a> From<&'a str> for ContentEncoding { - fn from(s: &'a str) -> ContentEncoding { - match s.trim().to_lowercase().as_ref() { - #[cfg(feature="brotli")] - "br" => ContentEncoding::Br, - "gzip" => ContentEncoding::Gzip, - "deflate" => ContentEncoding::Deflate, - _ => ContentEncoding::Identity, - } - } -} - -#[doc(hidden)] -pub(crate) struct Writer { - buf: BytesMut, -} - -impl Writer { - fn new() -> Writer { - Writer{buf: BytesMut::new()} - } - fn take(&mut self) -> Bytes { - self.buf.take().freeze() - } -} - -impl fmt::Write for Writer { - #[inline] - fn write_str(&mut self, s: &str) -> fmt::Result { - self.buf.extend_from_slice(s.as_bytes()); - Ok(()) - } - - #[inline] - fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result { - fmt::write(self, args) - } -} - -#[inline] -#[doc(hidden)] -/// Reads a comma-delimited raw header into a Vec. -pub fn from_comma_delimited(all: GetAll) - -> Result, ParseError> -{ - let mut result = Vec::new(); - for h in all { - let s = h.to_str().map_err(|_| ParseError::Header)?; - result.extend(s.split(',') - .filter_map(|x| match x.trim() { - "" => None, - y => Some(y) - }) - .filter_map(|x| x.trim().parse().ok())) - } - Ok(result) -} - -#[inline] -#[doc(hidden)] -/// Reads a single string when parsing a header. -pub fn from_one_raw_str(val: Option<&HeaderValue>) - -> Result -{ - if let Some(line) = val { - let line = line.to_str().map_err(|_| ParseError::Header)?; - if !line.is_empty() { - return T::from_str(line).or(Err(ParseError::Header)) - } - } - Err(ParseError::Header) -} - -#[inline] -#[doc(hidden)] -/// Format an array into a comma-delimited string. -pub fn fmt_comma_delimited(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result - where T: fmt::Display -{ - let mut iter = parts.iter(); - if let Some(part) = iter.next() { - fmt::Display::fmt(part, f)?; - } - for part in iter { - f.write_str(", ")?; - fmt::Display::fmt(part, f)?; - } - Ok(()) -} diff --git a/src/helpers.rs b/src/helpers.rs deleted file mode 100644 index 446e717a4..000000000 --- a/src/helpers.rs +++ /dev/null @@ -1,321 +0,0 @@ -//! Various helpers - -use regex::Regex; -use http::{header, StatusCode}; - -use handler::Handler; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -/// Path normalization helper -/// -/// By normalizing it means: -/// -/// - Add a trailing slash to the path. -/// - Remove a trailing slash from the path. -/// - Double slashes are replaced by one. -/// -/// The handler returns as soon as it finds a path that resolves -/// correctly. The order if all enable is 1) merge, 3) both merge and append -/// and 3) append. If the path resolves with -/// at least one of those conditions, it will redirect to the new path. -/// -/// If *append* is *true* append slash when needed. If a resource is -/// defined with trailing slash and the request comes without it, it will -/// append it automatically. -/// -/// If *merge* is *true*, merge multiple consecutive slashes in the path into one. -/// -/// This handler designed to be use as a handler for application's *default resource*. -/// -/// ```rust -/// # extern crate actix_web; -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// use actix_web::http::NormalizePath; -/// -/// # fn index(req: HttpRequest) -> HttpResponse { -/// # HttpResponse::Ok().into() -/// # } -/// fn main() { -/// let app = App::new() -/// .resource("/test/", |r| r.f(index)) -/// .default_resource(|r| r.h(NormalizePath::default())) -/// .finish(); -/// } -/// ``` -/// In this example `/test`, `/test///` will be redirected to `/test/` url. -pub struct NormalizePath { - append: bool, - merge: bool, - re_merge: Regex, - redirect: StatusCode, - not_found: StatusCode, -} - -impl Default for NormalizePath { - /// Create default `NormalizePath` instance, *append* is set to *true*, - /// *merge* is set to *true* and *redirect* is set to `StatusCode::MOVED_PERMANENTLY` - fn default() -> NormalizePath { - NormalizePath { - append: true, - merge: true, - re_merge: Regex::new("//+").unwrap(), - redirect: StatusCode::MOVED_PERMANENTLY, - not_found: StatusCode::NOT_FOUND, - } - } -} - -impl NormalizePath { - /// Create new `NormalizePath` instance - pub fn new(append: bool, merge: bool, redirect: StatusCode) -> NormalizePath { - NormalizePath { - append, - merge, - redirect, - re_merge: Regex::new("//+").unwrap(), - not_found: StatusCode::NOT_FOUND, - } - } -} - -impl Handler for NormalizePath { - type Result = HttpResponse; - - fn handle(&mut self, req: HttpRequest) -> Self::Result { - if let Some(router) = req.router() { - let query = req.query_string(); - if self.merge { - // merge slashes - let p = self.re_merge.replace_all(req.path(), "/"); - if p.len() != req.path().len() { - if router.has_route(p.as_ref()) { - let p = if !query.is_empty() { p + "?" + query } else { p }; - return HttpResponse::build(self.redirect) - .header(header::LOCATION, p.as_ref()) - .finish(); - } - // merge slashes and append trailing slash - if self.append && !p.ends_with('/') { - let p = p.as_ref().to_owned() + "/"; - if router.has_route(&p) { - let p = if !query.is_empty() { p + "?" + query } else { p }; - return HttpResponse::build(self.redirect) - .header(header::LOCATION, p.as_str()) - .finish() - } - } - - // try to remove trailing slash - if p.ends_with('/') { - let p = p.as_ref().trim_right_matches('/'); - if router.has_route(p) { - let mut req = HttpResponse::build(self.redirect); - return if !query.is_empty() { - req.header(header::LOCATION, (p.to_owned() + "?" + query).as_str()) - } else { - req.header(header::LOCATION, p) - } - .finish(); - } - } - } else if p.ends_with('/') { - // try to remove trailing slash - let p = p.as_ref().trim_right_matches('/'); - if router.has_route(p) { - let mut req = HttpResponse::build(self.redirect); - return if !query.is_empty() { - req.header(header::LOCATION, - (p.to_owned() + "?" + query).as_str()) - } else { - req.header(header::LOCATION, p) - } - .finish(); - } - } - } - // append trailing slash - if self.append && !req.path().ends_with('/') { - let p = req.path().to_owned() + "/"; - if router.has_route(&p) { - let p = if !query.is_empty() { p + "?" + query } else { p }; - return HttpResponse::build(self.redirect) - .header(header::LOCATION, p.as_str()) - .finish(); - } - } - } - HttpResponse::new(self.not_found) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use http::{header, Method}; - use test::TestRequest; - use application::App; - - fn index(_req: HttpRequest) -> HttpResponse { - HttpResponse::new(StatusCode::OK) - } - - #[test] - fn test_normalize_path_trailing_slashes() { - let mut app = App::new() - .resource("/resource1", |r| r.method(Method::GET).f(index)) - .resource("/resource2/", |r| r.method(Method::GET).f(index)) - .default_resource(|r| r.h(NormalizePath::default())) - .finish(); - - // trailing slashes - let params = - vec![("/resource1", "", StatusCode::OK), - ("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY), - ("/resource2", "/resource2/", StatusCode::MOVED_PERMANENTLY), - ("/resource2/", "", StatusCode::OK), - ("/resource1?p1=1&p2=2", "", StatusCode::OK), - ("/resource1/?p1=1&p2=2", "/resource1?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY), - ("/resource2?p1=1&p2=2", "/resource2/?p1=1&p2=2", - StatusCode::MOVED_PERMANENTLY), - ("/resource2/?p1=1&p2=2", "", StatusCode::OK) - ]; - for (path, target, code) in params { - let req = app.prepare_request(TestRequest::with_uri(path).finish()); - let resp = app.run(req); - let r = resp.as_response().unwrap(); - assert_eq!(r.status(), code); - if !target.is_empty() { - assert_eq!( - target, - r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); - } - } - } - - #[test] - fn test_normalize_path_trailing_slashes_disabled() { - let mut app = App::new() - .resource("/resource1", |r| r.method(Method::GET).f(index)) - .resource("/resource2/", |r| r.method(Method::GET).f(index)) - .default_resource(|r| r.h( - NormalizePath::new(false, true, StatusCode::MOVED_PERMANENTLY))) - .finish(); - - // trailing slashes - let params = vec![("/resource1", StatusCode::OK), - ("/resource1/", StatusCode::MOVED_PERMANENTLY), - ("/resource2", StatusCode::NOT_FOUND), - ("/resource2/", StatusCode::OK), - ("/resource1?p1=1&p2=2", StatusCode::OK), - ("/resource1/?p1=1&p2=2", StatusCode::MOVED_PERMANENTLY), - ("/resource2?p1=1&p2=2", StatusCode::NOT_FOUND), - ("/resource2/?p1=1&p2=2", StatusCode::OK) - ]; - for (path, code) in params { - let req = app.prepare_request(TestRequest::with_uri(path).finish()); - let resp = app.run(req); - let r = resp.as_response().unwrap(); - assert_eq!(r.status(), code); - } - } - - #[test] - fn test_normalize_path_merge_slashes() { - let mut app = App::new() - .resource("/resource1", |r| r.method(Method::GET).f(index)) - .resource("/resource1/a/b", |r| r.method(Method::GET).f(index)) - .default_resource(|r| r.h(NormalizePath::default())) - .finish(); - - // trailing slashes - let params = vec![ - ("/resource1/a/b", "", StatusCode::OK), - ("/resource1/", "/resource1", StatusCode::MOVED_PERMANENTLY), - ("/resource1//", "/resource1", StatusCode::MOVED_PERMANENTLY), - ("//resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("//resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("//resource1//a//b//", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/resource1/a/b?p=1", "", StatusCode::OK), - ("//resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("//resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a//b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ]; - for (path, target, code) in params { - let req = app.prepare_request(TestRequest::with_uri(path).finish()); - let resp = app.run(req); - let r = resp.as_response().unwrap(); - assert_eq!(r.status(), code); - if !target.is_empty() { - assert_eq!( - target, - r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); - } - } - } - - #[test] - fn test_normalize_path_merge_and_append_slashes() { - let mut app = App::new() - .resource("/resource1", |r| r.method(Method::GET).f(index)) - .resource("/resource2/", |r| r.method(Method::GET).f(index)) - .resource("/resource1/a/b", |r| r.method(Method::GET).f(index)) - .resource("/resource2/a/b/", |r| r.method(Method::GET).f(index)) - .default_resource(|r| r.h(NormalizePath::default())) - .finish(); - - // trailing slashes - let params = vec![ - ("/resource1/a/b", "", StatusCode::OK), - ("/resource1/a/b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b//", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b/", "/resource1/a/b", StatusCode::MOVED_PERMANENTLY), - ("/resource2/a/b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("/resource2/a/b/", "", StatusCode::OK), - ("//resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("///resource2//a//b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("///resource2//a//b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("/////resource2/a///b", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("/////resource2/a///b/", "/resource2/a/b/", StatusCode::MOVED_PERMANENTLY), - ("/resource1/a/b?p=1", "", StatusCode::OK), - ("/resource1/a/b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("///resource1//a//b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b/?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource1/a///b//?p=1", "/resource1/a/b?p=1", StatusCode::MOVED_PERMANENTLY), - ("/resource2/a/b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("//resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("///resource2//a//b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("///resource2//a//b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource2/a///b?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ("/////resource2/a///b/?p=1", "/resource2/a/b/?p=1", StatusCode::MOVED_PERMANENTLY), - ]; - for (path, target, code) in params { - let req = app.prepare_request(TestRequest::with_uri(path).finish()); - let resp = app.run(req); - let r = resp.as_response().unwrap(); - assert_eq!(r.status(), code); - if !target.is_empty() { - assert_eq!( - target, r.headers().get(header::LOCATION).unwrap().to_str().unwrap()); - } - } - } -} diff --git a/src/httpcodes.rs b/src/httpcodes.rs deleted file mode 100644 index 7ad66cb1b..000000000 --- a/src/httpcodes.rs +++ /dev/null @@ -1,275 +0,0 @@ -//! Basic http responses -#![allow(non_upper_case_globals, deprecated)] -use http::StatusCode; - -use body::Body; -use error::Error; -use handler::{Reply, Handler, RouteHandler, Responder}; -use httprequest::HttpRequest; -use httpresponse::{HttpResponse, HttpResponseBuilder}; - -#[deprecated(since="0.5.0", note="please use `HttpResponse::Ok()` instead")] -pub const HttpOk: StaticResponse = StaticResponse(StatusCode::OK); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Created()` instead")] -pub const HttpCreated: StaticResponse = StaticResponse(StatusCode::CREATED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Accepted()` instead")] -pub const HttpAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::pNonAuthoritativeInformation()` instead")] -pub const HttpNonAuthoritativeInformation: StaticResponse = - StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION); -#[deprecated(since="0.5.0", note="please use `HttpResponse::NoContent()` instead")] -pub const HttpNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::ResetContent()` instead")] -pub const HttpResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::PartialContent()` instead")] -pub const HttpPartialContent: StaticResponse = StaticResponse(StatusCode::PARTIAL_CONTENT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::MultiStatus()` instead")] -pub const HttpMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS); -#[deprecated(since="0.5.0", note="please use `HttpResponse::AlreadyReported()` instead")] -pub const HttpAlreadyReported: StaticResponse = StaticResponse(StatusCode::ALREADY_REPORTED); - -#[deprecated(since="0.5.0", note="please use `HttpResponse::MultipleChoices()` instead")] -pub const HttpMultipleChoices: StaticResponse = StaticResponse(StatusCode::MULTIPLE_CHOICES); -#[deprecated(since="0.5.0", note="please use `HttpResponse::MovedPermanently()` instead")] -pub const HttpMovedPermanently: StaticResponse = StaticResponse(StatusCode::MOVED_PERMANENTLY); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Found()` instead")] -pub const HttpFound: StaticResponse = StaticResponse(StatusCode::FOUND); -#[deprecated(since="0.5.0", note="please use `HttpResponse::SeeOther()` instead")] -pub const HttpSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER); -#[deprecated(since="0.5.0", note="please use `HttpResponse::NotModified()` instead")] -pub const HttpNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::UseProxy()` instead")] -pub const HttpUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY); -#[deprecated(since="0.5.0", note="please use `HttpResponse::TemporaryRedirect()` instead")] -pub const HttpTemporaryRedirect: StaticResponse = - StaticResponse(StatusCode::TEMPORARY_REDIRECT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::PermanentRedirect()` instead")] -pub const HttpPermanentRedirect: StaticResponse = - StaticResponse(StatusCode::PERMANENT_REDIRECT); - -#[deprecated(since="0.5.0", note="please use `HttpResponse::BadRequest()` instead")] -pub const HttpBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Unauthorized()` instead")] -pub const HttpUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::PaymentRequired()` instead")] -pub const HttpPaymentRequired: StaticResponse = StaticResponse(StatusCode::PAYMENT_REQUIRED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Forbidden()` instead")] -pub const HttpForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN); -#[deprecated(since="0.5.0", note="please use `HttpResponse::NotFound()` instead")] -pub const HttpNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND); -#[deprecated(since="0.5.0", note="please use `HttpResponse::MethodNotAllowed()` instead")] -pub const HttpMethodNotAllowed: StaticResponse = - StaticResponse(StatusCode::METHOD_NOT_ALLOWED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::NotAcceptable()` instead")] -pub const HttpNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::ProxyAuthenticationRequired()` instead")] -pub const HttpProxyAuthenticationRequired: StaticResponse = - StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::RequestTimeout()` instead")] -pub const HttpRequestTimeout: StaticResponse = StaticResponse(StatusCode::REQUEST_TIMEOUT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Conflict()` instead")] -pub const HttpConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT); -#[deprecated(since="0.5.0", note="please use `HttpResponse::Gone()` instead")] -pub const HttpGone: StaticResponse = StaticResponse(StatusCode::GONE); -#[deprecated(since="0.5.0", note="please use `HttpResponse::LengthRequired()` instead")] -pub const HttpLengthRequired: StaticResponse = StaticResponse(StatusCode::LENGTH_REQUIRED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::PreconditionFailed()` instead")] -pub const HttpPreconditionFailed: StaticResponse = - StaticResponse(StatusCode::PRECONDITION_FAILED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::PayloadTooLarge()` instead")] -pub const HttpPayloadTooLarge: StaticResponse = StaticResponse(StatusCode::PAYLOAD_TOO_LARGE); -#[deprecated(since="0.5.0", note="please use `HttpResponse::UriTooLong()` instead")] -pub const HttpUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::UnsupportedMediaType()` instead")] -pub const HttpUnsupportedMediaType: StaticResponse = - StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::RangeNotSatisfiable()` instead")] -pub const HttpRangeNotSatisfiable: StaticResponse = - StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE); -#[deprecated(since="0.5.0", note="please use `HttpResponse::ExpectationFailed()` instead")] -pub const HttpExpectationFailed: StaticResponse = - StaticResponse(StatusCode::EXPECTATION_FAILED); - -#[deprecated(since="0.5.0", - note="please use `HttpResponse::InternalServerError()` instead")] -pub const HttpInternalServerError: StaticResponse = - StaticResponse(StatusCode::INTERNAL_SERVER_ERROR); -#[deprecated(since="0.5.0", note="please use `HttpResponse::NotImplemented()` instead")] -pub const HttpNotImplemented: StaticResponse = StaticResponse(StatusCode::NOT_IMPLEMENTED); -#[deprecated(since="0.5.0", note="please use `HttpResponse::BadGateway()` instead")] -pub const HttpBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY); -#[deprecated(since="0.5.0", note="please use `HttpResponse::ServiceUnavailable()` instead")] -pub const HttpServiceUnavailable: StaticResponse = - StaticResponse(StatusCode::SERVICE_UNAVAILABLE); -#[deprecated(since="0.5.0", note="please use `HttpResponse::GatewayTimeout()` instead")] -pub const HttpGatewayTimeout: StaticResponse = - StaticResponse(StatusCode::GATEWAY_TIMEOUT); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::VersionNotSupported()` instead")] -pub const HttpVersionNotSupported: StaticResponse = - StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::VariantAlsoNegotiates()` instead")] -pub const HttpVariantAlsoNegotiates: StaticResponse = - StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES); -#[deprecated(since="0.5.0", - note="please use `HttpResponse::InsufficientStorage()` instead")] -pub const HttpInsufficientStorage: StaticResponse = - StaticResponse(StatusCode::INSUFFICIENT_STORAGE); -#[deprecated(since="0.5.0", note="please use `HttpResponse::LoopDetected()` instead")] -pub const HttpLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED); - - -#[deprecated(since="0.5.0", note="please use `HttpResponse` instead")] -#[derive(Copy, Clone, Debug)] -pub struct StaticResponse(StatusCode); - -impl StaticResponse { - pub fn build(&self) -> HttpResponseBuilder { - HttpResponse::build(self.0) - } - pub fn build_from(&self, req: &HttpRequest) -> HttpResponseBuilder { - req.build_response(self.0) - } - pub fn with_reason(self, reason: &'static str) -> HttpResponse { - let mut resp = HttpResponse::new(self.0); - resp.set_reason(reason); - resp - } - pub fn with_body>(self, body: B) -> HttpResponse { - HttpResponse::with_body(self.0, body.into()) - } -} - -impl Handler for StaticResponse { - type Result = HttpResponse; - - fn handle(&mut self, _: HttpRequest) -> HttpResponse { - HttpResponse::new(self.0) - } -} - -impl RouteHandler for StaticResponse { - fn handle(&mut self, _: HttpRequest) -> Reply { - Reply::response(HttpResponse::new(self.0)) - } -} - -impl Responder for StaticResponse { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(self.build().finish()) - } -} - -impl From for HttpResponse { - fn from(st: StaticResponse) -> Self { - HttpResponse::new(st.0) - } -} - -impl From for Reply { - fn from(st: StaticResponse) -> Self { - HttpResponse::new(st.0).into() - } -} - -macro_rules! STATIC_RESP { - ($name:ident, $status:expr) => { - #[allow(non_snake_case)] - pub fn $name() -> HttpResponseBuilder { - HttpResponse::build($status) - } - } -} - -impl HttpResponse { - STATIC_RESP!(Ok, StatusCode::OK); - STATIC_RESP!(Created, StatusCode::CREATED); - STATIC_RESP!(Accepted, StatusCode::ACCEPTED); - STATIC_RESP!(NonAuthoritativeInformation, StatusCode::NON_AUTHORITATIVE_INFORMATION); - - STATIC_RESP!(NoContent, StatusCode::NO_CONTENT); - STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT); - STATIC_RESP!(PartialContent, StatusCode::PARTIAL_CONTENT); - STATIC_RESP!(MultiStatus, StatusCode::MULTI_STATUS); - STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED); - - STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); - STATIC_RESP!(MovedPermanenty, StatusCode::MOVED_PERMANENTLY); - STATIC_RESP!(Found, StatusCode::FOUND); - STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER); - STATIC_RESP!(NotModified, StatusCode::NOT_MODIFIED); - STATIC_RESP!(UseProxy, StatusCode::USE_PROXY); - STATIC_RESP!(TemporaryRedirect, StatusCode::TEMPORARY_REDIRECT); - STATIC_RESP!(PermanentRedirect, StatusCode::PERMANENT_REDIRECT); - - STATIC_RESP!(BadRequest, StatusCode::BAD_REQUEST); - STATIC_RESP!(NotFound, StatusCode::NOT_FOUND); - STATIC_RESP!(Unauthorized, StatusCode::UNAUTHORIZED); - STATIC_RESP!(PaymentRequired, StatusCode::PAYMENT_REQUIRED); - STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN); - STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED); - STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE); - STATIC_RESP!(ProxyAuthenticationRequired, StatusCode::PROXY_AUTHENTICATION_REQUIRED); - STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT); - STATIC_RESP!(Conflict, StatusCode::CONFLICT); - STATIC_RESP!(Gone, StatusCode::GONE); - STATIC_RESP!(LengthRequired, StatusCode::LENGTH_REQUIRED); - STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED); - STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE); - STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG); - STATIC_RESP!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE); - STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE); - STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); - - STATIC_RESP!(InternalServerError, StatusCode::INTERNAL_SERVER_ERROR); - STATIC_RESP!(NotImplemented, StatusCode::NOT_IMPLEMENTED); - STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY); - STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE); - STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT); - STATIC_RESP!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED); - STATIC_RESP!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES); - STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE); - STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED); -} - -#[cfg(test)] -mod tests { - use http::StatusCode; - use super::{HttpOk, HttpBadRequest, Body, HttpResponse}; - - #[test] - fn test_build() { - let resp = HttpOk.build().body(Body::Empty); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_response() { - let resp: HttpResponse = HttpOk.into(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_from() { - let resp: HttpResponse = HttpOk.into(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_with_reason() { - let resp: HttpResponse = HttpOk.into(); - assert_eq!(resp.reason(), "OK"); - - let resp = HttpBadRequest.with_reason("test"); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - assert_eq!(resp.reason(), "test"); - } -} diff --git a/src/httpmessage.rs b/src/httpmessage.rs deleted file mode 100644 index 11d1d087b..000000000 --- a/src/httpmessage.rs +++ /dev/null @@ -1,618 +0,0 @@ -use std::str; -use bytes::{Bytes, BytesMut}; -use futures::{Future, Stream, Poll}; -use http_range::HttpRange; -use serde::de::DeserializeOwned; -use mime::Mime; -use serde_urlencoded; -use encoding::all::UTF_8; -use encoding::EncodingRef; -use encoding::types::{Encoding, DecoderTrap}; -use encoding::label::encoding_from_whatwg_label; -use http::{header, HeaderMap}; - -use json::JsonBody; -use header::Header; -use multipart::Multipart; -use error::{ParseError, ContentTypeError, - HttpRangeError, PayloadError, UrlencodedError}; - - -/// Trait that implements general purpose operations on http messages -pub trait HttpMessage { - - /// Read the message headers. - fn headers(&self) -> &HeaderMap; - - #[doc(hidden)] - /// Get a header - fn get_header(&self) -> Option where Self: Sized { - if self.headers().contains_key(H::name()) { - H::parse(self).ok() - } else { - None - } - } - - /// Read the request content type. If request does not contain - /// *Content-Type* header, empty str get returned. - fn content_type(&self) -> &str { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return content_type.split(';').next().unwrap().trim() - } - } - "" - } - - /// Get content type encoding - /// - /// UTF-8 is used by default, If request charset is not set. - fn encoding(&self) -> Result { - if let Some(mime_type) = self.mime_type()? { - if let Some(charset) = mime_type.get_param("charset") { - if let Some(enc) = encoding_from_whatwg_label(charset.as_str()) { - Ok(enc) - } else { - Err(ContentTypeError::UnknownEncoding) - } - } else { - Ok(UTF_8) - } - } else { - Ok(UTF_8) - } - } - - /// Convert the request content type to a known mime type. - fn mime_type(&self) -> Result, ContentTypeError> { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return match content_type.parse() { - Ok(mt) => Ok(Some(mt)), - Err(_) => Err(ContentTypeError::ParseError), - }; - } else { - return Err(ContentTypeError::ParseError) - } - } - Ok(None) - } - - /// Check if request has chunked transfer encoding - fn chunked(&self) -> Result { - if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { - if let Ok(s) = encodings.to_str() { - Ok(s.to_lowercase().contains("chunked")) - } else { - Err(ParseError::Header) - } - } else { - Ok(false) - } - } - - /// Parses Range HTTP header string as per RFC 2616. - /// `size` is full size of response (file). - fn range(&self, size: u64) -> Result, HttpRangeError> { - if let Some(range) = self.headers().get(header::RANGE) { - HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) - .map_err(|e| e.into()) - } else { - Ok(Vec::new()) - } - } - - /// Load http message body. - /// - /// By default only 256Kb payload reads to a memory, then `PayloadError::Overflow` - /// get returned. Use `MessageBody::limit()` method to change upper limit. - /// - /// ## Server example - /// - /// ```rust - /// # extern crate bytes; - /// # extern crate actix_web; - /// # extern crate futures; - /// # #[macro_use] extern crate serde_derive; - /// use actix_web::*; - /// use bytes::Bytes; - /// use futures::future::Future; - /// - /// fn index(mut req: HttpRequest) -> Box> { - /// req.body() // <- get Body future - /// .limit(1024) // <- change max size of the body to a 1kb - /// .from_err() - /// .and_then(|bytes: Bytes| { // <- complete body - /// println!("==== BODY ==== {:?}", bytes); - /// Ok(HttpResponse::Ok().into()) - /// }).responder() - /// } - /// # fn main() {} - /// ``` - fn body(self) -> MessageBody - where Self: Stream + Sized - { - MessageBody::new(self) - } - - /// Parse `application/x-www-form-urlencoded` encoded request's body. - /// Return `UrlEncoded` future. Form can be deserialized to any type that implements - /// `Deserialize` trait from *serde*. - /// - /// Returns error: - /// - /// * content type is not `application/x-www-form-urlencoded` - /// * transfer encoding is `chunked`. - /// * content-length is greater than 256k - /// - /// ## Server example - /// - /// ```rust - /// # extern crate actix_web; - /// # extern crate futures; - /// # use futures::Future; - /// # use std::collections::HashMap; - /// use actix_web::{HttpMessage, HttpRequest, HttpResponse, FutureResponse}; - /// - /// fn index(mut req: HttpRequest) -> FutureResponse { - /// Box::new( - /// req.urlencoded::>() // <- get UrlEncoded future - /// .from_err() - /// .and_then(|params| { // <- url encoded parameters - /// println!("==== BODY ==== {:?}", params); - /// Ok(HttpResponse::Ok().into()) - /// })) - /// } - /// # fn main() {} - /// ``` - fn urlencoded(self) -> UrlEncoded - where Self: Stream + Sized - { - UrlEncoded::new(self) - } - - /// Parse `application/json` encoded body. - /// Return `JsonBody` future. It resolves to a `T` value. - /// - /// Returns error: - /// - /// * content type is not `application/json` - /// * content length is greater than 256k - /// - /// ## Server example - /// - /// ```rust - /// # extern crate actix_web; - /// # extern crate futures; - /// # #[macro_use] extern crate serde_derive; - /// use actix_web::*; - /// use futures::future::{Future, ok}; - /// - /// #[derive(Deserialize, Debug)] - /// struct MyObj { - /// name: String, - /// } - /// - /// fn index(mut req: HttpRequest) -> Box> { - /// req.json() // <- get JsonBody future - /// .from_err() - /// .and_then(|val: MyObj| { // <- deserialized value - /// println!("==== BODY ==== {:?}", val); - /// Ok(HttpResponse::Ok().into()) - /// }).responder() - /// } - /// # fn main() {} - /// ``` - fn json(self) -> JsonBody - where Self: Stream + Sized - { - JsonBody::new(self) - } - - /// Return stream to http payload processes as multipart. - /// - /// Content-type: multipart/form-data; - /// - /// ## Server example - /// - /// ```rust - /// # extern crate actix; - /// # extern crate actix_web; - /// # extern crate env_logger; - /// # extern crate futures; - /// # use std::str; - /// # use actix::*; - /// # use actix_web::*; - /// # use futures::{Future, Stream}; - /// # use futures::future::{ok, result, Either}; - /// fn index(mut req: HttpRequest) -> Box> { - /// req.multipart().from_err() // <- get multipart stream for current request - /// .and_then(|item| match item { // <- iterate over multipart items - /// multipart::MultipartItem::Field(field) => { - /// // Field in turn is stream of *Bytes* object - /// Either::A(field.from_err() - /// .map(|c| println!("-- CHUNK: \n{:?}", str::from_utf8(&c))) - /// .finish()) - /// }, - /// multipart::MultipartItem::Nested(mp) => { - /// // Or item could be nested Multipart stream - /// Either::B(ok(())) - /// } - /// }) - /// .finish() // <- Stream::finish() combinator from actix - /// .map(|_| HttpResponse::Ok().into()) - /// .responder() - /// } - /// # fn main() {} - /// ``` - fn multipart(self) -> Multipart - where Self: Stream + Sized - { - let boundary = Multipart::boundary(self.headers()); - Multipart::new(boundary, self) - } -} - -/// Future that resolves to a complete http message body. -pub struct MessageBody { - limit: usize, - req: Option, - fut: Option>>, -} - -impl MessageBody { - - /// Create `RequestBody` for request. - pub fn new(req: T) -> MessageBody { - MessageBody { - limit: 262_144, - req: Some(req), - fut: None, - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for MessageBody - where T: HttpMessage + Stream + 'static -{ - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - if let Some(req) = self.req.take() { - if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(PayloadError::Overflow); - } - } else { - return Err(PayloadError::UnknownLength); - } - } else { - return Err(PayloadError::UnknownLength); - } - } - - // future - let limit = self.limit; - self.fut = Some(Box::new( - req.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(PayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .map(|body| body.freeze()) - )); - } - - self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() - } -} - -/// Future that resolves to a parsed urlencoded values. -pub struct UrlEncoded { - req: Option, - limit: usize, - fut: Option>>, -} - -impl UrlEncoded { - pub fn new(req: T) -> UrlEncoded { - UrlEncoded { - req: Some(req), - limit: 262_144, - fut: None, - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for UrlEncoded - where T: HttpMessage + Stream + 'static, - U: DeserializeOwned + 'static -{ - type Item = U; - type Error = UrlencodedError; - - fn poll(&mut self) -> Poll { - if let Some(req) = self.req.take() { - if req.chunked().unwrap_or(false) { - return Err(UrlencodedError::Chunked) - } else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > 262_144 { - return Err(UrlencodedError::Overflow); - } - } else { - return Err(UrlencodedError::UnknownLength) - } - } else { - return Err(UrlencodedError::UnknownLength) - } - } - - // check content type - if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { - return Err(UrlencodedError::ContentType) - } - let encoding = req.encoding().map_err(|_| UrlencodedError::ContentType)?; - - // future - let limit = self.limit; - let fut = req.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(UrlencodedError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .and_then(move |body| { - let enc: *const Encoding = encoding as *const Encoding; - if enc == UTF_8 { - serde_urlencoded::from_bytes::(&body) - .map_err(|_| UrlencodedError::Parse) - } else { - let body = encoding.decode(&body, DecoderTrap::Strict) - .map_err(|_| UrlencodedError::Parse)?; - serde_urlencoded::from_str::(&body) - .map_err(|_| UrlencodedError::Parse) - } - }); - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use mime; - use encoding::Encoding; - use encoding::all::ISO_8859_2; - use futures::Async; - use http::{Method, Version, Uri}; - use httprequest::HttpRequest; - use std::str::FromStr; - use test::TestRequest; - - #[test] - fn test_content_type() { - let req = TestRequest::with_header("content-type", "text/plain").finish(); - assert_eq!(req.content_type(), "text/plain"); - let req = TestRequest::with_header( - "content-type", "application/json; charset=utf=8").finish(); - assert_eq!(req.content_type(), "application/json"); - let req = HttpRequest::default(); - assert_eq!(req.content_type(), ""); - } - - #[test] - fn test_mime_type() { - let req = TestRequest::with_header("content-type", "application/json").finish(); - assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); - let req = HttpRequest::default(); - assert_eq!(req.mime_type().unwrap(), None); - let req = TestRequest::with_header( - "content-type", "application/json; charset=utf-8").finish(); - let mt = req.mime_type().unwrap().unwrap(); - assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); - assert_eq!(mt.type_(), mime::APPLICATION); - assert_eq!(mt.subtype(), mime::JSON); - } - - #[test] - fn test_mime_type_error() { - let req = TestRequest::with_header( - "content-type", "applicationadfadsfasdflknadsfklnadsfjson").finish(); - assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); - } - - #[test] - fn test_encoding() { - let req = HttpRequest::default(); - assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); - - let req = TestRequest::with_header( - "content-type", "application/json").finish(); - assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); - - let req = TestRequest::with_header( - "content-type", "application/json; charset=ISO-8859-2").finish(); - assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name()); - } - - #[test] - fn test_encoding_error() { - let req = TestRequest::with_header( - "content-type", "applicatjson").finish(); - assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); - - let req = TestRequest::with_header( - "content-type", "application/json; charset=kkkttktk").finish(); - assert_eq!(Some(ContentTypeError::UnknownEncoding), req.encoding().err()); - } - - #[test] - fn test_no_request_range_header() { - let req = HttpRequest::default(); - let ranges = req.range(100).unwrap(); - assert!(ranges.is_empty()); - } - - #[test] - fn test_request_range_header() { - let req = TestRequest::with_header(header::RANGE, "bytes=0-4").finish(); - let ranges = req.range(100).unwrap(); - assert_eq!(ranges.len(), 1); - assert_eq!(ranges[0].start, 0); - assert_eq!(ranges[0].length, 5); - } - - #[test] - fn test_chunked() { - let req = HttpRequest::default(); - assert!(!req.chunked().unwrap()); - - let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); - assert!(req.chunked().unwrap()); - - let mut headers = HeaderMap::new(); - let s = unsafe{str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())}; - - headers.insert(header::TRANSFER_ENCODING, - header::HeaderValue::from_str(s).unwrap()); - let req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert!(req.chunked().is_err()); - } - - impl PartialEq for UrlencodedError { - fn eq(&self, other: &UrlencodedError) -> bool { - match *self { - UrlencodedError::Chunked => match *other { - UrlencodedError::Chunked => true, - _ => false, - }, - UrlencodedError::Overflow => match *other { - UrlencodedError::Overflow => true, - _ => false, - }, - UrlencodedError::UnknownLength => match *other { - UrlencodedError::UnknownLength => true, - _ => false, - }, - UrlencodedError::ContentType => match *other { - UrlencodedError::ContentType => true, - _ => false, - }, - _ => false, - } - } - } - - #[derive(Deserialize, Debug, PartialEq)] - struct Info { - hello: String, - } - - #[test] - fn test_urlencoded_error() { - let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); - assert_eq!(req.urlencoded::() - .poll().err().unwrap(), UrlencodedError::Chunked); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "xxxx") - .finish(); - assert_eq!(req.urlencoded::() - .poll().err().unwrap(), UrlencodedError::UnknownLength); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "1000000") - .finish(); - assert_eq!(req.urlencoded::() - .poll().err().unwrap(), UrlencodedError::Overflow); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "text/plain") - .header(header::CONTENT_LENGTH, "10") - .finish(); - assert_eq!(req.urlencoded::() - .poll().err().unwrap(), UrlencodedError::ContentType); - } - - #[test] - fn test_urlencoded() { - let mut req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "11") - .finish(); - req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - - let result = req.urlencoded::().poll().ok().unwrap(); - assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()})); - - let mut req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8") - .header(header::CONTENT_LENGTH, "11") - .finish(); - req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); - - let result = req.urlencoded().poll().ok().unwrap(); - assert_eq!(result, Async::Ready(Info{hello: "world".to_owned()})); - } - - #[test] - fn test_message_body() { - let req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish(); - match req.body().poll().err().unwrap() { - PayloadError::UnknownLength => (), - _ => unreachable!("error"), - } - - let req = TestRequest::with_header(header::CONTENT_LENGTH, "1000000").finish(); - match req.body().poll().err().unwrap() { - PayloadError::Overflow => (), - _ => unreachable!("error"), - } - - let mut req = HttpRequest::default(); - req.payload_mut().unread_data(Bytes::from_static(b"test")); - match req.body().poll().ok().unwrap() { - Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), - _ => unreachable!("error"), - } - - let mut req = HttpRequest::default(); - req.payload_mut().unread_data(Bytes::from_static(b"11111111111111")); - match req.body().limit(5).poll().err().unwrap() { - PayloadError::Overflow => (), - _ => unreachable!("error"), - } - } -} diff --git a/src/httprequest.rs b/src/httprequest.rs deleted file mode 100644 index 00aacb810..000000000 --- a/src/httprequest.rs +++ /dev/null @@ -1,706 +0,0 @@ -//! HTTP Request message related code. -use std::{io, cmp, str, fmt, mem}; -use std::rc::Rc; -use std::net::SocketAddr; -use std::borrow::Cow; -use bytes::Bytes; -use cookie::Cookie; -use futures::{Async, Stream, Poll}; -use futures::future::{FutureResult, result}; -use futures_cpupool::CpuPool; -use failure; -use url::{Url, form_urlencoded}; -use http::{header, Uri, Method, Version, HeaderMap, Extensions, StatusCode}; -use tokio_io::AsyncRead; -use percent_encoding::percent_decode; - -use body::Body; -use info::ConnectionInfo; -use param::Params; -use router::{Router, Resource}; -use payload::Payload; -use handler::FromRequest; -use httpmessage::HttpMessage; -use httpresponse::{HttpResponse, HttpResponseBuilder}; -use server::helpers::SharedHttpInnerMessage; -use error::{Error, UrlGenerationError, CookieParseError, PayloadError}; - - -pub struct HttpInnerMessage { - pub version: Version, - pub method: Method, - pub uri: Uri, - pub headers: HeaderMap, - pub extensions: Extensions, - pub params: Params<'static>, - pub cookies: Option>>, - pub query: Params<'static>, - pub query_loaded: bool, - pub addr: Option, - pub payload: Option, - pub info: Option>, - resource: RouterResource, -} - -#[derive(Debug, Copy, Clone,PartialEq)] -enum RouterResource { - Notset, - Normal(u16), -} - -impl Default for HttpInnerMessage { - - fn default() -> HttpInnerMessage { - HttpInnerMessage { - method: Method::GET, - uri: Uri::default(), - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - params: Params::new(), - query: Params::new(), - query_loaded: false, - cookies: None, - addr: None, - payload: None, - extensions: Extensions::new(), - info: None, - resource: RouterResource::Notset, - } - } -} - -impl HttpInnerMessage { - - /// Checks if a connection should be kept alive. - #[inline] - pub fn keep_alive(&self) -> bool { - if let Some(conn) = self.headers.get(header::CONNECTION) { - if let Ok(conn) = conn.to_str() { - if self.version == Version::HTTP_10 && conn.contains("keep-alive") { - true - } else { - self.version == Version::HTTP_11 && - !(conn.contains("close") || conn.contains("upgrade")) - } - } else { - false - } - } else { - self.version != Version::HTTP_10 - } - } - - #[inline] - pub(crate) fn reset(&mut self) { - self.headers.clear(); - self.extensions.clear(); - self.params.clear(); - self.query.clear(); - self.query_loaded = false; - self.cookies = None; - self.addr = None; - self.info = None; - self.payload = None; - self.resource = RouterResource::Notset; - } -} - -lazy_static!{ - static ref RESOURCE: Resource = Resource::unset(); -} - - -/// An HTTP Request -pub struct HttpRequest(SharedHttpInnerMessage, Option>, Option); - -impl HttpRequest<()> { - /// Construct a new Request. - #[inline] - pub fn new(method: Method, uri: Uri, - version: Version, headers: HeaderMap, payload: Option) - -> HttpRequest - { - HttpRequest( - SharedHttpInnerMessage::from_message(HttpInnerMessage { - method, - uri, - version, - headers, - payload, - params: Params::new(), - query: Params::new(), - query_loaded: false, - cookies: None, - addr: None, - extensions: Extensions::new(), - info: None, - resource: RouterResource::Notset, - }), - None, - None, - ) - } - - #[inline(always)] - #[cfg_attr(feature="cargo-clippy", allow(inline_always))] - pub(crate) fn from_message(msg: SharedHttpInnerMessage) -> HttpRequest { - HttpRequest(msg, None, None) - } - - #[inline] - /// Construct new http request with state. - pub fn with_state(self, state: Rc, router: Router) -> HttpRequest { - HttpRequest(self.0, Some(state), Some(router)) - } -} - - -impl HttpMessage for HttpRequest { - #[inline] - fn headers(&self) -> &HeaderMap { - &self.as_ref().headers - } -} - -impl HttpRequest { - - #[inline] - /// Construct new http request with state. - pub fn change_state(&self, state: Rc) -> HttpRequest { - HttpRequest(self.0.clone(), Some(state), self.2.clone()) - } - - #[inline] - /// Construct new http request without state. - pub(crate) fn without_state(&self) -> HttpRequest { - HttpRequest(self.0.clone(), None, self.2.clone()) - } - - /// get mutable reference for inner message - /// mutable reference should not be returned as result for request's method - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - pub(crate) fn as_mut(&self) -> &mut HttpInnerMessage { - self.0.get_mut() - } - - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - fn as_ref(&self) -> &HttpInnerMessage { - self.0.get_ref() - } - - #[inline] - pub(crate) fn get_inner(&mut self) -> &mut HttpInnerMessage { - self.as_mut() - } - - /// Shared application state - #[inline] - pub fn state(&self) -> &S { - self.1.as_ref().unwrap() - } - - /// Request extensions - #[inline] - pub fn extensions(&mut self) -> &mut Extensions { - &mut self.as_mut().extensions - } - - /// Default `CpuPool` - #[inline] - #[doc(hidden)] - pub fn cpu_pool(&self) -> &CpuPool { - self.router().expect("HttpRequest has to have Router instance") - .server_settings().cpu_pool() - } - - /// Create http response - pub fn response(&self, status: StatusCode, body: Body) -> HttpResponse { - if let Some(router) = self.router() { - router.server_settings().get_response(status, body) - } else { - HttpResponse::with_body(status, body) - } - } - - /// Create http response builder - pub fn build_response(&self, status: StatusCode) -> HttpResponseBuilder { - if let Some(router) = self.router() { - router.server_settings().get_response_builder(status) - } else { - HttpResponse::build(status) - } - } - - #[doc(hidden)] - pub fn prefix_len(&self) -> usize { - if let Some(router) = self.router() { router.prefix().len() } else { 0 } - } - - /// Read the Request Uri. - #[inline] - pub fn uri(&self) -> &Uri { &self.as_ref().uri } - - /// Returns mutable the Request Uri. - /// - /// This might be useful for middlewares, e.g. path normalization. - #[inline] - pub fn uri_mut(&mut self) -> &mut Uri { - &mut self.as_mut().uri - } - - /// Read the Request method. - #[inline] - pub fn method(&self) -> &Method { &self.as_ref().method } - - /// Read the Request Version. - #[inline] - pub fn version(&self) -> Version { - self.as_ref().version - } - - ///Returns mutable Request's headers. - /// - ///This is intended to be used by middleware. - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.as_mut().headers - } - - /// The target path of this Request. - #[inline] - pub fn path(&self) -> &str { - self.uri().path() - } - - /// Percent decoded path of this Request. - #[inline] - pub fn path_decoded(&self) -> Cow { - percent_decode(self.uri().path().as_bytes()).decode_utf8().unwrap() - } - - /// Get *ConnectionInfo* for correct request. - pub fn connection_info(&self) -> &ConnectionInfo { - if self.as_ref().info.is_none() { - let info: ConnectionInfo<'static> = unsafe{ - mem::transmute(ConnectionInfo::new(self))}; - self.as_mut().info = Some(info); - } - self.as_ref().info.as_ref().unwrap() - } - - /// Generate url for named resource - /// - /// ```rust - /// # extern crate actix_web; - /// # use actix_web::{App, HttpRequest, HttpResponse, http}; - /// # - /// fn index(req: HttpRequest) -> HttpResponse { - /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource - /// HttpResponse::Ok().into() - /// } - /// - /// fn main() { - /// let app = App::new() - /// .resource("/test/{one}/{two}/{three}", |r| { - /// r.name("foo"); // <- set resource name, then it could be used in `url_for` - /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); - /// }) - /// .finish(); - /// } - /// ``` - pub fn url_for(&self, name: &str, elements: U) -> Result - where U: IntoIterator, - I: AsRef, - { - if self.router().is_none() { - Err(UrlGenerationError::RouterNotAvailable) - } else { - let path = self.router().unwrap().resource_path(name, elements)?; - if path.starts_with('/') { - let conn = self.connection_info(); - Ok(Url::parse(&format!("{}://{}{}", conn.scheme(), conn.host(), path))?) - } else { - Ok(Url::parse(&path)?) - } - } - } - - /// This method returns reference to current `Router` object. - #[inline] - pub fn router(&self) -> Option<&Router> { - self.2.as_ref() - } - - /// This method returns reference to matched `Resource` object. - #[inline] - pub fn resource(&self) -> &Resource { - if let Some(ref router) = self.2 { - if let RouterResource::Normal(idx) = self.as_ref().resource { - return router.get_resource(idx as usize) - } - } - &*RESOURCE - } - - pub(crate) fn set_resource(&mut self, res: usize) { - self.as_mut().resource = RouterResource::Normal(res as u16); - } - - /// Peer socket address - /// - /// Peer address is actual socket address, if proxy is used in front of - /// actix http server, then peer address would be address of this proxy. - /// - /// To get client connection information `connection_info()` method should be used. - #[inline] - pub fn peer_addr(&self) -> Option<&SocketAddr> { - self.as_ref().addr.as_ref() - } - - #[inline] - pub(crate) fn set_peer_addr(&mut self, addr: Option) { - self.as_mut().addr = addr - } - - /// Get a reference to the Params object. - /// Params is a container for url query parameters. - pub fn query(&self) -> &Params { - if !self.as_ref().query_loaded { - let params: &mut Params = unsafe{ mem::transmute(&mut self.as_mut().query) }; - self.as_mut().query_loaded = true; - for (key, val) in form_urlencoded::parse(self.query_string().as_ref()) { - params.add(key, val); - } - } - unsafe{ mem::transmute(&self.as_ref().query) } - } - - /// The query string in the URL. - /// - /// E.g., id=10 - #[inline] - pub fn query_string(&self) -> &str { - if let Some(query) = self.uri().query().as_ref() { - query - } else { - "" - } - } - - /// Load request cookies. - pub fn cookies(&self) -> Result<&Vec>, CookieParseError> { - if self.as_ref().cookies.is_none() { - let msg = self.as_mut(); - let mut cookies = Vec::new(); - for hdr in msg.headers.get_all(header::COOKIE) { - let s = str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; - for cookie_str in s.split(';').map(|s| s.trim()) { - if !cookie_str.is_empty() { - cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); - } - } - } - msg.cookies = Some(cookies) - } - Ok(self.as_ref().cookies.as_ref().unwrap()) - } - - /// Return request cookie. - pub fn cookie(&self, name: &str) -> Option<&Cookie> { - if let Ok(cookies) = self.cookies() { - for cookie in cookies { - if cookie.name() == name { - return Some(cookie) - } - } - } - None - } - - /// Get a reference to the Params object. - /// - /// Params is a container for url parameters. - /// Route supports glob patterns: * for a single wildcard segment and :param - /// for matching storing that segment of the request url in the Params object. - #[inline] - pub fn match_info(&self) -> &Params { - unsafe{ mem::transmute(&self.as_ref().params) } - } - - /// Get mutable reference to request's Params. - #[inline] - pub fn match_info_mut(&mut self) -> &mut Params { - unsafe{ mem::transmute(&mut self.as_mut().params) } - } - - /// Checks if a connection should be kept alive. - pub fn keep_alive(&self) -> bool { - self.as_ref().keep_alive() - } - - /// Check if request requires connection upgrade - pub(crate) fn upgrade(&self) -> bool { - if let Some(conn) = self.as_ref().headers.get(header::CONNECTION) { - if let Ok(s) = conn.to_str() { - return s.to_lowercase().contains("upgrade") - } - } - self.as_ref().method == Method::CONNECT - } - - /// Set read buffer capacity - /// - /// Default buffer capacity is 32Kb. - pub fn set_read_buffer_capacity(&mut self, cap: usize) { - if let Some(ref mut payload) = self.as_mut().payload { - payload.set_read_buffer_capacity(cap) - } - } - - #[cfg(test)] - pub(crate) fn payload(&self) -> &Payload { - let msg = self.as_mut(); - if msg.payload.is_none() { - msg.payload = Some(Payload::empty()); - } - msg.payload.as_ref().unwrap() - } - - #[cfg(test)] - pub(crate) fn payload_mut(&mut self) -> &mut Payload { - let msg = self.as_mut(); - if msg.payload.is_none() { - msg.payload = Some(Payload::empty()); - } - msg.payload.as_mut().unwrap() - } -} - -impl Default for HttpRequest<()> { - - /// Construct default request - fn default() -> HttpRequest { - HttpRequest(SharedHttpInnerMessage::default(), None, None) - } -} - -impl Clone for HttpRequest { - fn clone(&self) -> HttpRequest { - HttpRequest(self.0.clone(), self.1.clone(), self.2.clone()) - } -} - -impl FromRequest for HttpRequest -{ - type Result = FutureResult; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - result(Ok(req.clone())) - } -} - -impl Stream for HttpRequest { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll, PayloadError> { - let msg = self.as_mut(); - if msg.payload.is_none() { - Ok(Async::Ready(None)) - } else { - msg.payload.as_mut().unwrap().poll() - } - } -} - -impl io::Read for HttpRequest { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.as_mut().payload.is_some() { - match self.as_mut().payload.as_mut().unwrap().poll() { - Ok(Async::Ready(Some(mut b))) => { - let i = cmp::min(b.len(), buf.len()); - buf.copy_from_slice(&b.split_to(i)[..i]); - - if !b.is_empty() { - self.as_mut().payload.as_mut().unwrap().unread_data(b); - } - - if i < buf.len() { - match self.read(&mut buf[i..]) { - Ok(n) => Ok(i + n), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(i), - Err(e) => Err(e), - } - } else { - Ok(i) - } - } - Ok(Async::Ready(None)) => Ok(0), - Ok(Async::NotReady) => - Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready")), - Err(e) => - Err(io::Error::new(io::ErrorKind::Other, failure::Error::from(e).compat())), - } - } else { - Ok(0) - } - } -} - -impl AsyncRead for HttpRequest {} - -impl fmt::Debug for HttpRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!(f, "\nHttpRequest {:?} {}:{}\n", - self.as_ref().version, self.as_ref().method, self.path_decoded()); - if !self.query_string().is_empty() { - let _ = write!(f, " query: ?{:?}\n", self.query_string()); - } - if !self.match_info().is_empty() { - let _ = write!(f, " params: {:?}\n", self.as_ref().params); - } - let _ = write!(f, " headers:\n"); - for (key, val) in self.as_ref().headers.iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } -} - -#[cfg(test)] -mod tests { - use super::*; - use http::{Uri, HttpTryFrom}; - use router::Resource; - use resource::ResourceHandler; - use test::TestRequest; - use server::ServerSettings; - - #[test] - fn test_debug() { - let req = TestRequest::with_header("content-type", "text/plain").finish(); - let dbg = format!("{:?}", req); - assert!(dbg.contains("HttpRequest")); - } - - #[test] - fn test_uri_mut() { - let mut req = HttpRequest::default(); - assert_eq!(req.path(), "/"); - *req.uri_mut() = Uri::try_from("/test").unwrap(); - assert_eq!(req.path(), "/test"); - } - - #[test] - fn test_no_request_cookies() { - let req = HttpRequest::default(); - assert!(req.cookies().unwrap().is_empty()); - } - - #[test] - fn test_request_cookies() { - let req = TestRequest::default() - .header(header::COOKIE, "cookie1=value1") - .header(header::COOKIE, "cookie2=value2") - .finish(); - { - let cookies = req.cookies().unwrap(); - assert_eq!(cookies.len(), 2); - assert_eq!(cookies[0].name(), "cookie1"); - assert_eq!(cookies[0].value(), "value1"); - assert_eq!(cookies[1].name(), "cookie2"); - assert_eq!(cookies[1].value(), "value2"); - } - - let cookie = req.cookie("cookie1"); - assert!(cookie.is_some()); - let cookie = cookie.unwrap(); - assert_eq!(cookie.name(), "cookie1"); - assert_eq!(cookie.value(), "value1"); - - let cookie = req.cookie("cookie-unknown"); - assert!(cookie.is_none()); - } - - #[test] - fn test_request_query() { - let req = TestRequest::with_uri("/?id=test").finish(); - assert_eq!(req.query_string(), "id=test"); - let query = req.query(); - assert_eq!(&query["id"], "test"); - } - - #[test] - fn test_request_match_info() { - let mut req = TestRequest::with_uri("/value/?id=test").finish(); - - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let mut routes = Vec::new(); - routes.push((Resource::new("index", "/{key}/"), Some(resource))); - let (router, _) = Router::new("", ServerSettings::default(), routes); - assert!(router.recognize(&mut req).is_some()); - - assert_eq!(req.match_info().get("key"), Some("value")); - } - - #[test] - fn test_url_for() { - let req2 = HttpRequest::default(); - assert_eq!(req2.url_for("unknown", &["test"]), - Err(UrlGenerationError::RouterNotAvailable)); - - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let routes = vec!((Resource::new("index", "/user/{name}.{ext}"), Some(resource))); - let (router, _) = Router::new("/", ServerSettings::default(), routes); - assert!(router.has_route("/user/test.html")); - assert!(!router.has_route("/test/unknown")); - - let req = TestRequest::with_header(header::HOST, "www.rust-lang.org") - .finish_with_router(router); - - assert_eq!(req.url_for("unknown", &["test"]), - Err(UrlGenerationError::ResourceNotFound)); - assert_eq!(req.url_for("index", &["test"]), - Err(UrlGenerationError::NotEnoughElements)); - let url = req.url_for("index", &["test", "html"]); - assert_eq!(url.ok().unwrap().as_str(), "http://www.rust-lang.org/user/test.html"); - } - - #[test] - fn test_url_for_with_prefix() { - let req = TestRequest::with_header(header::HOST, "www.rust-lang.org").finish(); - - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let routes = vec![(Resource::new("index", "/user/{name}.{ext}"), Some(resource))]; - let (router, _) = Router::new("/prefix/", ServerSettings::default(), routes); - assert!(router.has_route("/user/test.html")); - assert!(!router.has_route("/prefix/user/test.html")); - - let req = req.with_state(Rc::new(()), router); - let url = req.url_for("index", &["test", "html"]); - assert_eq!(url.ok().unwrap().as_str(), - "http://www.rust-lang.org/prefix/user/test.html"); - } - - #[test] - fn test_url_for_external() { - let req = HttpRequest::default(); - - let mut resource = ResourceHandler::<()>::default(); - resource.name("index"); - let routes = vec![ - (Resource::external("youtube", "https://youtube.com/watch/{video_id}"), None)]; - let (router, _) = Router::new::<()>("", ServerSettings::default(), routes); - assert!(!router.has_route("https://youtube.com/watch/unknown")); - - let req = req.with_state(Rc::new(()), router); - let url = req.url_for("youtube", &["oHg5SJYRHA0"]); - assert_eq!(url.ok().unwrap().as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); - } -} diff --git a/src/httpresponse.rs b/src/httpresponse.rs deleted file mode 100644 index 1f763159d..000000000 --- a/src/httpresponse.rs +++ /dev/null @@ -1,1092 +0,0 @@ -//! Http response -use std::{mem, str, fmt}; -use std::rc::Rc; -use std::io::Write; -use std::cell::UnsafeCell; -use std::collections::VecDeque; - -use cookie::{Cookie, CookieJar}; -use bytes::{Bytes, BytesMut, BufMut}; -use futures::Stream; -use http::{StatusCode, Version, HeaderMap, HttpTryFrom, Error as HttpError}; -use http::header::{self, HeaderName, HeaderValue}; -use serde_json; -use serde::Serialize; - -use body::Body; -use error::Error; -use handler::Responder; -use header::{Header, IntoHeaderValue, ContentEncoding}; -use httprequest::HttpRequest; -use httpmessage::HttpMessage; -use client::ClientResponse; - -/// max write buffer size 64k -pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; - - -/// Represents various types of connection -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum ConnectionType { - /// Close connection after response - Close, - /// Keep connection alive after response - KeepAlive, - /// Connection is upgraded to different type - Upgrade, -} - -/// An HTTP Response -pub struct HttpResponse(Option>, Rc>); - -impl Drop for HttpResponse { - fn drop(&mut self) { - if let Some(inner) = self.0.take() { - HttpResponsePool::release(&self.1, inner) - } - } -} - -impl HttpResponse { - - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - fn get_ref(&self) -> &InnerHttpResponse { - self.0.as_ref().unwrap() - } - - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - fn get_mut(&mut self) -> &mut InnerHttpResponse { - self.0.as_mut().unwrap() - } - - /// Create http response builder with specific status. - #[inline] - pub fn build(status: StatusCode) -> HttpResponseBuilder { - HttpResponsePool::get(status) - } - - /// Create http response builder - #[inline] - pub fn build_from>(source: T) -> HttpResponseBuilder { - source.into() - } - - /// Constructs a response - #[inline] - pub fn new(status: StatusCode) -> HttpResponse { - HttpResponsePool::with_body(status, Body::Empty) - } - - /// Constructs a response with body - #[inline] - pub fn with_body>(status: StatusCode, body: B) -> HttpResponse { - HttpResponsePool::with_body(status, body.into()) - } - - /// Constructs a error response - #[inline] - pub fn from_error(error: Error) -> HttpResponse { - let mut resp = error.cause().error_response(); - resp.get_mut().error = Some(error); - resp - } - - /// Convert `HttpResponse` to a `HttpResponseBuilder` - #[inline] - pub fn into_builder(mut self) -> HttpResponseBuilder { - let response = self.0.take(); - let pool = Some(Rc::clone(&self.1)); - - HttpResponseBuilder { - response, - pool, - err: None, - cookies: None, // TODO: convert set-cookie headers - } - } - - /// The source `error` for this response - #[inline] - pub fn error(&self) -> Option<&Error> { - self.get_ref().error.as_ref() - } - - /// Get the HTTP version of this response - #[inline] - pub fn version(&self) -> Option { - self.get_ref().version - } - - /// Get the headers from the response - #[inline] - pub fn headers(&self) -> &HeaderMap { - &self.get_ref().headers - } - - /// Get a mutable reference to the headers - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.get_mut().headers - } - - /// Get the response status code - #[inline] - pub fn status(&self) -> StatusCode { - self.get_ref().status - } - - /// Set the `StatusCode` for this response - #[inline] - pub fn status_mut(&mut self) -> &mut StatusCode { - &mut self.get_mut().status - } - - /// Get custom reason for the response - #[inline] - pub fn reason(&self) -> &str { - if let Some(reason) = self.get_ref().reason { - reason - } else { - self.get_ref().status.canonical_reason().unwrap_or("") - } - } - - /// Set the custom reason for the response - #[inline] - pub fn set_reason(&mut self, reason: &'static str) -> &mut Self { - self.get_mut().reason = Some(reason); - self - } - - /// Set connection type - pub fn set_connection_type(&mut self, conn: ConnectionType) -> &mut Self { - self.get_mut().connection_type = Some(conn); - self - } - - /// Connection upgrade status - #[inline] - pub fn upgrade(&self) -> bool { - self.get_ref().connection_type == Some(ConnectionType::Upgrade) - } - - /// Keep-alive status for this connection - pub fn keep_alive(&self) -> Option { - if let Some(ct) = self.get_ref().connection_type { - match ct { - ConnectionType::KeepAlive => Some(true), - ConnectionType::Close | ConnectionType::Upgrade => Some(false), - } - } else { - None - } - } - - /// is chunked encoding enabled - #[inline] - pub fn chunked(&self) -> Option { - self.get_ref().chunked - } - - /// Content encoding - #[inline] - pub fn content_encoding(&self) -> Option { - self.get_ref().encoding - } - - /// Set content encoding - pub fn set_content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { - self.get_mut().encoding = Some(enc); - self - } - - /// Get body os this response - #[inline] - pub fn body(&self) -> &Body { - &self.get_ref().body - } - - /// Set a body - pub fn set_body>(&mut self, body: B) { - self.get_mut().body = body.into(); - } - - /// Set a body and return previous body value - pub fn replace_body>(&mut self, body: B) -> Body { - mem::replace(&mut self.get_mut().body, body.into()) - } - - /// Size of response in bytes, excluding HTTP headers - pub fn response_size(&self) -> u64 { - self.get_ref().response_size - } - - /// Set content encoding - pub(crate) fn set_response_size(&mut self, size: u64) { - self.get_mut().response_size = size; - } - - /// Set write buffer capacity - pub fn write_buffer_capacity(&self) -> usize { - self.get_ref().write_capacity - } - - /// Set write buffer capacity - pub fn set_write_buffer_capacity(&mut self, cap: usize) { - self.get_mut().write_capacity = cap; - } -} - -impl fmt::Debug for HttpResponse { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!(f, "\nHttpResponse {:?} {}{}\n", - self.get_ref().version, self.get_ref().status, - self.get_ref().reason.unwrap_or("")); - let _ = write!(f, " encoding: {:?}\n", self.get_ref().encoding); - let _ = write!(f, " headers:\n"); - for (key, val) in self.get_ref().headers.iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } -} - -/// An HTTP response builder -/// -/// This type can be used to construct an instance of `HttpResponse` through a -/// builder-like pattern. -pub struct HttpResponseBuilder { - response: Option>, - pool: Option>>, - err: Option, - cookies: Option, -} - -impl HttpResponseBuilder { - /// Set HTTP status code of this response. - #[inline] - pub fn status(&mut self, status: StatusCode) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.status = status; - } - self - } - - /// Set HTTP version of this response. - /// - /// By default response's http version depends on request's version. - #[inline] - pub fn version(&mut self, version: Version) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.version = Some(version); - } - self - } - - /// Set a header. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{HttpRequest, HttpResponse, Result, http}; - /// - /// fn index(req: HttpRequest) -> Result { - /// Ok(HttpResponse::Ok() - /// .set(http::header::IfModifiedSince("Sun, 07 Nov 1994 08:48:37 GMT".parse()?)) - /// .finish()) - /// } - /// fn main() {} - /// ``` - #[doc(hidden)] - pub fn set(&mut self, hdr: H) -> &mut Self - { - if let Some(parts) = parts(&mut self.response, &self.err) { - match hdr.try_into() { - Ok(value) => { parts.headers.append(H::name(), value); } - Err(e) => self.err = Some(e.into()), - } - } - self - } - - /// Set a header. - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{http, HttpRequest, HttpResponse}; - /// - /// fn index(req: HttpRequest) -> HttpResponse { - /// HttpResponse::Ok() - /// .header("X-TEST", "value") - /// .header(http::header::CONTENT_TYPE, "application/json") - /// .finish() - /// } - /// fn main() {} - /// ``` - pub fn header(&mut self, key: K, value: V) -> &mut Self - where HeaderName: HttpTryFrom, - V: IntoHeaderValue, - { - if let Some(parts) = parts(&mut self.response, &self.err) { - match HeaderName::try_from(key) { - Ok(key) => { - match value.try_into() { - Ok(value) => { parts.headers.append(key, value); } - Err(e) => self.err = Some(e.into()), - } - }, - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set the custom reason for the response. - #[inline] - pub fn reason(&mut self, reason: &'static str) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.reason = Some(reason); - } - self - } - - /// Set content encoding. - /// - /// By default `ContentEncoding::Auto` is used, which automatically - /// negotiates content encoding based on request's `Accept-Encoding` headers. - /// To enforce specific encoding, use specific ContentEncoding` value. - #[inline] - pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.encoding = Some(enc); - } - self - } - - /// Set connection type - #[inline] - #[doc(hidden)] - pub fn connection_type(&mut self, conn: ConnectionType) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.connection_type = Some(conn); - } - self - } - - /// Set connection type to Upgrade - #[inline] - #[doc(hidden)] - pub fn upgrade(&mut self) -> &mut Self { - self.connection_type(ConnectionType::Upgrade) - } - - /// Force close connection, even if it is marked as keep-alive - #[inline] - pub fn force_close(&mut self) -> &mut Self { - self.connection_type(ConnectionType::Close) - } - - /// Enables automatic chunked transfer encoding - #[inline] - pub fn chunked(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.chunked = Some(true); - } - self - } - - /// Force disable chunked encoding - #[inline] - pub fn no_chunking(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.chunked = Some(false); - } - self - } - - /// Set response content type - #[inline] - pub fn content_type(&mut self, value: V) -> &mut Self - where HeaderValue: HttpTryFrom - { - if let Some(parts) = parts(&mut self.response, &self.err) { - match HeaderValue::try_from(value) { - Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); }, - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set content length - #[inline] - pub fn content_length(&mut self, len: u64) -> &mut Self { - let mut wrt = BytesMut::new().writer(); - let _ = write!(wrt, "{}", len); - self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) - } - - /// Set a cookie - /// - /// ```rust - /// # extern crate actix_web; - /// use actix_web::{http, HttpRequest, HttpResponse, Result}; - /// - /// fn index(req: HttpRequest) -> HttpResponse { - /// HttpResponse::Ok() - /// .cookie( - /// http::Cookie::build("name", "value") - /// .domain("www.rust-lang.org") - /// .path("/") - /// .secure(true) - /// .http_only(true) - /// .finish()) - /// .finish() - /// } - /// fn main() {} - /// ``` - pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { - if self.cookies.is_none() { - let mut jar = CookieJar::new(); - jar.add(cookie.into_owned()); - self.cookies = Some(jar) - } else { - self.cookies.as_mut().unwrap().add(cookie.into_owned()); - } - self - } - - /// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()` method. - pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { - { - if self.cookies.is_none() { - self.cookies = Some(CookieJar::new()) - } - let jar = self.cookies.as_mut().unwrap(); - let cookie = cookie.clone().into_owned(); - jar.add_original(cookie.clone()); - jar.remove(cookie); - } - self - } - - /// This method calls provided closure with builder reference if value is true. - pub fn if_true(&mut self, value: bool, f: F) -> &mut Self - where F: FnOnce(&mut HttpResponseBuilder) - { - if value { - f(self); - } - self - } - - /// This method calls provided closure with builder reference if value is Some. - pub fn if_some(&mut self, value: Option, f: F) -> &mut Self - where F: FnOnce(T, &mut HttpResponseBuilder) - { - if let Some(val) = value { - f(val, self); - } - self - } - - /// Set write buffer capacity - /// - /// This parameter makes sense only for streaming response - /// or actor. If write buffer reaches specified capacity, stream or actor get - /// paused. - /// - /// Default write buffer capacity is 64kb - pub fn write_buffer_capacity(&mut self, cap: usize) -> &mut Self { - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.write_capacity = cap; - } - self - } - - /// Set a body and generate `HttpResponse`. - /// - /// `HttpResponseBuilder` can not be used after this call. - pub fn body>(&mut self, body: B) -> HttpResponse { - if let Some(e) = self.err.take() { - return Error::from(e).into() - } - let mut response = self.response.take().expect("cannot reuse response builder"); - if let Some(ref jar) = self.cookies { - for cookie in jar.delta() { - match HeaderValue::from_str(&cookie.to_string()) { - Ok(val) => response.headers.append(header::SET_COOKIE, val), - Err(e) => return Error::from(e).into(), - }; - } - } - response.body = body.into(); - HttpResponse(Some(response), self.pool.take().unwrap()) - } - - #[inline] - /// Set a streaming body and generate `HttpResponse`. - /// - /// `HttpResponseBuilder` can not be used after this call. - pub fn streaming(&mut self, stream: S) -> HttpResponse - where S: Stream + 'static, - E: Into, - { - self.body(Body::Streaming(Box::new(stream.map_err(|e| e.into())))) - } - - /// Set a json body and generate `HttpResponse` - /// - /// `HttpResponseBuilder` can not be used after this call. - pub fn json(&mut self, value: T) -> HttpResponse { - match serde_json::to_string(&value) { - Ok(body) => { - let contains = - if let Some(parts) = parts(&mut self.response, &self.err) { - parts.headers.contains_key(header::CONTENT_TYPE) - } else { - true - }; - if !contains { - self.header(header::CONTENT_TYPE, "application/json"); - } - - self.body(body) - }, - Err(e) => Error::from(e).into() - } - } - - #[inline] - /// Set an empty body and generate `HttpResponse` - /// - /// `HttpResponseBuilder` can not be used after this call. - pub fn finish(&mut self) -> HttpResponse { - self.body(Body::Empty) - } - - /// This method construct new `HttpResponseBuilder` - pub fn take(&mut self) -> HttpResponseBuilder { - HttpResponseBuilder { - response: self.response.take(), - pool: self.pool.take(), - err: self.err.take(), - cookies: self.cookies.take(), - } - } -} - -#[inline] -#[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))] -fn parts<'a>(parts: &'a mut Option>, err: &Option) - -> Option<&'a mut Box> -{ - if err.is_some() { - return None - } - parts.as_mut() -} - -/// Helper converters -impl, E: Into> From> for HttpResponse { - fn from(res: Result) -> Self { - match res { - Ok(val) => val.into(), - Err(err) => err.into().into(), - } - } -} - -impl From for HttpResponse { - fn from(mut builder: HttpResponseBuilder) -> Self { - builder.finish() - } -} - -impl Responder for HttpResponseBuilder { - type Item = HttpResponse; - type Error = Error; - - #[inline] - fn respond_to(mut self, _: HttpRequest) -> Result { - Ok(self.finish()) - } -} - -impl From<&'static str> for HttpResponse { - fn from(val: &'static str) -> Self { - HttpResponse::Ok() - .content_type("text/plain; charset=utf-8") - .body(val) - } -} - -impl Responder for &'static str { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("text/plain; charset=utf-8") - .body(self)) - } -} - -impl From<&'static [u8]> for HttpResponse { - fn from(val: &'static [u8]) -> Self { - HttpResponse::Ok() - .content_type("application/octet-stream") - .body(val) - } -} - -impl Responder for &'static [u8] { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("application/octet-stream") - .body(self)) - } -} - -impl From for HttpResponse { - fn from(val: String) -> Self { - HttpResponse::Ok() - .content_type("text/plain; charset=utf-8") - .body(val) - } -} - -impl Responder for String { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("text/plain; charset=utf-8") - .body(self)) - } -} - -impl<'a> From<&'a String> for HttpResponse { - fn from(val: &'a String) -> Self { - HttpResponse::build(StatusCode::OK) - .content_type("text/plain; charset=utf-8") - .body(val) - } -} - -impl<'a> Responder for &'a String { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("text/plain; charset=utf-8") - .body(self)) - } -} - -impl From for HttpResponse { - fn from(val: Bytes) -> Self { - HttpResponse::Ok() - .content_type("application/octet-stream") - .body(val) - } -} - -impl Responder for Bytes { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("application/octet-stream") - .body(self)) - } -} - -impl From for HttpResponse { - fn from(val: BytesMut) -> Self { - HttpResponse::Ok() - .content_type("application/octet-stream") - .body(val) - } -} - -impl Responder for BytesMut { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - Ok(HttpResponse::Ok() - .content_type("application/octet-stream") - .body(self)) - } -} - -/// Create `HttpResponseBuilder` from `ClientResponse` -/// -/// It is useful for proxy response. This implementation -/// copies all responses's headers and status. -impl<'a> From<&'a ClientResponse> for HttpResponseBuilder { - fn from(resp: &'a ClientResponse) -> HttpResponseBuilder { - let mut builder = HttpResponse::build(resp.status()); - for (key, value) in resp.headers() { - builder.header(key.clone(), value.clone()); - } - builder - } -} - -impl<'a, S> From<&'a HttpRequest> for HttpResponseBuilder { - fn from(req: &'a HttpRequest) -> HttpResponseBuilder { - if let Some(router) = req.router() { - router.server_settings().get_response_builder(StatusCode::OK) - } else { - HttpResponse::Ok() - } - } -} - -#[derive(Debug)] -struct InnerHttpResponse { - version: Option, - headers: HeaderMap, - status: StatusCode, - reason: Option<&'static str>, - body: Body, - chunked: Option, - encoding: Option, - connection_type: Option, - write_capacity: usize, - response_size: u64, - error: Option, -} - -impl InnerHttpResponse { - - #[inline] - fn new(status: StatusCode, body: Body) -> InnerHttpResponse { - InnerHttpResponse { - status, - body, - version: None, - headers: HeaderMap::with_capacity(16), - reason: None, - chunked: None, - encoding: None, - connection_type: None, - response_size: 0, - write_capacity: MAX_WRITE_BUFFER_SIZE, - error: None, - } - } -} - -/// Internal use only! unsafe -pub(crate) struct HttpResponsePool(VecDeque>); - -thread_local!(static POOL: Rc> = HttpResponsePool::pool()); - -impl HttpResponsePool { - - pub fn pool() -> Rc> { - Rc::new(UnsafeCell::new(HttpResponsePool(VecDeque::with_capacity(128)))) - } - - #[inline] - pub fn get_builder(pool: &Rc>, status: StatusCode) - -> HttpResponseBuilder - { - let p = unsafe{&mut *pool.as_ref().get()}; - if let Some(mut msg) = p.0.pop_front() { - msg.status = status; - HttpResponseBuilder { - response: Some(msg), - pool: Some(Rc::clone(pool)), - err: None, - cookies: None } - } else { - let msg = Box::new(InnerHttpResponse::new(status, Body::Empty)); - HttpResponseBuilder { - response: Some(msg), - pool: Some(Rc::clone(pool)), - err: None, - cookies: None } - } - } - - #[inline] - pub fn get_response(pool: &Rc>, - status: StatusCode, body: Body) -> HttpResponse - { - let p = unsafe{&mut *pool.as_ref().get()}; - if let Some(mut msg) = p.0.pop_front() { - msg.status = status; - msg.body = body; - HttpResponse(Some(msg), Rc::clone(pool)) - } else { - let msg = Box::new(InnerHttpResponse::new(status, body)); - HttpResponse(Some(msg), Rc::clone(pool)) - } - } - - #[inline] - fn get(status: StatusCode) -> HttpResponseBuilder { - POOL.with(|pool| HttpResponsePool::get_builder(pool, status)) - } - - #[inline] - fn with_body(status: StatusCode, body: Body) -> HttpResponse { - POOL.with(|pool| HttpResponsePool::get_response(pool, status, body)) - } - - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(boxed_local, inline_always))] - fn release(pool: &Rc>, mut inner: Box) - { - let pool = unsafe{&mut *pool.as_ref().get()}; - if pool.0.len() < 128 { - inner.headers.clear(); - inner.version = None; - inner.chunked = None; - inner.reason = None; - inner.encoding = None; - inner.connection_type = None; - inner.response_size = 0; - inner.error = None; - inner.write_capacity = MAX_WRITE_BUFFER_SIZE; - pool.0.push_front(inner); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::str::FromStr; - use time::Duration; - use http::{Method, Uri}; - use http::header::{COOKIE, CONTENT_TYPE, HeaderValue}; - use body::Binary; - use http; - - #[test] - fn test_debug() { - let resp = HttpResponse::Ok() - .header(COOKIE, HeaderValue::from_static("cookie1=value1; ")) - .header(COOKIE, HeaderValue::from_static("cookie2=value2; ")) - .finish(); - let dbg = format!("{:?}", resp); - assert!(dbg.contains("HttpResponse")); - } - - #[test] - fn test_response_cookies() { - let mut headers = HeaderMap::new(); - headers.insert(COOKIE, HeaderValue::from_static("cookie1=value1")); - headers.insert(COOKIE, HeaderValue::from_static("cookie2=value2")); - - let req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - let cookies = req.cookies().unwrap(); - - let resp = HttpResponse::Ok() - .cookie(http::Cookie::build("name", "value") - .domain("www.rust-lang.org") - .path("/test") - .http_only(true) - .max_age(Duration::days(1)) - .finish()) - .del_cookie(&cookies[0]) - .finish(); - - let mut val: Vec<_> = resp.headers().get_all("Set-Cookie") - .iter().map(|v| v.to_str().unwrap().to_owned()).collect(); - val.sort(); - assert!(val[0].starts_with("cookie2=; Max-Age=0;")); - assert_eq!( - val[1],"name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400"); - } - - #[test] - fn test_basic_builder() { - let resp = HttpResponse::Ok() - .header("X-TEST", "value") - .version(Version::HTTP_10) - .finish(); - assert_eq!(resp.version(), Some(Version::HTTP_10)); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_upgrade() { - let resp = HttpResponse::build(StatusCode::OK).upgrade().finish(); - assert!(resp.upgrade()) - } - - #[test] - fn test_force_close() { - let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); - assert!(!resp.keep_alive().unwrap()) - } - - #[test] - fn test_content_type() { - let resp = HttpResponse::build(StatusCode::OK) - .content_type("text/plain").body(Body::Empty); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") - } - - #[test] - fn test_content_encoding() { - let resp = HttpResponse::build(StatusCode::OK).finish(); - assert_eq!(resp.content_encoding(), None); - - #[cfg(feature="brotli")] - { - let resp = HttpResponse::build(StatusCode::OK) - .content_encoding(ContentEncoding::Br).finish(); - assert_eq!(resp.content_encoding(), Some(ContentEncoding::Br)); - } - - let resp = HttpResponse::build(StatusCode::OK) - .content_encoding(ContentEncoding::Gzip).finish(); - assert_eq!(resp.content_encoding(), Some(ContentEncoding::Gzip)); - } - - #[test] - fn test_json() { - let resp = HttpResponse::build(StatusCode::OK) - .json(vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))); - } - - #[test] - fn test_json_ct() { - let resp = HttpResponse::build(StatusCode::OK) - .header(CONTENT_TYPE, "text/json") - .json(vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]"))); - } - - impl Body { - pub(crate) fn binary(&self) -> Option<&Binary> { - match *self { - Body::Binary(ref bin) => Some(bin), - _ => None, - } - } - } - - #[test] - fn test_into_response() { - let req = HttpRequest::default(); - - let resp: HttpResponse = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from("test")); - - let resp: HttpResponse = "test".respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from("test")); - - let resp: HttpResponse = b"test".as_ref().into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref())); - - let resp: HttpResponse = b"test".as_ref().respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref())); - - let resp: HttpResponse = "test".to_owned().into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned())); - - let resp: HttpResponse = "test".to_owned().respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned())); - - let resp: HttpResponse = (&"test".to_owned()).into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned())); - - let resp: HttpResponse = (&"test".to_owned()).respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned())); - - let b = Bytes::from_static(b"test"); - let resp: HttpResponse = b.into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test"))); - - let b = Bytes::from_static(b"test"); - let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test"))); - - let b = BytesMut::from("test"); - let resp: HttpResponse = b.into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test"))); - - let b = BytesMut::from("test"); - let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream")); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test"))); - } - - #[test] - fn test_into_builder() { - let resp: HttpResponse = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); - - let mut builder = resp.into_builder(); - let resp = builder.status(StatusCode::BAD_REQUEST).finish(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } -} diff --git a/src/info.rs b/src/info.rs index 7e1b40f04..a9c3e4eeb 100644 --- a/src/info.rs +++ b/src/info.rs @@ -1,33 +1,39 @@ -use std::str::FromStr; -use http::header::{self, HeaderName}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; +use std::cell::Ref; -const X_FORWARDED_FOR: &str = "X-FORWARDED-FOR"; -const X_FORWARDED_HOST: &str = "X-FORWARDED-HOST"; -const X_FORWARDED_PROTO: &str = "X-FORWARDED-PROTO"; +use crate::dev::{AppConfig, RequestHead}; +use crate::http::header::{self, HeaderName}; +const X_FORWARDED_FOR: &[u8] = b"x-forwarded-for"; +const X_FORWARDED_HOST: &[u8] = b"x-forwarded-host"; +const X_FORWARDED_PROTO: &[u8] = b"x-forwarded-proto"; /// `HttpRequest` connection information -pub struct ConnectionInfo<'a> { - scheme: &'a str, - host: &'a str, - remote: Option<&'a str>, +#[derive(Debug, Clone, Default)] +pub struct ConnectionInfo { + scheme: String, + host: String, + remote: Option, peer: Option, } -impl<'a> ConnectionInfo<'a> { - +impl ConnectionInfo { /// Create *ConnectionInfo* instance for a request. - #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] - pub fn new(req: &'a HttpRequest) -> ConnectionInfo<'a> { + pub fn get<'a>(req: &'a RequestHead, cfg: &AppConfig) -> Ref<'a, Self> { + if !req.extensions().contains::() { + req.extensions_mut().insert(ConnectionInfo::new(req, cfg)); + } + Ref::map(req.extensions(), |e| e.get().unwrap()) + } + + #[allow(clippy::cognitive_complexity)] + fn new(req: &RequestHead, cfg: &AppConfig) -> ConnectionInfo { let mut host = None; let mut scheme = None; let mut remote = None; let mut peer = None; // load forwarded header - for hdr in req.headers().get_all(header::FORWARDED) { + for hdr in req.headers.get_all(&header::FORWARDED) { if let Ok(val) = hdr.to_str() { for pair in val.split(';') { for el in pair.split(',') { @@ -35,15 +41,21 @@ impl<'a> ConnectionInfo<'a> { if let Some(name) = items.next() { if let Some(val) = items.next() { match &name.to_lowercase() as &str { - "for" => if remote.is_none() { - remote = Some(val.trim()); - }, - "proto" => if scheme.is_none() { - scheme = Some(val.trim()); - }, - "host" => if host.is_none() { - host = Some(val.trim()); - }, + "for" => { + if remote.is_none() { + remote = Some(val.trim()); + } + } + "proto" => { + if scheme.is_none() { + scheme = Some(val.trim()); + } + } + "host" => { + if host.is_none() { + host = Some(val.trim()); + } + } _ => (), } } @@ -55,41 +67,40 @@ impl<'a> ConnectionInfo<'a> { // scheme if scheme.is_none() { - if let Some(h) = req.headers().get( - HeaderName::from_str(X_FORWARDED_PROTO).unwrap()) { + if let Some(h) = req + .headers + .get(&HeaderName::from_lowercase(X_FORWARDED_PROTO).unwrap()) + { if let Ok(h) = h.to_str() { scheme = h.split(',').next().map(|v| v.trim()); } } if scheme.is_none() { - scheme = req.uri().scheme_part().map(|a| a.as_str()); - if scheme.is_none() { - if let Some(router) = req.router() { - if router.server_settings().secure() { - scheme = Some("https") - } - } + scheme = req.uri.scheme_part().map(|a| a.as_str()); + if scheme.is_none() && cfg.secure() { + scheme = Some("https") } } } // host if host.is_none() { - if let Some(h) = req.headers().get(HeaderName::from_str(X_FORWARDED_HOST).unwrap()) { + if let Some(h) = req + .headers + .get(&HeaderName::from_lowercase(X_FORWARDED_HOST).unwrap()) + { if let Ok(h) = h.to_str() { host = h.split(',').next().map(|v| v.trim()); } } if host.is_none() { - if let Some(h) = req.headers().get(header::HOST) { + if let Some(h) = req.headers.get(&header::HOST) { host = h.to_str().ok(); } if host.is_none() { - host = req.uri().authority_part().map(|a| a.as_str()); + host = req.uri.authority_part().map(|a| a.as_str()); if host.is_none() { - if let Some(router) = req.router() { - host = Some(router.server_settings().host()); - } + host = Some(cfg.host()); } } } @@ -97,22 +108,25 @@ impl<'a> ConnectionInfo<'a> { // remote addr if remote.is_none() { - if let Some(h) = req.headers().get( - HeaderName::from_str(X_FORWARDED_FOR).unwrap()) { + if let Some(h) = req + .headers + .get(&HeaderName::from_lowercase(X_FORWARDED_FOR).unwrap()) + { if let Ok(h) = h.to_str() { remote = h.split(',').next().map(|v| v.trim()); } } - if remote.is_none() { // get peeraddr from socketaddr - peer = req.peer_addr().map(|addr| format!("{}", addr)); + if remote.is_none() { + // get peeraddr from socketaddr + peer = req.peer_addr.map(|addr| format!("{}", addr)); } } ConnectionInfo { - scheme: scheme.unwrap_or("http"), - host: host.unwrap_or("localhost"), - remote, peer, + scheme: scheme.unwrap_or("http").to_owned(), + host: host.unwrap_or("localhost").to_owned(), + remote: remote.map(|s| s.to_owned()), } } @@ -125,7 +139,7 @@ impl<'a> ConnectionInfo<'a> { /// - Uri #[inline] pub fn scheme(&self) -> &str { - self.scheme + &self.scheme } /// Hostname of the request. @@ -138,19 +152,25 @@ impl<'a> ConnectionInfo<'a> { /// - Uri /// - Server hostname pub fn host(&self) -> &str { - self.host + &self.host } - /// Remote IP of client initiated HTTP request. + /// Remote socket addr of client initiated HTTP request. /// - /// The IP is resolved through the following headers, in this order: + /// The addr is resolved through the following headers, in this order: /// /// - Forwarded /// - X-Forwarded-For /// - peer name of opened socket + /// + /// # Security + /// Do not use this function for security purposes, unless you can ensure the Forwarded and + /// X-Forwarded-For headers cannot be spoofed by the client. If you want the client's socket + /// address explicitly, use + /// [`HttpRequest::peer_addr()`](../web/struct.HttpRequest.html#method.peer_addr) instead. #[inline] pub fn remote(&self) -> Option<&str> { - if let Some(r) = self.remote { + if let Some(ref r) = self.remote { Some(r) } else if let Some(ref peer) = self.peer { Some(peer) @@ -163,52 +183,53 @@ impl<'a> ConnectionInfo<'a> { #[cfg(test)] mod tests { use super::*; - use http::header::HeaderValue; + use crate::test::TestRequest; #[test] fn test_forwarded() { - let req = HttpRequest::default(); - let info = ConnectionInfo::new(&req); + let req = TestRequest::default().to_http_request(); + let info = req.connection_info(); assert_eq!(info.scheme(), "http"); - assert_eq!(info.host(), "localhost"); + assert_eq!(info.host(), "localhost:8080"); - let mut req = HttpRequest::default(); - req.headers_mut().insert( - header::FORWARDED, - HeaderValue::from_static( - "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org")); + let req = TestRequest::default() + .header( + header::FORWARDED, + "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org", + ) + .to_http_request(); - let info = ConnectionInfo::new(&req); + let info = req.connection_info(); assert_eq!(info.scheme(), "https"); assert_eq!(info.host(), "rust-lang.org"); assert_eq!(info.remote(), Some("192.0.2.60")); - let mut req = HttpRequest::default(); - req.headers_mut().insert( - header::HOST, HeaderValue::from_static("rust-lang.org")); + let req = TestRequest::default() + .header(header::HOST, "rust-lang.org") + .to_http_request(); - let info = ConnectionInfo::new(&req); + let info = req.connection_info(); assert_eq!(info.scheme(), "http"); assert_eq!(info.host(), "rust-lang.org"); assert_eq!(info.remote(), None); - let mut req = HttpRequest::default(); - req.headers_mut().insert( - HeaderName::from_str(X_FORWARDED_FOR).unwrap(), HeaderValue::from_static("192.0.2.60")); - let info = ConnectionInfo::new(&req); + let req = TestRequest::default() + .header(X_FORWARDED_FOR, "192.0.2.60") + .to_http_request(); + let info = req.connection_info(); assert_eq!(info.remote(), Some("192.0.2.60")); - let mut req = HttpRequest::default(); - req.headers_mut().insert( - HeaderName::from_str(X_FORWARDED_HOST).unwrap(), HeaderValue::from_static("192.0.2.60")); - let info = ConnectionInfo::new(&req); + let req = TestRequest::default() + .header(X_FORWARDED_HOST, "192.0.2.60") + .to_http_request(); + let info = req.connection_info(); assert_eq!(info.host(), "192.0.2.60"); assert_eq!(info.remote(), None); - let mut req = HttpRequest::default(); - req.headers_mut().insert( - HeaderName::from_str(X_FORWARDED_PROTO).unwrap(), HeaderValue::from_static("https")); - let info = ConnectionInfo::new(&req); + let req = TestRequest::default() + .header(X_FORWARDED_PROTO, "https") + .to_http_request(); + let info = req.connection_info(); assert_eq!(info.scheme(), "https"); } } diff --git a/src/json.rs b/src/json.rs deleted file mode 100644 index 3c8f81e8d..000000000 --- a/src/json.rs +++ /dev/null @@ -1,312 +0,0 @@ -use std::fmt; -use std::ops::{Deref, DerefMut}; -use bytes::{Bytes, BytesMut}; -use futures::{Poll, Future, Stream}; -use http::header::CONTENT_LENGTH; - -use mime; -use serde_json; -use serde::Serialize; -use serde::de::DeserializeOwned; - -use error::{Error, JsonPayloadError, PayloadError}; -use handler::{Responder, FromRequest}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -/// Json helper -/// -/// Json can be used for two different purpose. First is for json response generation -/// and second is for extracting typed information from request's payload. -pub struct Json(pub T); - -impl Deref for Json { - type Target = T; - - fn deref(&self) -> &T { - &self.0 - } -} - -impl DerefMut for Json { - fn deref_mut(&mut self) -> &mut T { - &mut self.0 - } -} - -impl fmt::Debug for Json where T: fmt::Debug { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Json: {:?}", self.0) - } -} - -impl fmt::Display for Json where T: fmt::Display { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.0, f) - } -} - -/// The `Json` type allows you to respond with well-formed JSON data: simply -/// return a value of type Json where T is the type of a structure -/// to serialize into *JSON*. The type `T` must implement the `Serialize` -/// trait from *serde*. -/// -/// ```rust -/// # extern crate actix_web; -/// # #[macro_use] extern crate serde_derive; -/// # use actix_web::*; -/// # -/// #[derive(Serialize)] -/// struct MyObj { -/// name: String, -/// } -/// -/// fn index(req: HttpRequest) -> Result> { -/// Ok(Json(MyObj{name: req.match_info().query("name")?})) -/// } -/// # fn main() {} -/// ``` -impl Responder for Json { - type Item = HttpResponse; - type Error = Error; - - fn respond_to(self, _: HttpRequest) -> Result { - let body = serde_json::to_string(&self.0)?; - - Ok(HttpResponse::Ok() - .content_type("application/json") - .body(body)) - } -} - -/// To extract typed information from request's body, the type `T` must implement the -/// `Deserialize` trait from *serde*. -/// -/// ## Example -/// -/// ```rust -/// # extern crate actix_web; -/// #[macro_use] extern crate serde_derive; -/// use actix_web::{App, Json, Result, http}; -/// -/// #[derive(Deserialize)] -/// struct Info { -/// username: String, -/// } -/// -/// /// deserialize `Info` from request's body -/// fn index(info: Json) -> Result { -/// Ok(format!("Welcome {}!", info.username)) -/// } -/// -/// fn main() { -/// let app = App::new().resource( -/// "/index.html", -/// |r| r.method(http::Method::POST).with(index)); // <- use `with` extractor -/// } -/// ``` -impl FromRequest for Json - where T: DeserializeOwned + 'static, S: 'static -{ - type Result = Box>; - - #[inline] - fn from_request(req: &HttpRequest) -> Self::Result { - Box::new( - JsonBody::new(req.clone()) - .from_err() - .map(Json)) - } -} - -/// Request payload json parser that resolves to a deserialized `T` value. -/// -/// Returns error: -/// -/// * content type is not `application/json` -/// * content length is greater than 256k -/// -/// # Server example -/// -/// ```rust -/// # extern crate actix_web; -/// # extern crate futures; -/// # #[macro_use] extern crate serde_derive; -/// use futures::future::Future; -/// use actix_web::{AsyncResponder, HttpRequest, HttpResponse, HttpMessage, Error}; -/// -/// #[derive(Deserialize, Debug)] -/// struct MyObj { -/// name: String, -/// } -/// -/// fn index(mut req: HttpRequest) -> Box> { -/// req.json() // <- get JsonBody future -/// .from_err() -/// .and_then(|val: MyObj| { // <- deserialized value -/// println!("==== BODY ==== {:?}", val); -/// Ok(HttpResponse::Ok().into()) -/// }).responder() -/// } -/// # fn main() {} -/// ``` -pub struct JsonBody{ - limit: usize, - req: Option, - fut: Option>>, -} - -impl JsonBody { - - /// Create `JsonBody` for request. - pub fn new(req: T) -> Self { - JsonBody{ - limit: 262_144, - req: Some(req), - fut: None, - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for JsonBody - where T: HttpMessage + Stream + 'static -{ - type Item = U; - type Error = JsonPayloadError; - - fn poll(&mut self) -> Poll { - if let Some(req) = self.req.take() { - if let Some(len) = req.headers().get(CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(JsonPayloadError::Overflow); - } - } else { - return Err(JsonPayloadError::Overflow); - } - } - } - // check content-type - - let json = if let Ok(Some(mime)) = req.mime_type() { - mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - } else { - false - }; - if !json { - return Err(JsonPayloadError::ContentType) - } - - let limit = self.limit; - let fut = req.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(JsonPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("JsonBody could not be used second time").poll() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::Bytes; - use http::header; - use futures::Async; - - use with::With; - use handler::Handler; - - impl PartialEq for JsonPayloadError { - fn eq(&self, other: &JsonPayloadError) -> bool { - match *self { - JsonPayloadError::Overflow => match *other { - JsonPayloadError::Overflow => true, - _ => false, - }, - JsonPayloadError::ContentType => match *other { - JsonPayloadError::ContentType => true, - _ => false, - }, - _ => false, - } - } - } - - #[derive(Serialize, Deserialize, PartialEq, Debug)] - struct MyObject { - name: String, - } - - #[test] - fn test_json() { - let json = Json(MyObject{name: "test".to_owned()}); - let resp = json.respond_to(HttpRequest::default()).unwrap(); - assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "application/json"); - } - - #[test] - fn test_json_body() { - let req = HttpRequest::default(); - let mut json = req.json::(); - assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); - - let mut req = HttpRequest::default(); - req.headers_mut().insert(header::CONTENT_TYPE, - header::HeaderValue::from_static("application/text")); - let mut json = req.json::(); - assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); - - let mut req = HttpRequest::default(); - req.headers_mut().insert(header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json")); - req.headers_mut().insert(header::CONTENT_LENGTH, - header::HeaderValue::from_static("10000")); - let mut json = req.json::().limit(100); - assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow); - - let mut req = HttpRequest::default(); - req.headers_mut().insert(header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json")); - req.headers_mut().insert(header::CONTENT_LENGTH, - header::HeaderValue::from_static("16")); - req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}")); - let mut json = req.json::(); - assert_eq!(json.poll().ok().unwrap(), - Async::Ready(MyObject{name: "test".to_owned()})); - } - - #[test] - fn test_with_json() { - let mut handler = With::new(|data: Json| data); - - let req = HttpRequest::default(); - let err = handler.handle(req).as_response().unwrap().error().is_some(); - assert!(err); - - let mut req = HttpRequest::default(); - req.headers_mut().insert(header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json")); - req.headers_mut().insert(header::CONTENT_LENGTH, - header::HeaderValue::from_static("16")); - req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}")); - let ok = handler.handle(req).as_response().unwrap().error().is_none(); - assert!(ok) - } -} diff --git a/src/lib.rs b/src/lib.rs index 3e134c441..b7fd8d155 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,31 +1,55 @@ -//! Actix web is a small, pragmatic, extremely fast, web framework for Rust. +#![allow(clippy::borrow_interior_mutable_const)] +//! Actix web is a small, pragmatic, and extremely fast web framework +//! for Rust. //! //! ```rust -//! use actix_web::{App, HttpServer, Path}; +//! use actix_web::{web, App, Responder, HttpServer}; //! # use std::thread; //! -//! fn index(info: Path<(String, u32)>) -> String { -//! format!("Hello {}! id:{}", info.0, info.1) +//! async fn index(info: web::Path<(String, u32)>) -> impl Responder { +//! format!("Hello {}! id:{}", info.0, info.1) //! } //! -//! fn main() { -//! # thread::spawn(|| { -//! HttpServer::new( -//! || App::new() -//! .resource("/{name}/{id}/index.html", |r| r.with(index))) -//! .bind("127.0.0.1:8080").unwrap() -//! .run(); -//! # }); +//! fn main() -> std::io::Result<()> { +//! # thread::spawn(|| { +//! HttpServer::new(|| App::new().service( +//! web::resource("/{name}/{id}/index.html").to(index)) +//! ) +//! .bind("127.0.0.1:8080")? +//! .run() +//! # }); +//! # Ok(()) //! } //! ``` //! -//! ## Documentation +//! ## Documentation & community resources //! -//! * [User Guide](http://actix.github.io/actix-web/guide/) +//! Besides the API documentation (which you are currently looking +//! at!), several other resources are available: +//! +//! * [User Guide](https://actix.rs/docs/) //! * [Chat on gitter](https://gitter.im/actix/actix) //! * [GitHub repository](https://github.com/actix/actix-web) //! * [Cargo package](https://crates.io/crates/actix-web) -//! * Supported Rust version: 1.21 or later +//! +//! To get started navigating the API documentation you may want to +//! consider looking at the following pages: +//! +//! * [App](struct.App.html): This struct represents an actix-web +//! application and is used to configure routes and other common +//! settings. +//! +//! * [HttpServer](struct.HttpServer.html): This struct +//! represents an HTTP server instance and is used to instantiate and +//! configure servers. +//! +//! * [web](web/index.html): This module +//! provide essentials helper functions and types for application registration. +//! +//! * [HttpRequest](struct.HttpRequest.html) and +//! [HttpResponse](struct.HttpResponse.html): These structs +//! represent HTTP requests and responses and expose various methods +//! for inspecting, creating and otherwise utilizing them. //! //! ## Features //! @@ -35,176 +59,137 @@ //! * `WebSockets` server/client //! * Transparent content compression/decompression (br, gzip, deflate) //! * Configurable request routing -//! * Graceful server shutdown //! * Multipart streams -//! * SSL support with openssl or native-tls -//! * Middlewares (`Logger`, `Session`, `CORS`, `CSRF`, `DefaultHeaders`) -//! * Built on top of [Actix actor framework](https://github.com/actix/actix). +//! * SSL support with OpenSSL or `native-tls` +//! * Middlewares (`Logger`, `Session`, `CORS`, `DefaultHeaders`) +//! * Supports [Actix actor framework](https://github.com/actix/actix) +//! * Supported Rust version: 1.39 or later +//! +//! ## Package feature +//! +//! * `client` - enables http client (default enabled) +//! * `openssl` - enables ssl support via `openssl` crate, supports `http/2` +//! * `rustls` - enables ssl support via `rustls` crate, supports `http/2` +//! * `secure-cookies` - enables secure cookies support, includes `ring` crate as +//! dependency +//! * `brotli` - enables `brotli` compression support, requires `c` +//! compiler (default enabled) +//! * `flate2-zlib` - enables `gzip`, `deflate` compression support, requires +//! `c` compiler (default enabled) +//! * `flate2-rust` - experimental rust based implementation for +//! `gzip`, `deflate` compression. +//! +#![allow(clippy::type_complexity, clippy::new_without_default)] -#![cfg_attr(actix_nightly, feature( - specialization, // for impl ErrorResponse for std::error::Error -))] -#![cfg_attr(feature = "cargo-clippy", allow( - decimal_literal_representation,suspicious_arithmetic_impl,))] - -#[macro_use] -extern crate log; -extern crate time; -extern crate base64; -extern crate bytes; -extern crate byteorder; -extern crate sha1; -extern crate regex; -#[macro_use] -extern crate bitflags; -#[macro_use] -extern crate failure; -#[macro_use] -extern crate lazy_static; -#[macro_use] -extern crate futures; -extern crate futures_cpupool; -extern crate tokio_io; -extern crate tokio_core; -extern crate mio; -extern crate net2; -extern crate cookie; -extern crate http as modhttp; -extern crate httparse; -extern crate http_range; -extern crate mime; -extern crate mime_guess; -extern crate language_tags; -extern crate rand; -extern crate url; -extern crate libc; -#[macro_use] extern crate serde; -extern crate serde_json; -extern crate serde_urlencoded; -extern crate flate2; -#[cfg(feature="brotli")] -extern crate brotli2; -extern crate encoding; -extern crate percent_encoding; -extern crate smallvec; -extern crate num_cpus; -extern crate h2 as http2; -extern crate trust_dns_resolver; -#[macro_use] extern crate actix; - -#[cfg(test)] -#[macro_use] extern crate serde_derive; - -#[cfg(feature="tls")] -extern crate native_tls; -#[cfg(feature="tls")] -extern crate tokio_tls; - -#[cfg(feature="openssl")] -extern crate openssl; -#[cfg(feature="openssl")] -extern crate tokio_openssl; - -mod application; -mod body; -mod context; -mod de; -mod extractor; -mod handler; -mod header; -mod helpers; -mod httpmessage; -mod httprequest; -mod httpresponse; -mod info; -mod json; -mod route; -mod router; -mod resource; -mod param; -mod payload; -mod pipeline; -mod with; - -pub mod client; -pub mod fs; -pub mod ws; +mod app; +mod app_service; +mod config; +mod data; pub mod error; -pub mod multipart; +mod extract; +pub mod guard; +mod handler; +mod info; pub mod middleware; -pub mod pred; +mod request; +mod resource; +mod responder; +mod rmap; +mod route; +mod scope; +mod server; +mod service; pub mod test; -pub mod server; -pub use extractor::{Path, Form, Query}; -pub use error::{Error, Result, ResponseError}; -pub use body::{Body, Binary}; -pub use json::Json; -pub use application::App; -pub use httpmessage::HttpMessage; -pub use httprequest::HttpRequest; -pub use httpresponse::HttpResponse; -pub use handler::{Either, Responder, AsyncResponder, FromRequest, FutureResponse, State}; -pub use context::HttpContext; -pub use server::HttpServer; +mod types; +pub mod web; + +#[allow(unused_imports)] +#[macro_use] +extern crate actix_web_codegen; #[doc(hidden)] -pub mod httpcodes; +pub use actix_web_codegen::*; -#[doc(hidden)] -#[allow(deprecated)] -pub use application::Application; +// re-export for convenience +pub use actix_http::Response as HttpResponse; +pub use actix_http::{body, cookie, http, Error, HttpMessage, ResponseError, Result}; -#[cfg(feature="openssl")] -pub(crate) const HAS_OPENSSL: bool = true; -#[cfg(not(feature="openssl"))] -pub(crate) const HAS_OPENSSL: bool = false; - -#[cfg(feature="tls")] -pub(crate) const HAS_TLS: bool = true; -#[cfg(not(feature="tls"))] -pub(crate) const HAS_TLS: bool = false; +pub use crate::app::App; +pub use crate::extract::FromRequest; +pub use crate::request::HttpRequest; +pub use crate::resource::Resource; +pub use crate::responder::{Either, Responder}; +pub use crate::route::Route; +pub use crate::scope::Scope; +pub use crate::server::HttpServer; pub mod dev { -//! The `actix-web` prelude for library developers -//! -//! The purpose of this module is to alleviate imports of many common actix traits -//! by adding a glob import to the top of actix heavy modules: -//! -//! ``` -//! # #![allow(unused_imports)] -//! use actix_web::dev::*; -//! ``` - - pub use body::BodyStream; - pub use context::Drain; - pub use json::JsonBody; - pub use info::ConnectionInfo; - pub use handler::{Handler, Reply}; - pub use route::Route; - pub use router::{Router, Resource, ResourceType}; - pub use resource::ResourceHandler; - pub use param::{FromParam, Params}; - pub use httpmessage::{UrlEncoded, MessageBody}; - pub use httpresponse::HttpResponseBuilder; -} - -pub mod http { - //! Various http related types - - // re-exports - pub use modhttp::{Method, StatusCode, Version}; + //! The `actix-web` prelude for library developers + //! + //! The purpose of this module is to alleviate imports of many common actix + //! traits by adding a glob import to the top of actix heavy modules: + //! + //! ``` + //! # #![allow(unused_imports)] + //! use actix_web::dev::*; + //! ``` + pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] - pub use modhttp::{uri, Uri, Error, Extensions, HeaderMap, HttpTryFrom}; + pub use crate::handler::Factory; + pub use crate::info::ConnectionInfo; + pub use crate::rmap::ResourceMap; + pub use crate::service::{ + HttpServiceFactory, ServiceRequest, ServiceResponse, WebService, + }; - pub use http_range::HttpRange; - pub use cookie::{Cookie, CookieBuilder}; + pub use crate::types::form::UrlEncoded; + pub use crate::types::json::JsonBody; + pub use crate::types::readlines::Readlines; - pub use helpers::NormalizePath; + pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream}; + pub use actix_http::encoding::Decoder as Decompress; + pub use actix_http::ResponseBuilder as HttpResponseBuilder; + pub use actix_http::{ + Extensions, Payload, PayloadStream, RequestHead, ResponseHead, + }; + pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; + pub use actix_server::Server; + pub use actix_service::{Service, Transform}; - pub mod header { - pub use ::header::*; + pub(crate) fn insert_slash(path: &str) -> String { + let mut path = path.to_owned(); + if !path.is_empty() && !path.starts_with('/') { + path.insert(0, '/'); + }; + path } - pub use header::ContentEncoding; - pub use httpresponse::ConnectionType; +} + +#[cfg(feature = "client")] +pub mod client { + //! An HTTP Client + //! + //! ```rust + //! use actix_rt::System; + //! use actix_web::client::Client; + //! + //! #[actix_rt::main] + //! async fn main() { + //! let mut client = Client::default(); + //! + //! // Create request builder and send request + //! let response = client.get("http://www.rust-lang.org") + //! .header("User-Agent", "Actix-web") + //! .send().await; // <- Send http request + //! + //! println!("Response: {:?}", response); + //! } + //! ``` + pub use awc::error::{ + ConnectError, InvalidUrl, PayloadError, SendRequestError, WsClientError, + }; + pub use awc::{ + test, Client, ClientBuilder, ClientRequest, ClientResponse, Connector, + }; } diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs new file mode 100644 index 000000000..a697deaec --- /dev/null +++ b/src/middleware/compress.rs @@ -0,0 +1,241 @@ +//! `Middleware` for compressing response body. +use std::cmp; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{Context, Poll}; + +use actix_http::body::MessageBody; +use actix_http::encoding::Encoder; +use actix_http::http::header::{ContentEncoding, ACCEPT_ENCODING}; +use actix_http::{Error, Response, ResponseBuilder}; +use actix_service::{Service, Transform}; +use futures::future::{ok, Ready}; +use pin_project::pin_project; + +use crate::service::{ServiceRequest, ServiceResponse}; + +struct Enc(ContentEncoding); + +/// Helper trait that allows to set specific encoding for response. +pub trait BodyEncoding { + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self; +} + +impl BodyEncoding for ResponseBuilder { + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { + self.extensions_mut().insert(Enc(encoding)); + self + } +} + +impl BodyEncoding for Response { + fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { + self.extensions_mut().insert(Enc(encoding)); + self + } +} + +#[derive(Debug, Clone)] +/// `Middleware` for compressing response body. +/// +/// Use `BodyEncoding` trait for overriding response compression. +/// To disable compression set encoding to `ContentEncoding::Identity` value. +/// +/// ```rust +/// use actix_web::{web, middleware, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new() +/// .wrap(middleware::Compress::default()) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +pub struct Compress(ContentEncoding); + +impl Compress { + /// Create new `Compress` middleware with default encoding. + pub fn new(encoding: ContentEncoding) -> Self { + Compress(encoding) + } +} + +impl Default for Compress { + fn default() -> Self { + Compress::new(ContentEncoding::Auto) + } +} + +impl Transform for Compress +where + B: MessageBody, + S: Service, Error = Error>, +{ + type Request = ServiceRequest; + type Response = ServiceResponse>; + type Error = Error; + type InitError = (); + type Transform = CompressMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CompressMiddleware { + service, + encoding: self.0, + }) + } +} + +pub struct CompressMiddleware { + service: S, + encoding: ContentEncoding, +} + +impl Service for CompressMiddleware +where + B: MessageBody, + S: Service, Error = Error>, +{ + type Request = ServiceRequest; + type Response = ServiceResponse>; + type Error = Error; + type Future = CompressResponse; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + // negotiate content-encoding + let encoding = if let Some(val) = req.headers().get(&ACCEPT_ENCODING) { + if let Ok(enc) = val.to_str() { + AcceptEncoding::parse(enc, self.encoding) + } else { + ContentEncoding::Identity + } + } else { + ContentEncoding::Identity + }; + + CompressResponse { + encoding, + fut: self.service.call(req), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +#[pin_project] +pub struct CompressResponse +where + S: Service, + B: MessageBody, +{ + #[pin] + fut: S::Future, + encoding: ContentEncoding, + _t: PhantomData<(B)>, +} + +impl Future for CompressResponse +where + B: MessageBody, + S: Service, Error = Error>, +{ + type Output = Result>, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + match futures::ready!(this.fut.poll(cx)) { + Ok(resp) => { + let enc = if let Some(enc) = resp.response().extensions().get::() { + enc.0 + } else { + *this.encoding + }; + + Poll::Ready(Ok( + resp.map_body(move |head, body| Encoder::response(enc, head, body)) + )) + } + Err(e) => Poll::Ready(Err(e)), + } + } +} + +struct AcceptEncoding { + encoding: ContentEncoding, + quality: f64, +} + +impl Eq for AcceptEncoding {} + +impl Ord for AcceptEncoding { + fn cmp(&self, other: &AcceptEncoding) -> cmp::Ordering { + if self.quality > other.quality { + cmp::Ordering::Less + } else if self.quality < other.quality { + cmp::Ordering::Greater + } else { + cmp::Ordering::Equal + } + } +} + +impl PartialOrd for AcceptEncoding { + fn partial_cmp(&self, other: &AcceptEncoding) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for AcceptEncoding { + fn eq(&self, other: &AcceptEncoding) -> bool { + self.quality == other.quality + } +} + +impl AcceptEncoding { + fn new(tag: &str) -> Option { + let parts: Vec<&str> = tag.split(';').collect(); + let encoding = match parts.len() { + 0 => return None, + _ => ContentEncoding::from(parts[0]), + }; + let quality = match parts.len() { + 1 => encoding.quality(), + _ => match f64::from_str(parts[1]) { + Ok(q) => q, + Err(_) => 0.0, + }, + }; + Some(AcceptEncoding { encoding, quality }) + } + + /// Parse a raw Accept-Encoding header value into an ordered list. + pub fn parse(raw: &str, encoding: ContentEncoding) -> ContentEncoding { + let mut encodings: Vec<_> = raw + .replace(' ', "") + .split(',') + .map(|l| AcceptEncoding::new(l)) + .collect(); + encodings.sort(); + + for enc in encodings { + if let Some(enc) = enc { + if encoding == ContentEncoding::Auto { + return enc.encoding; + } else if encoding == enc.encoding { + return encoding; + } + } + } + ContentEncoding::Identity + } +} diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs new file mode 100644 index 000000000..2ede81783 --- /dev/null +++ b/src/middleware/condition.rs @@ -0,0 +1,151 @@ +//! `Middleware` for conditionally enables another middleware. +use std::task::{Context, Poll}; + +use actix_service::{Service, Transform}; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture}; + +/// `Middleware` for conditionally enables another middleware. +/// The controled middleware must not change the `Service` interfaces. +/// This means you cannot control such middlewares like `Logger` or `Compress`. +/// +/// ## Usage +/// +/// ```rust +/// use actix_web::middleware::{Condition, NormalizePath}; +/// use actix_web::App; +/// +/// # fn main() { +/// let enable_normalize = std::env::var("NORMALIZE_PATH") == Ok("true".into()); +/// let app = App::new() +/// .wrap(Condition::new(enable_normalize, NormalizePath)); +/// # } +/// ``` +pub struct Condition { + trans: T, + enable: bool, +} + +impl Condition { + pub fn new(enable: bool, trans: T) -> Self { + Self { trans, enable } + } +} + +impl Transform for Condition +where + S: Service + 'static, + T: Transform, + T::Future: 'static, + T::InitError: 'static, + T::Transform: 'static, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type InitError = T::InitError; + type Transform = ConditionMiddleware; + type Future = LocalBoxFuture<'static, Result>; + + fn new_transform(&self, service: S) -> Self::Future { + if self.enable { + let f = self.trans.new_transform(service).map(|res| { + res.map( + ConditionMiddleware::Enable as fn(T::Transform) -> Self::Transform, + ) + }); + Either::Left(f) + } else { + Either::Right(ok(ConditionMiddleware::Disable(service))) + } + .boxed_local() + } +} + +pub enum ConditionMiddleware { + Enable(E), + Disable(D), +} + +impl Service for ConditionMiddleware +where + E: Service, + D: Service, +{ + type Request = E::Request; + type Response = E::Response; + type Error = E::Error; + type Future = Either; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + use ConditionMiddleware::*; + match self { + Enable(service) => service.poll_ready(cx), + Disable(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, req: E::Request) -> Self::Future { + use ConditionMiddleware::*; + match self { + Enable(service) => Either::Left(service.call(req)), + Disable(service) => Either::Right(service.call(req)), + } + } +} + +#[cfg(test)] +mod tests { + use actix_service::IntoService; + + use super::*; + use crate::dev::{ServiceRequest, ServiceResponse}; + use crate::error::Result; + use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; + use crate::middleware::errhandlers::*; + use crate::test::{self, TestRequest}; + use crate::HttpResponse; + + fn render_500(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res)) + } + + #[actix_rt::test] + async fn test_handler_enabled() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + + let mut mw = Condition::new(true, mw) + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + + #[actix_rt::test] + async fn test_handler_disabled() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) + }; + + let mw = + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + + let mut mw = Condition::new(false, mw) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE), None); + } +} diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs deleted file mode 100644 index 28c5c7898..000000000 --- a/src/middleware/cors.rs +++ /dev/null @@ -1,867 +0,0 @@ -//! Cross-origin resource sharing (CORS) for Actix applications -//! -//! CORS middleware could be used with application and with resource. -//! First you need to construct CORS middleware instance. -//! -//! To construct a cors: -//! -//! 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. -//! 2. Use any of the builder methods to set fields in the backend. -//! 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend. -//! -//! Cors middleware could be used as parameter for `App::middleware()` or -//! `ResourceHandler::middleware()` methods. But you have to use `Cors::register()` method to -//! support *preflight* OPTIONS request. -//! -//! -//! # Example -//! -//! ```rust -//! # extern crate actix_web; -//! use actix_web::{http, App, HttpRequest, HttpResponse}; -//! use actix_web::middleware::cors; -//! -//! fn index(mut req: HttpRequest) -> &'static str { -//! "Hello world" -//! } -//! -//! fn main() { -//! let app = App::new() -//! .resource("/index.html", |r| { -//! cors::Cors::build() // <- Construct CORS middleware -//! .allowed_origin("https://www.rust-lang.org/") -//! .allowed_methods(vec!["GET", "POST"]) -//! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) -//! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600) -//! .finish().expect("Can not create CORS middleware") -//! .register(r); // <- Register CORS middleware -//! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); -//! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); -//! }) -//! .finish(); -//! } -//! ``` -//! In this example custom *CORS* middleware get registered for "/index.html" endpoint. -//! -//! Cors middleware automatically handle *OPTIONS* preflight request. -use std::collections::HashSet; -use std::iter::FromIterator; - -use http::{self, Method, HttpTryFrom, Uri, StatusCode}; -use http::header::{self, HeaderName, HeaderValue}; - -use error::{Result, ResponseError}; -use resource::ResourceHandler; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Response, Started}; - -/// A set of errors that can occur during processing CORS -#[derive(Debug, Fail)] -pub enum CorsError { - /// The HTTP request header `Origin` is required but was not provided - #[fail(display="The HTTP request header `Origin` is required but was not provided")] - MissingOrigin, - /// The HTTP request header `Origin` could not be parsed correctly. - #[fail(display="The HTTP request header `Origin` could not be parsed correctly.")] - BadOrigin, - /// The request header `Access-Control-Request-Method` is required but is missing - #[fail(display="The request header `Access-Control-Request-Method` is required but is missing")] - MissingRequestMethod, - /// The request header `Access-Control-Request-Method` has an invalid value - #[fail(display="The request header `Access-Control-Request-Method` has an invalid value")] - BadRequestMethod, - /// The request header `Access-Control-Request-Headers` has an invalid value - #[fail(display="The request header `Access-Control-Request-Headers` has an invalid value")] - BadRequestHeaders, - /// The request header `Access-Control-Request-Headers` is required but is missing. - #[fail(display="The request header `Access-Control-Request-Headers` is required but is - missing")] - MissingRequestHeaders, - /// Origin is not allowed to make this request - #[fail(display="Origin is not allowed to make this request")] - OriginNotAllowed, - /// Requested method is not allowed - #[fail(display="Requested method is not allowed")] - MethodNotAllowed, - /// One or more headers requested are not allowed - #[fail(display="One or more headers requested are not allowed")] - HeadersNotAllowed, -} - -/// A set of errors that can occur during building CORS middleware -#[derive(Debug, Fail)] -pub enum CorsBuilderError { - #[fail(display="Parse error: {}", _0)] - ParseError(http::Error), - /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C - /// - /// This is a misconfiguration. Check the documentation for `Cors`. - #[fail(display="Credentials are allowed, but the Origin is set to \"*\"")] - CredentialsWithWildcardOrigin, -} - - -impl ResponseError for CorsError { - - fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self)) - } -} - -/// An enum signifying that some of type T is allowed, or `All` (everything is allowed). -/// -/// `Default` is implemented for this enum and is `All`. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AllOrSome { - /// Everything is allowed. Usually equivalent to the "*" value. - All, - /// Only some of `T` is allowed - Some(T), -} - -impl Default for AllOrSome { - fn default() -> Self { - AllOrSome::All - } -} - -impl AllOrSome { - /// Returns whether this is an `All` variant - pub fn is_all(&self) -> bool { - match *self { - AllOrSome::All => true, - AllOrSome::Some(_) => false, - } - } - - /// Returns whether this is a `Some` variant - pub fn is_some(&self) -> bool { - !self.is_all() - } - - /// Returns &T - pub fn as_ref(&self) -> Option<&T> { - match *self { - AllOrSome::All => None, - AllOrSome::Some(ref t) => Some(t), - } - } -} - -/// `Middleware` for Cross-origin resource sharing support -/// -/// The Cors struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -pub struct Cors { - methods: HashSet, - origins: AllOrSome>, - origins_str: Option, - headers: AllOrSome>, - expose_hdrs: Option, - max_age: Option, - preflight: bool, - send_wildcard: bool, - supports_credentials: bool, - vary_header: bool, -} - -impl Default for Cors { - fn default() -> Cors { - Cors { - origins: AllOrSome::default(), - origins_str: None, - methods: HashSet::from_iter( - vec![Method::GET, Method::HEAD, - Method::POST, Method::OPTIONS, Method::PUT, - Method::PATCH, Method::DELETE].into_iter()), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - } - } -} - -impl Cors { - pub fn build() -> CorsBuilder { - CorsBuilder { - cors: Some(Cors { - origins: AllOrSome::All, - origins_str: None, - methods: HashSet::new(), - headers: AllOrSome::All, - expose_hdrs: None, - max_age: None, - preflight: true, - send_wildcard: false, - supports_credentials: false, - vary_header: true, - }), - methods: false, - error: None, - expose_hdrs: HashSet::new(), - } - } - - /// This method register cors middleware with resource and - /// adds route for *OPTIONS* preflight requests. - /// - /// It is possible to register *Cors* middleware with `ResourceHandler::middleware()` - /// method, but in that case *Cors* middleware wont be able to handle *OPTIONS* - /// requests. - pub fn register(self, resource: &mut ResourceHandler) { - resource.method(Method::OPTIONS).h(|_| HttpResponse::Ok()); - resource.middleware(self); - } - - fn validate_origin(&self, req: &mut HttpRequest) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ORIGIN) { - if let Ok(origin) = hdr.to_str() { - return match self.origins { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_origins) => { - allowed_origins - .get(origin) - .and_then(|_| Some(())) - .ok_or_else(|| CorsError::OriginNotAllowed) - } - }; - } - Err(CorsError::BadOrigin) - } else { - return match self.origins { - AllOrSome::All => Ok(()), - _ => Err(CorsError::MissingOrigin) - } - } - } - - fn validate_allowed_method(&self, req: &mut HttpRequest) -> Result<(), CorsError> { - if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { - if let Ok(meth) = hdr.to_str() { - if let Ok(method) = Method::try_from(meth) { - return self.methods.get(&method) - .and_then(|_| Some(())) - .ok_or_else(|| CorsError::MethodNotAllowed); - } - } - Err(CorsError::BadRequestMethod) - } else { - Err(CorsError::MissingRequestMethod) - } - } - - fn validate_allowed_headers(&self, req: &mut HttpRequest) -> Result<(), CorsError> { - match self.headers { - AllOrSome::All => Ok(()), - AllOrSome::Some(ref allowed_headers) => { - if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { - if let Ok(headers) = hdr.to_str() { - let mut hdrs = HashSet::new(); - for hdr in headers.split(',') { - match HeaderName::try_from(hdr.trim()) { - Ok(hdr) => hdrs.insert(hdr), - Err(_) => return Err(CorsError::BadRequestHeaders) - }; - } - - if !hdrs.is_empty() && !hdrs.is_subset(allowed_headers) { - return Err(CorsError::HeadersNotAllowed) - } - return Ok(()) - } - Err(CorsError::BadRequestHeaders) - } else { - Err(CorsError::MissingRequestHeaders) - } - } - } - } -} - -impl Middleware for Cors { - - fn start(&self, req: &mut HttpRequest) -> Result { - if self.preflight && Method::OPTIONS == *req.method() { - self.validate_origin(req)?; - self.validate_allowed_method(req)?; - self.validate_allowed_headers(req)?; - - // allowed headers - let headers = if let Some(headers) = self.headers.as_ref() { - Some(HeaderValue::try_from(&headers.iter().fold( - String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap()) - } else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { - Some(hdr.clone()) - } else { - None - }; - - Ok(Started::Response( - HttpResponse::Ok() - .if_some(self.max_age.as_ref(), |max_age, resp| { - let _ = resp.header( - header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) - .if_some(headers, |headers, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); }) - .if_true(self.origins.is_all(), |resp| { - if self.send_wildcard { - resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); - } else { - let origin = req.headers().get(header::ORIGIN).unwrap(); - resp.header( - header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - } - }) - .if_true(self.origins.is_some(), |resp| { - resp.header( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.origins_str.as_ref().unwrap().clone()); - }) - .if_true(self.supports_credentials, |resp| { - resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }) - .header( - header::ACCESS_CONTROL_ALLOW_METHODS, - &self.methods.iter().fold( - String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]) - .finish())) - } else { - self.validate_origin(req)?; - - Ok(Started::Done) - } - } - - fn response(&self, req: &mut HttpRequest, mut resp: HttpResponse) -> Result { - match self.origins { - AllOrSome::All => { - if self.send_wildcard { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); - } else if let Some(origin) = req.headers().get(header::ORIGIN) { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - } - } - AllOrSome::Some(_) => { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.origins_str.as_ref().unwrap().clone()); - } - } - - if let Some(ref expose) = self.expose_hdrs { - resp.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap()); - } - if self.supports_credentials { - resp.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); - } - if self.vary_header { - let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { - let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); - val.extend(hdr.as_bytes()); - val.extend(b", Origin"); - HeaderValue::try_from(&val[..]).unwrap() - } else { - HeaderValue::from_static("Origin") - }; - resp.headers_mut().insert(header::VARY, value); - } - Ok(Response::Done(resp)) - } -} - -/// Structure that follows the builder pattern for building `Cors` middleware structs. -/// -/// To construct a cors: -/// -/// 1. Call [`Cors::build`](struct.Cors.html#method.build) to start building. -/// 2. Use any of the builder methods to set fields in the backend. -/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend. -/// -/// # Example -/// -/// ```rust -/// # extern crate http; -/// # extern crate actix_web; -/// use http::header; -/// use actix_web::middleware::cors; -/// -/// # fn main() { -/// let cors = cors::Cors::build() -/// .allowed_origin("https://www.rust-lang.org/") -/// .allowed_methods(vec!["GET", "POST"]) -/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) -/// .allowed_header(header::CONTENT_TYPE) -/// .max_age(3600) -/// .finish().unwrap(); -/// # } -/// ``` -pub struct CorsBuilder { - cors: Option, - methods: bool, - error: Option, - expose_hdrs: HashSet, -} - -fn cors<'a>(parts: &'a mut Option, err: &Option) -> Option<&'a mut Cors> { - if err.is_some() { - return None - } - parts.as_mut() -} - -impl CorsBuilder { - - /// Add an origin that are allowed to make requests. - /// Will be verified against the `Origin` request header. - /// - /// When `All` is set, and `send_wildcard` is set, "*" will be sent in - /// the `Access-Control-Allow-Origin` response header. Otherwise, the client's `Origin` request - /// header will be echoed back in the `Access-Control-Allow-Origin` response header. - /// - /// When `Some` is set, the client's `Origin` request header will be checked in a - /// case-sensitive manner. - /// - /// This is the `list of origins` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// Defaults to `All`. - pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - match Uri::try_from(origin) { - Ok(_) => { - if cors.origins.is_all() { - cors.origins = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut origins) = cors.origins { - origins.insert(origin.to_owned()); - } - } - Err(e) => { - self.error = Some(e.into()); - } - } - } - self - } - - /// Set a list of methods which the allowed origins are allowed to access for - /// requests. - /// - /// This is the `list of methods` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` - pub fn allowed_methods(&mut self, methods: U) -> &mut CorsBuilder - where U: IntoIterator, Method: HttpTryFrom - { - self.methods = true; - if let Some(cors) = cors(&mut self.cors, &self.error) { - for m in methods { - match Method::try_from(m) { - Ok(method) => { - cors.methods.insert(method); - }, - Err(e) => { - self.error = Some(e.into()); - break - } - } - }; - } - self - } - - /// Set an allowed header - pub fn allowed_header(&mut self, header: H) -> &mut CorsBuilder - where HeaderName: HttpTryFrom - { - if let Some(cors) = cors(&mut self.cors, &self.error) { - match HeaderName::try_from(header) { - Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut headers) = cors.headers { - headers.insert(method); - } - } - Err(e) => self.error = Some(e.into()), - } - } - self - } - - /// Set a list of header field names which can be used when - /// this resource is accessed by allowed origins. - /// - /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers` - /// will be echoed back in the `Access-Control-Allow-Headers` header. - /// - /// This is the `list of headers` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// Defaults to `All`. - pub fn allowed_headers(&mut self, headers: U) -> &mut CorsBuilder - where U: IntoIterator, HeaderName: HttpTryFrom - { - if let Some(cors) = cors(&mut self.cors, &self.error) { - for h in headers { - match HeaderName::try_from(h) { - Ok(method) => { - if cors.headers.is_all() { - cors.headers = AllOrSome::Some(HashSet::new()); - } - if let AllOrSome::Some(ref mut headers) = cors.headers { - headers.insert(method); - } - } - Err(e) => { - self.error = Some(e.into()); - break - } - } - }; - } - self - } - - /// Set a list of headers which are safe to expose to the API of a CORS API specification. - /// This corresponds to the `Access-Control-Expose-Headers` response header. - /// - /// This is the `list of exposed headers` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// This defaults to an empty set. - pub fn expose_headers(&mut self, headers: U) -> &mut CorsBuilder - where U: IntoIterator, HeaderName: HttpTryFrom - { - for h in headers { - match HeaderName::try_from(h) { - Ok(method) => { - self.expose_hdrs.insert(method); - }, - Err(e) => { - self.error = Some(e.into()); - break - } - } - } - self - } - - /// Set a maximum time for which this CORS request maybe cached. - /// This value is set as the `Access-Control-Max-Age` header. - /// - /// This defaults to `None` (unset). - pub fn max_age(&mut self, max_age: usize) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.max_age = Some(max_age) - } - self - } - - /// Set a wildcard origins - /// - /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard - /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s - /// `Origin` header. - /// - /// This is the `supports credentials flag` in the - /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). - /// - /// This **CANNOT** be used in conjunction with `allowed_origins` set to `All` and - /// `allow_credentials` set to `true`. Depending on the mode of usage, this will either result - /// in an `Error::CredentialsWithWildcardOrigin` error during actix launch or runtime. - /// - /// Defaults to `false`. - pub fn send_wildcard(&mut self) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.send_wildcard = true - } - self - } - - /// Allows users to make authenticated requests - /// - /// If true, injects the `Access-Control-Allow-Credentials` header in responses. - /// This allows cookies and credentials to be submitted across domains. - /// - /// This option cannot be used in conjunction with an `allowed_origin` set to `All` - /// and `send_wildcards` set to `true`. - /// - /// Defaults to `false`. - pub fn supports_credentials(&mut self) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.supports_credentials = true - } - self - } - - /// Disable `Vary` header support. - /// - /// When enabled the header `Vary: Origin` will be returned as per the W3 - /// implementation guidelines. - /// - /// Setting this header when the `Access-Control-Allow-Origin` is - /// dynamically generated (e.g. when there is more than one allowed - /// origin, and an Origin than '*' is returned) informs CDNs and other - /// caches that the CORS headers are dynamic, and cannot be cached. - /// - /// By default `vary` header support is enabled. - pub fn disable_vary_header(&mut self) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.vary_header = false - } - self - } - - /// Disable *preflight* request support. - /// - /// When enabled cors middleware automatically handles *OPTIONS* request. - /// This is useful application level middleware. - /// - /// By default *preflight* support is enabled. - pub fn disable_preflight(&mut self) -> &mut CorsBuilder { - if let Some(cors) = cors(&mut self.cors, &self.error) { - cors.preflight = false - } - self - } - - /// Finishes building and returns the built `Cors` instance. - pub fn finish(&mut self) -> Result { - if !self.methods { - self.allowed_methods(vec![Method::GET, Method::HEAD, - Method::POST, Method::OPTIONS, Method::PUT, - Method::PATCH, Method::DELETE]); - } - - if let Some(e) = self.error.take() { - return Err(CorsBuilderError::ParseError(e)) - } - - let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); - - if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { - return Err(CorsBuilderError::CredentialsWithWildcardOrigin) - } - - if let AllOrSome::Some(ref origins) = cors.origins { - let s = origins.iter().fold(String::new(), |s, v| s + &format!("{}", v)); - cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap()); - } - - if !self.expose_hdrs.is_empty() { - cors.expose_hdrs = Some( - self.expose_hdrs.iter().fold( - String::new(), |s, v| s + v.as_str())[1..].to_owned()); - } - Ok(cors) - } -} - - -#[cfg(test)] -mod tests { - use super::*; - use test::TestRequest; - - impl Started { - fn is_done(&self) -> bool { - match *self { - Started::Done => true, - _ => false, - } - } - fn response(self) -> HttpResponse { - match self { - Started::Response(resp) => resp, - _ => panic!(), - } - } - } - impl Response { - fn response(self) -> HttpResponse { - match self { - Response::Done(resp) => resp, - _ => panic!(), - } - } - } - - #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] - fn cors_validates_illegal_allow_credentials() { - Cors::build() - .supports_credentials() - .send_wildcard() - .finish() - .unwrap(); - } - - #[test] - fn validate_origin_allows_all_origins() { - let cors = Cors::default(); - let mut req = TestRequest::with_header( - "Origin", "https://www.example.com").finish(); - - assert!(cors.start(&mut req).ok().unwrap().is_done()) - } - - #[test] - fn test_preflight() { - let mut cors = Cors::build() - .send_wildcard() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) - .allowed_header(header::CONTENT_TYPE) - .finish().unwrap(); - - let mut req = TestRequest::with_header( - "Origin", "https://www.example.com") - .method(Method::OPTIONS) - .finish(); - - assert!(cors.start(&mut req).is_err()); - - let mut req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") - .method(Method::OPTIONS) - .finish(); - - assert!(cors.start(&mut req).is_err()); - - let mut req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT") - .method(Method::OPTIONS) - .finish(); - - let resp = cors.start(&mut req).unwrap().response(); - assert_eq!( - &b"*"[..], - resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); - assert_eq!( - &b"3600"[..], - resp.headers().get(header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()); - //assert_eq!( - // &b"authorization,accept,content-type"[..], - // resp.headers().get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap().as_bytes()); - //assert_eq!( - // &b"POST,GET,OPTIONS"[..], - // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes()); - - cors.preflight = false; - assert!(cors.start(&mut req).unwrap().is_done()); - } - - #[test] - #[should_panic(expected = "MissingOrigin")] - fn test_validate_missing_origin() { - let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); - - let mut req = HttpRequest::default(); - cors.start(&mut req).unwrap(); - } - - #[test] - #[should_panic(expected = "OriginNotAllowed")] - fn test_validate_not_allowed_origin() { - let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); - - let mut req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .finish(); - cors.start(&mut req).unwrap(); - } - - #[test] - fn test_validate_origin() { - let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); - - let mut req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .finish(); - - assert!(cors.start(&mut req).unwrap().is_done()); - } - - #[test] - fn test_no_origin_response() { - let cors = Cors::build().finish().unwrap(); - - let mut req = TestRequest::default().method(Method::GET).finish(); - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&mut req, resp).unwrap().response(); - assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none()); - - let mut req = TestRequest::with_header( - "Origin", "https://www.example.com") - .method(Method::OPTIONS) - .finish(); - let resp = cors.response(&mut req, resp).unwrap().response(); - assert_eq!( - &b"https://www.example.com"[..], - resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); - } - - #[test] - fn test_response() { - let cors = Cors::build() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) - .allowed_header(header::CONTENT_TYPE) - .finish().unwrap(); - - let mut req = TestRequest::with_header( - "Origin", "https://www.example.com") - .method(Method::OPTIONS) - .finish(); - - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&mut req, resp).unwrap().response(); - assert_eq!( - &b"*"[..], - resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); - assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes()); - - let resp: HttpResponse = HttpResponse::Ok() - .header(header::VARY, "Accept") - .finish(); - let resp = cors.response(&mut req, resp).unwrap().response(); - assert_eq!( - &b"Accept, Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes()); - - let cors = Cors::build() - .disable_vary_header() - .allowed_origin("https://www.example.com") - .finish().unwrap(); - let resp: HttpResponse = HttpResponse::Ok().into(); - let resp = cors.response(&mut req, resp).unwrap().response(); - assert_eq!( - &b"https://www.example.com"[..], - resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()); - } -} diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs deleted file mode 100644 index c2003ae35..000000000 --- a/src/middleware/csrf.rs +++ /dev/null @@ -1,296 +0,0 @@ -//! A filter for cross-site request forgery (CSRF). -//! -//! This middleware is stateless and [based on request -//! headers](https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)_Prevention_Cheat_Sheet#Verifying_Same_Origin_with_Standard_Headers). -//! -//! By default requests are allowed only if one of these is true: -//! -//! * The request method is safe (`GET`, `HEAD`, `OPTIONS`). It is the -//! applications responsibility to ensure these methods cannot be used to -//! execute unwanted actions. Note that upgrade requests for websockets are -//! also considered safe. -//! * The `Origin` header (added automatically by the browser) matches one -//! of the allowed origins. -//! * There is no `Origin` header but the `Referer` header matches one of -//! the allowed origins. -//! -//! Use [`CsrfFilterBuilder::allow_xhr()`](struct.CsrfFilterBuilder.html#method.allow_xhr) -//! if you want to allow requests with unsafe methods via -//! [CORS](../cors/struct.Cors.html). -//! -//! # Example -//! -//! ``` -//! # extern crate actix_web; -//! use actix_web::{http, App, HttpRequest, HttpResponse}; -//! use actix_web::middleware::csrf; -//! -//! fn handle_post(_: HttpRequest) -> &'static str { -//! "This action should only be triggered with requests from the same site" -//! } -//! -//! fn main() { -//! let app = App::new() -//! .middleware( -//! csrf::CsrfFilter::build() -//! .allowed_origin("https://www.example.com") -//! .finish()) -//! .resource("/", |r| { -//! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); -//! r.method(http::Method::POST).f(handle_post); -//! }) -//! .finish(); -//! } -//! ``` -//! -//! In this example the entire application is protected from CSRF. - -use std::borrow::Cow; -use std::collections::HashSet; - -use bytes::Bytes; -use error::{Result, ResponseError}; -use http::{HeaderMap, HttpTryFrom, Uri, header}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Started}; - -/// Potential cross-site request forgery detected. -#[derive(Debug, Fail)] -pub enum CsrfError { - /// The HTTP request header `Origin` was required but not provided. - #[fail(display="Origin header required")] - MissingOrigin, - /// The HTTP request header `Origin` could not be parsed correctly. - #[fail(display="Could not parse Origin header")] - BadOrigin, - /// The cross-site request was denied. - #[fail(display="Cross-site request denied")] - CsrDenied, -} - -impl ResponseError for CsrfError { - fn error_response(&self) -> HttpResponse { - HttpResponse::Forbidden().body(self.to_string()) - } -} - -fn uri_origin(uri: &Uri) -> Option { - match (uri.scheme_part(), uri.host(), uri.port()) { - (Some(scheme), Some(host), Some(port)) => { - Some(format!("{}://{}:{}", scheme, host, port)) - } - (Some(scheme), Some(host), None) => { - Some(format!("{}://{}", scheme, host)) - } - _ => None - } -} - -fn origin(headers: &HeaderMap) -> Option, CsrfError>> { - headers.get(header::ORIGIN) - .map(|origin| { - origin - .to_str() - .map_err(|_| CsrfError::BadOrigin) - .map(|o| o.into()) - }) - .or_else(|| { - headers.get(header::REFERER) - .map(|referer| { - Uri::try_from(Bytes::from(referer.as_bytes())) - .ok() - .as_ref() - .and_then(uri_origin) - .ok_or(CsrfError::BadOrigin) - .map(|o| o.into()) - }) - }) -} - -/// A middleware that filters cross-site requests. -pub struct CsrfFilter { - origins: HashSet, - allow_xhr: bool, - allow_missing_origin: bool, - allow_upgrade: bool, -} - -impl CsrfFilter { - /// Start building a `CsrfFilter`. - pub fn build() -> CsrfFilterBuilder { - CsrfFilterBuilder { - csrf: CsrfFilter { - origins: HashSet::new(), - allow_xhr: false, - allow_missing_origin: false, - allow_upgrade: false, - } - } - } - - fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { - let is_upgrade = req.headers().contains_key(header::UPGRADE); - let is_safe = req.method().is_safe() && (self.allow_upgrade || !is_upgrade); - - if is_safe || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { - Ok(()) - } else if let Some(header) = origin(req.headers()) { - match header { - Ok(ref origin) if self.origins.contains(origin.as_ref()) => Ok(()), - Ok(_) => Err(CsrfError::CsrDenied), - Err(err) => Err(err), - } - } else if self.allow_missing_origin { - Ok(()) - } else { - Err(CsrfError::MissingOrigin) - } - } -} - -impl Middleware for CsrfFilter { - fn start(&self, req: &mut HttpRequest) -> Result { - self.validate(req)?; - Ok(Started::Done) - } -} - -/// Used to build a `CsrfFilter`. -/// -/// To construct a CSRF filter: -/// -/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to -/// start building. -/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed -/// origins. -/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve -/// the constructed filter. -/// -/// # Example -/// -/// ``` -/// use actix_web::middleware::csrf; -/// -/// let csrf = csrf::CsrfFilter::build() -/// .allowed_origin("https://www.example.com") -/// .finish(); -/// ``` -pub struct CsrfFilterBuilder { - csrf: CsrfFilter, -} - -impl CsrfFilterBuilder { - /// Add an origin that is allowed to make requests. Will be verified - /// against the `Origin` request header. - pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { - self.csrf.origins.insert(origin.to_owned()); - self - } - - /// Allow all requests with an `X-Requested-With` header. - /// - /// A cross-site attacker should not be able to send requests with custom - /// headers unless a CORS policy whitelists them. Therefore it should be - /// safe to allow requests with an `X-Requested-With` header (added - /// automatically by many JavaScript libraries). - /// - /// This is disabled by default, because in Safari it is possible to - /// circumvent this using redirects and Flash. - /// - /// Use this method to enable more lax filtering. - pub fn allow_xhr(mut self) -> CsrfFilterBuilder { - self.csrf.allow_xhr = true; - self - } - - /// Allow requests if the expected `Origin` header is missing (and - /// there is no `Referer` to fall back on). - /// - /// The filter is conservative by default, but it should be safe to allow - /// missing `Origin` headers because a cross-site attacker cannot prevent - /// the browser from sending `Origin` on unsafe requests. - pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { - self.csrf.allow_missing_origin = true; - self - } - - /// Allow cross-site upgrade requests (for example to open a WebSocket). - pub fn allow_upgrade(mut self) -> CsrfFilterBuilder { - self.csrf.allow_upgrade = true; - self - } - - /// Finishes building the `CsrfFilter` instance. - pub fn finish(self) -> CsrfFilter { - self.csrf - } -} - -#[cfg(test)] -mod tests { - use super::*; - use http::Method; - use test::TestRequest; - - #[test] - fn test_safe() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); - - let mut req = TestRequest::with_header("Origin", "https://www.w3.org") - .method(Method::HEAD) - .finish(); - - assert!(csrf.start(&mut req).is_ok()); - } - - #[test] - fn test_csrf() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); - - let mut req = TestRequest::with_header("Origin", "https://www.w3.org") - .method(Method::POST) - .finish(); - - assert!(csrf.start(&mut req).is_err()); - } - - #[test] - fn test_referer() { - let csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); - - let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param") - .method(Method::POST) - .finish(); - - assert!(csrf.start(&mut req).is_ok()); - } - - #[test] - fn test_upgrade() { - let strict_csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .finish(); - - let lax_csrf = CsrfFilter::build() - .allowed_origin("https://www.example.com") - .allow_upgrade() - .finish(); - - let mut req = TestRequest::with_header("Origin", "https://cswsh.com") - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .method(Method::GET) - .finish(); - - assert!(strict_csrf.start(&mut req).is_err()); - assert!(lax_csrf.start(&mut req).is_ok()); - } -} diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index 5399b29d4..05a031065 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -1,40 +1,50 @@ -//! Default response headers -use http::{HeaderMap, HttpTryFrom}; -use http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; +//! Middleware for setting default response headers +use std::rc::Rc; +use std::task::{Context, Poll}; -use error::Result; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Response, Middleware}; +use actix_service::{Service, Transform}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; + +use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; +use crate::http::{HeaderMap, HttpTryFrom}; +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::Error; /// `Middleware` for setting default response headers. /// /// This middleware does not set header if response headers already contains it. /// /// ```rust -/// # extern crate actix_web; -/// use actix_web::{http, middleware, App, HttpResponse}; +/// use actix_web::{web, http, middleware, App, HttpResponse}; /// /// fn main() { /// let app = App::new() -/// .middleware( -/// middleware::DefaultHeaders::new() -/// .header("X-Version", "0.2")) -/// .resource("/test", |r| { -/// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); -/// r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); -/// }) -/// .finish(); +/// .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); /// } /// ``` -pub struct DefaultHeaders{ +#[derive(Clone)] +pub struct DefaultHeaders { + inner: Rc, +} + +struct Inner { ct: bool, headers: HeaderMap, } impl Default for DefaultHeaders { fn default() -> Self { - DefaultHeaders{ct: false, headers: HeaderMap::new()} + DefaultHeaders { + inner: Rc::new(Inner { + ct: false, + headers: HeaderMap::new(), + }), + } } } @@ -46,17 +56,21 @@ impl DefaultHeaders { /// Set a header. #[inline] - #[cfg_attr(feature = "cargo-clippy", allow(match_wild_err_arm))] pub fn header(mut self, key: K, value: V) -> Self - where HeaderName: HttpTryFrom, - HeaderValue: HttpTryFrom + where + HeaderName: HttpTryFrom, + HeaderValue: HttpTryFrom, { + #[allow(clippy::match_wild_err_arm)] match HeaderName::try_from(key) { - Ok(key) => { - match HeaderValue::try_from(value) { - Ok(value) => { self.headers.append(key, value); } - Err(_) => panic!("Can not create header value"), + Ok(key) => match HeaderValue::try_from(value) { + Ok(value) => { + Rc::get_mut(&mut self.inner) + .expect("Multiple copies exist") + .headers + .append(key, value); } + Err(_) => panic!("Can not create header value"), }, Err(_) => panic!("Can not create header name"), } @@ -65,52 +79,130 @@ impl DefaultHeaders { /// Set *CONTENT-TYPE* header if response does not contain this header. pub fn content_type(mut self) -> Self { - self.ct = true; + Rc::get_mut(&mut self.inner) + .expect("Multiple copies exist") + .ct = true; self } } -impl Middleware for DefaultHeaders { +impl Transform for DefaultHeaders +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = DefaultHeadersMiddleware; + type Future = Ready>; - fn response(&self, _: &mut HttpRequest, mut resp: HttpResponse) -> Result { - for (key, value) in self.headers.iter() { - if !resp.headers().contains_key(key) { - resp.headers_mut().insert(key, value.clone()); + fn new_transform(&self, service: S) -> Self::Future { + ok(DefaultHeadersMiddleware { + service, + inner: self.inner.clone(), + }) + } +} + +pub struct DefaultHeadersMiddleware { + service: S, + inner: Rc, +} + +impl Service for DefaultHeadersMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + let inner = self.inner.clone(); + let fut = self.service.call(req); + + async move { + let mut res = fut.await?; + + // set response headers + for (key, value) in inner.headers.iter() { + if !res.headers().contains_key(key) { + res.headers_mut().insert(key.clone(), value.clone()); + } } + // default content-type + if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + } + Ok(res) } - // default content-type - if self.ct && !resp.headers().contains_key(CONTENT_TYPE) { - resp.headers_mut().insert( - CONTENT_TYPE, HeaderValue::from_static("application/octet-stream")); - } - Ok(Response::Done(resp)) + .boxed_local() } } #[cfg(test)] mod tests { + use actix_service::IntoService; + use futures::future::ok; + use super::*; - use http::header::CONTENT_TYPE; + use crate::dev::ServiceRequest; + use crate::http::header::CONTENT_TYPE; + use crate::test::{ok_service, TestRequest}; + use crate::HttpResponse; - #[test] - fn test_default_headers() { - let mw = DefaultHeaders::new() - .header(CONTENT_TYPE, "0001"); + #[actix_rt::test] + async fn test_default_headers() { + let mut mw = DefaultHeaders::new() + .header(CONTENT_TYPE, "0001") + .new_transform(ok_service()) + .await + .unwrap(); - let mut req = HttpRequest::default(); - - let resp = HttpResponse::Ok().finish(); - let resp = match mw.response(&mut req, resp) { - Ok(Response::Done(resp)) => resp, - _ => panic!(), - }; + let req = TestRequest::default().to_srv_request(); + let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); - let resp = HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish(); - let resp = match mw.response(&mut req, resp) { - Ok(Response::Done(resp)) => resp, - _ => panic!(), + let req = TestRequest::default().to_srv_request(); + let srv = |req: ServiceRequest| { + ok(req + .into_response(HttpResponse::Ok().header(CONTENT_TYPE, "0002").finish())) }; + let mut mw = DefaultHeaders::new() + .header(CONTENT_TYPE, "0001") + .new_transform(srv.into_service()) + .await + .unwrap(); + let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); } + + #[actix_rt::test] + async fn test_content_type() { + let srv = + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); + let mut mw = DefaultHeaders::new() + .content_type() + .new_transform(srv.into_service()) + .await + .unwrap(); + + let req = TestRequest::default().to_srv_request(); + let resp = mw.call(req).await.unwrap(); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + "application/octet-stream" + ); + } } diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index db3c70f34..3dc1f0828 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -1,59 +1,69 @@ -use std::collections::HashMap; +//! Custom handlers service for responses. +use std::rc::Rc; +use std::task::{Context, Poll}; -use error::Result; -use http::StatusCode; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Response}; +use actix_service::{Service, Transform}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use hashbrown::HashMap; +use crate::dev::{ServiceRequest, ServiceResponse}; +use crate::error::{Error, Result}; +use crate::http::StatusCode; -type ErrorHandler = Fn(&mut HttpRequest, HttpResponse) -> Result; +/// Error handler response +pub enum ErrorHandlerResponse { + /// New http response got generated + Response(ServiceResponse), + /// Result is a future that resolves to a new http response + Future(LocalBoxFuture<'static, Result, Error>>), +} + +type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; /// `Middleware` for allowing custom handlers for responses. /// -/// You can use `ErrorHandlers::handler()` method to register a custom error handler -/// for specific status code. You can modify existing response or create completly new -/// one. +/// You can use `ErrorHandlers::handler()` method to register a custom error +/// handler for specific status code. You can modify existing response or +/// create completely new one. /// /// ## Example /// /// ```rust -/// # extern crate actix_web; -/// use actix_web::{http, App, HttpRequest, HttpResponse, Result}; -/// use actix_web::middleware::{Response, ErrorHandlers}; +/// use actix_web::middleware::errhandlers::{ErrorHandlers, ErrorHandlerResponse}; +/// use actix_web::{web, http, dev, App, HttpRequest, HttpResponse, Result}; /// -/// fn render_500(_: &mut HttpRequest, resp: HttpResponse) -> Result { -/// let mut builder = resp.into_builder(); -/// builder.header(http::header::CONTENT_TYPE, "application/json"); -/// Ok(Response::Done(builder.into())) +/// fn render_500(mut res: dev::ServiceResponse) -> Result> { +/// res.response_mut() +/// .headers_mut() +/// .insert(http::header::CONTENT_TYPE, http::HeaderValue::from_static("Error")); +/// Ok(ErrorHandlerResponse::Response(res)) /// } /// -/// fn main() { -/// let app = App::new() -/// .middleware( -/// ErrorHandlers::new() -/// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500)) -/// .resource("/test", |r| { -/// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); -/// r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); -/// }) -/// .finish(); -/// } +/// # fn main() { +/// let app = App::new() +/// .wrap( +/// ErrorHandlers::new() +/// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500), +/// ) +/// .service(web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::head().to(|| HttpResponse::MethodNotAllowed()) +/// )); +/// # } /// ``` -pub struct ErrorHandlers { - handlers: HashMap>>, +pub struct ErrorHandlers { + handlers: Rc>>>, } -impl Default for ErrorHandlers { +impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { - handlers: HashMap::new(), + handlers: Rc::new(HashMap::new()), } } } -impl ErrorHandlers { - +impl ErrorHandlers { /// Construct new `ErrorHandlers` instance pub fn new() -> Self { ErrorHandlers::default() @@ -61,54 +71,136 @@ impl ErrorHandlers { /// Register error handler for specified status code pub fn handler(mut self, status: StatusCode, handler: F) -> Self - where F: Fn(&mut HttpRequest, HttpResponse) -> Result + 'static + where + F: Fn(ServiceResponse) -> Result> + 'static, { - self.handlers.insert(status, Box::new(handler)); + Rc::get_mut(&mut self.handlers) + .unwrap() + .insert(status, Box::new(handler)); self } } -impl Middleware for ErrorHandlers { +impl Transform for ErrorHandlers +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = ErrorHandlersMiddleware; + type Future = Ready>; - fn response(&self, req: &mut HttpRequest, resp: HttpResponse) -> Result { - if let Some(handler) = self.handlers.get(&resp.status()) { - handler(req, resp) - } else { - Ok(Response::Done(resp)) + fn new_transform(&self, service: S) -> Self::Future { + ok(ErrorHandlersMiddleware { + service, + handlers: self.handlers.clone(), + }) + } +} + +#[doc(hidden)] +pub struct ErrorHandlersMiddleware { + service: S, + handlers: Rc>>>, +} + +impl Service for ErrorHandlersMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + let handlers = self.handlers.clone(); + let fut = self.service.call(req); + + async move { + let res = fut.await?; + + if let Some(handler) = handlers.get(&res.status()) { + match handler(res) { + Ok(ErrorHandlerResponse::Response(res)) => Ok(res), + Ok(ErrorHandlerResponse::Future(fut)) => fut.await, + Err(e) => Err(e), + } + } else { + Ok(res) + } } + .boxed_local() } } #[cfg(test)] mod tests { + use actix_service::IntoService; + use futures::future::ok; + use super::*; - use http::StatusCode; - use http::header::CONTENT_TYPE; + use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; + use crate::test::{self, TestRequest}; + use crate::HttpResponse; - fn render_500(_: &mut HttpRequest, resp: HttpResponse) -> Result { - let mut builder = resp.into_builder(); - builder.header(CONTENT_TYPE, "0001"); - Ok(Response::Done(builder.into())) + fn render_500(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res)) } - - #[test] - fn test_handler() { - let mw = ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); - let mut req = HttpRequest::default(); - let resp = HttpResponse::InternalServerError().finish(); - let resp = match mw.response(&mut req, resp) { - Ok(Response::Done(resp)) => resp, - _ => panic!(), + #[actix_rt::test] + async fn test_handler() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; + + let mut mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } - let resp = HttpResponse::Ok().finish(); - let resp = match mw.response(&mut req, resp) { - Ok(Response::Done(resp)) => resp, - _ => panic!(), + fn render_500_async( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) + } + + #[actix_rt::test] + async fn test_handler_async() { + let srv = |req: ServiceRequest| { + ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - assert!(!resp.headers().contains_key(CONTENT_TYPE)); + + let mut mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let resp = + test::call_service(&mut mw, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 48a8d3db9..a57ea2961 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -1,19 +1,28 @@ //! Request logging middleware +use std::collections::HashSet; use std::env; -use std::fmt; -use std::fmt::{Display, Formatter}; +use std::fmt::{self, Display, Formatter}; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; -use libc; -use time; +use actix_service::{Service, Transform}; +use bytes::Bytes; +use futures::future::{ok, Ready}; +use log::debug; use regex::Regex; +use time; -use error::Result; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Started, Finished}; +use crate::dev::{BodySize, MessageBody, ResponseBody}; +use crate::error::{Error, Result}; +use crate::http::{HeaderName, HttpTryFrom, StatusCode}; +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::HttpResponse; /// `Middleware` for logging request and response info to the terminal. +/// /// `Logger` middleware uses standard log crate to log information. You should /// enable logger for `actix_web` package to see access log. /// ([`env_logger`](https://docs.rs/env_logger/*/env_logger/) or similar) @@ -21,25 +30,23 @@ use middleware::{Middleware, Started, Finished}; /// ## Usage /// /// Create `Logger` middleware with the specified `format`. -/// Default `Logger` could be created with `default` method, it uses the default format: +/// Default `Logger` could be created with `default` method, it uses the +/// default format: /// /// ```ignore -/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T +/// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// ``` /// ```rust -/// # extern crate actix_web; -/// extern crate env_logger; -/// use actix_web::App; /// use actix_web::middleware::Logger; +/// use actix_web::App; /// /// fn main() { /// std::env::set_var("RUST_LOG", "actix_web=info"); /// env_logger::init(); /// /// let app = App::new() -/// .middleware(Logger::default()) -/// .middleware(Logger::new("%a %{User-Agent}i")) -/// .finish(); +/// .wrap(Logger::default()) +/// .wrap(Logger::new("%a %{User-Agent}i")); /// } /// ``` /// @@ -49,9 +56,7 @@ use middleware::{Middleware, Started, Finished}; /// /// `%a` Remote IP-address (IP-address of proxy if using reverse proxy) /// -/// `%t` Time when the request was started to process -/// -/// `%P` The process ID of the child that serviced the request +/// `%t` Time when the request was started to process (in rfc3339 format) /// /// `%r` First line of request /// @@ -59,24 +64,42 @@ use middleware::{Middleware, Started, Finished}; /// /// `%b` Size of response in bytes, including HTTP headers /// -/// `%T` Time taken to serve the request, in seconds with floating fraction in .06f format +/// `%T` Time taken to serve the request, in seconds with floating fraction in +/// .06f format /// /// `%D` Time taken to serve the request, in milliseconds /// +/// `%U` Request URL +/// /// `%{FOO}i` request.headers['FOO'] /// /// `%{FOO}o` response.headers['FOO'] /// /// `%{FOO}e` os.environ['FOO'] /// -pub struct Logger { +pub struct Logger(Rc); + +struct Inner { format: Format, + exclude: HashSet, } impl Logger { /// Create `Logger` middleware with the specified `format`. pub fn new(format: &str) -> Logger { - Logger { format: Format::new(format) } + Logger(Rc::new(Inner { + format: Format::new(format), + exclude: HashSet::new(), + })) + } + + /// Ignore and do not log access info for specified path. + pub fn exclude>(mut self, path: T) -> Self { + Rc::get_mut(&mut self.0) + .unwrap() + .exclude + .insert(path.into()); + self } } @@ -84,40 +107,170 @@ impl Default for Logger { /// Create `Logger` middleware with format: /// /// ```ignore - /// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T + /// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// ``` fn default() -> Logger { - Logger { format: Format::default() } + Logger(Rc::new(Inner { + format: Format::default(), + exclude: HashSet::new(), + })) } } -struct StartTime(time::Tm); +impl Transform for Logger +where + S: Service, Error = Error>, + B: MessageBody, +{ + type Request = ServiceRequest; + type Response = ServiceResponse>; + type Error = Error; + type InitError = (); + type Transform = LoggerMiddleware; + type Future = Ready>; -impl Logger { + fn new_transform(&self, service: S) -> Self::Future { + ok(LoggerMiddleware { + service, + inner: self.0.clone(), + }) + } +} - fn log(&self, req: &mut HttpRequest, resp: &HttpResponse) { - let entry_time = req.extensions().get::().unwrap().0; +/// Logger middleware +pub struct LoggerMiddleware { + inner: Rc, + service: S, +} - let render = |fmt: &mut Formatter| { - for unit in &self.format.0 { - unit.render(fmt, req, resp, entry_time)?; +impl Service for LoggerMiddleware +where + S: Service, Error = Error>, + B: MessageBody, +{ + type Request = ServiceRequest; + type Response = ServiceResponse>; + type Error = Error; + type Future = LoggerResponse; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + if self.inner.exclude.contains(req.path()) { + LoggerResponse { + fut: self.service.call(req), + format: None, + time: time::now(), + _t: PhantomData, } - Ok(()) - }; - info!("{}", FormatDisplay(&render)); + } else { + let now = time::now(); + let mut format = self.inner.format.clone(); + + for unit in &mut format.0 { + unit.render_request(now, &req); + } + LoggerResponse { + fut: self.service.call(req), + format: Some(format), + time: now, + _t: PhantomData, + } + } } } -impl Middleware for Logger { +#[doc(hidden)] +#[pin_project::pin_project] +pub struct LoggerResponse +where + B: MessageBody, + S: Service, +{ + #[pin] + fut: S::Future, + time: time::Tm, + format: Option, + _t: PhantomData<(B,)>, +} - fn start(&self, req: &mut HttpRequest) -> Result { - req.extensions().insert(StartTime(time::now())); - Ok(Started::Done) +impl Future for LoggerResponse +where + B: MessageBody, + S: Service, Error = Error>, +{ + type Output = Result>, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + let res = match futures::ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; + + if let Some(error) = res.response().error() { + if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { + debug!("Error in response: {:?}", error); + } + } + + if let Some(ref mut format) = this.format { + for unit in &mut format.0 { + unit.render_response(res.response()); + } + } + + let time = *this.time; + let format = this.format.take(); + + Poll::Ready(Ok(res.map_body(move |_, body| { + ResponseBody::Body(StreamLog { + body, + time, + format, + size: 0, + }) + }))) + } +} + +pub struct StreamLog { + body: ResponseBody, + format: Option, + size: usize, + time: time::Tm, +} + +impl Drop for StreamLog { + fn drop(&mut self) { + if let Some(ref format) = self.format { + let render = |fmt: &mut Formatter| { + for unit in &format.0 { + unit.render(fmt, self.size, self.time)?; + } + Ok(()) + }; + log::info!("{}", FormatDisplay(&render)); + } + } +} + +impl MessageBody for StreamLog { + fn size(&self) -> BodySize { + self.body.size() } - fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { - self.log(req, resp); - Finished::Done + fn poll_next(&mut self, cx: &mut Context) -> Poll>> { + match self.body.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + self.size += chunk.len(); + Poll::Ready(Some(Ok(chunk))) + } + val => val, + } } } @@ -130,7 +283,7 @@ struct Format(Vec); impl Default for Format { /// Return the default formatting style for the `Logger`: fn default() -> Format { - Format::new(r#"%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#) + Format::new(r#"%a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#) } } @@ -139,8 +292,8 @@ impl Format { /// /// Returns `None` if the format string syntax is incorrect. pub fn new(s: &str) -> Format { - trace!("Access log format: {}", s); - let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbTD]?)").unwrap(); + log::trace!("Access log format: {}", s); + let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrUsbTD]?)").unwrap(); let mut idx = 0; let mut results = Vec::new(); @@ -153,29 +306,30 @@ impl Format { idx = m.end(); if let Some(key) = cap.get(2) { - results.push( - match cap.get(3).unwrap().as_str() { - "i" => FormatText::RequestHeader(key.as_str().to_owned()), - "o" => FormatText::ResponseHeader(key.as_str().to_owned()), - "e" => FormatText::EnvironHeader(key.as_str().to_owned()), - _ => unreachable!(), - }) + results.push(match cap.get(3).unwrap().as_str() { + "i" => FormatText::RequestHeader( + HeaderName::try_from(key.as_str()).unwrap(), + ), + "o" => FormatText::ResponseHeader( + HeaderName::try_from(key.as_str()).unwrap(), + ), + "e" => FormatText::EnvironHeader(key.as_str().to_owned()), + _ => unreachable!(), + }) } else { let m = cap.get(1).unwrap(); - results.push( - match m.as_str() { - "%" => FormatText::Percent, - "a" => FormatText::RemoteAddr, - "t" => FormatText::RequestTime, - "P" => FormatText::Pid, - "r" => FormatText::RequestLine, - "s" => FormatText::ResponseStatus, - "b" => FormatText::ResponseSize, - "T" => FormatText::Time, - "D" => FormatText::TimeMillis, - _ => FormatText::Str(m.as_str().to_owned()), - } - ); + results.push(match m.as_str() { + "%" => FormatText::Percent, + "a" => FormatText::RemoteAddr, + "t" => FormatText::RequestTime, + "r" => FormatText::RequestLine, + "s" => FormatText::ResponseStatus, + "b" => FormatText::ResponseSize, + "U" => FormatText::UrlPath, + "T" => FormatText::Time, + "D" => FormatText::TimeMillis, + _ => FormatText::Str(m.as_str().to_owned()), + }); } } if idx != s.len() { @@ -192,7 +346,6 @@ impl Format { #[derive(Debug, Clone)] pub enum FormatText { Str(String), - Pid, Percent, RequestLine, RequestTime, @@ -201,72 +354,32 @@ pub enum FormatText { Time, TimeMillis, RemoteAddr, - RequestHeader(String), - ResponseHeader(String), + UrlPath, + RequestHeader(HeaderName), + ResponseHeader(HeaderName), EnvironHeader(String), } impl FormatText { - - fn render(&self, fmt: &mut Formatter, - req: &HttpRequest, - resp: &HttpResponse, - entry_time: time::Tm) -> Result<(), fmt::Error> - { + fn render( + &self, + fmt: &mut Formatter, + size: usize, + entry_time: time::Tm, + ) -> Result<(), fmt::Error> { match *self { FormatText::Str(ref string) => fmt.write_str(string), FormatText::Percent => "%".fmt(fmt), - FormatText::RequestLine => { - if req.query_string().is_empty() { - fmt.write_fmt(format_args!( - "{} {} {:?}", - req.method(), req.path(), req.version())) - } else { - fmt.write_fmt(format_args!( - "{} {}?{} {:?}", - req.method(), req.path(), req.query_string(), req.version())) - } - }, - FormatText::ResponseStatus => resp.status().as_u16().fmt(fmt), - FormatText::ResponseSize => resp.response_size().fmt(fmt), - FormatText::Pid => unsafe{libc::getpid().fmt(fmt)}, + FormatText::ResponseSize => size.fmt(fmt), FormatText::Time => { let rt = time::now() - entry_time; let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000_000.0; fmt.write_fmt(format_args!("{:.6}", rt)) - }, + } FormatText::TimeMillis => { let rt = time::now() - entry_time; let rt = (rt.num_nanoseconds().unwrap_or(0) as f64) / 1_000_000.0; fmt.write_fmt(format_args!("{:.6}", rt)) - }, - FormatText::RemoteAddr => { - if let Some(remote) = req.connection_info().remote() { - return remote.fmt(fmt); - } else { - "-".fmt(fmt) - } - } - FormatText::RequestTime => { - entry_time.strftime("[%d/%b/%Y:%H:%M:%S %z]") - .unwrap() - .fmt(fmt) - } - FormatText::RequestHeader(ref name) => { - let s = if let Some(val) = req.headers().get(name) { - if let Ok(s) = val.to_str() { s } else { "-" } - } else { - "-" - }; - fmt.write_fmt(format_args!("{}", s)) - } - FormatText::ResponseHeader(ref name) => { - let s = if let Some(val) = resp.headers().get(name) { - if let Ok(s) = val.to_str() { s } else { "-" } - } else { - "-" - }; - fmt.write_fmt(format_args!("{}", s)) } FormatText::EnvironHeader(ref name) => { if let Ok(val) = env::var(name) { @@ -275,12 +388,83 @@ impl FormatText { "-".fmt(fmt) } } + _ => Ok(()), + } + } + + fn render_response(&mut self, res: &HttpResponse) { + match *self { + FormatText::ResponseStatus => { + *self = FormatText::Str(format!("{}", res.status().as_u16())) + } + FormatText::ResponseHeader(ref name) => { + let s = if let Some(val) = res.headers().get(name) { + if let Ok(s) = val.to_str() { + s + } else { + "-" + } + } else { + "-" + }; + *self = FormatText::Str(s.to_string()) + } + _ => (), + } + } + + fn render_request(&mut self, now: time::Tm, req: &ServiceRequest) { + match *self { + FormatText::RequestLine => { + *self = if req.query_string().is_empty() { + FormatText::Str(format!( + "{} {} {:?}", + req.method(), + req.path(), + req.version() + )) + } else { + FormatText::Str(format!( + "{} {}?{} {:?}", + req.method(), + req.path(), + req.query_string(), + req.version() + )) + }; + } + FormatText::UrlPath => *self = FormatText::Str(req.path().to_string()), + FormatText::RequestTime => { + *self = FormatText::Str(now.rfc3339().to_string()) + } + FormatText::RequestHeader(ref name) => { + let s = if let Some(val) = req.headers().get(name) { + if let Ok(s) = val.to_str() { + s + } else { + "-" + } + } else { + "-" + }; + *self = FormatText::Str(s.to_string()); + } + FormatText::RemoteAddr => { + let s = if let Some(remote) = req.connection_info().remote() { + FormatText::Str(remote.to_string()) + } else { + FormatText::Str("-".to_string()) + }; + *self = s; + } + _ => (), } } } pub(crate) struct FormatDisplay<'a>( - &'a Fn(&mut Formatter) -> Result<(), fmt::Error>); + &'a dyn Fn(&mut Formatter) -> Result<(), fmt::Error>, +); impl<'a> fmt::Display for FormatDisplay<'a> { fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { @@ -290,79 +474,120 @@ impl<'a> fmt::Display for FormatDisplay<'a> { #[cfg(test)] mod tests { - use super::*; - use std::str::FromStr; - use time; - use http::{Method, Version, StatusCode, Uri}; - use http::header::{self, HeaderMap}; + use actix_service::{IntoService, Service, Transform}; + use futures::future::ok; - #[test] - fn test_logger() { + use super::*; + use crate::http::{header, StatusCode}; + use crate::test::TestRequest; + + #[actix_rt::test] + async fn test_logger() { + let srv = |req: ServiceRequest| { + ok(req.into_response( + HttpResponse::build(StatusCode::OK) + .header("X-Test", "ttt") + .finish(), + )) + }; let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); - let mut headers = HeaderMap::new(); - headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB")); - let mut req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - let resp = HttpResponse::build(StatusCode::OK) - .header("X-Test", "ttt") - .force_close() - .finish(); + let mut srv = logger.new_transform(srv.into_service()).await.unwrap(); - match logger.start(&mut req) { - Ok(Started::Done) => (), - _ => panic!(), - }; - match logger.finish(&mut req, &resp) { - Finished::Done => (), - _ => panic!(), + let req = TestRequest::with_header( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + ) + .to_srv_request(); + let _res = srv.call(req).await; + } + + #[actix_rt::test] + async fn test_url_path() { + let mut format = Format::new("%T %U"); + let req = TestRequest::with_header( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + ) + .uri("/test/route/yeah") + .to_srv_request(); + + let now = time::now(); + for unit in &mut format.0 { + unit.render_request(now, &req); } - let entry_time = time::now(); + + let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); + for unit in &mut format.0 { + unit.render_response(&resp); + } + let render = |fmt: &mut Formatter| { - for unit in logger.format.0.iter() { - unit.render(fmt, &req, &resp, entry_time)?; + for unit in &format.0 { + unit.render(fmt, 1024, now)?; } Ok(()) }; let s = format!("{}", FormatDisplay(&render)); - assert!(s.contains("ACTIX-WEB ttt")); + println!("{}", s); + assert!(s.contains("/test/route/yeah")); } - #[test] - fn test_default_format() { - let format = Format::default(); + #[actix_rt::test] + async fn test_default_format() { + let mut format = Format::default(); + + let req = TestRequest::with_header( + header::USER_AGENT, + header::HeaderValue::from_static("ACTIX-WEB"), + ) + .to_srv_request(); + + let now = time::now(); + for unit in &mut format.0 { + unit.render_request(now, &req); + } - let mut headers = HeaderMap::new(); - headers.insert(header::USER_AGENT, header::HeaderValue::from_static("ACTIX-WEB")); - let req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); - let entry_time = time::now(); + for unit in &mut format.0 { + unit.render_response(&resp); + } + let entry_time = time::now(); let render = |fmt: &mut Formatter| { - for unit in format.0.iter() { - unit.render(fmt, &req, &resp, entry_time)?; + for unit in &format.0 { + unit.render(fmt, 1024, entry_time)?; } Ok(()) }; let s = format!("{}", FormatDisplay(&render)); assert!(s.contains("GET / HTTP/1.1")); - assert!(s.contains("200 0")); + assert!(s.contains("200 1024")); assert!(s.contains("ACTIX-WEB")); + } + + #[actix_rt::test] + async fn test_request_time_format() { + let mut format = Format::new("%t"); + let req = TestRequest::default().to_srv_request(); + + let now = time::now(); + for unit in &mut format.0 { + unit.render_request(now, &req); + } - let req = HttpRequest::new( - Method::GET, Uri::from_str("/?test").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); - let entry_time = time::now(); + for unit in &mut format.0 { + unit.render_response(&resp); + } let render = |fmt: &mut Formatter| { - for unit in format.0.iter() { - unit.render(fmt, &req, &resp, entry_time)?; + for unit in &format.0 { + unit.render(fmt, 1024, now)?; } Ok(()) }; let s = format!("{}", FormatDisplay(&render)); - assert!(s.contains("GET /?test HTTP/1.1")); + assert!(s.contains(&format!("{}", now.rfc3339()))); } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 8b0503925..84e0758bf 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,71 +1,14 @@ //! Middlewares -use futures::Future; +mod compress; +pub use self::compress::{BodyEncoding, Compress}; -use error::{Error, Result}; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -mod logger; - -#[cfg(feature = "session")] -mod session; +mod condition; mod defaultheaders; -mod errhandlers; -pub mod cors; -pub mod csrf; -pub use self::logger::Logger; -pub use self::errhandlers::ErrorHandlers; +pub mod errhandlers; +mod logger; +mod normalize; + +pub use self::condition::Condition; pub use self::defaultheaders::DefaultHeaders; - -#[cfg(feature = "session")] -pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage, - CookieSessionError, CookieSessionBackend, CookieSessionBackendBuilder}; - -/// Middleware start result -pub enum Started { - /// Execution completed - Done, - /// New http response got generated. If middleware generates response - /// handler execution halts. - Response(HttpResponse), - /// Execution completed, runs future to completion. - Future(Box, Error=Error>>), -} - -/// Middleware execution result -pub enum Response { - /// New http response got generated - Done(HttpResponse), - /// Result is a future that resolves to a new http response - Future(Box>), -} - -/// Middleware finish result -pub enum Finished { - /// Execution completed - Done, - /// Execution completed, but run future to completion - Future(Box>), -} - -/// Middleware definition -#[allow(unused_variables)] -pub trait Middleware: 'static { - - /// Method is called when request is ready. It may return - /// future, which should resolve before next middleware get called. - fn start(&self, req: &mut HttpRequest) -> Result { - Ok(Started::Done) - } - - /// Method is called when handler returns response, - /// but before sending http message to peer. - fn response(&self, req: &mut HttpRequest, resp: HttpResponse) -> Result { - Ok(Response::Done(resp)) - } - - /// Method is called after body stream get sent to peer. - fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) -> Finished { - Finished::Done - } -} +pub use self::logger::Logger; +pub use self::normalize::NormalizePath; diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs new file mode 100644 index 000000000..2926eacc9 --- /dev/null +++ b/src/middleware/normalize.rs @@ -0,0 +1,160 @@ +//! `Middleware` to normalize request's URI +use std::task::{Context, Poll}; + +use actix_http::http::{HttpTryFrom, PathAndQuery, Uri}; +use actix_service::{Service, Transform}; +use bytes::Bytes; +use futures::future::{ok, Ready}; +use regex::Regex; + +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::Error; + +#[derive(Default, Clone, Copy)] +/// `Middleware` to normalize request's URI in place +/// +/// Performs following: +/// +/// - Merges multiple slashes into one. +/// +/// ```rust +/// use actix_web::{web, http, middleware, App, HttpResponse}; +/// +/// # fn main() { +/// let app = App::new() +/// .wrap(middleware::NormalizePath) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// # } +/// ``` + +pub struct NormalizePath; + +impl Transform for NormalizePath +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = NormalizePathNormalization; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(NormalizePathNormalization { + service, + merge_slash: Regex::new("//+").unwrap(), + }) + } +} + +pub struct NormalizePathNormalization { + service: S, + merge_slash: Regex, +} + +impl Service for NormalizePathNormalization +where + S: Service, Error = Error>, + S::Future: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let head = req.head_mut(); + + let path = head.uri.path(); + let original_len = path.len(); + let path = self.merge_slash.replace_all(path, "/"); + + if original_len != path.len() { + let mut parts = head.uri.clone().into_parts(); + let pq = parts.path_and_query.as_ref().unwrap(); + + let path = if let Some(q) = pq.query() { + Bytes::from(format!("{}?{}", path, q)) + } else { + Bytes::from(path.as_ref()) + }; + parts.path_and_query = Some(PathAndQuery::try_from(path).unwrap()); + + let uri = Uri::from_parts(parts).unwrap(); + req.match_info_mut().get_mut().update(&uri); + req.head_mut().uri = uri; + } + + self.service.call(req) + } +} + +#[cfg(test)] +mod tests { + use actix_service::IntoService; + + use super::*; + use crate::dev::ServiceRequest; + use crate::test::{call_service, init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + + #[actix_rt::test] + async fn test_wrap() { + let mut app = init_service( + App::new() + .wrap(NormalizePath::default()) + .service(web::resource("/v1/something/").to(|| HttpResponse::Ok())), + ) + .await; + + let req = TestRequest::with_uri("/v1//something////").to_request(); + let res = call_service(&mut app, req).await; + assert!(res.status().is_success()); + } + + #[actix_rt::test] + async fn test_in_place_normalization() { + let srv = |req: ServiceRequest| { + assert_eq!("/v1/something/", req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; + + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); + + let req = TestRequest::with_uri("/v1//something////").to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + } + + #[actix_rt::test] + async fn should_normalize_nothing() { + const URI: &str = "/v1/something/"; + + let srv = |req: ServiceRequest| { + assert_eq!(URI, req.path()); + ok(req.into_response(HttpResponse::Ok().finish())) + }; + + let mut normalize = NormalizePath + .new_transform(srv.into_service()) + .await + .unwrap(); + + let req = TestRequest::with_uri(URI).to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success()); + } +} diff --git a/src/middleware/session.rs b/src/middleware/session.rs deleted file mode 100644 index c0fe80158..000000000 --- a/src/middleware/session.rs +++ /dev/null @@ -1,439 +0,0 @@ -use std::rc::Rc; -use std::sync::Arc; -use std::marker::PhantomData; -use std::collections::HashMap; - -use serde_json; -use serde_json::error::Error as JsonError; -use serde::{Serialize, Deserialize}; -use http::header::{self, HeaderValue}; -use cookie::{CookieJar, Cookie, Key}; -use futures::Future; -use futures::future::{FutureResult, ok as FutOk, err as FutErr}; - -use error::{Result, Error, ResponseError}; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Started, Response}; - -/// The helper trait to obtain your session data from a request. -/// -/// ```rust -/// use actix_web::*; -/// use actix_web::middleware::RequestSession; -/// -/// fn index(mut req: HttpRequest) -> Result<&'static str> { -/// // access session data -/// if let Some(count) = req.session().get::("counter")? { -/// req.session().set("counter", count+1)?; -/// } else { -/// req.session().set("counter", 1)?; -/// } -/// -/// Ok("Welcome!") -/// } -/// # fn main() {} -/// ``` -pub trait RequestSession { - fn session(&mut self) -> Session; -} - -impl RequestSession for HttpRequest { - - fn session(&mut self) -> Session { - if let Some(s_impl) = self.extensions().get_mut::>() { - if let Some(s) = Arc::get_mut(s_impl) { - return Session(s.0.as_mut()) - } - } - Session(unsafe{&mut DUMMY}) - } -} - -/// The high-level interface you use to modify session data. -/// -/// Session object could be obtained with -/// [`RequestSession::session`](trait.RequestSession.html#tymethod.session) -/// method. `RequestSession` trait is implemented for `HttpRequest`. -/// -/// ```rust -/// use actix_web::*; -/// use actix_web::middleware::RequestSession; -/// -/// fn index(mut req: HttpRequest) -> Result<&'static str> { -/// // access session data -/// if let Some(count) = req.session().get::("counter")? { -/// req.session().set("counter", count+1)?; -/// } else { -/// req.session().set("counter", 1)?; -/// } -/// -/// Ok("Welcome!") -/// } -/// # fn main() {} -/// ``` -pub struct Session<'a>(&'a mut SessionImpl); - -impl<'a> Session<'a> { - - /// Get a `value` from the session. - pub fn get>(&'a self, key: &str) -> Result> { - if let Some(s) = self.0.get(key) { - Ok(Some(serde_json::from_str(s)?)) - } else { - Ok(None) - } - } - - /// Set a `value` from the session. - pub fn set(&mut self, key: &str, value: T) -> Result<()> { - self.0.set(key, serde_json::to_string(&value)?); - Ok(()) - } - - /// Remove value from the session. - pub fn remove(&'a mut self, key: &str) { - self.0.remove(key) - } - - /// Clear the session. - pub fn clear(&'a mut self) { - self.0.clear() - } -} - -struct SessionImplBox(Box); - -#[doc(hidden)] -unsafe impl Send for SessionImplBox {} -#[doc(hidden)] -unsafe impl Sync for SessionImplBox {} - -/// Session storage middleware -/// -/// ```rust -/// # extern crate actix; -/// # extern crate actix_web; -/// use actix_web::App; -/// use actix_web::middleware::{SessionStorage, CookieSessionBackend}; -/// -/// fn main() { -/// let app = App::new().middleware( -/// SessionStorage::new( // <- create session middleware -/// CookieSessionBackend::build(&[0; 32]) // <- create cookie session backend -/// .secure(false) -/// .finish()) -/// ); -/// } -/// ``` -pub struct SessionStorage(T, PhantomData); - -impl> SessionStorage { - /// Create session storage - pub fn new(backend: T) -> SessionStorage { - SessionStorage(backend, PhantomData) - } -} - -impl> Middleware for SessionStorage { - - fn start(&self, req: &mut HttpRequest) -> Result { - let mut req = req.clone(); - - let fut = self.0.from_request(&mut req) - .then(move |res| { - match res { - Ok(sess) => { - req.extensions().insert(Arc::new(SessionImplBox(Box::new(sess)))); - FutOk(None) - }, - Err(err) => FutErr(err) - } - }); - Ok(Started::Future(Box::new(fut))) - } - - fn response(&self, req: &mut HttpRequest, resp: HttpResponse) -> Result { - if let Some(s_box) = req.extensions().remove::>() { - s_box.0.write(resp) - } else { - Ok(Response::Done(resp)) - } - } -} - -/// A simple key-value storage interface that is internally used by `Session`. -#[doc(hidden)] -pub trait SessionImpl: 'static { - - fn get(&self, key: &str) -> Option<&str>; - - fn set(&mut self, key: &str, value: String); - - fn remove(&mut self, key: &str); - - fn clear(&mut self); - - /// Write session to storage backend. - fn write(&self, resp: HttpResponse) -> Result; -} - -/// Session's storage backend trait definition. -#[doc(hidden)] -pub trait SessionBackend: Sized + 'static { - type Session: SessionImpl; - type ReadFuture: Future; - - /// Parse the session from request and load data from a storage backend. - fn from_request(&self, request: &mut HttpRequest) -> Self::ReadFuture; -} - -/// Dummy session impl, does not do anything -struct DummySessionImpl; - -static mut DUMMY: DummySessionImpl = DummySessionImpl; - -impl SessionImpl for DummySessionImpl { - - fn get(&self, _: &str) -> Option<&str> { None } - fn set(&mut self, _: &str, _: String) {} - fn remove(&mut self, _: &str) {} - fn clear(&mut self) {} - fn write(&self, resp: HttpResponse) -> Result { - Ok(Response::Done(resp)) - } -} - -/// Session that uses signed cookies as session storage -pub struct CookieSession { - changed: bool, - state: HashMap, - inner: Rc, -} - -/// Errors that can occur during handling cookie session -#[derive(Fail, Debug)] -pub enum CookieSessionError { - /// Size of the serialized session is greater than 4000 bytes. - #[fail(display="Size of the serialized session is greater than 4000 bytes.")] - Overflow, - /// Fail to serialize session. - #[fail(display="Fail to serialize session")] - Serialize(JsonError), -} - -impl ResponseError for CookieSessionError {} - -impl SessionImpl for CookieSession { - - fn get(&self, key: &str) -> Option<&str> { - if let Some(s) = self.state.get(key) { - Some(s) - } else { - None - } - } - - fn set(&mut self, key: &str, value: String) { - self.changed = true; - self.state.insert(key.to_owned(), value); - } - - fn remove(&mut self, key: &str) { - self.changed = true; - self.state.remove(key); - } - - fn clear(&mut self) { - self.changed = true; - self.state.clear() - } - - fn write(&self, mut resp: HttpResponse) -> Result { - if self.changed { - let _ = self.inner.set_cookie(&mut resp, &self.state); - } - Ok(Response::Done(resp)) - } -} - -struct CookieSessionInner { - key: Key, - name: String, - path: String, - domain: Option, - secure: bool, -} - -impl CookieSessionInner { - - fn new(key: &[u8]) -> CookieSessionInner { - CookieSessionInner { - key: Key::from_master(key), - name: "actix-session".to_owned(), - path: "/".to_owned(), - domain: None, - secure: true } - } - - fn set_cookie(&self, resp: &mut HttpResponse, state: &HashMap) -> Result<()> { - let value = serde_json::to_string(&state) - .map_err(CookieSessionError::Serialize)?; - if value.len() > 4064 { - return Err(CookieSessionError::Overflow.into()) - } - - let mut cookie = Cookie::new(self.name.clone(), value); - cookie.set_path(self.path.clone()); - cookie.set_secure(self.secure); - cookie.set_http_only(true); - - if let Some(ref domain) = self.domain { - cookie.set_domain(domain.clone()); - } - - let mut jar = CookieJar::new(); - jar.signed(&self.key).add(cookie); - - for cookie in jar.delta() { - let val = HeaderValue::from_str(&cookie.to_string())?; - resp.headers_mut().append(header::SET_COOKIE, val); - } - - Ok(()) - } - - fn load(&self, req: &mut HttpRequest) -> HashMap { - if let Ok(cookies) = req.cookies() { - for cookie in cookies { - if cookie.name() == self.name { - let mut jar = CookieJar::new(); - jar.add_original(cookie.clone()); - if let Some(cookie) = jar.signed(&self.key).get(&self.name) { - if let Ok(val) = serde_json::from_str(cookie.value()) { - return val; - } - } - } - } - } - HashMap::new() - } -} - -/// Use signed cookies as session storage. -/// -/// `CookieSessionBackend` creates sessions which are limited to storing -/// fewer than 4000 bytes of data (as the payload must fit into a single cookie). -/// Internal server error get generated if session contains more than 4000 bytes. -/// -/// You need to pass a random value to the constructor of `CookieSessionBackend`. -/// This is private key for cookie session, When this value is changed, all session data is lost. -/// -/// Note that whatever you write into your session is visible by the user (but not modifiable). -/// -/// Constructor panics if key length is less than 32 bytes. -pub struct CookieSessionBackend(Rc); - -impl CookieSessionBackend { - - /// Construct new `CookieSessionBackend` instance. - /// - /// Panics if key length is less than 32 bytes. - pub fn new(key: &[u8]) -> CookieSessionBackend { - CookieSessionBackend( - Rc::new(CookieSessionInner::new(key))) - } - - /// Creates a new `CookieSessionBackendBuilder` instance from the given key. - /// - /// Panics if key length is less than 32 bytes. - /// - /// # Example - /// - /// ``` - /// use actix_web::middleware::CookieSessionBackend; - /// - /// let backend = CookieSessionBackend::build(&[0; 32]).finish(); - /// ``` - pub fn build(key: &[u8]) -> CookieSessionBackendBuilder { - CookieSessionBackendBuilder::new(key) - } -} - -impl SessionBackend for CookieSessionBackend { - - type Session = CookieSession; - type ReadFuture = FutureResult; - - fn from_request(&self, req: &mut HttpRequest) -> Self::ReadFuture { - let state = self.0.load(req); - FutOk( - CookieSession { - changed: false, - inner: Rc::clone(&self.0), - state, - }) - } -} - -/// Structure that follows the builder pattern for building `CookieSessionBackend` structs. -/// -/// To construct a backend: -/// -/// 1. Call [`CookieSessionBackend::build`](struct.CookieSessionBackend.html#method.build) to start building. -/// 2. Use any of the builder methods to set fields in the backend. -/// 3. Call [finish](#method.finish) to retrieve the constructed backend. -/// -/// # Example -/// -/// ```rust -/// # extern crate actix_web; -/// use actix_web::middleware::CookieSessionBackend; -/// -/// # fn main() { -/// let backend: CookieSessionBackend = CookieSessionBackend::build(&[0; 32]) -/// .domain("www.rust-lang.org") -/// .name("actix_session") -/// .path("/") -/// .secure(true) -/// .finish(); -/// # } -/// ``` -pub struct CookieSessionBackendBuilder(CookieSessionInner); - -impl CookieSessionBackendBuilder { - pub fn new(key: &[u8]) -> CookieSessionBackendBuilder { - CookieSessionBackendBuilder( - CookieSessionInner::new(key)) - } - - /// Sets the `path` field in the session cookie being built. - pub fn path>(mut self, value: S) -> CookieSessionBackendBuilder { - self.0.path = value.into(); - self - } - - /// Sets the `name` field in the session cookie being built. - pub fn name>(mut self, value: S) -> CookieSessionBackendBuilder { - self.0.name = value.into(); - self - } - - /// Sets the `domain` field in the session cookie being built. - pub fn domain>(mut self, value: S) -> CookieSessionBackendBuilder { - self.0.domain = Some(value.into()); - self - } - - /// Sets the `secure` field in the session cookie being built. - pub fn secure(mut self, value: bool) -> CookieSessionBackendBuilder { - self.0.secure = value; - self - } - - /// Finishes building and returns the built `CookieSessionBackend`. - pub fn finish(self) -> CookieSessionBackend { - CookieSessionBackend(Rc::new(self.0)) - } -} diff --git a/src/multipart.rs b/src/multipart.rs deleted file mode 100644 index 4ac7b2a15..000000000 --- a/src/multipart.rs +++ /dev/null @@ -1,739 +0,0 @@ -//! Multipart requests support -use std::{cmp, fmt}; -use std::rc::Rc; -use std::cell::RefCell; -use std::marker::PhantomData; - -use mime; -use httparse; -use bytes::Bytes; -use http::HttpTryFrom; -use http::header::{self, HeaderMap, HeaderName, HeaderValue}; -use futures::{Async, Stream, Poll}; -use futures::task::{Task, current as current_task}; - -use error::{ParseError, PayloadError, MultipartError}; -use payload::PayloadHelper; - -const MAX_HEADERS: usize = 32; - -/// The server-side implementation of `multipart/form-data` requests. -/// -/// This will parse the incoming stream into `MultipartItem` instances via its -/// Stream implementation. -/// `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` -/// is used for nested multipart streams. -pub struct Multipart { - safety: Safety, - error: Option, - inner: Option>>>, -} - -/// -pub enum MultipartItem { - /// Multipart field - Field(Field), - /// Nested multipart stream - Nested(Multipart), -} - -enum InnerMultipartItem { - None, - Field(Rc>>), - Multipart(Rc>>), -} - -#[derive(PartialEq, Debug)] -enum InnerState { - /// Stream eof - Eof, - /// Skip data until first boundary - FirstBoundary, - /// Reading boundary - Boundary, - /// Reading Headers, - Headers, -} - -struct InnerMultipart { - payload: PayloadRef, - boundary: String, - state: InnerState, - item: InnerMultipartItem, -} - -impl Multipart<()> { - /// Extract boundary info from headers. - pub fn boundary(headers: &HeaderMap) -> Result { - if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - if let Some(boundary) = ct.get_param(mime::BOUNDARY) { - Ok(boundary.as_str().to_owned()) - } else { - Err(MultipartError::Boundary) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::NoContentType) - } - } -} - -impl Multipart where S: Stream { - - /// Create multipart instance for boundary. - pub fn new(boundary: Result, stream: S) -> Multipart { - match boundary { - Ok(boundary) => Multipart { - error: None, - safety: Safety::new(), - inner: Some(Rc::new(RefCell::new( - InnerMultipart { - boundary, - payload: PayloadRef::new(PayloadHelper::new(stream)), - state: InnerState::FirstBoundary, - item: InnerMultipartItem::None, - }))) - }, - Err(err) => - Multipart { - error: Some(err), - safety: Safety::new(), - inner: None, - } - } - } -} - -impl Stream for Multipart where S: Stream { - type Item = MultipartItem; - type Error = MultipartError; - - fn poll(&mut self) -> Poll, Self::Error> { - if let Some(err) = self.error.take() { - Err(err) - } else if self.safety.current() { - self.inner.as_mut().unwrap().borrow_mut().poll(&self.safety) - } else { - Ok(Async::NotReady) - } - } -} - -impl InnerMultipart where S: Stream { - - fn read_headers(payload: &mut PayloadHelper) -> Poll - { - match payload.read_until(b"\r\n\r\n")? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(None) => Err(MultipartError::Incomplete), - Async::Ready(Some(bytes)) => { - let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; - match httparse::parse_headers(&bytes, &mut hdrs) { - Ok(httparse::Status::Complete((_, hdrs))) => { - // convert headers - let mut headers = HeaderMap::with_capacity(hdrs.len()); - for h in hdrs { - if let Ok(name) = HeaderName::try_from(h.name) { - if let Ok(value) = HeaderValue::try_from(h.value) { - headers.append(name, value); - } else { - return Err(ParseError::Header.into()) - } - } else { - return Err(ParseError::Header.into()) - } - } - Ok(Async::Ready(headers)) - } - Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), - Err(err) => Err(ParseError::from(err).into()), - } - } - } - } - - fn read_boundary(payload: &mut PayloadHelper, boundary: &str) - -> Poll - { - // TODO: need to read epilogue - match payload.readline()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(None) => Err(MultipartError::Incomplete), - Async::Ready(Some(chunk)) => { - if chunk.len() == boundary.len() + 4 && - &chunk[..2] == b"--" && - &chunk[2..boundary.len()+2] == boundary.as_bytes() - { - Ok(Async::Ready(false)) - } else if chunk.len() == boundary.len() + 6 && - &chunk[..2] == b"--" && - &chunk[2..boundary.len()+2] == boundary.as_bytes() && - &chunk[boundary.len()+2..boundary.len()+4] == b"--" - { - Ok(Async::Ready(true)) - } else { - Err(MultipartError::Boundary) - } - } - } - } - - fn skip_until_boundary(payload: &mut PayloadHelper, boundary: &str) - -> Poll - { - let mut eof = false; - loop { - match payload.readline()? { - Async::Ready(Some(chunk)) => { - if chunk.is_empty() { - //ValueError("Could not find starting boundary %r" - //% (self._boundary)) - } - if chunk.len() < boundary.len() { - continue - } - if &chunk[..2] == b"--" && &chunk[2..chunk.len()-2] == boundary.as_bytes() { - break; - } else { - if chunk.len() < boundary.len() + 2{ - continue - } - let b: &[u8] = boundary.as_ref(); - if &chunk[..boundary.len()] == b && - &chunk[boundary.len()..boundary.len()+2] == b"--" { - eof = true; - break; - } - } - }, - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Err(MultipartError::Incomplete), - } - } - Ok(Async::Ready(eof)) - } - - fn poll(&mut self, safety: &Safety) -> Poll>, MultipartError> { - if self.state == InnerState::Eof { - Ok(Async::Ready(None)) - } else { - // release field - loop { - // Nested multipart streams of fields has to be consumed - // before switching to next - if safety.current() { - let stop = match self.item { - InnerMultipartItem::Field(ref mut field) => { - match field.borrow_mut().poll(safety)? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(_)) => continue, - Async::Ready(None) => true, - } - }, - InnerMultipartItem::Multipart(ref mut multipart) => { - match multipart.borrow_mut().poll(safety)? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(Some(_)) => continue, - Async::Ready(None) => true, - } - }, - _ => false, - }; - if stop { - self.item = InnerMultipartItem::None; - } - if let InnerMultipartItem::None = self.item { - break; - } - } - } - - let headers = if let Some(payload) = self.payload.get_mut(safety) { - match self.state { - // read until first boundary - InnerState::FirstBoundary => { - match InnerMultipart::skip_until_boundary(payload, &self.boundary)? { - Async::Ready(eof) => { - if eof { - self.state = InnerState::Eof; - return Ok(Async::Ready(None)); - } else { - self.state = InnerState::Headers; - } - }, - Async::NotReady => return Ok(Async::NotReady), - } - }, - // read boundary - InnerState::Boundary => { - match InnerMultipart::read_boundary(payload, &self.boundary)? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(eof) => { - if eof { - self.state = InnerState::Eof; - return Ok(Async::Ready(None)); - } else { - self.state = InnerState::Headers; - } - } - } - } - _ => (), - } - - // read field headers for next field - if self.state == InnerState::Headers { - if let Async::Ready(headers) = InnerMultipart::read_headers(payload)? { - self.state = InnerState::Boundary; - headers - } else { - return Ok(Async::NotReady) - } - } else { - unreachable!() - } - } else { - debug!("NotReady: field is in flight"); - return Ok(Async::NotReady) - }; - - // content type - let mut mt = mime::APPLICATION_OCTET_STREAM; - if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - mt = ct; - } - } - } - - self.state = InnerState::Boundary; - - // nested multipart stream - if mt.type_() == mime::MULTIPART { - let inner = if let Some(boundary) = mt.get_param(mime::BOUNDARY) { - Rc::new(RefCell::new( - InnerMultipart { - payload: self.payload.clone(), - boundary: boundary.as_str().to_owned(), - state: InnerState::FirstBoundary, - item: InnerMultipartItem::None, - })) - } else { - return Err(MultipartError::Boundary) - }; - - self.item = InnerMultipartItem::Multipart(Rc::clone(&inner)); - - Ok(Async::Ready(Some( - MultipartItem::Nested( - Multipart{safety: safety.clone(), - error: None, - inner: Some(inner)})))) - } else { - let field = Rc::new(RefCell::new(InnerField::new( - self.payload.clone(), self.boundary.clone(), &headers)?)); - self.item = InnerMultipartItem::Field(Rc::clone(&field)); - - Ok(Async::Ready(Some( - MultipartItem::Field( - Field::new(safety.clone(), headers, mt, field))))) - } - } - } -} - -impl Drop for InnerMultipart { - fn drop(&mut self) { - // InnerMultipartItem::Field has to be dropped first because of Safety. - self.item = InnerMultipartItem::None; - } -} - -/// A single field in a multipart stream -pub struct Field { - ct: mime::Mime, - headers: HeaderMap, - inner: Rc>>, - safety: Safety, -} - -impl Field where S: Stream { - - fn new(safety: Safety, headers: HeaderMap, - ct: mime::Mime, inner: Rc>>) -> Self { - Field {ct, headers, inner, safety} - } - - pub fn headers(&self) -> &HeaderMap { - &self.headers - } - - pub fn content_type(&self) -> &mime::Mime { - &self.ct - } -} - -impl Stream for Field where S: Stream { - type Item = Bytes; - type Error = MultipartError; - - fn poll(&mut self) -> Poll, Self::Error> { - if self.safety.current() { - self.inner.borrow_mut().poll(&self.safety) - } else { - Ok(Async::NotReady) - } - } -} - -impl fmt::Debug for Field { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!(f, "\nMultipartField: {}\n", self.ct); - let _ = write!(f, " boundary: {}\n", self.inner.borrow().boundary); - let _ = write!(f, " headers:\n"); - for (key, val) in self.headers.iter() { - let _ = write!(f, " {:?}: {:?}\n", key, val); - } - res - } -} - -struct InnerField { - payload: Option>, - boundary: String, - eof: bool, - length: Option, -} - -impl InnerField where S: Stream { - - fn new(payload: PayloadRef, boundary: String, headers: &HeaderMap) - -> Result, PayloadError> - { - let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - Some(len) - } else { - return Err(PayloadError::Incomplete) - } - } else { - return Err(PayloadError::Incomplete) - } - } else { - None - }; - - Ok(InnerField { - boundary, - payload: Some(payload), - eof: false, - length: len }) - } - - /// Reads body part content chunk of the specified size. - /// The body part must has `Content-Length` header with proper value. - fn read_len(payload: &mut PayloadHelper, size: &mut u64) - -> Poll, MultipartError> - { - if *size == 0 { - Ok(Async::Ready(None)) - } else { - match payload.readany() { - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => Err(MultipartError::Incomplete), - Ok(Async::Ready(Some(mut chunk))) => { - let len = cmp::min(chunk.len() as u64, *size); - *size -= len; - let ch = chunk.split_to(len as usize); - if !chunk.is_empty() { - payload.unread_data(chunk); - } - Ok(Async::Ready(Some(ch))) - }, - Err(err) => Err(err.into()) - } - } - } - - /// Reads content chunk of body part with unknown length. - /// The `Content-Length` header for body part is not necessary. - fn read_stream(payload: &mut PayloadHelper, boundary: &str) - -> Poll, MultipartError> - { - match payload.read_until(b"\r")? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(None) => Err(MultipartError::Incomplete), - Async::Ready(Some(mut chunk)) => { - if chunk.len() == 1 { - payload.unread_data(chunk); - match payload.read_exact(boundary.len() + 4)? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(None) => Err(MultipartError::Incomplete), - Async::Ready(Some(chunk)) => { - if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--" && - &chunk[4..] == boundary.as_bytes() - { - payload.unread_data(chunk); - Ok(Async::Ready(None)) - } else { - Ok(Async::Ready(Some(chunk))) - } - } - } - } else { - let to = chunk.len() - 1; - let ch = chunk.split_to(to); - payload.unread_data(chunk); - Ok(Async::Ready(Some(ch))) - } - } - } - } - - fn poll(&mut self, s: &Safety) -> Poll, MultipartError> { - if self.payload.is_none() { - return Ok(Async::Ready(None)) - } - - let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) { - let res = if let Some(ref mut len) = self.length { - InnerField::read_len(payload, len)? - } else { - InnerField::read_stream(payload, &self.boundary)? - }; - - match res { - Async::NotReady => Async::NotReady, - Async::Ready(Some(bytes)) => Async::Ready(Some(bytes)), - Async::Ready(None) => { - self.eof = true; - match payload.readline()? { - Async::NotReady => Async::NotReady, - Async::Ready(None) => Async::Ready(None), - Async::Ready(Some(line)) => { - if line.as_ref() != b"\r\n" { - warn!("multipart field did not read all the data or it is malformed"); - } - Async::Ready(None) - } - } - } - } - } else { - Async::NotReady - }; - - if Async::Ready(None) == result { - self.payload.take(); - } - Ok(result) - } -} - -struct PayloadRef { - payload: Rc>, -} - -impl PayloadRef where S: Stream { - fn new(payload: PayloadHelper) -> PayloadRef { - PayloadRef { - payload: Rc::new(payload), - } - } - - fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadHelper> - where 'a: 'b - { - if s.current() { - let payload: &mut PayloadHelper = unsafe { - &mut *(self.payload.as_ref() as *const _ as *mut _)}; - Some(payload) - } else { - None - } - } -} - -impl Clone for PayloadRef { - fn clone(&self) -> PayloadRef { - PayloadRef { - payload: Rc::clone(&self.payload), - } - } -} - -/// Counter. It tracks of number of clones of payloads and give access to payload only -/// to top most task panics if Safety get destroyed and it not top most task. -#[derive(Debug)] -struct Safety { - task: Option, - level: usize, - payload: Rc>, -} - -impl Safety { - fn new() -> Safety { - let payload = Rc::new(PhantomData); - Safety { - task: None, - level: Rc::strong_count(&payload), - payload, - } - } - - fn current(&self) -> bool { - Rc::strong_count(&self.payload) == self.level - } - -} - -impl Clone for Safety { - fn clone(&self) -> Safety { - let payload = Rc::clone(&self.payload); - Safety { - task: Some(current_task()), - level: Rc::strong_count(&payload), - payload, - } - } -} - -impl Drop for Safety { - fn drop(&mut self) { - // parent task is dead - if Rc::strong_count(&self.payload) != self.level { - panic!("Safety get dropped but it is not from top-most task"); - } - if let Some(task) = self.task.take() { - task.notify() - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::Bytes; - use futures::future::{lazy, result}; - use tokio_core::reactor::Core; - use payload::{Payload, PayloadWriter}; - - #[test] - fn test_boundary() { - let headers = HeaderMap::new(); - match Multipart::boundary(&headers) { - Err(MultipartError::NoContentType) => (), - _ => unreachable!("should not happen"), - } - - let mut headers = HeaderMap::new(); - headers.insert(header::CONTENT_TYPE, - header::HeaderValue::from_static("test")); - - match Multipart::boundary(&headers) { - Err(MultipartError::ParseContentType) => (), - _ => unreachable!("should not happen"), - } - - let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("multipart/mixed")); - match Multipart::boundary(&headers) { - Err(MultipartError::Boundary) => (), - _ => unreachable!("should not happen"), - } - - let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static( - "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"")); - - assert_eq!(Multipart::boundary(&headers).unwrap(), - "5c02368e880e436dab70ed54e1c58209"); - } - - #[test] - fn test_multipart() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - - let bytes = Bytes::from( - "testasdadsad\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ - Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ - test\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ - Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ - data\r\n\ - --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n"); - sender.feed_data(bytes); - - let mut multipart = Multipart::new( - Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()), payload); - match multipart.poll() { - Ok(Async::Ready(Some(item))) => { - match item { - MultipartItem::Field(mut field) => { - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); - - match field.poll() { - Ok(Async::Ready(Some(chunk))) => - assert_eq!(chunk, "test"), - _ => unreachable!() - } - match field.poll() { - Ok(Async::Ready(None)) => (), - _ => unreachable!() - } - }, - _ => unreachable!() - } - } - _ => unreachable!() - } - - match multipart.poll() { - Ok(Async::Ready(Some(item))) => { - match item { - MultipartItem::Field(mut field) => { - assert_eq!(field.content_type().type_(), mime::TEXT); - assert_eq!(field.content_type().subtype(), mime::PLAIN); - - match field.poll() { - Ok(Async::Ready(Some(chunk))) => - assert_eq!(chunk, "data"), - _ => unreachable!() - } - match field.poll() { - Ok(Async::Ready(None)) => (), - _ => unreachable!() - } - }, - _ => unreachable!() - } - } - _ => unreachable!() - } - - match multipart.poll() { - Ok(Async::Ready(None)) => (), - _ => unreachable!() - } - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } -} diff --git a/src/param.rs b/src/param.rs deleted file mode 100644 index b3476ae53..000000000 --- a/src/param.rs +++ /dev/null @@ -1,205 +0,0 @@ -use std; -use std::ops::Index; -use std::path::PathBuf; -use std::str::FromStr; -use std::slice::Iter; -use std::borrow::Cow; -use http::StatusCode; -use smallvec::SmallVec; - -use error::{ResponseError, UriSegmentError, InternalError}; - - -/// A trait to abstract the idea of creating a new instance of a type from a path parameter. -pub trait FromParam: Sized { - /// The associated error which can be returned from parsing. - type Err: ResponseError; - - /// Parses a string `s` to return a value of this type. - fn from_param(s: &str) -> Result; -} - -/// Route match information -/// -/// If resource path contains variable patterns, `Params` stores this variables. -#[derive(Debug)] -pub struct Params<'a>(SmallVec<[(Cow<'a, str>, Cow<'a, str>); 3]>); - -impl<'a> Params<'a> { - - pub(crate) fn new() -> Params<'a> { - Params(SmallVec::new()) - } - - pub(crate) fn clear(&mut self) { - self.0.clear(); - } - - pub(crate) fn add(&mut self, name: N, value: V) - where N: Into>, V: Into>, - { - self.0.push((name.into(), value.into())); - } - - /// Check if there are any matched patterns - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - /// Check number of extracted parameters - pub fn len(&self) -> usize { - self.0.len() - } - - /// Get matched parameter by name without type conversion - pub fn get(&'a self, key: &str) -> Option<&'a str> { - for item in self.0.iter() { - if key == item.0 { - return Some(item.1.as_ref()) - } - } - None - } - - /// Get matched `FromParam` compatible parameter by name. - /// - /// If keyed parameter is not available empty string is used as default value. - /// - /// ```rust - /// # extern crate actix_web; - /// # use actix_web::*; - /// fn index(req: HttpRequest) -> Result { - /// let ivalue: isize = req.match_info().query("val")?; - /// Ok(format!("isuze value: {:?}", ivalue)) - /// } - /// # fn main() {} - /// ``` - pub fn query(&'a self, key: &str) -> Result::Err> - { - if let Some(s) = self.get(key) { - T::from_param(s) - } else { - T::from_param("") - } - } - - /// Return iterator to items in parameter container - pub fn iter(&self) -> Iter<(Cow<'a, str>, Cow<'a, str>)> { - self.0.iter() - } -} - -impl<'a, 'b, 'c: 'a> Index<&'b str> for &'c Params<'a> { - type Output = str; - - fn index(&self, name: &'b str) -> &str { - self.get(name).expect("Value for parameter is not available") - } -} - -impl<'a, 'c: 'a> Index for &'c Params<'a> { - type Output = str; - - fn index(&self, idx: usize) -> &str { - self.0[idx].1.as_ref() - } -} - -/// Creates a `PathBuf` from a path parameter. The returned `PathBuf` is -/// percent-decoded. If a segment is equal to "..", the previous segment (if -/// any) is skipped. -/// -/// For security purposes, if a segment meets any of the following conditions, -/// an `Err` is returned indicating the condition met: -/// -/// * Decoded segment starts with any of: `.` (except `..`), `*` -/// * Decoded segment ends with any of: `:`, `>`, `<` -/// * Decoded segment contains any of: `/` -/// * On Windows, decoded segment contains any of: '\' -/// * Percent-encoding results in invalid UTF8. -/// -/// As a result of these conditions, a `PathBuf` parsed from request path parameter is -/// safe to interpolate within, or use as a suffix of, a path without additional -/// checks. -impl FromParam for PathBuf { - type Err = UriSegmentError; - - fn from_param(val: &str) -> Result { - let mut buf = PathBuf::new(); - for segment in val.split('/') { - if segment == ".." { - buf.pop(); - } else if segment.starts_with('.') { - return Err(UriSegmentError::BadStart('.')) - } else if segment.starts_with('*') { - return Err(UriSegmentError::BadStart('*')) - } else if segment.ends_with(':') { - return Err(UriSegmentError::BadEnd(':')) - } else if segment.ends_with('>') { - return Err(UriSegmentError::BadEnd('>')) - } else if segment.ends_with('<') { - return Err(UriSegmentError::BadEnd('<')) - } else if segment.is_empty() { - continue - } else if cfg!(windows) && segment.contains('\\') { - return Err(UriSegmentError::BadChar('\\')) - } else { - buf.push(segment) - } - } - - Ok(buf) - } -} - -macro_rules! FROM_STR { - ($type:ty) => { - impl FromParam for $type { - type Err = InternalError<<$type as FromStr>::Err>; - - fn from_param(val: &str) -> Result { - <$type as FromStr>::from_str(val) - .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST)) - } - } - } -} - -FROM_STR!(u8); -FROM_STR!(u16); -FROM_STR!(u32); -FROM_STR!(u64); -FROM_STR!(usize); -FROM_STR!(i8); -FROM_STR!(i16); -FROM_STR!(i32); -FROM_STR!(i64); -FROM_STR!(isize); -FROM_STR!(f32); -FROM_STR!(f64); -FROM_STR!(String); -FROM_STR!(std::net::IpAddr); -FROM_STR!(std::net::Ipv4Addr); -FROM_STR!(std::net::Ipv6Addr); -FROM_STR!(std::net::SocketAddr); -FROM_STR!(std::net::SocketAddrV4); -FROM_STR!(std::net::SocketAddrV6); - -#[cfg(test)] -mod tests { - use super::*; - use std::iter::FromIterator; - - #[test] - fn test_path_buf() { - assert_eq!(PathBuf::from_param("/test/.tt"), Err(UriSegmentError::BadStart('.'))); - assert_eq!(PathBuf::from_param("/test/*tt"), Err(UriSegmentError::BadStart('*'))); - assert_eq!(PathBuf::from_param("/test/tt:"), Err(UriSegmentError::BadEnd(':'))); - assert_eq!(PathBuf::from_param("/test/tt<"), Err(UriSegmentError::BadEnd('<'))); - assert_eq!(PathBuf::from_param("/test/tt>"), Err(UriSegmentError::BadEnd('>'))); - assert_eq!(PathBuf::from_param("/seg1/seg2/"), - Ok(PathBuf::from_iter(vec!["seg1", "seg2"]))); - assert_eq!(PathBuf::from_param("/seg1/../seg2/"), - Ok(PathBuf::from_iter(vec!["seg2"]))); - } -} diff --git a/src/payload.rs b/src/payload.rs deleted file mode 100644 index 8afff81c9..000000000 --- a/src/payload.rs +++ /dev/null @@ -1,675 +0,0 @@ -//! Payload stream -use std::cmp; -use std::rc::{Rc, Weak}; -use std::cell::RefCell; -use std::collections::VecDeque; -use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll, Stream}; -use futures::task::{Task, current as current_task}; - -use error::PayloadError; - -/// max buffer size 32k -pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; - - -#[derive(Debug, PartialEq)] -pub(crate) enum PayloadStatus { - Read, - Pause, - Dropped, -} - -/// Buffered stream of bytes chunks -/// -/// Payload stores chunks in a vector. First chunk can be received with `.readany()` method. -/// Payload stream is not thread safe. Payload does not notify current task when -/// new data is available. -/// -/// Payload stream can be used as `HttpResponse` body stream. -#[derive(Debug)] -pub struct Payload { - inner: Rc>, -} - -impl Payload { - - /// Create payload stream. - /// - /// This method construct two objects responsible for bytes stream generation. - /// - /// * `PayloadSender` - *Sender* side of the stream - /// - /// * `Payload` - *Receiver* side of the stream - pub fn new(eof: bool) -> (PayloadSender, Payload) { - let shared = Rc::new(RefCell::new(Inner::new(eof))); - - (PayloadSender{inner: Rc::downgrade(&shared)}, Payload{inner: shared}) - } - - /// Create empty payload - #[doc(hidden)] - pub fn empty() -> Payload { - Payload{inner: Rc::new(RefCell::new(Inner::new(true)))} - } - - /// Indicates EOF of payload - #[inline] - pub fn eof(&self) -> bool { - self.inner.borrow().eof() - } - - /// Length of the data in this payload - #[inline] - pub fn len(&self) -> usize { - self.inner.borrow().len() - } - - /// Is payload empty - #[inline] - pub fn is_empty(&self) -> bool { - self.inner.borrow().len() == 0 - } - - /// Put unused data back to payload - #[inline] - pub fn unread_data(&mut self, data: Bytes) { - self.inner.borrow_mut().unread_data(data); - } - - #[cfg(test)] - pub(crate) fn readall(&self) -> Option { - self.inner.borrow_mut().readall() - } - - #[inline] - /// Set read buffer capacity - /// - /// Default buffer capacity is 32Kb. - pub fn set_read_buffer_capacity(&mut self, cap: usize) { - self.inner.borrow_mut().capacity = cap; - } -} - -impl Stream for Payload { - type Item = Bytes; - type Error = PayloadError; - - #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.inner.borrow_mut().readany() - } -} - -impl Clone for Payload { - fn clone(&self) -> Payload { - Payload{inner: Rc::clone(&self.inner)} - } -} - -/// Payload writer interface. -pub(crate) trait PayloadWriter { - - /// Set stream error. - fn set_error(&mut self, err: PayloadError); - - /// Write eof into a stream which closes reading side of a stream. - fn feed_eof(&mut self); - - /// Feed bytes into a payload stream - fn feed_data(&mut self, data: Bytes); - - /// Need read data - fn need_read(&self) -> PayloadStatus; -} - -/// Sender part of the payload stream -pub struct PayloadSender { - inner: Weak>, -} - -impl PayloadWriter for PayloadSender { - - #[inline] - fn set_error(&mut self, err: PayloadError) { - if let Some(shared) = self.inner.upgrade() { - shared.borrow_mut().set_error(err) - } - } - - #[inline] - fn feed_eof(&mut self) { - if let Some(shared) = self.inner.upgrade() { - shared.borrow_mut().feed_eof() - } - } - - #[inline] - fn feed_data(&mut self, data: Bytes) { - if let Some(shared) = self.inner.upgrade() { - shared.borrow_mut().feed_data(data) - } - } - - #[inline] - fn need_read(&self) -> PayloadStatus { - // we check need_read only if Payload (other side) is alive, - // otherwise always return true (consume payload) - if let Some(shared) = self.inner.upgrade() { - if shared.borrow().need_read { - PayloadStatus::Read - } else { - #[cfg(not(test))] - { - if shared.borrow_mut().io_task.is_none() { - shared.borrow_mut().io_task = Some(current_task()); - } - } - PayloadStatus::Pause - } - } else { - PayloadStatus::Dropped - } - } -} - -#[derive(Debug)] -struct Inner { - len: usize, - eof: bool, - err: Option, - need_read: bool, - items: VecDeque, - capacity: usize, - task: Option, - io_task: Option, -} - -impl Inner { - - fn new(eof: bool) -> Self { - Inner { - eof, - len: 0, - err: None, - items: VecDeque::new(), - need_read: true, - capacity: MAX_BUFFER_SIZE, - task: None, - io_task: None, - } - } - - #[inline] - fn set_error(&mut self, err: PayloadError) { - self.err = Some(err); - } - - #[inline] - fn feed_eof(&mut self) { - self.eof = true; - } - - #[inline] - fn feed_data(&mut self, data: Bytes) { - self.len += data.len(); - self.items.push_back(data); - self.need_read = self.len < self.capacity; - if let Some(task) = self.task.take() { - task.notify() - } - } - - #[inline] - fn eof(&self) -> bool { - self.items.is_empty() && self.eof - } - - #[inline] - fn len(&self) -> usize { - self.len - } - - #[cfg(test)] - pub(crate) fn readall(&mut self) -> Option { - let len = self.items.iter().map(|b| b.len()).sum(); - if len > 0 { - let mut buf = BytesMut::with_capacity(len); - for item in &self.items { - buf.extend_from_slice(item); - } - self.items = VecDeque::new(); - self.len = 0; - Some(buf.take().freeze()) - } else { - self.need_read = true; - None - } - } - - fn readany(&mut self) -> Poll, PayloadError> { - if let Some(data) = self.items.pop_front() { - self.len -= data.len(); - self.need_read = self.len < self.capacity; - #[cfg(not(test))] - { - if self.need_read && self.task.is_none() { - self.task = Some(current_task()); - } - if let Some(task) = self.io_task.take() { - task.notify() - } - } - Ok(Async::Ready(Some(data))) - } else if let Some(err) = self.err.take() { - Err(err) - } else if self.eof { - Ok(Async::Ready(None)) - } else { - self.need_read = true; - #[cfg(not(test))] - { - if self.task.is_none() { - self.task = Some(current_task()); - } - if let Some(task) = self.io_task.take() { - task.notify() - } - } - Ok(Async::NotReady) - } - } - - fn unread_data(&mut self, data: Bytes) { - self.len += data.len(); - self.items.push_front(data); - } -} - -pub struct PayloadHelper { - len: usize, - items: VecDeque, - stream: S, -} - -impl PayloadHelper where S: Stream { - - pub fn new(stream: S) -> Self { - PayloadHelper { - len: 0, - items: VecDeque::new(), - stream, - } - } - - /// Get mutable reference to an inner stream. - pub fn get_mut(&mut self) -> &mut S { - &mut self.stream - } - - #[inline] - fn poll_stream(&mut self) -> Poll { - self.stream.poll().map(|res| { - match res { - Async::Ready(Some(data)) => { - self.len += data.len(); - self.items.push_back(data); - Async::Ready(true) - }, - Async::Ready(None) => Async::Ready(false), - Async::NotReady => Async::NotReady, - } - }) - } - - #[inline] - pub fn readany(&mut self) -> Poll, PayloadError> { - if let Some(data) = self.items.pop_front() { - self.len -= data.len(); - Ok(Async::Ready(Some(data))) - } else { - match self.poll_stream()? { - Async::Ready(true) => self.readany(), - Async::Ready(false) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } - } - - #[inline] - pub fn can_read(&mut self, size: usize) -> Poll, PayloadError> { - if size <= self.len { - Ok(Async::Ready(Some(true))) - } else { - match self.poll_stream()? { - Async::Ready(true) => self.can_read(size), - Async::Ready(false) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } - } - - #[inline] - pub fn get_chunk(&mut self) -> Poll, PayloadError> { - if self.items.is_empty() { - match self.poll_stream()? { - Async::Ready(true) => (), - Async::Ready(false) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - } - } - match self.items.front().map(|c| c.as_ref()) { - Some(chunk) => Ok(Async::Ready(Some(chunk))), - None => Ok(Async::NotReady), - } - } - - #[inline] - pub fn read_exact(&mut self, size: usize) -> Poll, PayloadError> { - if size <= self.len { - self.len -= size; - let mut chunk = self.items.pop_front().unwrap(); - if size < chunk.len() { - let buf = chunk.split_to(size); - self.items.push_front(chunk); - Ok(Async::Ready(Some(buf))) - } - else if size == chunk.len() { - Ok(Async::Ready(Some(chunk))) - } - else { - let mut buf = BytesMut::with_capacity(size); - buf.extend_from_slice(&chunk); - - while buf.len() < size { - let mut chunk = self.items.pop_front().unwrap(); - let rem = cmp::min(size - buf.len(), chunk.len()); - buf.extend_from_slice(&chunk.split_to(rem)); - if !chunk.is_empty() { - self.items.push_front(chunk); - } - } - Ok(Async::Ready(Some(buf.freeze()))) - } - } else { - match self.poll_stream()? { - Async::Ready(true) => self.read_exact(size), - Async::Ready(false) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } - } - - #[inline] - pub fn drop_payload(&mut self, size: usize) { - if size <= self.len { - self.len -= size; - - let mut len = 0; - while len < size { - let mut chunk = self.items.pop_front().unwrap(); - let rem = cmp::min(size-len, chunk.len()); - len += rem; - if rem < chunk.len() { - chunk.split_to(rem); - self.items.push_front(chunk); - } - } - } - } - - pub fn copy(&mut self, size: usize) -> Poll, PayloadError> { - if size <= self.len { - let mut buf = BytesMut::with_capacity(size); - for chunk in &self.items { - if buf.len() < size { - let rem = cmp::min(size - buf.len(), chunk.len()); - buf.extend_from_slice(&chunk[..rem]); - } - if buf.len() == size { - return Ok(Async::Ready(Some(buf))) - } - } - } - - match self.poll_stream()? { - Async::Ready(true) => self.copy(size), - Async::Ready(false) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } - - pub fn read_until(&mut self, line: &[u8]) -> Poll, PayloadError> { - let mut idx = 0; - let mut num = 0; - let mut offset = 0; - let mut found = false; - let mut length = 0; - - for no in 0..self.items.len() { - { - let chunk = &self.items[no]; - for (pos, ch) in chunk.iter().enumerate() { - if *ch == line[idx] { - idx += 1; - if idx == line.len() { - num = no; - offset = pos+1; - length += pos+1; - found = true; - break; - } - } else { - idx = 0 - } - } - if !found { - length += chunk.len() - } - } - - if found { - let mut buf = BytesMut::with_capacity(length); - if num > 0 { - for _ in 0..num { - buf.extend_from_slice(&self.items.pop_front().unwrap()); - } - } - if offset > 0 { - let mut chunk = self.items.pop_front().unwrap(); - buf.extend_from_slice(&chunk.split_to(offset)); - if !chunk.is_empty() { - self.items.push_front(chunk) - } - } - self.len -= length; - return Ok(Async::Ready(Some(buf.freeze()))) - } - } - - match self.poll_stream()? { - Async::Ready(true) => self.read_until(line), - Async::Ready(false) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } - - pub fn readline(&mut self) -> Poll, PayloadError> { - self.read_until(b"\n") - } - - pub fn unread_data(&mut self, data: Bytes) { - self.len += data.len(); - self.items.push_front(data); - } - - #[allow(dead_code)] - pub fn remaining(&mut self) -> Bytes { - self.items.iter_mut() - .fold(BytesMut::new(), |mut b, c| { - b.extend_from_slice(c); - b - }).freeze() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io; - use failure::Fail; - use futures::future::{lazy, result}; - use tokio_core::reactor::Core; - - #[test] - fn test_error() { - let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); - assert_eq!(format!("{}", err), "ParseError"); - assert_eq!(format!("{}", err.cause().unwrap()), "ParseError"); - - let err = PayloadError::Incomplete; - assert_eq!(format!("{}", err), "A payload reached EOF, but is not complete."); - } - - #[test] - fn test_basic() { - Core::new().unwrap().run(lazy(|| { - let (_, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - assert_eq!(payload.len, 0); - assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_eof() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); - sender.feed_data(Bytes::from("data")); - sender.feed_eof(); - - assert_eq!(Async::Ready(Some(Bytes::from("data"))), - payload.readany().ok().unwrap()); - assert_eq!(payload.len, 0); - assert_eq!(Async::Ready(None), payload.readany().ok().unwrap()); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_err() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); - - sender.set_error(PayloadError::Incomplete); - payload.readany().err().unwrap(); - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_readany() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - - assert_eq!(Async::Ready(Some(Bytes::from("line1"))), - payload.readany().ok().unwrap()); - assert_eq!(payload.len, 0); - - assert_eq!(Async::Ready(Some(Bytes::from("line2"))), - payload.readany().ok().unwrap()); - assert_eq!(payload.len, 0); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_readexactly() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap()); - - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - - assert_eq!(Async::Ready(Some(Bytes::from_static(b"li"))), - payload.read_exact(2).ok().unwrap()); - assert_eq!(payload.len, 3); - - assert_eq!(Async::Ready(Some(Bytes::from_static(b"ne1l"))), - payload.read_exact(4).ok().unwrap()); - assert_eq!(payload.len, 4); - - sender.set_error(PayloadError::Incomplete); - payload.read_exact(10).err().unwrap(); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_readuntil() { - Core::new().unwrap().run(lazy(|| { - let (mut sender, payload) = Payload::new(false); - let mut payload = PayloadHelper::new(payload); - - assert_eq!(Async::NotReady, payload.read_until(b"ne").ok().unwrap()); - - sender.feed_data(Bytes::from("line1")); - sender.feed_data(Bytes::from("line2")); - - assert_eq!(Async::Ready(Some(Bytes::from("line"))), - payload.read_until(b"ne").ok().unwrap()); - assert_eq!(payload.len, 1); - - assert_eq!(Async::Ready(Some(Bytes::from("1line2"))), - payload.read_until(b"2").ok().unwrap()); - assert_eq!(payload.len, 0); - - sender.set_error(PayloadError::Incomplete); - payload.read_until(b"b").err().unwrap(); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } - - #[test] - fn test_unread_data() { - Core::new().unwrap().run(lazy(|| { - let (_, mut payload) = Payload::new(false); - - payload.unread_data(Bytes::from("data")); - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 4); - - assert_eq!(Async::Ready(Some(Bytes::from("data"))), - payload.poll().ok().unwrap()); - - let res: Result<(), ()> = Ok(()); - result(res) - })).unwrap(); - } -} diff --git a/src/pipeline.rs b/src/pipeline.rs deleted file mode 100644 index 842d519ab..000000000 --- a/src/pipeline.rs +++ /dev/null @@ -1,791 +0,0 @@ -use std::{io, mem}; -use std::rc::Rc; -use std::cell::UnsafeCell; -use std::marker::PhantomData; - -use log::Level::Debug; -use futures::{Async, Poll, Future, Stream}; -use futures::unsync::oneshot; - -use body::{Body, BodyStream}; -use context::{Frame, ActorHttpContext}; -use error::Error; -use header::ContentEncoding; -use handler::{Reply, ReplyItem}; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use middleware::{Middleware, Finished, Started, Response}; -use application::Inner; -use server::{Writer, WriterState, HttpHandlerTask}; - -#[derive(Debug, Clone, Copy)] -pub(crate) enum HandlerType { - Normal(usize), - Handler(usize), - Default, -} - -pub(crate) trait PipelineHandler { - - fn encoding(&self) -> ContentEncoding; - - fn handle(&mut self, req: HttpRequest, htype: HandlerType) -> Reply; -} - -pub(crate) struct Pipeline(PipelineInfo, PipelineState); - -enum PipelineState { - None, - Error, - Starting(StartMiddlewares), - Handler(WaitingResponse), - RunMiddlewares(RunMiddlewares), - Response(ProcessResponse), - Finishing(FinishingMiddlewares), - Completed(Completed), -} - -impl> PipelineState { - - fn is_response(&self) -> bool { - match *self { - PipelineState::Response(_) => true, - _ => false, - } - } - - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - match *self { - PipelineState::Starting(ref mut state) => state.poll(info), - PipelineState::Handler(ref mut state) => state.poll(info), - PipelineState::RunMiddlewares(ref mut state) => state.poll(info), - PipelineState::Finishing(ref mut state) => state.poll(info), - PipelineState::Completed(ref mut state) => state.poll(info), - PipelineState::Response(_) | PipelineState::None | PipelineState::Error => None, - } - } -} - -struct PipelineInfo { - req: HttpRequest, - count: u16, - mws: Rc>>>, - context: Option>, - error: Option, - disconnected: Option, - encoding: ContentEncoding, -} - -impl PipelineInfo { - fn new(req: HttpRequest) -> PipelineInfo { - PipelineInfo { - req, - count: 0, - mws: Rc::new(Vec::new()), - error: None, - context: None, - disconnected: None, - encoding: ContentEncoding::Auto, - } - } - - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref))] - fn req_mut(&self) -> &mut HttpRequest { - #[allow(mutable_transmutes)] - unsafe{mem::transmute(&self.req)} - } - - fn poll_context(&mut self) -> Poll<(), Error> { - if let Some(ref mut context) = self.context { - match context.poll() { - Err(err) => Err(err), - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(_)) => Ok(Async::Ready(())), - } - } else { - Ok(Async::Ready(())) - } - } -} - -impl> Pipeline { - - pub fn new(req: HttpRequest, - mws: Rc>>>, - handler: Rc>, htype: HandlerType) -> Pipeline - { - let mut info = PipelineInfo { - req, mws, - count: 0, - error: None, - context: None, - disconnected: None, - encoding: unsafe{&*handler.get()}.encoding(), - }; - let state = StartMiddlewares::init(&mut info, handler, htype); - - Pipeline(info, state) - } -} - -impl Pipeline<(), Inner<()>> { - pub fn error>(err: R) -> Box { - Box::new(Pipeline::<(), Inner<()>>( - PipelineInfo::new(HttpRequest::default()), ProcessResponse::init(err.into()))) - } -} - -impl Pipeline { - - fn is_done(&self) -> bool { - match self.1 { - PipelineState::None | PipelineState::Error - | PipelineState::Starting(_) | PipelineState::Handler(_) - | PipelineState::RunMiddlewares(_) | PipelineState::Response(_) => true, - PipelineState::Finishing(_) | PipelineState::Completed(_) => false, - } - } -} - -impl> HttpHandlerTask for Pipeline { - - fn disconnected(&mut self) { - self.0.disconnected = Some(true); - } - - fn poll_io(&mut self, io: &mut Writer) -> Poll { - let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) }; - - loop { - if self.1.is_response() { - let state = mem::replace(&mut self.1, PipelineState::None); - if let PipelineState::Response(st) = state { - match st.poll_io(io, info) { - Ok(state) => { - self.1 = state; - if let Some(error) = self.0.error.take() { - return Err(error) - } else { - return Ok(Async::Ready(self.is_done())) - } - } - Err(state) => { - self.1 = state; - return Ok(Async::NotReady); - } - } - } - } - match self.1 { - PipelineState::None => - return Ok(Async::Ready(true)), - PipelineState::Error => - return Err(io::Error::new( - io::ErrorKind::Other, "Internal error").into()), - _ => (), - } - - match self.1.poll(info) { - Some(state) => self.1 = state, - None => return Ok(Async::NotReady), - } - } - } - - fn poll(&mut self) -> Poll<(), Error> { - let info: &mut PipelineInfo<_> = unsafe{ mem::transmute(&mut self.0) }; - - loop { - match self.1 { - PipelineState::None | PipelineState::Error => { - return Ok(Async::Ready(())) - } - _ => (), - } - - if let Some(state) = self.1.poll(info) { - self.1 = state; - } else { - return Ok(Async::NotReady); - } - } - } -} - -type Fut = Box, Error=Error>>; - -/// Middlewares start executor -struct StartMiddlewares { - hnd: Rc>, - htype: HandlerType, - fut: Option, - _s: PhantomData, -} - -impl> StartMiddlewares { - - fn init(info: &mut PipelineInfo, hnd: Rc>, htype: HandlerType) - -> PipelineState - { - // execute middlewares, we need this stage because middlewares could be non-async - // and we can move to next state immediately - let len = info.mws.len() as u16; - loop { - if info.count == len { - let reply = unsafe{&mut *hnd.get()}.handle(info.req.clone(), htype); - return WaitingResponse::init(info, reply) - } else { - match info.mws[info.count as usize].start(&mut info.req) { - Ok(Started::Done) => - info.count += 1, - Ok(Started::Response(resp)) => - return RunMiddlewares::init(info, resp), - Ok(Started::Future(mut fut)) => - match fut.poll() { - Ok(Async::NotReady) => - return PipelineState::Starting(StartMiddlewares { - hnd, htype, - fut: Some(fut), - _s: PhantomData}), - Ok(Async::Ready(resp)) => { - if let Some(resp) = resp { - return RunMiddlewares::init(info, resp); - } - info.count += 1; - } - Err(err) => - return ProcessResponse::init(err.into()), - }, - Err(err) => - return ProcessResponse::init(err.into()), - } - } - } - } - - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - let len = info.mws.len() as u16; - 'outer: loop { - match self.fut.as_mut().unwrap().poll() { - Ok(Async::NotReady) => return None, - Ok(Async::Ready(resp)) => { - info.count += 1; - if let Some(resp) = resp { - return Some(RunMiddlewares::init(info, resp)); - } - if info.count == len { - let reply = unsafe{ - &mut *self.hnd.get()}.handle(info.req.clone(), self.htype); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { - match info.mws[info.count as usize].start(info.req_mut()) { - Ok(Started::Done) => - info.count += 1, - Ok(Started::Response(resp)) => { - return Some(RunMiddlewares::init(info, resp)); - }, - Ok(Started::Future(fut)) => { - self.fut = Some(fut); - continue 'outer - }, - Err(err) => - return Some(ProcessResponse::init(err.into())) - } - } - } - } - Err(err) => - return Some(ProcessResponse::init(err.into())) - } - } - } -} - -// waiting for response -struct WaitingResponse { - fut: Box>, - _s: PhantomData, - _h: PhantomData, -} - -impl WaitingResponse { - - #[inline] - fn init(info: &mut PipelineInfo, reply: Reply) -> PipelineState { - match reply.into() { - ReplyItem::Message(resp) => - RunMiddlewares::init(info, resp), - ReplyItem::Future(fut) => - PipelineState::Handler( - WaitingResponse { fut, _s: PhantomData, _h: PhantomData }), - } - } - - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - match self.fut.poll() { - Ok(Async::NotReady) => None, - Ok(Async::Ready(response)) => - Some(RunMiddlewares::init(info, response)), - Err(err) => - Some(ProcessResponse::init(err.into())), - } - } -} - -/// Middlewares response executor -struct RunMiddlewares { - curr: usize, - fut: Option>>, - _s: PhantomData, - _h: PhantomData, -} - -impl RunMiddlewares { - - fn init(info: &mut PipelineInfo, mut resp: HttpResponse) -> PipelineState { - if info.count == 0 { - return ProcessResponse::init(resp); - } - let mut curr = 0; - let len = info.mws.len(); - - loop { - resp = match info.mws[curr].response(info.req_mut(), resp) { - Err(err) => { - info.count = (curr + 1) as u16; - return ProcessResponse::init(err.into()) - } - Ok(Response::Done(r)) => { - curr += 1; - if curr == len { - return ProcessResponse::init(r) - } else { - r - } - }, - Ok(Response::Future(fut)) => { - return PipelineState::RunMiddlewares( - RunMiddlewares { curr, fut: Some(fut), - _s: PhantomData, _h: PhantomData }) - }, - }; - } - } - - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - let len = info.mws.len(); - - loop { - // poll latest fut - let mut resp = match self.fut.as_mut().unwrap().poll() { - Ok(Async::NotReady) => { - return None - } - Ok(Async::Ready(resp)) => { - self.curr += 1; - resp - } - Err(err) => - return Some(ProcessResponse::init(err.into())), - }; - - loop { - if self.curr == len { - return Some(ProcessResponse::init(resp)); - } else { - match info.mws[self.curr].response(info.req_mut(), resp) { - Err(err) => - return Some(ProcessResponse::init(err.into())), - Ok(Response::Done(r)) => { - self.curr += 1; - resp = r - }, - Ok(Response::Future(fut)) => { - self.fut = Some(fut); - break - }, - } - } - } - } - } -} - -struct ProcessResponse { - resp: HttpResponse, - iostate: IOState, - running: RunningState, - drain: Option>, - _s: PhantomData, - _h: PhantomData, -} - -#[derive(PartialEq)] -enum RunningState { - Running, - Paused, - Done, -} - -impl RunningState { - #[inline] - fn pause(&mut self) { - if *self != RunningState::Done { - *self = RunningState::Paused - } - } - #[inline] - fn resume(&mut self) { - if *self != RunningState::Done { - *self = RunningState::Running - } - } -} - -enum IOState { - Response, - Payload(BodyStream), - Actor(Box), - Done, -} - -impl ProcessResponse { - - #[inline] - fn init(resp: HttpResponse) -> PipelineState { - PipelineState::Response( - ProcessResponse{ resp, - iostate: IOState::Response, - running: RunningState::Running, - drain: None, _s: PhantomData, _h: PhantomData}) - } - - fn poll_io(mut self, io: &mut Writer, info: &mut PipelineInfo) - -> Result, PipelineState> - { - loop { - if self.drain.is_none() && self.running != RunningState::Paused { - // if task is paused, write buffer is probably full - 'inner: loop { - let result = match mem::replace(&mut self.iostate, IOState::Done) { - IOState::Response => { - let encoding = self.resp.content_encoding().unwrap_or(info.encoding); - - let result = match io.start(info.req_mut().get_inner(), - &mut self.resp, encoding) - { - Ok(res) => res, - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - }; - - if let Some(err) = self.resp.error() { - if self.resp.status().is_server_error() { - error!("Error occured during request handling: {}", err); - } else { - warn!("Error occured during request handling: {}", err); - } - if log_enabled!(Debug) { - debug!("{:?}", err); - } - } - - // always poll stream or actor for the first time - match self.resp.replace_body(Body::Empty) { - Body::Streaming(stream) => { - self.iostate = IOState::Payload(stream); - continue 'inner - }, - Body::Actor(ctx) => { - self.iostate = IOState::Actor(ctx); - continue 'inner - }, - _ => (), - } - - result - }, - IOState::Payload(mut body) => { - match body.poll() { - Ok(Async::Ready(None)) => { - if let Err(err) = io.write_eof() { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - break - }, - Ok(Async::Ready(Some(chunk))) => { - self.iostate = IOState::Payload(body); - match io.write(chunk.into()) { - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - }, - Ok(result) => result - } - } - Ok(Async::NotReady) => { - self.iostate = IOState::Payload(body); - break - }, - Err(err) => { - info.error = Some(err); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - }, - IOState::Actor(mut ctx) => { - if info.disconnected.take().is_some() { - ctx.disconnected(); - } - match ctx.poll() { - Ok(Async::Ready(Some(vec))) => { - if vec.is_empty() { - self.iostate = IOState::Actor(ctx); - break - } - let mut res = None; - for frame in vec { - match frame { - Frame::Chunk(None) => { - info.context = Some(ctx); - if let Err(err) = io.write_eof() { - info.error = Some(err.into()); - return Ok( - FinishingMiddlewares::init(info, self.resp)) - } - break 'inner - }, - Frame::Chunk(Some(chunk)) => { - match io.write(chunk) { - Err(err) => { - info.error = Some(err.into()); - return Ok( - FinishingMiddlewares::init(info, self.resp)) - }, - Ok(result) => res = Some(result), - } - }, - Frame::Drain(fut) => self.drain = Some(fut), - } - } - self.iostate = IOState::Actor(ctx); - if self.drain.is_some() { - self.running.resume(); - break 'inner - } - res.unwrap() - }, - Ok(Async::Ready(None)) => { - break - } - Ok(Async::NotReady) => { - self.iostate = IOState::Actor(ctx); - break - } - Err(err) => { - info.error = Some(err); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - } - IOState::Done => break, - }; - - match result { - WriterState::Pause => { - self.running.pause(); - break - } - WriterState::Done => { - self.running.resume() - }, - } - } - } - - // flush io but only if we need to - if self.running == RunningState::Paused || self.drain.is_some() { - match io.poll_completed(false) { - Ok(Async::Ready(_)) => { - self.running.resume(); - - // resolve drain futures - if let Some(tx) = self.drain.take() { - let _ = tx.send(()); - } - // restart io processing - continue - }, - Ok(Async::NotReady) => - return Err(PipelineState::Response(self)), - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - } - break - } - - // response is completed - match self.iostate { - IOState::Done => { - match io.write_eof() { - Ok(_) => (), - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - self.resp.set_response_size(io.written()); - Ok(FinishingMiddlewares::init(info, self.resp)) - } - _ => Err(PipelineState::Response(self)), - } - } -} - -/// Middlewares start executor -struct FinishingMiddlewares { - resp: HttpResponse, - fut: Option>>, - _s: PhantomData, - _h: PhantomData, -} - -impl FinishingMiddlewares { - - fn init(info: &mut PipelineInfo, resp: HttpResponse) -> PipelineState { - if info.count == 0 { - Completed::init(info) - } else { - let mut state = FinishingMiddlewares{resp, fut: None, - _s: PhantomData, _h: PhantomData}; - if let Some(st) = state.poll(info) { - st - } else { - PipelineState::Finishing(state) - } - } - } - - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - loop { - // poll latest fut - let not_ready = if let Some(ref mut fut) = self.fut { - match fut.poll() { - Ok(Async::NotReady) => { - true - }, - Ok(Async::Ready(())) => { - false - }, - Err(err) => { - error!("Middleware finish error: {}", err); - false - } - } - } else { - false - }; - if not_ready { - return None; - } - self.fut = None; - info.count -= 1; - - match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) { - Finished::Done => { - if info.count == 0 { - return Some(Completed::init(info)) - } - } - Finished::Future(fut) => { - self.fut = Some(fut); - }, - } - } - } -} - -#[derive(Debug)] -struct Completed(PhantomData, PhantomData); - -impl Completed { - - #[inline] - fn init(info: &mut PipelineInfo) -> PipelineState { - if let Some(ref err) = info.error { - error!("Error occurred during request handling: {}", err); - } - - if info.context.is_none() { - PipelineState::None - } else { - PipelineState::Completed(Completed(PhantomData, PhantomData)) - } - } - - #[inline] - fn poll(&mut self, info: &mut PipelineInfo) -> Option> { - match info.poll_context() { - Ok(Async::NotReady) => None, - Ok(Async::Ready(())) => Some(PipelineState::None), - Err(_) => Some(PipelineState::Error), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use actix::*; - use context::HttpContext; - use tokio_core::reactor::Core; - use futures::future::{lazy, result}; - - impl PipelineState { - fn is_none(&self) -> Option { - if let PipelineState::None = *self { Some(true) } else { None } - } - fn completed(self) -> Option> { - if let PipelineState::Completed(c) = self { Some(c) } else { None } - } - } - - struct MyActor; - impl Actor for MyActor { - type Context = HttpContext; - } - - #[test] - fn test_completed() { - Core::new().unwrap().run(lazy(|| { - let mut info = PipelineInfo::new(HttpRequest::default()); - Completed::<(), Inner<()>>::init(&mut info).is_none().unwrap(); - - let req = HttpRequest::default(); - let mut ctx = HttpContext::new(req.clone(), MyActor); - let addr: Addr = ctx.address(); - let mut info = PipelineInfo::new(req); - info.context = Some(Box::new(ctx)); - let mut state = Completed::<(), Inner<()>>::init(&mut info).completed().unwrap(); - - assert!(state.poll(&mut info).is_none()); - let pp = Pipeline(info, PipelineState::Completed(state)); - assert!(!pp.is_done()); - - let Pipeline(mut info, st) = pp; - let mut st = st.completed().unwrap(); - drop(addr); - - assert!(st.poll(&mut info).unwrap().is_none().unwrap()); - - result(Ok::<_, ()>(())) - })).unwrap(); - } -} diff --git a/src/pred.rs b/src/pred.rs deleted file mode 100644 index 57398fc2b..000000000 --- a/src/pred.rs +++ /dev/null @@ -1,292 +0,0 @@ -//! Route match predicates -#![allow(non_snake_case)] -use std::marker::PhantomData; -use http; -use http::{header, HttpTryFrom}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; - -/// Trait defines resource route predicate. -/// Predicate can modify request object. It is also possible to -/// to store extra attributes on request by using `Extensions` container, -/// Extensions container available via `HttpRequest::extensions()` method. -pub trait Predicate { - - /// Check if request matches predicate - fn check(&self, &mut HttpRequest) -> bool; - -} - -/// Return predicate that matches if any of supplied predicate matches. -/// -/// ```rust -/// # extern crate actix_web; -/// use actix_web::{pred, App, HttpResponse}; -/// -/// fn main() { -/// App::new() -/// .resource("/index.html", |r| r.route() -/// .filter(pred::Any(pred::Get()).or(pred::Post())) -/// .f(|r| HttpResponse::MethodNotAllowed())); -/// } -/// ``` -pub fn Any + 'static>(pred: P) -> AnyPredicate -{ - AnyPredicate(vec![Box::new(pred)]) -} - -/// Matches if any of supplied predicate matches. -pub struct AnyPredicate(Vec>>); - -impl AnyPredicate { - /// Add new predicate to list of predicates to check - pub fn or + 'static>(mut self, pred: P) -> Self { - self.0.push(Box::new(pred)); - self - } -} - -impl Predicate for AnyPredicate { - fn check(&self, req: &mut HttpRequest) -> bool { - for p in &self.0 { - if p.check(req) { - return true - } - } - false - } -} - -/// Return predicate that matches if all of supplied predicate matches. -/// -/// ```rust -/// # extern crate actix_web; -/// use actix_web::{pred, Application, HttpResponse}; -/// -/// fn main() { -/// Application::new() -/// .resource("/index.html", |r| r.route() -/// .filter(pred::All(pred::Get()) -/// .and(pred::Header("content-type", "plain/text"))) -/// .f(|_| HttpResponse::MethodNotAllowed())); -/// } -/// ``` -pub fn All + 'static>(pred: P) -> AllPredicate { - AllPredicate(vec![Box::new(pred)]) -} - -/// Matches if all of supplied predicate matches. -pub struct AllPredicate(Vec>>); - -impl AllPredicate { - /// Add new predicate to list of predicates to check - pub fn and + 'static>(mut self, pred: P) -> Self { - self.0.push(Box::new(pred)); - self - } -} - -impl Predicate for AllPredicate { - fn check(&self, req: &mut HttpRequest) -> bool { - for p in &self.0 { - if !p.check(req) { - return false - } - } - true - } -} - -/// Return predicate that matches if supplied predicate does not match. -pub fn Not + 'static>(pred: P) -> NotPredicate -{ - NotPredicate(Box::new(pred)) -} - -#[doc(hidden)] -pub struct NotPredicate(Box>); - -impl Predicate for NotPredicate { - fn check(&self, req: &mut HttpRequest) -> bool { - !self.0.check(req) - } -} - -/// Http method predicate -#[doc(hidden)] -pub struct MethodPredicate(http::Method, PhantomData); - -impl Predicate for MethodPredicate { - fn check(&self, req: &mut HttpRequest) -> bool { - *req.method() == self.0 - } -} - -/// Predicate to match *GET* http method -pub fn Get() -> MethodPredicate { - MethodPredicate(http::Method::GET, PhantomData) -} - -/// Predicate to match *POST* http method -pub fn Post() -> MethodPredicate { - MethodPredicate(http::Method::POST, PhantomData) -} - -/// Predicate to match *PUT* http method -pub fn Put() -> MethodPredicate { - MethodPredicate(http::Method::PUT, PhantomData) -} - -/// Predicate to match *DELETE* http method -pub fn Delete() -> MethodPredicate { - MethodPredicate(http::Method::DELETE, PhantomData) -} - -/// Predicate to match *HEAD* http method -pub fn Head() -> MethodPredicate { - MethodPredicate(http::Method::HEAD, PhantomData) -} - -/// Predicate to match *OPTIONS* http method -pub fn Options() -> MethodPredicate { - MethodPredicate(http::Method::OPTIONS, PhantomData) -} - -/// Predicate to match *CONNECT* http method -pub fn Connect() -> MethodPredicate { - MethodPredicate(http::Method::CONNECT, PhantomData) -} - -/// Predicate to match *PATCH* http method -pub fn Patch() -> MethodPredicate { - MethodPredicate(http::Method::PATCH, PhantomData) -} - -/// Predicate to match *TRACE* http method -pub fn Trace() -> MethodPredicate { - MethodPredicate(http::Method::TRACE, PhantomData) -} - -/// Predicate to match specified http method -pub fn Method(method: http::Method) -> MethodPredicate { - MethodPredicate(method, PhantomData) -} - -/// Return predicate that matches if request contains specified header and value. -pub fn Header(name: &'static str, value: &'static str) -> HeaderPredicate -{ - HeaderPredicate(header::HeaderName::try_from(name).unwrap(), - header::HeaderValue::from_static(value), - PhantomData) -} - -#[doc(hidden)] -pub struct HeaderPredicate(header::HeaderName, header::HeaderValue, PhantomData); - -impl Predicate for HeaderPredicate { - fn check(&self, req: &mut HttpRequest) -> bool { - if let Some(val) = req.headers().get(&self.0) { - return val == self.1 - } - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::str::FromStr; - use http::{Uri, Version, Method}; - use http::header::{self, HeaderMap}; - - #[test] - fn test_header() { - let mut headers = HeaderMap::new(); - headers.insert(header::TRANSFER_ENCODING, - header::HeaderValue::from_static("chunked")); - let mut req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - - let pred = Header("transfer-encoding", "chunked"); - assert!(pred.check(&mut req)); - - let pred = Header("transfer-encoding", "other"); - assert!(!pred.check(&mut req)); - - let pred = Header("content-type", "other"); - assert!(!pred.check(&mut req)); - } - - #[test] - fn test_methods() { - let mut req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - let mut req2 = HttpRequest::new( - Method::POST, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - - assert!(Get().check(&mut req)); - assert!(!Get().check(&mut req2)); - assert!(Post().check(&mut req2)); - assert!(!Post().check(&mut req)); - - let mut r = HttpRequest::new( - Method::PUT, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Put().check(&mut r)); - assert!(!Put().check(&mut req)); - - let mut r = HttpRequest::new( - Method::DELETE, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Delete().check(&mut r)); - assert!(!Delete().check(&mut req)); - - let mut r = HttpRequest::new( - Method::HEAD, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Head().check(&mut r)); - assert!(!Head().check(&mut req)); - - let mut r = HttpRequest::new( - Method::OPTIONS, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Options().check(&mut r)); - assert!(!Options().check(&mut req)); - - let mut r = HttpRequest::new( - Method::CONNECT, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Connect().check(&mut r)); - assert!(!Connect().check(&mut req)); - - let mut r = HttpRequest::new( - Method::PATCH, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Patch().check(&mut r)); - assert!(!Patch().check(&mut req)); - - let mut r = HttpRequest::new( - Method::TRACE, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert!(Trace().check(&mut r)); - assert!(!Trace().check(&mut req)); - } - - #[test] - fn test_preds() { - let mut r = HttpRequest::new( - Method::TRACE, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - - assert!(Not(Get()).check(&mut r)); - assert!(!Not(Trace()).check(&mut r)); - - assert!(All(Trace()).and(Trace()).check(&mut r)); - assert!(!All(Get()).and(Trace()).check(&mut r)); - - assert!(Any(Get()).or(Trace()).check(&mut r)); - assert!(!Any(Get()).or(Get()).check(&mut r)); - } -} diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 000000000..84f0503c0 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,536 @@ +use std::cell::{Ref, RefCell, RefMut}; +use std::rc::Rc; +use std::{fmt, net}; + +use actix_http::http::{HeaderMap, Method, Uri, Version}; +use actix_http::{Error, Extensions, HttpMessage, Message, Payload, RequestHead}; +use actix_router::{Path, Url}; +use futures::future::{ok, Ready}; + +use crate::config::AppConfig; +use crate::data::Data; +use crate::error::UrlGenerationError; +use crate::extract::FromRequest; +use crate::info::ConnectionInfo; +use crate::rmap::ResourceMap; + +#[derive(Clone)] +/// An HTTP Request +pub struct HttpRequest(pub(crate) Rc); + +pub(crate) struct HttpRequestInner { + pub(crate) head: Message, + pub(crate) path: Path, + pub(crate) payload: Payload, + pub(crate) app_data: Rc, + rmap: Rc, + config: AppConfig, + pool: &'static HttpRequestPool, +} + +impl HttpRequest { + #[inline] + pub(crate) fn new( + path: Path, + head: Message, + payload: Payload, + rmap: Rc, + config: AppConfig, + app_data: Rc, + pool: &'static HttpRequestPool, + ) -> HttpRequest { + HttpRequest(Rc::new(HttpRequestInner { + head, + path, + payload, + rmap, + config, + app_data, + pool, + })) + } +} + +impl HttpRequest { + /// This method returns reference to the request head + #[inline] + pub fn head(&self) -> &RequestHead { + &self.0.head + } + + /// This method returns muttable reference to the request head. + /// panics if multiple references of http request exists. + #[inline] + pub(crate) fn head_mut(&mut self) -> &mut RequestHead { + &mut Rc::get_mut(&mut self.0).unwrap().head + } + + /// Request's uri. + #[inline] + pub fn uri(&self) -> &Uri { + &self.head().uri + } + + /// Read the Request method. + #[inline] + pub fn method(&self) -> &Method { + &self.head().method + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// The target path of this Request. + #[inline] + pub fn path(&self) -> &str { + self.head().uri.path() + } + + /// The query string in the URL. + /// + /// E.g., id=10 + #[inline] + pub fn query_string(&self) -> &str { + if let Some(query) = self.uri().query().as_ref() { + query + } else { + "" + } + } + + /// Get a reference to the Path parameters. + /// + /// Params is a container for url parameters. + /// A variable segment is specified in the form `{identifier}`, + /// where the identifier can be used later in a request handler to + /// access the matched value for that segment. + #[inline] + pub fn match_info(&self) -> &Path { + &self.0.path + } + + #[inline] + pub(crate) fn match_info_mut(&mut self) -> &mut Path { + &mut Rc::get_mut(&mut self.0).unwrap().path + } + + /// Request extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.head().extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.head().extensions_mut() + } + + /// Generate url for named resource + /// + /// ```rust + /// # extern crate actix_web; + /// # use actix_web::{web, App, HttpRequest, HttpResponse}; + /// # + /// fn index(req: HttpRequest) -> HttpResponse { + /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource + /// HttpResponse::Ok().into() + /// } + /// + /// fn main() { + /// let app = App::new() + /// .service(web::resource("/test/{one}/{two}/{three}") + /// .name("foo") // <- set resource name, then it could be used in `url_for` + /// .route(web::get().to(|| HttpResponse::Ok())) + /// ); + /// } + /// ``` + pub fn url_for( + &self, + name: &str, + elements: U, + ) -> Result + where + U: IntoIterator, + I: AsRef, + { + self.0.rmap.url_for(&self, name, elements) + } + + /// Generate url for named resource + /// + /// This method is similar to `HttpRequest::url_for()` but it can be used + /// for urls that do not contain variable parts. + pub fn url_for_static(&self, name: &str) -> Result { + const NO_PARAMS: [&str; 0] = []; + self.url_for(name, &NO_PARAMS) + } + + #[inline] + /// Get a reference to a `ResourceMap` of current application. + pub fn resource_map(&self) -> &ResourceMap { + &self.0.rmap + } + + /// Peer socket address + /// + /// Peer address is actual socket address, if proxy is used in front of + /// actix http server, then peer address would be address of this proxy. + /// + /// To get client connection information `.connection_info()` should be used. + #[inline] + pub fn peer_addr(&self) -> Option { + self.head().peer_addr + } + + /// Get *ConnectionInfo* for the current request. + /// + /// This method panics if request's extensions container is already + /// borrowed. + #[inline] + pub fn connection_info(&self) -> Ref { + ConnectionInfo::get(self.head(), &*self.app_config()) + } + + /// App config + #[inline] + pub fn app_config(&self) -> &AppConfig { + &self.0.config + } + + /// Get an application data stored with `App::data()` method during + /// application configuration. + pub fn app_data(&self) -> Option<&T> { + if let Some(st) = self.0.app_data.get::>() { + Some(&st) + } else { + None + } + } + + /// Get an application data stored with `App::data()` method during + /// application configuration. + pub fn get_app_data(&self) -> Option> { + if let Some(st) = self.0.app_data.get::>() { + Some(st.clone()) + } else { + None + } + } +} + +impl HttpMessage for HttpRequest { + type Stream = (); + + #[inline] + /// Returns Request's headers. + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Request extensions + #[inline] + fn extensions(&self) -> Ref { + self.0.head.extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + fn extensions_mut(&self) -> RefMut { + self.0.head.extensions_mut() + } + + #[inline] + fn take_payload(&mut self) -> Payload { + Payload::None + } +} + +impl Drop for HttpRequest { + fn drop(&mut self) { + if Rc::strong_count(&self.0) == 1 { + let v = &mut self.0.pool.0.borrow_mut(); + if v.len() < 128 { + self.extensions_mut().clear(); + v.push(self.0.clone()); + } + } + } +} + +/// It is possible to get `HttpRequest` as an extractor handler parameter +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App, HttpRequest}; +/// use serde_derive::Deserialize; +/// +/// /// extract `Thing` from request +/// async fn index(req: HttpRequest) -> String { +/// format!("Got thing: {:?}", req) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/users/{first}").route( +/// web::get().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for HttpRequest { + type Config = (); + type Error = Error; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(req.clone()) + } +} + +impl fmt::Debug for HttpRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nHttpRequest {:?} {}:{}", + self.0.head.version, + self.0.head.method, + self.path() + )?; + if !self.query_string().is_empty() { + writeln!(f, " query: ?{:?}", self.query_string())?; + } + if !self.match_info().is_empty() { + writeln!(f, " params: {:?}", self.match_info())?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +/// Request's objects pool +pub(crate) struct HttpRequestPool(RefCell>>); + +impl HttpRequestPool { + pub(crate) fn create() -> &'static HttpRequestPool { + let pool = HttpRequestPool(RefCell::new(Vec::with_capacity(128))); + Box::leak(Box::new(pool)) + } + + /// Get message from the pool + #[inline] + pub(crate) fn get_request(&self) -> Option { + if let Some(inner) = self.0.borrow_mut().pop() { + Some(HttpRequest(inner)) + } else { + None + } + } + + pub(crate) fn clear(&self) { + self.0.borrow_mut().clear() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dev::{ResourceDef, ResourceMap}; + use crate::http::{header, StatusCode}; + use crate::test::{call_service, init_service, TestRequest}; + use crate::{web, App, HttpResponse}; + + #[test] + fn test_debug() { + let req = + TestRequest::with_header("content-type", "text/plain").to_http_request(); + let dbg = format!("{:?}", req); + assert!(dbg.contains("HttpRequest")); + } + + #[test] + fn test_no_request_cookies() { + let req = TestRequest::default().to_http_request(); + assert!(req.cookies().unwrap().is_empty()); + } + + #[test] + fn test_request_cookies() { + let req = TestRequest::default() + .header(header::COOKIE, "cookie1=value1") + .header(header::COOKIE, "cookie2=value2") + .to_http_request(); + { + let cookies = req.cookies().unwrap(); + assert_eq!(cookies.len(), 2); + assert_eq!(cookies[0].name(), "cookie2"); + assert_eq!(cookies[0].value(), "value2"); + assert_eq!(cookies[1].name(), "cookie1"); + assert_eq!(cookies[1].value(), "value1"); + } + + let cookie = req.cookie("cookie1"); + assert!(cookie.is_some()); + let cookie = cookie.unwrap(); + assert_eq!(cookie.name(), "cookie1"); + assert_eq!(cookie.value(), "value1"); + + let cookie = req.cookie("cookie-unknown"); + assert!(cookie.is_none()); + } + + #[test] + fn test_request_query() { + let req = TestRequest::with_uri("/?id=test").to_http_request(); + assert_eq!(req.query_string(), "id=test"); + } + + #[test] + fn test_url_for() { + let mut res = ResourceDef::new("/user/{name}.{ext}"); + *res.name_mut() = "index".to_string(); + + let mut rmap = ResourceMap::new(ResourceDef::new("")); + rmap.add(&mut res, None); + assert!(rmap.has_resource("/user/test.html")); + assert!(!rmap.has_resource("/test/unknown")); + + let req = TestRequest::with_header(header::HOST, "www.rust-lang.org") + .rmap(rmap) + .to_http_request(); + + assert_eq!( + req.url_for("unknown", &["test"]), + Err(UrlGenerationError::ResourceNotFound) + ); + assert_eq!( + req.url_for("index", &["test"]), + Err(UrlGenerationError::NotEnoughElements) + ); + let url = req.url_for("index", &["test", "html"]); + assert_eq!( + url.ok().unwrap().as_str(), + "http://www.rust-lang.org/user/test.html" + ); + } + + #[test] + fn test_url_for_static() { + let mut rdef = ResourceDef::new("/index.html"); + *rdef.name_mut() = "index".to_string(); + + let mut rmap = ResourceMap::new(ResourceDef::new("")); + rmap.add(&mut rdef, None); + + assert!(rmap.has_resource("/index.html")); + + let req = TestRequest::with_uri("/test") + .header(header::HOST, "www.rust-lang.org") + .rmap(rmap) + .to_http_request(); + let url = req.url_for_static("index"); + assert_eq!( + url.ok().unwrap().as_str(), + "http://www.rust-lang.org/index.html" + ); + } + + #[test] + fn test_url_for_external() { + let mut rdef = ResourceDef::new("https://youtube.com/watch/{video_id}"); + + *rdef.name_mut() = "youtube".to_string(); + + let mut rmap = ResourceMap::new(ResourceDef::new("")); + rmap.add(&mut rdef, None); + assert!(rmap.has_resource("https://youtube.com/watch/unknown")); + + let req = TestRequest::default().rmap(rmap).to_http_request(); + let url = req.url_for("youtube", &["oHg5SJYRHA0"]); + assert_eq!( + url.ok().unwrap().as_str(), + "https://youtube.com/watch/oHg5SJYRHA0" + ); + } + + #[actix_rt::test] + async fn test_app_data() { + let mut srv = init_service(App::new().data(10usize).service( + web::resource("/").to(|req: HttpRequest| { + if req.app_data::().is_some() { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let mut srv = init_service(App::new().data(10u32).service( + web::resource("/").to(|req: HttpRequest| { + if req.app_data::().is_some() { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[actix_rt::test] + async fn test_extensions_dropped() { + struct Tracker { + pub dropped: bool, + } + struct Foo { + tracker: Rc>, + } + impl Drop for Foo { + fn drop(&mut self) { + self.tracker.borrow_mut().dropped = true; + } + } + + let tracker = Rc::new(RefCell::new(Tracker { dropped: false })); + { + let tracker2 = Rc::clone(&tracker); + let mut srv = init_service(App::new().data(10u32).service( + web::resource("/").to(move |req: HttpRequest| { + req.extensions_mut().insert(Foo { + tracker: Rc::clone(&tracker2), + }); + HttpResponse::Ok() + }), + )) + .await; + + let req = TestRequest::default().to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + assert!(tracker.borrow().dropped); + } +} diff --git a/src/resource.rs b/src/resource.rs index f28363e28..866cbecf5 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,177 +1,781 @@ +use std::cell::RefCell; +use std::fmt; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; -use std::marker::PhantomData; +use std::task::{Context, Poll}; -use smallvec::SmallVec; -use http::{Method, StatusCode}; +use actix_http::{Error, Extensions, Response}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{ + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, +}; +use futures::future::{ok, Either, LocalBoxFuture, Ready}; -use pred; -use route::Route; -use handler::{Reply, Handler, Responder, FromRequest}; -use middleware::Middleware; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; +use crate::data::Data; +use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; +use crate::extract::FromRequest; +use crate::guard::Guard; +use crate::handler::Factory; +use crate::responder::Responder; +use crate::route::{CreateRouteService, Route, RouteService}; +use crate::service::{ServiceRequest, ServiceResponse}; -/// *Resource* is an entry in route table which corresponds to requested URL. +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; + +/// *Resource* is an entry in resources table which corresponds to requested URL. /// /// Resource in turn has at least one route. -/// Route consists of an object that implements `Handler` trait (handler) -/// and list of predicates (objects that implement `Predicate` trait). -/// Route uses builder-like pattern for configuration. +/// Route consists of an handlers objects and list of guards +/// (objects that implement `Guard` trait). +/// Resources and routes uses builder-like pattern for configuration. /// During request handling, resource object iterate through all routes -/// and check all predicates for specific route, if request matches all predicates route -/// route considered matched and route handler get called. +/// and check guards for specific route, if request matches all +/// guards, route considered matched and route handler get called. /// /// ```rust -/// # extern crate actix_web; -/// use actix_web::{App, HttpResponse, http}; +/// use actix_web::{web, App, HttpResponse}; /// /// fn main() { -/// let app = App::new() -/// .resource( -/// "/", |r| r.method(http::Method::GET).f(|r| HttpResponse::Ok())) -/// .finish(); +/// let app = App::new().service( +/// web::resource("/") +/// .route(web::get().to(|| HttpResponse::Ok()))); /// } -pub struct ResourceHandler { - name: String, - state: PhantomData, - routes: SmallVec<[Route; 3]>, - middlewares: Rc>>>, +/// ``` +/// +/// If no matching route could be found, *405* response code get returned. +/// Default behavior could be overriden with `default_resource()` method. +pub struct Resource { + endpoint: T, + rdef: String, + name: Option, + routes: Vec, + data: Option, + guards: Vec>, + default: Rc>>>, + factory_ref: Rc>>, } -impl Default for ResourceHandler { - fn default() -> Self { - ResourceHandler { - name: String::new(), - state: PhantomData, - routes: SmallVec::new(), - middlewares: Rc::new(Vec::new()) } +impl Resource { + pub fn new(path: &str) -> Resource { + let fref = Rc::new(RefCell::new(None)); + + Resource { + routes: Vec::new(), + rdef: path.to_string(), + name: None, + endpoint: ResourceEndpoint::new(fref.clone()), + factory_ref: fref, + guards: Vec::new(), + data: None, + default: Rc::new(RefCell::new(None)), + } } } -impl ResourceHandler { - - pub(crate) fn default_not_found() -> Self { - ResourceHandler { - name: String::new(), - state: PhantomData, - routes: SmallVec::new(), - middlewares: Rc::new(Vec::new()) } +impl Resource +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + /// Set resource name. + /// + /// Name is used for url generation. + pub fn name(mut self, name: &str) -> Self { + self.name = Some(name.to_string()); + self } - /// Set resource name - pub fn name>(&mut self, name: T) { - self.name = name.into(); - } - - pub(crate) fn get_name(&self) -> &str { - &self.name - } -} - -impl ResourceHandler { - - /// Register a new route and return mutable reference to *Route* object. - /// *Route* is used for route configuration, i.e. adding predicates, setting up handler. + /// Add match guard to a resource. /// /// ```rust - /// # extern crate actix_web; - /// use actix_web::*; + /// use actix_web::{web, guard, App, HttpResponse}; + /// + /// async fn index(data: web::Path<(String, String)>) -> &'static str { + /// "Welcome!" + /// } /// /// fn main() { /// let app = App::new() - /// .resource( - /// "/", |r| r.route() - /// .filter(pred::Any(pred::Get()).or(pred::Put())) - /// .filter(pred::Header("Content-Type", "text/plain")) - /// .f(|r| HttpResponse::Ok())) - /// .finish(); + /// .service( + /// web::resource("/app") + /// .guard(guard::Header("content-type", "text/plain")) + /// .route(web::get().to(index)) + /// ) + /// .service( + /// web::resource("/app") + /// .guard(guard::Header("content-type", "text/json")) + /// .route(web::get().to(|| HttpResponse::MethodNotAllowed())) + /// ); /// } /// ``` - pub fn route(&mut self) -> &mut Route { - self.routes.push(Route::default()); - self.routes.last_mut().unwrap() + pub fn guard(mut self, guard: G) -> Self { + self.guards.push(Box::new(guard)); + self } - /// Register a new route and add method check to route. + pub(crate) fn add_guards(mut self, guards: Vec>) -> Self { + self.guards.extend(guards); + self + } + + /// Register a new route. /// - /// This is shortcut for: + /// ```rust + /// use actix_web::{web, guard, App, HttpResponse}; /// - /// ```rust,ignore - /// Application::resource("/", |r| r.route().filter(pred::Get()).f(index) + /// fn main() { + /// let app = App::new().service( + /// web::resource("/").route( + /// web::route() + /// .guard(guard::Any(guard::Get()).or(guard::Put())) + /// .guard(guard::Header("Content-Type", "text/plain")) + /// .to(|| HttpResponse::Ok())) + /// ); + /// } /// ``` - pub fn method(&mut self, method: Method) -> &mut Route { - self.routes.push(Route::default()); - self.routes.last_mut().unwrap().filter(pred::Method(method)) - } - - /// Register a new route and add handler object. /// - /// This is shortcut for: + /// Multiple routes could be added to a resource. Resource object uses + /// match guards for route selection. /// - /// ```rust,ignore - /// Application::resource("/", |r| r.route().h(handler) + /// ```rust + /// use actix_web::{web, guard, App}; + /// + /// fn main() { + /// let app = App::new().service( + /// web::resource("/container/") + /// .route(web::get().to(get_handler)) + /// .route(web::post().to(post_handler)) + /// .route(web::delete().to(delete_handler)) + /// ); + /// } + /// # async fn get_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } + /// # async fn post_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } + /// # async fn delete_handler() -> impl actix_web::Responder { actix_web::HttpResponse::Ok() } /// ``` - pub fn h>(&mut self, handler: H) { - self.routes.push(Route::default()); - self.routes.last_mut().unwrap().h(handler) + pub fn route(mut self, route: Route) -> Self { + self.routes.push(route); + self } - /// Register a new route and add handler function. + /// Provide resource specific data. This method allows to add extractor + /// configuration or specific state available via `Data` extractor. + /// Provided data is available for all routes registered for the current resource. + /// Resource data overrides data registered by `App::data()` method. /// - /// This is shortcut for: + /// ```rust + /// use actix_web::{web, App, FromRequest}; /// - /// ```rust,ignore - /// Application::resource("/", |r| r.route().f(index) + /// /// extract text data from request + /// async fn index(body: String) -> String { + /// format!("Body {}!", body) + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::resource("/index.html") + /// // limit size of the payload + /// .data(String::configure(|cfg| { + /// cfg.limit(4096) + /// })) + /// .route( + /// web::get() + /// // register handler + /// .to(index) + /// )); + /// } /// ``` - pub fn f(&mut self, handler: F) - where F: Fn(HttpRequest) -> R + 'static, - R: Responder + 'static, - { - self.routes.push(Route::default()); - self.routes.last_mut().unwrap().f(handler) + pub fn data(self, data: U) -> Self { + self.register_data(Data::new(data)) } - /// Register a new route and add handler. + /// Set or override application data. /// - /// This is shortcut for: - /// - /// ```rust,ignore - /// Application::resource("/", |r| r.route().with(index) - /// ``` - pub fn with(&mut self, handler: F) - where F: Fn(T) -> R + 'static, - R: Responder + 'static, - T: FromRequest + 'static, - { - self.routes.push(Route::default()); - self.routes.last_mut().unwrap().with(handler) - } - - /// Register a resource middleware - /// - /// This is similar to `App's` middlewares, but - /// middlewares get invoked on resource level. - pub fn middleware>(&mut self, mw: M) { - Rc::get_mut(&mut self.middlewares).unwrap().push(Box::new(mw)); - } - - pub(crate) fn handle(&mut self, - mut req: HttpRequest, - default: Option<&mut ResourceHandler>) -> Reply - { - for route in &mut self.routes { - if route.check(&mut req) { - return if self.middlewares.is_empty() { - route.handle(req) - } else { - route.compose(req, Rc::clone(&self.middlewares)) - }; - } + /// This method has the same effect as [`Resource::data`](#method.data), + /// except that instead of taking a value of some type `T`, it expects a + /// value of type `Data`. Use a `Data` extractor to retrieve its + /// value. + pub fn register_data(mut self, data: Data) -> Self { + if self.data.is_none() { + self.data = Some(Extensions::new()); } - if let Some(resource) = default { - resource.handle(req, None) + self.data.as_mut().unwrap().insert(data); + self + } + + /// Register a new route and add handler. This route matches all requests. + /// + /// ```rust + /// use actix_web::*; + /// + /// fn index(req: HttpRequest) -> HttpResponse { + /// unimplemented!() + /// } + /// + /// App::new().service(web::resource("/").to(index)); + /// ``` + /// + /// This is shortcut for: + /// + /// ```rust + /// # extern crate actix_web; + /// # use actix_web::*; + /// # fn index(req: HttpRequest) -> HttpResponse { unimplemented!() } + /// App::new().service(web::resource("/").route(web::route().to(index))); + /// ``` + pub fn to(mut self, handler: F) -> Self + where + F: Factory, + I: FromRequest + 'static, + R: Future + 'static, + U: Responder + 'static, + { + self.routes.push(Route::new().to(handler)); + self + } + + /// Register a resource middleware. + /// + /// This is similar to `App's` middlewares, but middleware get invoked on resource level. + /// Resource level middlewares are not allowed to change response + /// type (i.e modify response's body). + /// + /// **Note**: middlewares get called in opposite order of middlewares registration. + pub fn wrap( + self, + mw: M, + ) -> Resource< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > + where + M: Transform< + T::Service, + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + { + Resource { + endpoint: apply(mw, self.endpoint), + rdef: self.rdef, + name: self.name, + guards: self.guards, + routes: self.routes, + default: self.default, + data: self.data, + factory_ref: self.factory_ref, + } + } + + /// Register a resource middleware function. + /// + /// This function accepts instance of `ServiceRequest` type and + /// mutable reference to the next middleware in chain. + /// + /// This is similar to `App's` middlewares, but middleware get invoked on resource level. + /// Resource level middlewares are not allowed to change response + /// type (i.e modify response's body). + /// + /// ```rust + /// use actix_service::Service; + /// use actix_web::{web, App}; + /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// + /// async fn index() -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::resource("/index.html") + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; + /// res.headers_mut().insert( + /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), + /// ); + /// Ok(res) + /// } + /// }) + /// .route(web::get().to(index))); + /// } + /// ``` + pub fn wrap_fn( + self, + mw: F, + ) -> Resource< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > + where + F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + R: Future>, + { + Resource { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + name: self.name, + guards: self.guards, + routes: self.routes, + default: self.default, + data: self.data, + factory_ref: self.factory_ref, + } + } + + /// Default service to be used if no matching route could be found. + /// By default *405* response get returned. Resource does not use + /// default handler from `App` or `Scope`. + pub fn default_service(mut self, f: F) -> Self + where + F: IntoServiceFactory, + U: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + > + 'static, + U::InitError: fmt::Debug, + { + // create and configure default resource + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { + log::error!("Can not construct default service: {:?}", e) + }), + ))))); + + self + } +} + +impl HttpServiceFactory for Resource +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + 'static, +{ + fn register(mut self, config: &mut AppService) { + let guards = if self.guards.is_empty() { + None } else { - Reply::response(HttpResponse::new(StatusCode::NOT_FOUND)) + Some(std::mem::replace(&mut self.guards, Vec::new())) + }; + let mut rdef = if config.is_root() || !self.rdef.is_empty() { + ResourceDef::new(&insert_slash(&self.rdef)) + } else { + ResourceDef::new(&self.rdef) + }; + if let Some(ref name) = self.name { + *rdef.name_mut() = name.clone(); + } + // custom app data storage + if let Some(ref mut ext) = self.data { + config.set_service_data(ext); + } + config.register_service(rdef, guards, self, None) + } +} + +impl IntoServiceFactory for Resource +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + fn into_factory(self) -> T { + *self.factory_ref.borrow_mut() = Some(ResourceFactory { + routes: self.routes, + data: self.data.map(Rc::new), + default: self.default, + }); + + self.endpoint + } +} + +pub struct ResourceFactory { + routes: Vec, + data: Option>, + default: Rc>>>, +} + +impl ServiceFactory for ResourceFactory { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = ResourceService; + type Future = CreateResourceService; + + fn new_service(&self, _: &()) -> Self::Future { + let default_fut = if let Some(ref default) = *self.default.borrow() { + Some(default.new_service(&())) + } else { + None + }; + + CreateResourceService { + fut: self + .routes + .iter() + .map(|route| CreateRouteServiceItem::Future(route.new_service(&()))) + .collect(), + data: self.data.clone(), + default: None, + default_fut, } } } + +enum CreateRouteServiceItem { + Future(CreateRouteService), + Service(RouteService), +} + +pub struct CreateResourceService { + fut: Vec, + data: Option>, + default: Option, + default_fut: Option>>, +} + +impl Future for CreateResourceService { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut done = true; + + if let Some(ref mut fut) = self.default_fut { + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, + } + } + + // poll http services + for item in &mut self.fut { + match item { + CreateRouteServiceItem::Future(ref mut fut) => match Pin::new(fut) + .poll(cx)? + { + Poll::Ready(route) => *item = CreateRouteServiceItem::Service(route), + Poll::Pending => { + done = false; + } + }, + CreateRouteServiceItem::Service(_) => continue, + }; + } + + if done { + let routes = self + .fut + .drain(..) + .map(|item| match item { + CreateRouteServiceItem::Service(service) => service, + CreateRouteServiceItem::Future(_) => unreachable!(), + }) + .collect(); + Poll::Ready(Ok(ResourceService { + routes, + data: self.data.clone(), + default: self.default.take(), + })) + } else { + Poll::Pending + } + } +} + +pub struct ResourceService { + routes: Vec, + data: Option>, + default: Option, +} + +impl Service for ResourceService { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either< + Ready>, + LocalBoxFuture<'static, Result>, + >; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + for route in self.routes.iter_mut() { + if route.check(&mut req) { + if let Some(ref data) = self.data { + req.set_data_container(data.clone()); + } + return Either::Right(route.call(req)); + } + } + if let Some(ref mut default) = self.default { + Either::Right(default.call(req)) + } else { + let req = req.into_parts().0; + Either::Left(ok(ServiceResponse::new( + req, + Response::MethodNotAllowed().finish(), + ))) + } + } +} + +#[doc(hidden)] +pub struct ResourceEndpoint { + factory: Rc>>, +} + +impl ResourceEndpoint { + fn new(factory: Rc>>) -> Self { + ResourceEndpoint { factory } + } +} + +impl ServiceFactory for ResourceEndpoint { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = ResourceService; + type Future = CreateResourceService; + + fn new_service(&self, _: &()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use actix_rt::time::delay_for; + use actix_service::Service; + use futures::future::ok; + + use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; + use crate::test::{call_service, init_service, TestRequest}; + use crate::{guard, web, App, Error, HttpResponse}; + + #[actix_rt::test] + async fn test_middleware() { + let mut srv = + init_service( + App::new().service( + web::resource("/test") + .name("test") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .route(web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_middleware_fn() { + let mut srv = init_service( + App::new().service( + web::resource("/test") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { + fut.await.map(|mut res| { + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + res + }) + } + }) + .route(web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_to() { + let mut srv = + init_service(App::new().service(web::resource("/test").to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Ok::<_, Error>(HttpResponse::Ok()) + } + }))) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_default_resource() { + let mut srv = init_service( + App::new() + .service( + web::resource("/test").route(web::get().to(|| HttpResponse::Ok())), + ) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::BadRequest())) + }), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let mut srv = init_service( + App::new().service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::BadRequest())) + }), + ), + ) + .await; + + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[actix_rt::test] + async fn test_resource_guards() { + let mut srv = init_service( + App::new() + .service( + web::resource("/test/{p}") + .guard(guard::Get()) + .to(|| HttpResponse::Ok()), + ) + .service( + web::resource("/test/{p}") + .guard(guard::Put()) + .to(|| HttpResponse::Created()), + ) + .service( + web::resource("/test/{p}") + .guard(guard::Delete()) + .to(|| HttpResponse::NoContent()), + ), + ) + .await; + + let req = TestRequest::with_uri("/test/it") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test/it") + .method(Method::PUT) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/test/it") + .method(Method::DELETE) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + } + + #[actix_rt::test] + async fn test_data() { + let mut srv = init_service( + App::new() + .data(1.0f64) + .data(1usize) + .register_data(web::Data::new('-')) + .service( + web::resource("/test") + .data(10usize) + .register_data(web::Data::new('*')) + .guard(guard::Get()) + .to( + |data1: web::Data, + data2: web::Data, + data3: web::Data| { + assert_eq!(*data1, 10); + assert_eq!(*data2, '*'); + assert_eq!(*data3, 1.0); + HttpResponse::Ok() + }, + ), + ), + ) + .await; + + let req = TestRequest::get().uri("/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/src/responder.rs b/src/responder.rs new file mode 100644 index 000000000..7b30315f5 --- /dev/null +++ b/src/responder.rs @@ -0,0 +1,654 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_http::error::InternalError; +use actix_http::http::{ + header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, HttpTryFrom, + StatusCode, +}; +use actix_http::{Error, Response, ResponseBuilder}; +use bytes::{Bytes, BytesMut}; +use futures::future::{err, ok, Either as EitherFuture, Ready}; +use futures::ready; +use pin_project::{pin_project, project}; + +use crate::request::HttpRequest; + +/// Trait implemented by types that can be converted to a http response. +/// +/// Types that implement this trait can be used as the return type of a handler. +pub trait Responder { + /// The associated error which can be returned. + type Error: Into; + + /// The future response value. + type Future: Future>; + + /// Convert itself to `AsyncResult` or `Error`. + fn respond_to(self, req: &HttpRequest) -> Self::Future; + + /// Override a status code for a Responder. + /// + /// ```rust + /// use actix_web::{HttpRequest, Responder, http::StatusCode}; + /// + /// fn index(req: HttpRequest) -> impl Responder { + /// "Welcome!".with_status(StatusCode::OK) + /// } + /// # fn main() {} + /// ``` + fn with_status(self, status: StatusCode) -> CustomResponder + where + Self: Sized, + { + CustomResponder::new(self).with_status(status) + } + + /// Add header to the Responder's response. + /// + /// ```rust + /// use actix_web::{web, HttpRequest, Responder}; + /// use serde::Serialize; + /// + /// #[derive(Serialize)] + /// struct MyObj { + /// name: String, + /// } + /// + /// fn index(req: HttpRequest) -> impl Responder { + /// web::Json( + /// MyObj{name: "Name".to_string()} + /// ) + /// .with_header("x-version", "1.2.3") + /// } + /// # fn main() {} + /// ``` + fn with_header(self, key: K, value: V) -> CustomResponder + where + Self: Sized, + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + CustomResponder::new(self).with_header(key, value) + } +} + +impl Responder for Response { + type Error = Error; + type Future = Ready>; + + #[inline] + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(self) + } +} + +impl Responder for Option +where + T: Responder, +{ + type Error = T::Error; + type Future = EitherFuture>>; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + match self { + Some(t) => EitherFuture::Left(t.respond_to(req)), + None => { + EitherFuture::Right(ok(Response::build(StatusCode::NOT_FOUND).finish())) + } + } + } +} + +impl Responder for Result +where + T: Responder, + E: Into, +{ + type Error = Error; + type Future = EitherFuture< + ResponseFuture, + Ready>, + >; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + match self { + Ok(val) => EitherFuture::Left(ResponseFuture::new(val.respond_to(req))), + Err(e) => EitherFuture::Right(err(e.into())), + } + } +} + +impl Responder for ResponseBuilder { + type Error = Error; + type Future = Ready>; + + #[inline] + fn respond_to(mut self, _: &HttpRequest) -> Self::Future { + ok(self.finish()) + } +} + +impl Responder for (T, StatusCode) +where + T: Responder, +{ + type Error = T::Error; + type Future = CustomResponderFut; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + CustomResponderFut { + fut: self.0.respond_to(req), + status: Some(self.1), + headers: None, + } + } +} + +impl Responder for &'static str { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("text/plain; charset=utf-8") + .body(self)) + } +} + +impl Responder for &'static [u8] { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("application/octet-stream") + .body(self)) + } +} + +impl Responder for String { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("text/plain; charset=utf-8") + .body(self)) + } +} + +impl<'a> Responder for &'a String { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("text/plain; charset=utf-8") + .body(self)) + } +} + +impl Responder for Bytes { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("application/octet-stream") + .body(self)) + } +} + +impl Responder for BytesMut { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + ok(Response::build(StatusCode::OK) + .content_type("application/octet-stream") + .body(self)) + } +} + +/// Allows to override status code and headers for a responder. +pub struct CustomResponder { + responder: T, + status: Option, + headers: Option, + error: Option, +} + +impl CustomResponder { + fn new(responder: T) -> Self { + CustomResponder { + responder, + status: None, + headers: None, + error: None, + } + } + + /// Override a status code for the Responder's response. + /// + /// ```rust + /// use actix_web::{HttpRequest, Responder, http::StatusCode}; + /// + /// fn index(req: HttpRequest) -> impl Responder { + /// "Welcome!".with_status(StatusCode::OK) + /// } + /// # fn main() {} + /// ``` + pub fn with_status(mut self, status: StatusCode) -> Self { + self.status = Some(status); + self + } + + /// Add header to the Responder's response. + /// + /// ```rust + /// use actix_web::{web, HttpRequest, Responder}; + /// use serde::Serialize; + /// + /// #[derive(Serialize)] + /// struct MyObj { + /// name: String, + /// } + /// + /// fn index(req: HttpRequest) -> impl Responder { + /// web::Json( + /// MyObj{name: "Name".to_string()} + /// ) + /// .with_header("x-version", "1.2.3") + /// } + /// # fn main() {} + /// ``` + pub fn with_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if self.headers.is_none() { + self.headers = Some(HeaderMap::new()); + } + + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.headers.as_mut().unwrap().append(key, value); + } + Err(e) => self.error = Some(e.into()), + }, + Err(e) => self.error = Some(e.into()), + }; + self + } +} + +impl Responder for CustomResponder { + type Error = T::Error; + type Future = CustomResponderFut; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + CustomResponderFut { + fut: self.responder.respond_to(req), + status: self.status, + headers: self.headers, + } + } +} + +#[pin_project] +pub struct CustomResponderFut { + #[pin] + fut: T::Future, + status: Option, + headers: Option, +} + +impl Future for CustomResponderFut { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + let mut res = match ready!(this.fut.poll(cx)) { + Ok(res) => res, + Err(e) => return Poll::Ready(Err(e)), + }; + if let Some(status) = this.status.take() { + *res.status_mut() = status; + } + if let Some(ref headers) = this.headers { + for (k, v) in headers { + res.headers_mut().insert(k.clone(), v.clone()); + } + } + Poll::Ready(Ok(res)) + } +} + +/// Combines two different responder types into a single type +/// +/// ```rust +/// use actix_web::{Either, Error, HttpResponse}; +/// +/// type RegisterResult = Either>; +/// +/// fn index() -> RegisterResult { +/// if is_a_variant() { +/// // <- choose left variant +/// Either::A(HttpResponse::BadRequest().body("Bad data")) +/// } else { +/// Either::B( +/// // <- Right variant +/// Ok(HttpResponse::Ok() +/// .content_type("text/html") +/// .body("Hello!")) +/// ) +/// } +/// } +/// # fn is_a_variant() -> bool { true } +/// # fn main() {} +/// ``` +#[derive(Debug, PartialEq)] +pub enum Either { + /// First branch of the type + A(A), + /// Second branch of the type + B(B), +} + +impl Responder for Either +where + A: Responder, + B: Responder, +{ + type Error = Error; + type Future = EitherResponder; + + fn respond_to(self, req: &HttpRequest) -> Self::Future { + match self { + Either::A(a) => EitherResponder::A(a.respond_to(req)), + Either::B(b) => EitherResponder::B(b.respond_to(req)), + } + } +} + +#[pin_project] +pub enum EitherResponder +where + A: Responder, + B: Responder, +{ + A(#[pin] A::Future), + B(#[pin] B::Future), +} + +impl Future for EitherResponder +where + A: Responder, + B: Responder, +{ + type Output = Result; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + #[project] + match self.project() { + EitherResponder::A(fut) => { + Poll::Ready(ready!(fut.poll(cx)).map_err(|e| e.into())) + } + EitherResponder::B(fut) => { + Poll::Ready(ready!(fut.poll(cx).map_err(|e| e.into()))) + } + } + } +} + +impl Responder for InternalError +where + T: std::fmt::Debug + std::fmt::Display + 'static, +{ + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + let err: Error = self.into(); + ok(err.into()) + } +} + +#[pin_project] +pub struct ResponseFuture { + #[pin] + fut: T, + _t: PhantomData, +} + +impl ResponseFuture { + pub fn new(fut: T) -> Self { + ResponseFuture { + fut, + _t: PhantomData, + } + } +} + +impl Future for ResponseFuture +where + T: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready(ready!(self.project().fut.poll(cx)).map_err(|e| e.into())) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use actix_service::Service; + use bytes::{Bytes, BytesMut}; + + use super::*; + use crate::dev::{Body, ResponseBody}; + use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; + use crate::test::{init_service, TestRequest}; + use crate::{error, web, App, HttpResponse}; + + #[actix_rt::test] + async fn test_option_responder() { + let mut srv = init_service( + App::new() + .service( + web::resource("/none").to(|| async { Option::<&'static str>::None }), + ) + .service(web::resource("/some").to(|| async { Some("some") })), + ) + .await; + + let req = TestRequest::with_uri("/none").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/some").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"some")); + } + _ => panic!(), + } + } + + pub(crate) trait BodyTest { + fn bin_ref(&self) -> &[u8]; + fn body(&self) -> &Body; + } + + impl BodyTest for ResponseBody { + fn bin_ref(&self) -> &[u8] { + match self { + ResponseBody::Body(ref b) => match b { + Body::Bytes(ref bin) => &bin, + _ => panic!(), + }, + ResponseBody::Other(ref b) => match b { + Body::Bytes(ref bin) => &bin, + _ => panic!(), + }, + } + } + fn body(&self) -> &Body { + match self { + ResponseBody::Body(ref b) => b, + ResponseBody::Other(ref b) => b, + } + } + } + + #[actix_rt::test] + async fn test_responder() { + let req = TestRequest::default().to_http_request(); + + let resp: HttpResponse = "test".respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp: HttpResponse = b"test".respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + let resp: HttpResponse = "test".to_string().respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp: HttpResponse = (&"test".to_string()).respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp: HttpResponse = + Bytes::from_static(b"test").respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + let resp: HttpResponse = BytesMut::from(b"test".as_ref()) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + // InternalError + let resp: HttpResponse = + error::InternalError::new("err", StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[actix_rt::test] + async fn test_result_responder() { + let req = TestRequest::default().to_http_request(); + + // Result + let resp: HttpResponse = Ok::<_, Error>("test".to_string()) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let res = + Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) + .respond_to(&req) + .await; + assert!(res.is_err()); + } + + #[actix_rt::test] + async fn test_custom_responder() { + let req = TestRequest::default().to_http_request(); + let res = "test" + .to_string() + .with_status(StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.body().bin_ref(), b"test"); + + let res = "test" + .to_string() + .with_header("content-type", "json") + .respond_to(&req) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("json") + ); + } + + #[actix_rt::test] + async fn test_tuple_responder_with_status_code() { + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::BAD_REQUEST) + .respond_to(&req) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!(res.body().bin_ref(), b"test"); + + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::OK) + .with_header("content-type", "json") + .respond_to(&req) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("json") + ); + } +} diff --git a/src/rmap.rs b/src/rmap.rs new file mode 100644 index 000000000..42ddb1349 --- /dev/null +++ b/src/rmap.rs @@ -0,0 +1,190 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use actix_router::ResourceDef; +use hashbrown::HashMap; +use url::Url; + +use crate::error::UrlGenerationError; +use crate::request::HttpRequest; + +#[derive(Clone, Debug)] +pub struct ResourceMap { + root: ResourceDef, + parent: RefCell>>, + named: HashMap, + patterns: Vec<(ResourceDef, Option>)>, +} + +impl ResourceMap { + pub fn new(root: ResourceDef) -> Self { + ResourceMap { + root, + parent: RefCell::new(None), + named: HashMap::new(), + patterns: Vec::new(), + } + } + + pub fn add(&mut self, pattern: &mut ResourceDef, nested: Option>) { + pattern.set_id(self.patterns.len() as u16); + self.patterns.push((pattern.clone(), nested)); + if !pattern.name().is_empty() { + self.named + .insert(pattern.name().to_string(), pattern.clone()); + } + } + + pub(crate) fn finish(&self, current: Rc) { + for (_, nested) in &self.patterns { + if let Some(ref nested) = nested { + *nested.parent.borrow_mut() = Some(current.clone()); + nested.finish(nested.clone()); + } + } + } +} + +impl ResourceMap { + /// Generate url for named resource + /// + /// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method. + /// url_for) for detailed information. + pub fn url_for( + &self, + req: &HttpRequest, + name: &str, + elements: U, + ) -> Result + where + U: IntoIterator, + I: AsRef, + { + let mut path = String::new(); + let mut elements = elements.into_iter(); + + if self.patterns_for(name, &mut path, &mut elements)?.is_some() { + if path.starts_with('/') { + let conn = req.connection_info(); + Ok(Url::parse(&format!( + "{}://{}{}", + conn.scheme(), + conn.host(), + path + ))?) + } else { + Ok(Url::parse(&path)?) + } + } else { + Err(UrlGenerationError::ResourceNotFound) + } + } + + pub fn has_resource(&self, path: &str) -> bool { + let path = if path.is_empty() { "/" } else { path }; + + for (pattern, rmap) in &self.patterns { + if let Some(ref rmap) = rmap { + if let Some(plen) = pattern.is_prefix_match(path) { + return rmap.has_resource(&path[plen..]); + } + } else if pattern.is_match(path) { + return true; + } + } + false + } + + fn patterns_for( + &self, + name: &str, + path: &mut String, + elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if self.pattern_for(name, path, elements)?.is_some() { + Ok(Some(())) + } else { + self.parent_pattern_for(name, path, elements) + } + } + + fn pattern_for( + &self, + name: &str, + path: &mut String, + elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(pattern) = self.named.get(name) { + if pattern.pattern().starts_with('/') { + self.fill_root(path, elements)?; + } + if pattern.resource_path(path, elements) { + Ok(Some(())) + } else { + Err(UrlGenerationError::NotEnoughElements) + } + } else { + for (_, rmap) in &self.patterns { + if let Some(ref rmap) = rmap { + if rmap.pattern_for(name, path, elements)?.is_some() { + return Ok(Some(())); + } + } + } + Ok(None) + } + } + + fn fill_root( + &self, + path: &mut String, + elements: &mut U, + ) -> Result<(), UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(ref parent) = *self.parent.borrow() { + parent.fill_root(path, elements)?; + } + if self.root.resource_path(path, elements) { + Ok(()) + } else { + Err(UrlGenerationError::NotEnoughElements) + } + } + + fn parent_pattern_for( + &self, + name: &str, + path: &mut String, + elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(ref parent) = *self.parent.borrow() { + if let Some(pattern) = parent.named.get(name) { + self.fill_root(path, elements)?; + if pattern.resource_path(path, elements) { + Ok(Some(())) + } else { + Err(UrlGenerationError::NotEnoughElements) + } + } else { + parent.parent_pattern_for(name, path, elements) + } + } else { + Ok(None) + } + } +} diff --git a/src/route.rs b/src/route.rs index 1eebaa3ea..93f88bfe2 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,115 +1,185 @@ -use std::mem; +use std::future::Future; +use std::pin::Pin; use std::rc::Rc; -use std::marker::PhantomData; -use futures::{Async, Future, Poll}; +use std::task::{Context, Poll}; -use error::Error; -use pred::Predicate; -use http::StatusCode; -use handler::{Reply, ReplyItem, Handler, FromRequest, - Responder, RouteHandler, AsyncHandler, WrapHandler}; -use middleware::{Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted}; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use with::{With, With2, With3}; +use actix_http::{http::Method, Error}; +use actix_service::{Service, ServiceFactory}; +use futures::future::{ready, FutureExt, LocalBoxFuture}; + +use crate::extract::FromRequest; +use crate::guard::{self, Guard}; +use crate::handler::{Extract, Factory, Handler}; +use crate::responder::Responder; +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::HttpResponse; + +type BoxedRouteService = Box< + dyn Service< + Request = Req, + Response = Res, + Error = Error, + Future = LocalBoxFuture<'static, Result>, + >, +>; + +type BoxedRouteNewService = Box< + dyn ServiceFactory< + Config = (), + Request = Req, + Response = Res, + Error = Error, + InitError = (), + Service = BoxedRouteService, + Future = LocalBoxFuture<'static, Result, ()>>, + >, +>; /// Resource route definition /// /// Route uses builder-like pattern for configuration. /// If handler is not explicitly set, default *404 Not Found* handler is used. -pub struct Route { - preds: Vec>>, - handler: InnerHandler, +pub struct Route { + service: BoxedRouteNewService, + guards: Rc>>, } -impl Default for Route { - - fn default() -> Route { +impl Route { + /// Create new route which matches any request. + pub fn new() -> Route { Route { - preds: Vec::new(), - handler: InnerHandler::new(|_| HttpResponse::new(StatusCode::NOT_FOUND)), + service: Box::new(RouteNewService::new(Extract::new(Handler::new(|| { + ready(HttpResponse::NotFound()) + })))), + guards: Rc::new(Vec::new()), + } + } + + pub(crate) fn take_guards(&mut self) -> Vec> { + std::mem::replace(Rc::get_mut(&mut self.guards).unwrap(), Vec::new()) + } +} + +impl ServiceFactory for Route { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = RouteService; + type Future = CreateRouteService; + + fn new_service(&self, _: &()) -> Self::Future { + CreateRouteService { + fut: self.service.new_service(&()), + guards: self.guards.clone(), } } } -impl Route { +type RouteFuture = LocalBoxFuture< + 'static, + Result, ()>, +>; - #[inline] - pub(crate) fn check(&self, req: &mut HttpRequest) -> bool { - for pred in &self.preds { - if !pred.check(req) { - return false +#[pin_project::pin_project] +pub struct CreateRouteService { + #[pin] + fut: RouteFuture, + guards: Rc>>, +} + +impl Future for CreateRouteService { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + match this.fut.poll(cx)? { + Poll::Ready(service) => Poll::Ready(Ok(RouteService { + service, + guards: this.guards.clone(), + })), + Poll::Pending => Poll::Pending, + } + } +} + +pub struct RouteService { + service: BoxedRouteService, + guards: Rc>>, +} + +impl RouteService { + pub fn check(&self, req: &mut ServiceRequest) -> bool { + for f in self.guards.iter() { + if !f.check(req.head()) { + return false; } } true } +} - #[inline] - pub(crate) fn handle(&mut self, req: HttpRequest) -> Reply { - self.handler.handle(req) +impl Service for RouteService { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } - #[inline] - pub(crate) fn compose(&mut self, - req: HttpRequest, - mws: Rc>>>) -> Reply { - Reply::async(Compose::new(req, mws, self.handler.clone())) + fn call(&mut self, req: ServiceRequest) -> Self::Future { + self.service.call(req).boxed_local() } +} - /// Add match predicate to route. +impl Route { + /// Add method guard to the route. /// /// ```rust - /// # extern crate actix_web; /// # use actix_web::*; /// # fn main() { - /// App::new() - /// .resource("/path", |r| - /// r.route() - /// .filter(pred::Get()) - /// .filter(pred::Header("content-type", "text/plain")) - /// .f(|req| HttpResponse::Ok()) - /// ) - /// # .finish(); + /// App::new().service(web::resource("/path").route( + /// web::get() + /// .method(http::Method::CONNECT) + /// .guard(guard::Header("content-type", "text/plain")) + /// .to(|req: HttpRequest| HttpResponse::Ok())) + /// ); /// # } /// ``` - pub fn filter + 'static>(&mut self, p: T) -> &mut Self { - self.preds.push(Box::new(p)); + pub fn method(mut self, method: Method) -> Self { + Rc::get_mut(&mut self.guards) + .unwrap() + .push(Box::new(guard::Method(method))); self } - /// Set handler object. Usually call to this method is last call - /// during route configuration, so it does not return reference to self. - pub fn h>(&mut self, handler: H) { - self.handler = InnerHandler::new(handler); - } - - /// Set handler function. Usually call to this method is last call - /// during route configuration, so it does not return reference to self. - pub fn f(&mut self, handler: F) - where F: Fn(HttpRequest) -> R + 'static, - R: Responder + 'static, - { - self.handler = InnerHandler::new(handler); - } - - /// Set async handler function. - pub fn a(&mut self, handler: H) - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static - { - self.handler = InnerHandler::async(handler); - } - - /// Set handler function with http request extractor. + /// Add guard to the route. /// /// ```rust - /// # extern crate bytes; - /// # extern crate actix_web; - /// # extern crate futures; - /// #[macro_use] extern crate serde_derive; - /// use actix_web::{App, Path, Result, http}; + /// # use actix_web::*; + /// # fn main() { + /// App::new().service(web::resource("/path").route( + /// web::route() + /// .guard(guard::Get()) + /// .guard(guard::Header("content-type", "text/plain")) + /// .to(|req: HttpRequest| HttpResponse::Ok())) + /// ); + /// # } + /// ``` + pub fn guard(mut self, f: F) -> Self { + Rc::get_mut(&mut self.guards).unwrap().push(Box::new(f)); + self + } + + /// Set handler function, use request extractors for parameters. + /// + /// ```rust + /// use actix_web::{web, http, App}; + /// use serde_derive::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -117,376 +187,245 @@ impl Route { /// } /// /// /// extract path info using serde - /// fn index(info: Path) -> Result { - /// Ok(format!("Welcome {}!", info.username)) + /// async fn index(info: web::Path) -> String { + /// format!("Welcome {}!", info.username) /// } /// /// fn main() { - /// let app = App::new().resource( - /// "/{username}/index.html", // <- define path parameters - /// |r| r.method(http::Method::GET).with(index)); // <- use `with` extractor + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) // <- register handler + /// ); /// } /// ``` - pub fn with(&mut self, handler: F) - where F: Fn(T) -> R + 'static, - R: Responder + 'static, - T: FromRequest + 'static, - { - self.h(With::new(handler)) - } - - /// Set handler function, function has to accept two request extractors. + /// + /// It is possible to use multiple extractors for one handler function. /// /// ```rust - /// # extern crate bytes; - /// # extern crate actix_web; - /// # extern crate futures; - /// #[macro_use] extern crate serde_derive; - /// use actix_web::{App, Query, Path, Result, http}; + /// # use std::collections::HashMap; + /// # use serde_derive::Deserialize; + /// use actix_web::{web, App}; /// /// #[derive(Deserialize)] - /// struct PParam { + /// struct Info { /// username: String, /// } /// - /// #[derive(Deserialize)] - /// struct QParam { - /// count: u32, - /// } - /// - /// /// extract path and query information using serde - /// fn index(p: Path, q: Query) -> Result { - /// Ok(format!("Welcome {}!", p.username)) + /// /// extract path info using serde + /// async fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { + /// format!("Welcome {}!", path.username) /// } /// /// fn main() { - /// let app = App::new().resource( - /// "/{username}/index.html", // <- define path parameters - /// |r| r.method(http::Method::GET).with2(index)); // <- use `with` extractor + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) + /// ); /// } /// ``` - pub fn with2(&mut self, handler: F) - where F: Fn(T1, T2) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, + pub fn to(mut self, handler: F) -> Self + where + F: Factory, + T: FromRequest + 'static, + R: Future + 'static, + U: Responder + 'static, { - self.h(With2::new(handler)) - } - - /// Set handler function, function has to accept three request extractors. - pub fn with3(&mut self, handler: F) - where F: Fn(T1, T2, T3) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - T3: FromRequest + 'static, - { - self.h(With3::new(handler)) + self.service = + Box::new(RouteNewService::new(Extract::new(Handler::new(handler)))); + self } } -/// `RouteHandler` wrapper. This struct is required because it needs to be shared -/// for resource level middlewares. -struct InnerHandler(Rc>>); +struct RouteNewService +where + T: ServiceFactory, +{ + service: T, +} -impl InnerHandler { - - #[inline] - fn new>(h: H) -> Self { - InnerHandler(Rc::new(Box::new(WrapHandler::new(h)))) - } - - #[inline] - fn async(h: H) -> Self - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static - { - InnerHandler(Rc::new(Box::new(AsyncHandler::new(h)))) - } - - #[inline] - pub fn handle(&self, req: HttpRequest) -> Reply { - // reason: handler is unique per thread, - // handler get called from async code only - #[allow(mutable_transmutes)] - #[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))] - let h: &mut Box> = unsafe { mem::transmute(self.0.as_ref()) }; - h.handle(req) +impl RouteNewService +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = (Error, ServiceRequest), + >, + T::Future: 'static, + T::Service: 'static, + ::Future: 'static, +{ + pub fn new(service: T) -> Self { + RouteNewService { service } } } -impl Clone for InnerHandler { - #[inline] - fn clone(&self) -> Self { - InnerHandler(Rc::clone(&self.0)) - } -} - - -/// Compose resource level middlewares with route handler. -struct Compose { - info: ComposeInfo, - state: ComposeState, -} - -struct ComposeInfo { - count: usize, - req: HttpRequest, - mws: Rc>>>, - handler: InnerHandler, -} - -enum ComposeState { - Starting(StartMiddlewares), - Handler(WaitingResponse), - RunMiddlewares(RunMiddlewares), - Response(Response), -} - -impl ComposeState { - fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - match *self { - ComposeState::Starting(ref mut state) => state.poll(info), - ComposeState::Handler(ref mut state) => state.poll(info), - ComposeState::RunMiddlewares(ref mut state) => state.poll(info), - ComposeState::Response(_) => None, - } - } -} - -impl Compose { - fn new(req: HttpRequest, - mws: Rc>>>, - handler: InnerHandler) -> Self - { - let mut info = ComposeInfo { count: 0, req, mws, handler }; - let state = StartMiddlewares::init(&mut info); - - Compose {state, info} - } -} - -impl Future for Compose { - type Item = HttpResponse; +impl ServiceFactory for RouteNewService +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = (Error, ServiceRequest), + >, + T::Future: 'static, + T::Service: 'static, + ::Future: 'static, +{ + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; type Error = Error; + type InitError = (); + type Service = BoxedRouteService; + type Future = LocalBoxFuture<'static, Result>; - fn poll(&mut self) -> Poll { - loop { - if let ComposeState::Response(ref mut resp) = self.state { - let resp = resp.resp.take().unwrap(); - return Ok(Async::Ready(resp)) - } - if let Some(state) = self.state.poll(&mut self.info) { - self.state = state; - } else { - return Ok(Async::NotReady) - } - } + fn new_service(&self, _: &()) -> Self::Future { + self.service + .new_service(&()) + .map(|result| match result { + Ok(service) => { + let service: BoxedRouteService<_, _> = + Box::new(RouteServiceWrapper { service }); + Ok(service) + } + Err(_) => Err(()), + }) + .boxed_local() } } -/// Middlewares start executor -struct StartMiddlewares { - fut: Option, - _s: PhantomData, +struct RouteServiceWrapper { + service: T, } -type Fut = Box, Error=Error>>; +impl Service for RouteServiceWrapper +where + T::Future: 'static, + T: Service< + Request = ServiceRequest, + Response = ServiceResponse, + Error = (Error, ServiceRequest), + >, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; -impl StartMiddlewares { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx).map_err(|(e, _)| e) + } - fn init(info: &mut ComposeInfo) -> ComposeState { - let len = info.mws.len(); - loop { - if info.count == len { - let reply = info.handler.handle(info.req.clone()); - return WaitingResponse::init(info, reply) - } else { - match info.mws[info.count].start(&mut info.req) { - Ok(MiddlewareStarted::Done) => - info.count += 1, - Ok(MiddlewareStarted::Response(resp)) => - return RunMiddlewares::init(info, resp), - Ok(MiddlewareStarted::Future(mut fut)) => - match fut.poll() { - Ok(Async::NotReady) => - return ComposeState::Starting(StartMiddlewares { - fut: Some(fut), - _s: PhantomData}), - Ok(Async::Ready(resp)) => { - if let Some(resp) = resp { - return RunMiddlewares::init(info, resp); - } - info.count += 1; + fn call(&mut self, req: ServiceRequest) -> Self::Future { + // let mut fut = self.service.call(req); + self.service + .call(req) + .map(|res| match res { + Ok(res) => Ok(res), + Err((err, req)) => Ok(req.error_response(err)), + }) + .boxed_local() + + // match fut.poll() { + // Poll::Ready(Ok(res)) => Either::Left(ok(res)), + // Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))), + // Poll::Pending => Either::Right(Box::new(fut.then(|res| match res { + // Ok(res) => Ok(res), + // Err((err, req)) => Ok(req.error_response(err)), + // }))), + // } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use actix_rt::time::delay_for; + use bytes::Bytes; + use serde_derive::Serialize; + + use crate::http::{Method, StatusCode}; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{error, web, App, HttpResponse}; + + #[derive(Serialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + #[actix_rt::test] + async fn test_route() { + let mut srv = init_service( + App::new() + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())) + .route(web::put().to(|| { + async { + Err::(error::ErrorBadRequest("err")) } - Err(err) => - return Response::init(err.into()), - }, - Err(err) => - return Response::init(err.into()), - } - } - } - } - - fn poll(&mut self, info: &mut ComposeInfo) -> Option> - { - let len = info.mws.len(); - 'outer: loop { - match self.fut.as_mut().unwrap().poll() { - Ok(Async::NotReady) => - return None, - Ok(Async::Ready(resp)) => { - info.count += 1; - if let Some(resp) = resp { - return Some(RunMiddlewares::init(info, resp)); - } - if info.count == len { - let reply = info.handler.handle(info.req.clone()); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { - match info.mws[info.count].start(&mut info.req) { - Ok(MiddlewareStarted::Done) => - info.count += 1, - Ok(MiddlewareStarted::Response(resp)) => { - return Some(RunMiddlewares::init(info, resp)); - }, - Ok(MiddlewareStarted::Future(fut)) => { - self.fut = Some(fut); - continue 'outer - }, - Err(err) => - return Some(Response::init(err.into())) + })) + .route(web::post().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + HttpResponse::Created() } - } + })) + .route(web::delete().to(|| { + async { + delay_for(Duration::from_millis(100)).await; + Err::(error::ErrorBadRequest("err")) + } + })), + ) + .service(web::resource("/json").route(web::get().to(|| { + async { + delay_for(Duration::from_millis(25)).await; + web::Json(MyObject { + name: "test".to_string(), + }) } - } - Err(err) => - return Some(Response::init(err.into())) - } - } - } -} - -// waiting for response -struct WaitingResponse { - fut: Box>, - _s: PhantomData, -} - -impl WaitingResponse { - - #[inline] - fn init(info: &mut ComposeInfo, reply: Reply) -> ComposeState { - match reply.into() { - ReplyItem::Message(resp) => - RunMiddlewares::init(info, resp), - ReplyItem::Future(fut) => - ComposeState::Handler( - WaitingResponse { fut, _s: PhantomData }), - } - } - - fn poll(&mut self, info: &mut ComposeInfo) -> Option> { - match self.fut.poll() { - Ok(Async::NotReady) => None, - Ok(Async::Ready(response)) => - Some(RunMiddlewares::init(info, response)), - Err(err) => - Some(Response::init(err.into())), - } - } -} - - -/// Middlewares response executor -struct RunMiddlewares { - curr: usize, - fut: Option>>, - _s: PhantomData, -} - -impl RunMiddlewares { - - fn init(info: &mut ComposeInfo, mut resp: HttpResponse) -> ComposeState { - let mut curr = 0; - let len = info.mws.len(); - - loop { - resp = match info.mws[curr].response(&mut info.req, resp) { - Err(err) => { - info.count = curr + 1; - return Response::init(err.into()) - }, - Ok(MiddlewareResponse::Done(r)) => { - curr += 1; - if curr == len { - return Response::init(r) - } else { - r - } - }, - Ok(MiddlewareResponse::Future(fut)) => { - return ComposeState::RunMiddlewares( - RunMiddlewares { curr, fut: Some(fut), _s: PhantomData }) - }, - }; - } - } - - fn poll(&mut self, info: &mut ComposeInfo) -> Option> - { - let len = info.mws.len(); - - loop { - // poll latest fut - let mut resp = match self.fut.as_mut().unwrap().poll() { - Ok(Async::NotReady) => { - return None - } - Ok(Async::Ready(resp)) => { - self.curr += 1; - resp - } - Err(err) => - return Some(Response::init(err.into())), - }; - - loop { - if self.curr == len { - return Some(Response::init(resp)); - } else { - match info.mws[self.curr].response(&mut info.req, resp) { - Err(err) => - return Some(Response::init(err.into())), - Ok(MiddlewareResponse::Done(r)) => { - self.curr += 1; - resp = r - }, - Ok(MiddlewareResponse::Future(fut)) => { - self.fut = Some(fut); - break - }, - } - } - } - } - } -} - -struct Response { - resp: Option, - _s: PhantomData, -} - -impl Response { - - fn init(resp: HttpResponse) -> ComposeState { - ComposeState::Response( - Response{resp: Some(resp), _s: PhantomData}) + }))), + ) + .await; + + let req = TestRequest::with_uri("/test") + .method(Method::GET) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/test") + .method(Method::POST) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/test") + .method(Method::PUT) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_uri("/test") + .method(Method::DELETE) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_uri("/test") + .method(Method::HEAD) + .to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let req = TestRequest::with_uri("/json").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"{\"name\":\"test\"}")); } } diff --git a/src/router.rs b/src/router.rs deleted file mode 100644 index b8e6baf00..000000000 --- a/src/router.rs +++ /dev/null @@ -1,544 +0,0 @@ -use std::mem; -use std::rc::Rc; -use std::hash::{Hash, Hasher}; -use std::collections::HashMap; - -use regex::{Regex, escape}; -use percent_encoding::percent_decode; - -use param::Params; -use error::UrlGenerationError; -use resource::ResourceHandler; -use httprequest::HttpRequest; -use server::ServerSettings; - -/// Interface for application router. -pub struct Router(Rc); - -struct Inner { - prefix: String, - prefix_len: usize, - named: HashMap, - patterns: Vec, - srv: ServerSettings, -} - -impl Router { - /// Create new router - pub fn new(prefix: &str, - settings: ServerSettings, - map: Vec<(Resource, Option>)>) - -> (Router, Vec>) - { - let prefix = prefix.trim().trim_right_matches('/').to_owned(); - let mut named = HashMap::new(); - let mut patterns = Vec::new(); - let mut resources = Vec::new(); - - for (pattern, resource) in map { - if !pattern.name().is_empty() { - let name = pattern.name().into(); - named.insert(name, (pattern.clone(), resource.is_none())); - } - - if let Some(resource) = resource { - patterns.push(pattern); - resources.push(resource); - } - } - - let prefix_len = prefix.len(); - (Router(Rc::new( - Inner{ prefix, prefix_len, named, patterns, srv: settings })), resources) - } - - /// Router prefix - #[inline] - pub fn prefix(&self) -> &str { - &self.0.prefix - } - - /// Server settings - #[inline] - pub fn server_settings(&self) -> &ServerSettings { - &self.0.srv - } - - pub(crate) fn get_resource(&self, idx: usize) -> &Resource { - &self.0.patterns[idx] - } - - /// Query for matched resource - pub fn recognize(&self, req: &mut HttpRequest) -> Option { - if self.0.prefix_len > req.path().len() { - return None - } - let path: &str = unsafe{mem::transmute(&req.path()[self.0.prefix_len..])}; - let route_path = if path.is_empty() { "/" } else { path }; - let p = percent_decode(route_path.as_bytes()).decode_utf8().unwrap(); - - for (idx, pattern) in self.0.patterns.iter().enumerate() { - if pattern.match_with_params(p.as_ref(), req.match_info_mut()) { - req.set_resource(idx); - return Some(idx) - } - } - None - } - - /// Check if application contains matching route. - /// - /// This method does not take `prefix` into account. - /// For example if prefix is `/test` and router contains route `/name`, - /// following path would be recognizable `/test/name` but `has_route()` call - /// would return `false`. - pub fn has_route(&self, path: &str) -> bool { - let path = if path.is_empty() { "/" } else { path }; - - for pattern in &self.0.patterns { - if pattern.is_match(path) { - return true - } - } - false - } - - /// Build named resource path. - /// - /// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method.url_for) - /// for detailed information. - pub fn resource_path(&self, name: &str, elements: U) - -> Result - where U: IntoIterator, - I: AsRef, - { - if let Some(pattern) = self.0.named.get(name) { - pattern.0.resource_path(self, elements) - } else { - Err(UrlGenerationError::ResourceNotFound) - } - } -} - -impl Clone for Router { - fn clone(&self) -> Router { - Router(Rc::clone(&self.0)) - } -} - -#[derive(Debug, Clone, PartialEq)] -enum PatternElement { - Str(String), - Var(String), -} - -#[derive(Clone, Debug)] -enum PatternType { - Static(String), - Dynamic(Regex, Vec), -} - -#[derive(Debug, Copy, Clone, PartialEq)] -/// Resource type -pub enum ResourceType { - /// Normal resource - Normal, - /// Resource for applicaiton default handler - Default, - /// External resource - External, - /// Unknown resource type - Unset, -} - -/// Reslource type describes an entry in resources table -#[derive(Clone)] -pub struct Resource { - tp: PatternType, - rtp: ResourceType, - name: String, - pattern: String, - elements: Vec, -} - -impl Resource { - /// Parse path pattern and create new `Resource` instance. - /// - /// Panics if path pattern is wrong. - pub fn new(name: &str, path: &str) -> Self { - Resource::with_prefix(name, path, "/") - } - - /// Construct external resource - /// - /// Panics if path pattern is wrong. - pub fn external(name: &str, path: &str) -> Self { - let mut resource = Resource::with_prefix(name, path, "/"); - resource.rtp = ResourceType::External; - resource - } - - /// Unset resource type - pub(crate) fn unset() -> Resource { - Resource { - tp: PatternType::Static("".to_owned()), - rtp: ResourceType::Unset, - name: "".to_owned(), - pattern: "".to_owned(), - elements: Vec::new(), - } - } - - /// Parse path pattern and create new `Resource` instance with custom prefix - pub fn with_prefix(name: &str, path: &str, prefix: &str) -> Self { - let (pattern, elements, is_dynamic) = Resource::parse(path, prefix); - - let tp = if is_dynamic { - let re = match Regex::new(&pattern) { - Ok(re) => re, - Err(err) => panic!("Wrong path pattern: \"{}\" {}", path, err) - }; - let names = re.capture_names() - .filter_map(|name| name.map(|name| name.to_owned())) - .collect(); - PatternType::Dynamic(re, names) - } else { - PatternType::Static(pattern.clone()) - }; - - Resource { - tp, - elements, - name: name.into(), - rtp: ResourceType::Normal, - pattern: path.to_owned(), - } - } - - /// Name of the resource - pub fn name(&self) -> &str { - &self.name - } - - /// Resource type - pub fn rtype(&self) -> ResourceType { - self.rtp - } - - /// Path pattern of the resource - pub fn pattern(&self) -> &str { - &self.pattern - } - - pub fn is_match(&self, path: &str) -> bool { - match self.tp { - PatternType::Static(ref s) => s == path, - PatternType::Dynamic(ref re, _) => re.is_match(path), - } - } - - pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) - -> bool - { - match self.tp { - PatternType::Static(ref s) => s == path, - PatternType::Dynamic(ref re, ref names) => { - if let Some(captures) = re.captures(path) { - let mut idx = 0; - for capture in captures.iter() { - if let Some(ref m) = capture { - if idx != 0 { - params.add(names[idx-1].as_str(), m.as_str()); - } - idx += 1; - } - } - true - } else { - false - } - } - } - } - - /// Build reousrce path. - pub fn resource_path(&self, router: &Router, elements: U) - -> Result - where U: IntoIterator, - I: AsRef, - { - let mut iter = elements.into_iter(); - let mut path = if self.rtp != ResourceType::External { - format!("{}/", router.prefix()) - } else { - String::new() - }; - for el in &self.elements { - match *el { - PatternElement::Str(ref s) => path.push_str(s), - PatternElement::Var(_) => { - if let Some(val) = iter.next() { - path.push_str(val.as_ref()) - } else { - return Err(UrlGenerationError::NotEnoughElements) - } - } - } - } - Ok(path) - } - - fn parse(pattern: &str, prefix: &str) -> (String, Vec, bool) { - const DEFAULT_PATTERN: &str = "[^/]+"; - - let mut re1 = String::from("^") + prefix; - let mut re2 = String::from(prefix); - let mut el = String::new(); - let mut in_param = false; - let mut in_param_pattern = false; - let mut param_name = String::new(); - let mut param_pattern = String::from(DEFAULT_PATTERN); - let mut is_dynamic = false; - let mut elems = Vec::new(); - - for (index, ch) in pattern.chars().enumerate() { - // All routes must have a leading slash so its optional to have one - if index == 0 && ch == '/' { - continue; - } - - if in_param { - // In parameter segment: `{....}` - if ch == '}' { - elems.push(PatternElement::Var(param_name.clone())); - re1.push_str(&format!(r"(?P<{}>{})", ¶m_name, ¶m_pattern)); - - param_name.clear(); - param_pattern = String::from(DEFAULT_PATTERN); - - in_param_pattern = false; - in_param = false; - } else if ch == ':' { - // The parameter name has been determined; custom pattern land - in_param_pattern = true; - param_pattern.clear(); - } else if in_param_pattern { - // Ignore leading whitespace for pattern - if !(ch == ' ' && param_pattern.is_empty()) { - param_pattern.push(ch); - } - } else { - param_name.push(ch); - } - } else if ch == '{' { - in_param = true; - is_dynamic = true; - elems.push(PatternElement::Str(el.clone())); - el.clear(); - } else { - re1.push_str(escape(&ch.to_string()).as_str()); - re2.push(ch); - el.push(ch); - } - } - - let re = if is_dynamic { - re1.push('$'); - re1 - } else { - re2 - }; - (re, elems, is_dynamic) - } -} - -impl PartialEq for Resource { - fn eq(&self, other: &Resource) -> bool { - self.pattern == other.pattern - } -} - -impl Eq for Resource {} - -impl Hash for Resource { - fn hash(&self, state: &mut H) { - self.pattern.hash(state); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use test::TestRequest; - - #[test] - fn test_recognizer() { - let routes = vec![ - (Resource::new("", "/name"), - Some(ResourceHandler::default())), - (Resource::new("", "/name/{val}"), - Some(ResourceHandler::default())), - (Resource::new("", "/name/{val}/index.html"), - Some(ResourceHandler::default())), - (Resource::new("", "/file/{file}.{ext}"), - Some(ResourceHandler::default())), - (Resource::new("", "/v{val}/{val2}/index.html"), - Some(ResourceHandler::default())), - (Resource::new("", "/v/{tail:.*}"), - Some(ResourceHandler::default())), - (Resource::new("", "{test}/index.html"), - Some(ResourceHandler::default()))]; - let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/name").finish(); - assert_eq!(rec.recognize(&mut req), Some(0)); - assert!(req.match_info().is_empty()); - - let mut req = TestRequest::with_uri("/name/value").finish(); - assert_eq!(rec.recognize(&mut req), Some(1)); - assert_eq!(req.match_info().get("val").unwrap(), "value"); - assert_eq!(&req.match_info()["val"], "value"); - - let mut req = TestRequest::with_uri("/name/value2/index.html").finish(); - assert_eq!(rec.recognize(&mut req), Some(2)); - assert_eq!(req.match_info().get("val").unwrap(), "value2"); - - let mut req = TestRequest::with_uri("/file/file.gz").finish(); - assert_eq!(rec.recognize(&mut req), Some(3)); - assert_eq!(req.match_info().get("file").unwrap(), "file"); - assert_eq!(req.match_info().get("ext").unwrap(), "gz"); - - let mut req = TestRequest::with_uri("/vtest/ttt/index.html").finish(); - assert_eq!(rec.recognize(&mut req), Some(4)); - assert_eq!(req.match_info().get("val").unwrap(), "test"); - assert_eq!(req.match_info().get("val2").unwrap(), "ttt"); - - let mut req = TestRequest::with_uri("/v/blah-blah/index.html").finish(); - assert_eq!(rec.recognize(&mut req), Some(5)); - assert_eq!(req.match_info().get("tail").unwrap(), "blah-blah/index.html"); - - let mut req = TestRequest::with_uri("/bbb/index.html").finish(); - assert_eq!(rec.recognize(&mut req), Some(6)); - assert_eq!(req.match_info().get("test").unwrap(), "bbb"); - } - - #[test] - fn test_recognizer_2() { - let routes = vec![ - (Resource::new("", "/index.json"), Some(ResourceHandler::default())), - (Resource::new("", "/{source}.json"), Some(ResourceHandler::default()))]; - let (rec, _) = Router::new::<()>("", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/index.json").finish(); - assert_eq!(rec.recognize(&mut req), Some(0)); - - let mut req = TestRequest::with_uri("/test.json").finish(); - assert_eq!(rec.recognize(&mut req), Some(1)); - } - - #[test] - fn test_recognizer_with_prefix() { - let routes = vec![ - (Resource::new("", "/name"), Some(ResourceHandler::default())), - (Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))]; - let (rec, _) = Router::new::<()>("/test", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/name").finish(); - assert!(rec.recognize(&mut req).is_none()); - - let mut req = TestRequest::with_uri("/test/name").finish(); - assert_eq!(rec.recognize(&mut req), Some(0)); - - let mut req = TestRequest::with_uri("/test/name/value").finish(); - assert_eq!(rec.recognize(&mut req), Some(1)); - assert_eq!(req.match_info().get("val").unwrap(), "value"); - assert_eq!(&req.match_info()["val"], "value"); - - // same patterns - let routes = vec![ - (Resource::new("", "/name"), Some(ResourceHandler::default())), - (Resource::new("", "/name/{val}"), Some(ResourceHandler::default()))]; - let (rec, _) = Router::new::<()>("/test2", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/name").finish(); - assert!(rec.recognize(&mut req).is_none()); - let mut req = TestRequest::with_uri("/test2/name").finish(); - assert_eq!(rec.recognize(&mut req), Some(0)); - let mut req = TestRequest::with_uri("/test2/name-test").finish(); - assert!(rec.recognize(&mut req).is_none()); - let mut req = TestRequest::with_uri("/test2/name/ttt").finish(); - assert_eq!(rec.recognize(&mut req), Some(1)); - assert_eq!(&req.match_info()["val"], "ttt"); - } - - #[test] - fn test_parse_static() { - let re = Resource::new("test", "/"); - assert!(re.is_match("/")); - assert!(!re.is_match("/a")); - - let re = Resource::new("test", "/name"); - assert!(re.is_match("/name")); - assert!(!re.is_match("/name1")); - assert!(!re.is_match("/name/")); - assert!(!re.is_match("/name~")); - - let re = Resource::new("test", "/name/"); - assert!(re.is_match("/name/")); - assert!(!re.is_match("/name")); - assert!(!re.is_match("/name/gs")); - - let re = Resource::new("test", "/user/profile"); - assert!(re.is_match("/user/profile")); - assert!(!re.is_match("/user/profile/profile")); - } - - #[test] - fn test_parse_param() { - let mut req = HttpRequest::default(); - - let re = Resource::new("test", "/user/{id}"); - assert!(re.is_match("/user/profile")); - assert!(re.is_match("/user/2345")); - assert!(!re.is_match("/user/2345/")); - assert!(!re.is_match("/user/2345/sdg")); - - req.match_info_mut().clear(); - assert!(re.match_with_params("/user/profile", req.match_info_mut())); - assert_eq!(req.match_info().get("id").unwrap(), "profile"); - - req.match_info_mut().clear(); - assert!(re.match_with_params("/user/1245125", req.match_info_mut())); - assert_eq!(req.match_info().get("id").unwrap(), "1245125"); - - let re = Resource::new("test", "/v{version}/resource/{id}"); - assert!(re.is_match("/v1/resource/320120")); - assert!(!re.is_match("/v/resource/1")); - assert!(!re.is_match("/resource")); - - req.match_info_mut().clear(); - assert!(re.match_with_params("/v151/resource/adahg32", req.match_info_mut())); - assert_eq!(req.match_info().get("version").unwrap(), "151"); - assert_eq!(req.match_info().get("id").unwrap(), "adahg32"); - } - - #[test] - fn test_request_resource() { - let routes = vec![ - (Resource::new("r1", "/index.json"), Some(ResourceHandler::default())), - (Resource::new("r2", "/test.json"), Some(ResourceHandler::default()))]; - let (router, _) = Router::new::<()>("", ServerSettings::default(), routes); - - let mut req = TestRequest::with_uri("/index.json") - .finish_with_router(router.clone()); - assert_eq!(router.recognize(&mut req), Some(0)); - let resource = req.resource(); - assert_eq!(resource.name(), "r1"); - - let mut req = TestRequest::with_uri("/test.json") - .finish_with_router(router.clone()); - assert_eq!(router.recognize(&mut req), Some(1)); - let resource = req.resource(); - assert_eq!(resource.name(), "r2"); - } -} diff --git a/src/scope.rs b/src/scope.rs new file mode 100644 index 000000000..db6f5da57 --- /dev/null +++ b/src/scope.rs @@ -0,0 +1,1231 @@ +use std::cell::RefCell; +use std::fmt; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_http::{Extensions, Response}; +use actix_router::{ResourceDef, ResourceInfo, Router}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{ + apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, +}; +use futures::future::{ok, Either, Future, LocalBoxFuture, Ready}; + +use crate::config::ServiceConfig; +use crate::data::Data; +use crate::dev::{AppService, HttpServiceFactory}; +use crate::error::Error; +use crate::guard::Guard; +use crate::resource::Resource; +use crate::rmap::ResourceMap; +use crate::route::Route; +use crate::service::{ + AppServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, +}; + +type Guards = Vec>; +type HttpService = BoxService; +type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; +type BoxedResponse = LocalBoxFuture<'static, Result>; + +/// Resources scope. +/// +/// Scope is a set of resources with common root path. +/// Scopes collect multiple paths under a common path prefix. +/// Scope path can contain variable path segments as resources. +/// Scope prefix is always complete path segment, i.e `/app` would +/// be converted to a `/app/` and it would not match `/app` path. +/// +/// You can get variable path segments from `HttpRequest::match_info()`. +/// `Path` extractor also is able to extract scope level variable segments. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::scope("/{project_id}/") +/// .service(web::resource("/path1").to(|| async { HttpResponse::Ok() })) +/// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) +/// .service(web::resource("/path3").route(web::head().to(|| HttpResponse::MethodNotAllowed()))) +/// ); +/// } +/// ``` +/// +/// In the above example three routes get registered: +/// * /{project_id}/path1 - reponds to all http method +/// * /{project_id}/path2 - `GET` requests +/// * /{project_id}/path3 - `HEAD` requests +/// +pub struct Scope { + endpoint: T, + rdef: String, + data: Option, + services: Vec>, + guards: Vec>, + default: Rc>>>, + external: Vec, + factory_ref: Rc>>, +} + +impl Scope { + /// Create a new scope + pub fn new(path: &str) -> Scope { + let fref = Rc::new(RefCell::new(None)); + Scope { + endpoint: ScopeEndpoint::new(fref.clone()), + rdef: path.to_string(), + data: None, + guards: Vec::new(), + services: Vec::new(), + default: Rc::new(RefCell::new(None)), + external: Vec::new(), + factory_ref: fref, + } + } +} + +impl Scope +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, +{ + /// Add match guard to a scope. + /// + /// ```rust + /// use actix_web::{web, guard, App, HttpRequest, HttpResponse}; + /// + /// async fn index(data: web::Path<(String, String)>) -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::scope("/app") + /// .guard(guard::Header("content-type", "text/plain")) + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|r: HttpRequest| { + /// HttpResponse::MethodNotAllowed() + /// })) + /// ); + /// } + /// ``` + pub fn guard(mut self, guard: G) -> Self { + self.guards.push(Box::new(guard)); + self + } + + /// Set or override application data. Application data could be accessed + /// by using `Data` extractor where `T` is data type. + /// + /// ```rust + /// use std::cell::Cell; + /// use actix_web::{web, App, HttpResponse, Responder}; + /// + /// struct MyData { + /// counter: Cell, + /// } + /// + /// async fn index(data: web::Data) -> impl Responder { + /// data.counter.set(data.counter.get() + 1); + /// HttpResponse::Ok() + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::scope("/app") + /// .data(MyData{ counter: Cell::new(0) }) + /// .service( + /// web::resource("/index.html").route( + /// web::get().to(index))) + /// ); + /// } + /// ``` + pub fn data(self, data: U) -> Self { + self.register_data(Data::new(data)) + } + + /// Set or override application data. + /// + /// This method has the same effect as [`Scope::data`](#method.data), except + /// that instead of taking a value of some type `T`, it expects a value of + /// type `Data`. Use a `Data` extractor to retrieve its value. + pub fn register_data(mut self, data: Data) -> Self { + if self.data.is_none() { + self.data = Some(Extensions::new()); + } + self.data.as_mut().unwrap().insert(data); + self + } + + /// Run external configuration as part of the scope building + /// process + /// + /// This function is useful for moving parts of configuration to a + /// different module or even library. For example, + /// some of the resource's configuration could be moved to different module. + /// + /// ```rust + /// # extern crate actix_web; + /// use actix_web::{web, middleware, App, HttpResponse}; + /// + /// // this function could be located in different module + /// fn config(cfg: &mut web::ServiceConfig) { + /// cfg.service(web::resource("/test") + /// .route(web::get().to(|| HttpResponse::Ok())) + /// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) + /// ); + /// } + /// + /// fn main() { + /// let app = App::new() + /// .wrap(middleware::Logger::default()) + /// .service( + /// web::scope("/api") + /// .configure(config) + /// ) + /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); + /// } + /// ``` + pub fn configure(mut self, f: F) -> Self + where + F: FnOnce(&mut ServiceConfig), + { + let mut cfg = ServiceConfig::new(); + f(&mut cfg); + self.services.extend(cfg.services); + self.external.extend(cfg.external); + + if !cfg.data.is_empty() { + let mut data = self.data.unwrap_or_else(Extensions::new); + + for value in cfg.data.iter() { + value.create(&mut data); + } + + self.data = Some(data); + } + self + } + + /// Register http service. + /// + /// This is similar to `App's` service registration. + /// + /// Actix web provides several services implementations: + /// + /// * *Resource* is an entry in resource table which corresponds to requested URL. + /// * *Scope* is a set of resources with common root path. + /// * "StaticFiles" is a service for static files support + /// + /// ```rust + /// use actix_web::{web, App, HttpRequest}; + /// + /// struct AppState; + /// + /// async fn index(req: HttpRequest) -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::scope("/app").service( + /// web::scope("/v1") + /// .service(web::resource("/test1").to(index))) + /// ); + /// } + /// ``` + pub fn service(mut self, factory: F) -> Self + where + F: HttpServiceFactory + 'static, + { + self.services + .push(Box::new(ServiceFactoryWrapper::new(factory))); + self + } + + /// Configure route for a specific path. + /// + /// This is a simplified version of the `Scope::service()` method. + /// This method can be called multiple times, in that case + /// multiple resources with one route would be registered for same resource path. + /// + /// ```rust + /// use actix_web::{web, App, HttpResponse}; + /// + /// async fn index(data: web::Path<(String, String)>) -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::scope("/app") + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())) + /// ); + /// } + /// ``` + pub fn route(self, path: &str, mut route: Route) -> Self { + self.service( + Resource::new(path) + .add_guards(route.take_guards()) + .route(route), + ) + } + + /// Default service to be used if no matching route could be found. + /// + /// If default resource is not registered, app's default resource is being used. + pub fn default_service(mut self, f: F) -> Self + where + F: IntoServiceFactory, + U: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + > + 'static, + U::InitError: fmt::Debug, + { + // create and configure default resource + self.default = Rc::new(RefCell::new(Some(Rc::new(boxed::factory( + f.into_factory().map_init_err(|e| { + log::error!("Can not construct default service: {:?}", e) + }), + ))))); + + self + } + + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound processing in the request + /// lifecycle (request -> response), modifying request as + /// necessary, across all requests managed by the *Scope*. Scope-level + /// middleware is more limited in what it can modify, relative to Route or + /// Application level middleware, in that Scope-level middleware can not modify + /// ServiceResponse. + /// + /// Use middleware when you need to read or modify *every* request in some way. + pub fn wrap( + self, + mw: M, + ) -> Scope< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > + where + M: Transform< + T::Service, + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + { + Scope { + endpoint: apply(mw, self.endpoint), + rdef: self.rdef, + data: self.data, + guards: self.guards, + services: self.services, + default: self.default, + external: self.external, + factory_ref: self.factory_ref, + } + } + + /// Registers middleware, in the form of a closure, that runs during inbound + /// processing in the request lifecycle (request -> response), modifying + /// request as necessary, across all requests managed by the *Scope*. + /// Scope-level middleware is more limited in what it can modify, relative + /// to Route or Application level middleware, in that Scope-level middleware + /// can not modify ServiceResponse. + /// + /// ```rust + /// use actix_service::Service; + /// use actix_web::{web, App}; + /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// + /// async fn index() -> &'static str { + /// "Welcome!" + /// } + /// + /// fn main() { + /// let app = App::new().service( + /// web::scope("/app") + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; + /// res.headers_mut().insert( + /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), + /// ); + /// Ok(res) + /// } + /// }) + /// .route("/index.html", web::get().to(index))); + /// } + /// ``` + pub fn wrap_fn( + self, + mw: F, + ) -> Scope< + impl ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > + where + F: FnMut(ServiceRequest, &mut T::Service) -> R + Clone, + R: Future>, + { + Scope { + endpoint: apply_fn_factory(self.endpoint, mw), + rdef: self.rdef, + data: self.data, + guards: self.guards, + services: self.services, + default: self.default, + external: self.external, + factory_ref: self.factory_ref, + } + } +} + +impl HttpServiceFactory for Scope +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + 'static, +{ + fn register(mut self, config: &mut AppService) { + // update default resource if needed + if self.default.borrow().is_none() { + *self.default.borrow_mut() = Some(config.default_service()); + } + + // register nested services + let mut cfg = config.clone_config(); + self.services + .into_iter() + .for_each(|mut srv| srv.register(&mut cfg)); + + let mut rmap = ResourceMap::new(ResourceDef::root_prefix(&self.rdef)); + + // external resources + for mut rdef in std::mem::replace(&mut self.external, Vec::new()) { + rmap.add(&mut rdef, None); + } + + // custom app data storage + if let Some(ref mut ext) = self.data { + config.set_service_data(ext); + } + + // complete scope pipeline creation + *self.factory_ref.borrow_mut() = Some(ScopeFactory { + data: self.data.take().map(Rc::new), + default: self.default.clone(), + services: Rc::new( + cfg.into_services() + .1 + .into_iter() + .map(|(mut rdef, srv, guards, nested)| { + rmap.add(&mut rdef, nested); + (rdef, srv, RefCell::new(guards)) + }) + .collect(), + ), + }); + + // get guards + let guards = if self.guards.is_empty() { + None + } else { + Some(self.guards) + }; + + // register final service + config.register_service( + ResourceDef::root_prefix(&self.rdef), + guards, + self.endpoint, + Some(Rc::new(rmap)), + ) + } +} + +pub struct ScopeFactory { + data: Option>, + services: Rc>)>>, + default: Rc>>>, +} + +impl ServiceFactory for ScopeFactory { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = ScopeService; + type Future = ScopeFactoryResponse; + + fn new_service(&self, _: &()) -> Self::Future { + let default_fut = if let Some(ref default) = *self.default.borrow() { + Some(default.new_service(&())) + } else { + None + }; + + ScopeFactoryResponse { + fut: self + .services + .iter() + .map(|(path, service, guards)| { + CreateScopeServiceItem::Future( + Some(path.clone()), + guards.borrow_mut().take(), + service.new_service(&()), + ) + }) + .collect(), + default: None, + data: self.data.clone(), + default_fut, + } + } +} + +/// Create scope service +#[doc(hidden)] +#[pin_project::pin_project] +pub struct ScopeFactoryResponse { + fut: Vec, + data: Option>, + default: Option, + default_fut: Option>>, +} + +type HttpServiceFut = LocalBoxFuture<'static, Result>; + +enum CreateScopeServiceItem { + Future(Option, Option, HttpServiceFut), + Service(ResourceDef, Option, HttpService), +} + +impl Future for ScopeFactoryResponse { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut done = true; + + if let Some(ref mut fut) = self.default_fut { + match Pin::new(fut).poll(cx)? { + Poll::Ready(default) => self.default = Some(default), + Poll::Pending => done = false, + } + } + + // poll http services + for item in &mut self.fut { + let res = match item { + CreateScopeServiceItem::Future( + ref mut path, + ref mut guards, + ref mut fut, + ) => match Pin::new(fut).poll(cx)? { + Poll::Ready(service) => { + Some((path.take().unwrap(), guards.take(), service)) + } + Poll::Pending => { + done = false; + None + } + }, + CreateScopeServiceItem::Service(_, _, _) => continue, + }; + + if let Some((path, guards, service)) = res { + *item = CreateScopeServiceItem::Service(path, guards, service); + } + } + + if done { + let router = self + .fut + .drain(..) + .fold(Router::build(), |mut router, item| { + match item { + CreateScopeServiceItem::Service(path, guards, service) => { + router.rdef(path, service).2 = guards; + } + CreateScopeServiceItem::Future(_, _, _) => unreachable!(), + } + router + }); + Poll::Ready(Ok(ScopeService { + data: self.data.clone(), + router: router.finish(), + default: self.default.take(), + _ready: None, + })) + } else { + Poll::Pending + } + } +} + +pub struct ScopeService { + data: Option>, + router: Router>>, + default: Option, + _ready: Option<(ServiceRequest, ResourceInfo)>, +} + +impl Service for ScopeService { + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Either>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let res = self.router.recognize_mut_checked(&mut req, |req, guards| { + if let Some(ref guards) = guards { + for f in guards { + if !f.check(req.head()) { + return false; + } + } + } + true + }); + + if let Some((srv, _info)) = res { + if let Some(ref data) = self.data { + req.set_data_container(data.clone()); + } + Either::Left(srv.call(req)) + } else if let Some(ref mut default) = self.default { + Either::Left(default.call(req)) + } else { + let req = req.into_parts().0; + Either::Right(ok(ServiceResponse::new(req, Response::NotFound().finish()))) + } + } +} + +#[doc(hidden)] +pub struct ScopeEndpoint { + factory: Rc>>, +} + +impl ScopeEndpoint { + fn new(factory: Rc>>) -> Self { + ScopeEndpoint { factory } + } +} + +impl ServiceFactory for ScopeEndpoint { + type Config = (); + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Service = ScopeService; + type Future = ScopeFactoryResponse; + + fn new_service(&self, _: &()) -> Self::Future { + self.factory.borrow_mut().as_mut().unwrap().new_service(&()) + } +} + +#[cfg(test)] +mod tests { + use actix_service::Service; + use bytes::Bytes; + use futures::future::ok; + + use crate::dev::{Body, ResponseBody}; + use crate::http::{header, HeaderValue, Method, StatusCode}; + use crate::middleware::DefaultHeaders; + use crate::service::ServiceRequest; + use crate::test::{call_service, init_service, read_body, TestRequest}; + use crate::{guard, web, App, HttpRequest, HttpResponse}; + + #[actix_rt::test] + async fn test_scope() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_root() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("").to(|| HttpResponse::Ok())) + .service(web::resource("/").to(|| HttpResponse::Created())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + } + + #[actix_rt::test] + async fn test_scope_root2() { + let mut srv = init_service(App::new().service( + web::scope("/app/").service(web::resource("").to(|| HttpResponse::Ok())), + )) + .await; + + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_root3() { + let mut srv = init_service(App::new().service( + web::scope("/app/").service(web::resource("/").to(|| HttpResponse::Ok())), + )) + .await; + + let req = TestRequest::with_uri("/app").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_scope_route() { + let mut srv = init_service( + App::new().service( + web::scope("app") + .route("/path1", web::get().to(|| HttpResponse::Ok())) + .route("/path1", web::delete().to(|| HttpResponse::Ok())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::DELETE) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_scope_route_without_leading_slash() { + let mut srv = init_service( + App::new().service( + web::scope("app").service( + web::resource("path1") + .route(web::get().to(|| HttpResponse::Ok())) + .route(web::delete().to(|| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::DELETE) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[actix_rt::test] + async fn test_scope_guard() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .guard(guard::Get()) + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::GET) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_variable_segment() { + let mut srv = + init_service(App::new().service(web::scope("/ab-{project}").service( + web::resource("/path1").to(|r: HttpRequest| { + async move { + HttpResponse::Ok() + .body(format!("project: {}", &r.match_info()["project"])) + } + }), + ))) + .await; + + let req = TestRequest::with_uri("/ab-project1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: project1")); + } + _ => panic!(), + } + + let req = TestRequest::with_uri("/aa-project1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_nested_scope() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::scope("/t1").service( + web::resource("/path1").to(|| HttpResponse::Created()), + )), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/t1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + } + + #[actix_rt::test] + async fn test_nested_scope_no_slash() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::scope("t1").service( + web::resource("/path1").to(|| HttpResponse::Created()), + )), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/t1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + } + + #[actix_rt::test] + async fn test_nested_scope_root() { + let mut srv = init_service( + App::new().service( + web::scope("/app").service( + web::scope("/t1") + .service(web::resource("").to(|| HttpResponse::Ok())) + .service(web::resource("/").to(|| HttpResponse::Created())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/t1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/t1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + } + + #[actix_rt::test] + async fn test_nested_scope_filter() { + let mut srv = init_service( + App::new().service( + web::scope("/app").service( + web::scope("/t1") + .guard(guard::Get()) + .service(web::resource("/path1").to(|| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::POST) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/app/t1/path1") + .method(Method::GET) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_nested_scope_with_variable_segment() { + let mut srv = init_service(App::new().service(web::scope("/app").service( + web::scope("/{project_id}").service(web::resource("/path1").to( + |r: HttpRequest| { + async move { + HttpResponse::Created() + .body(format!("project: {}", &r.match_info()["project_id"])) + } + }, + )), + ))) + .await; + + let req = TestRequest::with_uri("/app/project_1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: project_1")); + } + _ => panic!(), + } + } + + #[actix_rt::test] + async fn test_nested2_scope_with_variable_segment() { + let mut srv = init_service(App::new().service(web::scope("/app").service( + web::scope("/{project}").service(web::scope("/{id}").service( + web::resource("/path1").to(|r: HttpRequest| { + async move { + HttpResponse::Created().body(format!( + "project: {} - {}", + &r.match_info()["project"], + &r.match_info()["id"], + )) + } + }), + )), + ))) + .await; + + let req = TestRequest::with_uri("/app/test/1/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + match resp.response().body() { + ResponseBody::Body(Body::Bytes(ref b)) => { + let bytes: Bytes = b.clone().into(); + assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); + } + _ => panic!(), + } + + let req = TestRequest::with_uri("/app/test/1/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_default_resource() { + let mut srv = init_service( + App::new().service( + web::scope("/app") + .service(web::resource("/path1").to(|| HttpResponse::Ok())) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::BadRequest())) + }), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_uri("/path2").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_default_resource_propagation() { + let mut srv = init_service( + App::new() + .service(web::scope("/app1").default_service( + web::resource("").to(|| HttpResponse::BadRequest()), + )) + .service(web::scope("/app2")) + .default_service(|r: ServiceRequest| { + ok(r.into_response(HttpResponse::MethodNotAllowed())) + }), + ) + .await; + + let req = TestRequest::with_uri("/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + + let req = TestRequest::with_uri("/app1/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let req = TestRequest::with_uri("/app2/non-exist").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + } + + #[actix_rt::test] + async fn test_middleware() { + let mut srv = + init_service( + App::new().service( + web::scope("app") + .wrap(DefaultHeaders::new().header( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + )) + .service( + web::resource("/test") + .route(web::get().to(|| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_middleware_fn() { + let mut srv = init_service( + App::new().service( + web::scope("app") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async move { + let mut res = fut.await?; + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("0001"), + ); + Ok(res) + } + }) + .route("/test", web::get().to(|| HttpResponse::Ok())), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + HeaderValue::from_static("0001") + ); + } + + #[actix_rt::test] + async fn test_override_data() { + let mut srv = init_service(App::new().data(1usize).service( + web::scope("app").data(10usize).route( + "/t", + web::get().to(|data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }), + ), + )) + .await; + + let req = TestRequest::with_uri("/app/t").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_override_register_data() { + let mut srv = init_service( + App::new().register_data(web::Data::new(1usize)).service( + web::scope("app") + .register_data(web::Data::new(10usize)) + .route( + "/t", + web::get().to(|data: web::Data| { + assert_eq!(*data, 10); + let _ = data.clone(); + HttpResponse::Ok() + }), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/t").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_config() { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.route("/path1", web::get().to(|| HttpResponse::Ok())); + }))) + .await; + + let req = TestRequest::with_uri("/app/path1").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_scope_config_2() { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.route("/", web::get().to(|| HttpResponse::Ok())); + })); + }))) + .await; + + let req = TestRequest::with_uri("/app/v1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_url_for_external() { + let mut srv = + init_service(App::new().service(web::scope("/app").configure(|s| { + s.service(web::scope("/v1").configure(|s| { + s.external_resource( + "youtube", + "https://youtube.com/watch/{video_id}", + ); + s.route( + "/", + web::get().to(|req: HttpRequest| { + async move { + HttpResponse::Ok().body(format!( + "{}", + req.url_for("youtube", &["xxxxxx"]) + .unwrap() + .as_str() + )) + } + }), + ); + })); + }))) + .await; + + let req = TestRequest::with_uri("/app/v1/").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, &b"https://youtube.com/watch/xxxxxx"[..]); + } + + #[actix_rt::test] + async fn test_url_for_nested() { + let mut srv = init_service(App::new().service(web::scope("/a").service( + web::scope("/b").service(web::resource("/c/{stuff}").name("c").route( + web::get().to(|req: HttpRequest| { + async move { + HttpResponse::Ok() + .body(format!("{}", req.url_for("c", &["12345"]).unwrap())) + } + }), + )), + ))) + .await; + + let req = TestRequest::with_uri("/a/b/c/test").to_request(); + let resp = call_service(&mut srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!( + body, + Bytes::from_static(b"http://localhost:8080/a/b/c/12345") + ); + } +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 000000000..a98d06275 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,597 @@ +use std::marker::PhantomData; +use std::sync::Arc; +use std::{fmt, io, net}; + +use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Response}; +use actix_rt::System; +use actix_server::{Server, ServerBuilder}; +use actix_server_config::ServerConfig; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use parking_lot::Mutex; + +use net2::TcpBuilder; + +#[cfg(feature = "openssl")] +use open_ssl::ssl::{SslAcceptor, SslAcceptorBuilder}; +#[cfg(feature = "rustls")] +use rust_tls::ServerConfig as RustlsServerConfig; + +struct Socket { + scheme: &'static str, + addr: net::SocketAddr, +} + +struct Config { + keep_alive: KeepAlive, + client_timeout: u64, + client_shutdown: u64, +} + +/// An HTTP Server. +/// +/// Create new http server with application factory. +/// +/// ```rust +/// use std::io; +/// use actix_web::{web, App, HttpResponse, HttpServer}; +/// +/// fn main() -> io::Result<()> { +/// let sys = actix_rt::System::new("example"); // <- create Actix runtime +/// +/// HttpServer::new( +/// || App::new() +/// .service(web::resource("/").to(|| HttpResponse::Ok()))) +/// .bind("127.0.0.1:59090")? +/// .start(); +/// +/// # actix_rt::System::current().stop(); +/// sys.run() +/// } +/// ``` +pub struct HttpServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + B: MessageBody, +{ + pub(super) factory: F, + pub(super) host: Option, + config: Arc>, + backlog: i32, + sockets: Vec, + builder: ServerBuilder, + _t: PhantomData<(S, B)>, +} + +impl HttpServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + /// Create new http server with application factory + pub fn new(factory: F) -> Self { + HttpServer { + factory, + host: None, + config: Arc::new(Mutex::new(Config { + keep_alive: KeepAlive::Timeout(5), + client_timeout: 5000, + client_shutdown: 5000, + })), + backlog: 1024, + sockets: Vec::new(), + builder: ServerBuilder::default(), + _t: PhantomData, + } + } + + /// Set number of workers to start. + /// + /// By default http server uses number of available logical cpu as threads + /// count. + pub fn workers(mut self, num: usize) -> Self { + self.builder = self.builder.workers(num); + self + } + + /// Set the maximum number of pending connections. + /// + /// This refers to the number of clients that can be waiting to be served. + /// Exceeding this number results in the client getting an error when + /// attempting to connect. It should only affect servers under significant + /// load. + /// + /// Generally set in the 64-2048 range. Default value is 2048. + /// + /// This method should be called before `bind()` method call. + pub fn backlog(mut self, backlog: i32) -> Self { + self.backlog = backlog; + self.builder = self.builder.backlog(backlog); + self + } + + /// Sets the maximum per-worker number of concurrent connections. + /// + /// All socket listeners will stop accepting connections when this limit is reached + /// for each worker. + /// + /// By default max connections is set to a 25k. + pub fn maxconn(mut self, num: usize) -> Self { + self.builder = self.builder.maxconn(num); + self + } + + /// Sets the maximum per-worker concurrent connection establish process. + /// + /// All listeners will stop accepting connections when this limit is reached. It + /// can be used to limit the global SSL CPU usage. + /// + /// By default max connections is set to a 256. + pub fn maxconnrate(mut self, num: usize) -> Self { + self.builder = self.builder.maxconnrate(num); + self + } + + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(self, val: T) -> Self { + self.config.lock().keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(self, val: u64) -> Self { + self.config.lock().client_timeout = val; + self + } + + /// Set server connection shutdown timeout in milliseconds. + /// + /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete + /// within this time, the request is dropped. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_shutdown(self, val: u64) -> Self { + self.config.lock().client_shutdown = val; + self + } + + /// Set server host name. + /// + /// Host name is used by application router as a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + pub fn server_hostname>(mut self, val: T) -> Self { + self.host = Some(val.as_ref().to_owned()); + self + } + + /// Stop actix system. + pub fn system_exit(mut self) -> Self { + self.builder = self.builder.system_exit(); + self + } + + /// Disable signal handling + pub fn disable_signals(mut self) -> Self { + self.builder = self.builder.disable_signals(); + self + } + + /// Timeout for graceful workers shutdown. + /// + /// After receiving a stop signal, workers have this much time to finish + /// serving requests. Workers still alive after the timeout are force + /// dropped. + /// + /// By default shutdown timeout sets to 30 seconds. + pub fn shutdown_timeout(mut self, sec: u64) -> Self { + self.builder = self.builder.shutdown_timeout(sec); + self + } + + /// Get addresses of bound sockets. + pub fn addrs(&self) -> Vec { + self.sockets.iter().map(|s| s.addr).collect() + } + + /// Get addresses of bound sockets and the scheme for it. + /// + /// This is useful when the server is bound from different sources + /// with some sockets listening on http and some listening on https + /// and the user should be presented with an enumeration of which + /// socket requires which protocol. + pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { + self.sockets.iter().map(|s| (s.addr, s.scheme)).collect() + } + + /// Use listener for accepting incoming connection requests + /// + /// HttpServer does not change any configuration for TcpListener, + /// it needs to be configured before passing it to listen() method. + pub fn listen(mut self, lst: net::TcpListener) -> io::Result { + let cfg = self.config.clone(); + let factory = self.factory.clone(); + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + addr, + scheme: "http", + }); + + self.builder = self.builder.listen( + format!("actix-web-service-{}", addr), + lst, + move || { + let c = cfg.lock(); + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(factory()) + }, + )?; + Ok(self) + } + + #[cfg(feature = "openssl")] + /// Use listener for accepting incoming tls connection requests + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn listen_openssl( + self, + lst: net::TcpListener, + builder: SslAcceptorBuilder, + ) -> io::Result { + self.listen_ssl_inner(lst, openssl_acceptor(builder)?) + } + + #[cfg(feature = "openssl")] + fn listen_ssl_inner( + mut self, + lst: net::TcpListener, + acceptor: SslAcceptor, + ) -> io::Result { + use actix_server::ssl::{OpensslAcceptor, SslError}; + use actix_service::pipeline_factory; + + let acceptor = OpensslAcceptor::new(acceptor); + let factory = self.factory.clone(); + let cfg = self.config.clone(); + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + addr, + scheme: "https", + }); + + self.builder = self.builder.listen( + format!("actix-web-service-{}", addr), + lst, + move || { + let c = cfg.lock(); + pipeline_factory(acceptor.clone().map_err(SslError::Ssl)).and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown) + .finish(factory()) + .map_err(SslError::Service) + .map_init_err(|_| ()), + ) + }, + )?; + Ok(self) + } + + #[cfg(feature = "rustls")] + /// Use listener for accepting incoming tls connection requests + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn listen_rustls( + self, + lst: net::TcpListener, + config: RustlsServerConfig, + ) -> io::Result { + self.listen_rustls_inner(lst, config) + } + + #[cfg(feature = "rustls")] + fn listen_rustls_inner( + mut self, + lst: net::TcpListener, + mut config: RustlsServerConfig, + ) -> io::Result { + use actix_server::ssl::{RustlsAcceptor, SslError}; + use actix_service::pipeline_factory; + + let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; + config.set_protocols(&protos); + + let acceptor = RustlsAcceptor::new(config); + let factory = self.factory.clone(); + let cfg = self.config.clone(); + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + addr, + scheme: "https", + }); + + self.builder = self.builder.listen( + format!("actix-web-service-{}", addr), + lst, + move || { + let c = cfg.lock(); + pipeline_factory(acceptor.clone().map_err(SslError::Ssl)).and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown) + .finish(factory()) + .map_err(SslError::Service) + .map_init_err(|_| ()), + ) + }, + )?; + Ok(self) + } + + /// The socket address to bind + /// + /// To bind multiple addresses this method can be called multiple times. + pub fn bind(mut self, addr: A) -> io::Result { + let sockets = self.bind2(addr)?; + + for lst in sockets { + self = self.listen(lst)?; + } + + Ok(self) + } + + fn bind2( + &self, + addr: A, + ) -> io::Result> { + let mut err = None; + let mut succ = false; + let mut sockets = Vec::new(); + for addr in addr.to_socket_addrs()? { + match create_tcp_listener(addr, self.backlog) { + Ok(lst) => { + succ = true; + sockets.push(lst); + } + Err(e) => err = Some(e), + } + } + + if !succ { + if let Some(e) = err.take() { + Err(e) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Can not bind to address.", + )) + } + } else { + Ok(sockets) + } + } + + #[cfg(feature = "openssl")] + /// Start listening for incoming tls connections. + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn bind_openssl( + mut self, + addr: A, + builder: SslAcceptorBuilder, + ) -> io::Result + where + A: net::ToSocketAddrs, + { + let sockets = self.bind2(addr)?; + let acceptor = openssl_acceptor(builder)?; + + for lst in sockets { + self = self.listen_ssl_inner(lst, acceptor.clone())?; + } + + Ok(self) + } + + #[cfg(feature = "rustls")] + /// Start listening for incoming tls connections. + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn bind_rustls( + mut self, + addr: A, + config: RustlsServerConfig, + ) -> io::Result { + let sockets = self.bind2(addr)?; + for lst in sockets { + self = self.listen_rustls_inner(lst, config.clone())?; + } + Ok(self) + } + + #[cfg(unix)] + /// Start listening for unix domain connections on existing listener. + /// + /// This method is available with `uds` feature. + pub fn listen_uds( + mut self, + lst: std::os::unix::net::UnixListener, + ) -> io::Result { + let cfg = self.config.clone(); + let factory = self.factory.clone(); + // todo duplicated: + self.sockets.push(Socket { + scheme: "http", + addr: net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8080, + ), + }); + + let addr = format!("actix-web-service-{:?}", lst.local_addr()?); + + self.builder = self.builder.listen_uds(addr, lst, move || { + let c = cfg.lock(); + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(factory()) + })?; + Ok(self) + } + + #[cfg(unix)] + /// Start listening for incoming unix domain connections. + /// + /// This method is available with `uds` feature. + pub fn bind_uds(mut self, addr: A) -> io::Result + where + A: AsRef, + { + let cfg = self.config.clone(); + let factory = self.factory.clone(); + self.sockets.push(Socket { + scheme: "http", + addr: net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 8080, + ), + }); + + self.builder = self.builder.bind_uds( + format!("actix-web-service-{:?}", addr.as_ref()), + addr, + move || { + let c = cfg.lock(); + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(factory()) + }, + )?; + Ok(self) + } +} + +impl HttpServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory, + S::Error: Into, + S::InitError: fmt::Debug, + S::Response: Into>, + S::Service: 'static, + B: MessageBody, +{ + /// Start listening for incoming connections. + /// + /// This method starts number of http workers in separate threads. + /// For each address this method starts separate thread which does + /// `accept()` in a loop. + /// + /// This methods panics if no socket address can be bound or an `Actix` system is not yet + /// configured. + /// + /// ```rust + /// use std::io; + /// use actix_web::{web, App, HttpResponse, HttpServer}; + /// + /// fn main() -> io::Result<()> { + /// let sys = actix_rt::System::new("example"); // <- create Actix system + /// + /// HttpServer::new(|| App::new().service(web::resource("/").to(|| HttpResponse::Ok()))) + /// .bind("127.0.0.1:0")? + /// .start(); + /// # actix_rt::System::current().stop(); + /// sys.run() // <- Run actix system, this method starts all async processes + /// } + /// ``` + pub fn start(self) -> Server { + self.builder.start() + } + + /// Spawn new thread and start listening for incoming connections. + /// + /// This method spawns new thread and starts new actix system. Other than + /// that it is similar to `start()` method. This method blocks. + /// + /// This methods panics if no socket addresses get bound. + /// + /// ```rust + /// use std::io; + /// use actix_web::{web, App, HttpResponse, HttpServer}; + /// + /// fn main() -> io::Result<()> { + /// # std::thread::spawn(|| { + /// HttpServer::new(|| App::new().service(web::resource("/").to(|| HttpResponse::Ok()))) + /// .bind("127.0.0.1:0")? + /// .run() + /// # }); + /// # Ok(()) + /// } + /// ``` + pub fn run(self) -> io::Result<()> { + let sys = System::new("http-server"); + self.start(); + sys.run() + } +} + +fn create_tcp_listener( + addr: net::SocketAddr, + backlog: i32, +) -> io::Result { + let builder = match addr { + net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, + net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, + }; + builder.reuse_address(true)?; + builder.bind(addr)?; + Ok(builder.listen(backlog)?) +} + +#[cfg(feature = "openssl")] +/// Configure `SslAcceptorBuilder` with custom server flags. +fn openssl_acceptor(mut builder: SslAcceptorBuilder) -> io::Result { + use open_ssl::ssl::AlpnError; + + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x08http/1.1\x02h2")?; + + Ok(builder.build()) +} diff --git a/src/server/channel.rs b/src/server/channel.rs deleted file mode 100644 index 390aaee87..000000000 --- a/src/server/channel.rs +++ /dev/null @@ -1,289 +0,0 @@ -use std::{ptr, mem, time, io}; -use std::rc::Rc; -use std::net::{SocketAddr, Shutdown}; - -use bytes::{Bytes, BytesMut, Buf, BufMut}; -use futures::{Future, Poll, Async}; -use tokio_io::{AsyncRead, AsyncWrite}; - -use super::{h1, h2, utils, HttpHandler, IoStream}; -use super::settings::WorkerSettings; - -const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; - - -enum HttpProtocol { - H1(h1::Http1), - H2(h2::Http2), - Unknown(Rc>, Option, T, BytesMut), -} - -enum ProtocolKind { - Http1, - Http2, -} - -#[doc(hidden)] -pub struct HttpChannel where T: IoStream, H: HttpHandler + 'static { - proto: Option>, - node: Option>>, -} - -impl HttpChannel where T: IoStream, H: HttpHandler + 'static -{ - pub(crate) fn new(settings: Rc>, - mut io: T, peer: Option, http2: bool) -> HttpChannel - { - settings.add_channel(); - let _ = io.set_nodelay(true); - - if http2 { - HttpChannel { - node: None, proto: Some(HttpProtocol::H2( - h2::Http2::new(settings, io, peer, Bytes::new()))) } - } else { - HttpChannel { - node: None, proto: Some(HttpProtocol::Unknown( - settings, peer, io, BytesMut::with_capacity(4096))) } - } - } - - fn shutdown(&mut self) { - match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - let io = h1.io(); - let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); - let _ = IoStream::shutdown(io, Shutdown::Both); - } - Some(HttpProtocol::H2(ref mut h2)) => { - h2.shutdown() - } - _ => (), - } - } -} - -impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'static -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll { - if !self.node.is_none() { - let el = self as *mut _; - self.node = Some(Node::new(el)); - let _ = match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => - self.node.as_ref().map(|n| h1.settings().head().insert(n)), - Some(HttpProtocol::H2(ref mut h2)) => - self.node.as_ref().map(|n| h2.settings().head().insert(n)), - Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => - self.node.as_ref().map(|n| settings.head().insert(n)), - None => unreachable!(), - }; - } - - let kind = match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - let result = h1.poll(); - match result { - Ok(Async::Ready(())) | Err(_) => { - h1.settings().remove_channel(); - self.node.as_mut().map(|n| n.remove()); - }, - _ => (), - } - return result - }, - Some(HttpProtocol::H2(ref mut h2)) => { - let result = h2.poll(); - match result { - Ok(Async::Ready(())) | Err(_) => { - h2.settings().remove_channel(); - self.node.as_mut().map(|n| n.remove()); - }, - _ => (), - } - return result - }, - Some(HttpProtocol::Unknown(ref mut settings, _, ref mut io, ref mut buf)) => { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) | Err(_) => { - debug!("Ignored premature client disconnection"); - settings.remove_channel(); - self.node.as_mut().map(|n| n.remove()); - return Err(()) - }, - _ => (), - } - - if buf.len() >= 14 { - if buf[..14] == HTTP2_PREFACE[..] { - ProtocolKind::Http2 - } else { - ProtocolKind::Http1 - } - } else { - return Ok(Async::NotReady); - } - }, - None => unreachable!(), - }; - - // upgrade to specific http protocol - if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { - match kind { - ProtocolKind::Http1 => { - self.proto = Some( - HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); - return self.poll() - }, - ProtocolKind::Http2 => { - self.proto = Some( - HttpProtocol::H2(h2::Http2::new(settings, io, addr, buf.freeze()))); - return self.poll() - }, - } - } - unreachable!() - } -} - -pub(crate) struct Node -{ - next: Option<*mut Node<()>>, - prev: Option<*mut Node<()>>, - element: *mut T, -} - -impl Node -{ - fn new(el: *mut T) -> Self { - Node { - next: None, - prev: None, - element: el, - } - } - - fn insert(&self, next: &Node) { - #[allow(mutable_transmutes)] - unsafe { - if let Some(ref next2) = self.next { - let n: &mut Node<()> = mem::transmute(next2.as_ref().unwrap()); - n.prev = Some(next as *const _ as *mut _); - } - let slf: &mut Node = mem::transmute(self); - slf.next = Some(next as *const _ as *mut _); - - let next: &mut Node = mem::transmute(next); - next.prev = Some(slf as *const _ as *mut _); - } - } - - fn remove(&mut self) { - unsafe { - self.element = ptr::null_mut(); - let next = self.next.take(); - let mut prev = self.prev.take(); - - if let Some(ref mut prev) = prev { - prev.as_mut().unwrap().next = next; - } - } - } -} - - -impl Node<()> { - - pub(crate) fn head() -> Self { - Node { - next: None, - prev: None, - element: ptr::null_mut(), - } - } - - pub(crate) fn traverse(&self) where T: IoStream, H: HttpHandler + 'static { - let mut next = self.next.as_ref(); - loop { - if let Some(n) = next { - unsafe { - let n: &Node<()> = mem::transmute(n.as_ref().unwrap()); - next = n.next.as_ref(); - - if !n.element.is_null() { - let ch: &mut HttpChannel = mem::transmute( - &mut *(n.element as *mut _)); - ch.shutdown(); - } - } - } else { - return - } - } - } -} - -/// Wrapper for `AsyncRead + AsyncWrite` types -pub(crate) struct WrapperStream where T: AsyncRead + AsyncWrite + 'static { - io: T, -} - -impl WrapperStream where T: AsyncRead + AsyncWrite + 'static { - pub fn new(io: T) -> Self { - WrapperStream{ io } - } -} - -impl IoStream for WrapperStream where T: AsyncRead + AsyncWrite + 'static { - #[inline] - fn shutdown(&mut self, _: Shutdown) -> io::Result<()> { - Ok(()) - } - #[inline] - fn set_nodelay(&mut self, _: bool) -> io::Result<()> { - Ok(()) - } - #[inline] - fn set_linger(&mut self, _: Option) -> io::Result<()> { - Ok(()) - } -} - -impl io::Read for WrapperStream where T: AsyncRead + AsyncWrite + 'static { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.read(buf) - } -} - -impl io::Write for WrapperStream where T: AsyncRead + AsyncWrite + 'static { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.write(buf) - } - #[inline] - fn flush(&mut self) -> io::Result<()> { - self.io.flush() - } -} - -impl AsyncRead for WrapperStream where T: AsyncRead + AsyncWrite + 'static { - #[inline] - fn read_buf(&mut self, buf: &mut B) -> Poll { - self.io.read_buf(buf) - } -} - -impl AsyncWrite for WrapperStream where T: AsyncRead + AsyncWrite + 'static { - #[inline] - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.shutdown() - } - #[inline] - fn write_buf(&mut self, buf: &mut B) -> Poll { - self.io.write_buf(buf) - } -} diff --git a/src/server/encoding.rs b/src/server/encoding.rs deleted file mode 100644 index fc624d55f..000000000 --- a/src/server/encoding.rs +++ /dev/null @@ -1,854 +0,0 @@ -use std::{io, cmp, mem}; -use std::io::{Read, Write}; -use std::fmt::Write as FmtWrite; -use std::str::FromStr; - -use bytes::{Bytes, BytesMut, BufMut}; -use http::{Version, Method, HttpTryFrom}; -use http::header::{HeaderMap, HeaderValue, - ACCEPT_ENCODING, CONNECTION, - CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; -use flate2::Compression; -use flate2::read::GzDecoder; -use flate2::write::{GzEncoder, DeflateDecoder, DeflateEncoder}; -#[cfg(feature="brotli")] -use brotli2::write::{BrotliDecoder, BrotliEncoder}; - -use header::ContentEncoding; -use body::{Body, Binary}; -use error::PayloadError; -use httprequest::HttpInnerMessage; -use httpresponse::HttpResponse; -use payload::{PayloadSender, PayloadWriter, PayloadStatus}; - -use super::shared::SharedBytes; - -pub(crate) enum PayloadType { - Sender(PayloadSender), - Encoding(Box), -} - -impl PayloadType { - - pub fn new(headers: &HeaderMap, sender: PayloadSender) -> PayloadType { - // check content-encoding - let enc = if let Some(enc) = headers.get(CONTENT_ENCODING) { - if let Ok(enc) = enc.to_str() { - ContentEncoding::from(enc) - } else { - ContentEncoding::Auto - } - } else { - ContentEncoding::Auto - }; - - match enc { - ContentEncoding::Auto | ContentEncoding::Identity => - PayloadType::Sender(sender), - _ => PayloadType::Encoding(Box::new(EncodedPayload::new(sender, enc))), - } - } -} - -impl PayloadWriter for PayloadType { - #[inline] - fn set_error(&mut self, err: PayloadError) { - match *self { - PayloadType::Sender(ref mut sender) => sender.set_error(err), - PayloadType::Encoding(ref mut enc) => enc.set_error(err), - } - } - - #[inline] - fn feed_eof(&mut self) { - match *self { - PayloadType::Sender(ref mut sender) => sender.feed_eof(), - PayloadType::Encoding(ref mut enc) => enc.feed_eof(), - } - } - - #[inline] - fn feed_data(&mut self, data: Bytes) { - match *self { - PayloadType::Sender(ref mut sender) => sender.feed_data(data), - PayloadType::Encoding(ref mut enc) => enc.feed_data(data), - } - } - - #[inline] - fn need_read(&self) -> PayloadStatus { - match *self { - PayloadType::Sender(ref sender) => sender.need_read(), - PayloadType::Encoding(ref enc) => enc.need_read(), - } - } -} - - -/// Payload wrapper with content decompression support -pub(crate) struct EncodedPayload { - inner: PayloadSender, - error: bool, - payload: PayloadStream, -} - -impl EncodedPayload { - pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload { - EncodedPayload{ inner, error: false, payload: PayloadStream::new(enc) } - } -} - -impl PayloadWriter for EncodedPayload { - - fn set_error(&mut self, err: PayloadError) { - self.inner.set_error(err) - } - - fn feed_eof(&mut self) { - if !self.error { - match self.payload.feed_eof() { - Err(err) => { - self.error = true; - self.set_error(PayloadError::Io(err)); - }, - Ok(value) => { - if let Some(b) = value { - self.inner.feed_data(b); - } - self.inner.feed_eof(); - } - } - } - } - - fn feed_data(&mut self, data: Bytes) { - if self.error { - return - } - - match self.payload.feed_data(data) { - Ok(Some(b)) => self.inner.feed_data(b), - Ok(None) => (), - Err(e) => { - self.error = true; - self.set_error(e.into()); - } - } - } - - #[inline] - fn need_read(&self) -> PayloadStatus { - self.inner.need_read() - } -} - -pub(crate) enum Decoder { - Deflate(Box>), - Gzip(Option>>), - #[cfg(feature="brotli")] - Br(Box>), - Identity, -} - -// should go after write::GzDecoder get implemented -#[derive(Debug)] -pub(crate) struct Wrapper { - pub buf: BytesMut, - pub eof: bool, -} - -impl io::Read for Wrapper { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let len = cmp::min(buf.len(), self.buf.len()); - buf[..len].copy_from_slice(&self.buf[..len]); - self.buf.split_to(len); - if len == 0 { - if self.eof { - Ok(0) - } else { - Err(io::Error::new(io::ErrorKind::WouldBlock, "")) - } - } else { - Ok(len) - } - } -} - -impl io::Write for Wrapper { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.buf.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -pub(crate) struct Writer { - buf: BytesMut, -} - -impl Writer { - fn new() -> Writer { - Writer{buf: BytesMut::with_capacity(8192)} - } - fn take(&mut self) -> Bytes { - self.buf.take().freeze() - } -} - -impl io::Write for Writer { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.buf.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -/// Payload stream with decompression support -pub(crate) struct PayloadStream { - decoder: Decoder, - dst: BytesMut, -} - -impl PayloadStream { - pub fn new(enc: ContentEncoding) -> PayloadStream { - let dec = match enc { - #[cfg(feature="brotli")] - ContentEncoding::Br => Decoder::Br( - Box::new(BrotliDecoder::new(Writer::new()))), - ContentEncoding::Deflate => Decoder::Deflate( - Box::new(DeflateDecoder::new(Writer::new()))), - ContentEncoding::Gzip => Decoder::Gzip(None), - _ => Decoder::Identity, - }; - PayloadStream{ decoder: dec, dst: BytesMut::new() } - } -} - -impl PayloadStream { - - pub fn feed_eof(&mut self) -> io::Result> { - match self.decoder { - #[cfg(feature="brotli")] - Decoder::Br(ref mut decoder) => { - match decoder.finish() { - Ok(mut writer) => { - let b = writer.take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - }, - Err(e) => Err(e), - } - }, - Decoder::Gzip(ref mut decoder) => { - if let Some(ref mut decoder) = *decoder { - decoder.as_mut().get_mut().eof = true; - - self.dst.reserve(8192); - match decoder.read(unsafe{self.dst.bytes_mut()}) { - Ok(n) => { - unsafe{self.dst.advance_mut(n)}; - return Ok(Some(self.dst.take().freeze())) - } - Err(e) => - return Err(e), - } - } else { - Ok(None) - } - }, - Decoder::Deflate(ref mut decoder) => { - match decoder.try_finish() { - Ok(_) => { - let b = decoder.get_mut().take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - }, - Err(e) => Err(e), - } - }, - Decoder::Identity => Ok(None), - } - } - - pub fn feed_data(&mut self, data: Bytes) -> io::Result> { - match self.decoder { - #[cfg(feature="brotli")] - Decoder::Br(ref mut decoder) => { - match decoder.write_all(&data) { - Ok(_) => { - decoder.flush()?; - let b = decoder.get_mut().take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - }, - Err(e) => Err(e) - } - }, - Decoder::Gzip(ref mut decoder) => { - if decoder.is_none() { - *decoder = Some( - Box::new(GzDecoder::new( - Wrapper{buf: BytesMut::from(data), eof: false}))); - } else { - let _ = decoder.as_mut().unwrap().write(&data); - } - - loop { - self.dst.reserve(8192); - match decoder.as_mut() - .as_mut().unwrap().read(unsafe{self.dst.bytes_mut()}) - { - Ok(n) => { - if n != 0 { - unsafe{self.dst.advance_mut(n)}; - } - if n == 0 { - return Ok(Some(self.dst.take().freeze())); - } - } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock && !self.dst.is_empty() - { - return Ok(Some(self.dst.take().freeze())); - } - return Err(e) - } - } - } - }, - Decoder::Deflate(ref mut decoder) => { - match decoder.write_all(&data) { - Ok(_) => { - decoder.flush()?; - let b = decoder.get_mut().take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - }, - Err(e) => Err(e), - } - }, - Decoder::Identity => Ok(Some(data)), - } - } -} - -pub(crate) enum ContentEncoder { - Deflate(DeflateEncoder), - Gzip(GzEncoder), - #[cfg(feature="brotli")] - Br(BrotliEncoder), - Identity(TransferEncoding), -} - -impl ContentEncoder { - - pub fn empty(bytes: SharedBytes) -> ContentEncoder { - ContentEncoder::Identity(TransferEncoding::eof(bytes)) - } - - pub fn for_server(buf: SharedBytes, - req: &HttpInnerMessage, - resp: &mut HttpResponse, - response_encoding: ContentEncoding) -> ContentEncoder - { - let version = resp.version().unwrap_or_else(|| req.version); - let is_head = req.method == Method::HEAD; - let mut body = resp.replace_body(Body::Empty); - let has_body = match body { - Body::Empty => false, - Body::Binary(ref bin) => - !(response_encoding == ContentEncoding::Auto && bin.len() < 96), - _ => true, - }; - - // Enable content encoding only if response does not contain Content-Encoding header - let mut encoding = if has_body { - let encoding = match response_encoding { - ContentEncoding::Auto => { - // negotiate content-encoding - if let Some(val) = req.headers.get(ACCEPT_ENCODING) { - if let Ok(enc) = val.to_str() { - AcceptEncoding::parse(enc) - } else { - ContentEncoding::Identity - } - } else { - ContentEncoding::Identity - } - } - encoding => encoding, - }; - if encoding.is_compression() { - resp.headers_mut().insert( - CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str())); - } - encoding - } else { - ContentEncoding::Identity - }; - - let mut transfer = match body { - Body::Empty => { - if req.method != Method::HEAD { - resp.headers_mut().remove(CONTENT_LENGTH); - } - TransferEncoding::length(0, buf) - }, - Body::Binary(ref mut bytes) => { - if !(encoding == ContentEncoding::Identity - || encoding == ContentEncoding::Auto) - { - let tmp = SharedBytes::default(); - let transfer = TransferEncoding::eof(tmp.clone()); - let mut enc = match encoding { - ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::fast())), - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::fast())), - #[cfg(feature="brotli")] - ContentEncoding::Br => ContentEncoder::Br( - BrotliEncoder::new(transfer, 3)), - ContentEncoding::Identity => ContentEncoder::Identity(transfer), - ContentEncoding::Auto => unreachable!() - }; - // TODO return error! - let _ = enc.write(bytes.clone()); - let _ = enc.write_eof(); - - *bytes = Binary::from(tmp.take()); - encoding = ContentEncoding::Identity; - } - if is_head { - let mut b = BytesMut::new(); - let _ = write!(b, "{}", bytes.len()); - resp.headers_mut().insert( - CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap()); - } else { - // resp.headers_mut().remove(CONTENT_LENGTH); - } - TransferEncoding::eof(buf) - } - Body::Streaming(_) | Body::Actor(_) => { - if resp.upgrade() { - if version == Version::HTTP_2 { - error!("Connection upgrade is forbidden for HTTP/2"); - } else { - resp.headers_mut().insert( - CONNECTION, HeaderValue::from_static("upgrade")); - } - if encoding != ContentEncoding::Identity { - encoding = ContentEncoding::Identity; - resp.headers_mut().remove(CONTENT_ENCODING); - } - TransferEncoding::eof(buf) - } else { - ContentEncoder::streaming_encoding(buf, version, resp) - } - } - }; - // - if is_head { - transfer.kind = TransferEncodingKind::Length(0); - } else { - resp.replace_body(body); - } - - match encoding { - ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::fast())), - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::fast())), - #[cfg(feature="brotli")] - ContentEncoding::Br => ContentEncoder::Br( - BrotliEncoder::new(transfer, 3)), - ContentEncoding::Identity | ContentEncoding::Auto => - ContentEncoder::Identity(transfer), - } - } - - fn streaming_encoding(buf: SharedBytes, version: Version, - resp: &mut HttpResponse) -> TransferEncoding { - match resp.chunked() { - Some(true) => { - // Enable transfer encoding - resp.headers_mut().remove(CONTENT_LENGTH); - if version == Version::HTTP_2 { - resp.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof(buf) - } else { - resp.headers_mut().insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked(buf) - } - }, - Some(false) => - TransferEncoding::eof(buf), - None => { - // if Content-Length is specified, then use it as length hint - let (len, chunked) = - if let Some(len) = resp.headers().get(CONTENT_LENGTH) { - // Content-Length - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - (Some(len), false) - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - error!("illegal Content-Length: {:?}", len); - (None, false) - } - } else { - (None, true) - }; - - if !chunked { - if let Some(len) = len { - TransferEncoding::length(len, buf) - } else { - TransferEncoding::eof(buf) - } - } else { - // Enable transfer encoding - match version { - Version::HTTP_11 => { - resp.headers_mut().insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked(buf) - }, - _ => { - resp.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof(buf) - } - } - } - } - } - } -} - -impl ContentEncoder { - - #[inline] - pub fn is_eof(&self) -> bool { - match *self { - #[cfg(feature="brotli")] - ContentEncoder::Br(ref encoder) => encoder.get_ref().is_eof(), - ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_eof(), - ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_eof(), - ContentEncoder::Identity(ref encoder) => encoder.is_eof(), - } - } - - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - #[inline(always)] - pub fn write_eof(&mut self) -> Result<(), io::Error> { - let encoder = mem::replace( - self, ContentEncoder::Identity(TransferEncoding::eof(SharedBytes::empty()))); - - match encoder { - #[cfg(feature="brotli")] - ContentEncoder::Br(encoder) => { - match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(()) - }, - Err(err) => Err(err), - } - } - ContentEncoder::Gzip(encoder) => { - match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(()) - }, - Err(err) => Err(err), - } - }, - ContentEncoder::Deflate(encoder) => { - match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(()) - }, - Err(err) => Err(err), - } - }, - ContentEncoder::Identity(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(()) - } - } - } - - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - #[inline(always)] - pub fn write(&mut self, data: Binary) -> Result<(), io::Error> { - match *self { - #[cfg(feature="brotli")] - ContentEncoder::Br(ref mut encoder) => { - match encoder.write_all(data.as_ref()) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding br encoding: {}", err); - Err(err) - }, - } - }, - ContentEncoder::Gzip(ref mut encoder) => { - match encoder.write_all(data.as_ref()) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding gzip encoding: {}", err); - Err(err) - }, - } - } - ContentEncoder::Deflate(ref mut encoder) => { - match encoder.write_all(data.as_ref()) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding deflate encoding: {}", err); - Err(err) - }, - } - } - ContentEncoder::Identity(ref mut encoder) => { - encoder.encode(data)?; - Ok(()) - } - } - } -} - -/// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] -pub(crate) struct TransferEncoding { - kind: TransferEncodingKind, - buffer: SharedBytes, -} - -#[derive(Debug, PartialEq, Clone)] -enum TransferEncodingKind { - /// An Encoder for when Transfer-Encoding includes `chunked`. - Chunked(bool), - /// An Encoder for when Content-Length is set. - /// - /// Enforces that the body is not longer than the Content-Length header. - Length(u64), - /// An Encoder for when Content-Length is not known. - /// - /// Application decides when to stop writing. - Eof, -} - -impl TransferEncoding { - - #[inline] - pub fn eof(bytes: SharedBytes) -> TransferEncoding { - TransferEncoding { - kind: TransferEncodingKind::Eof, - buffer: bytes, - } - } - - #[inline] - pub fn chunked(bytes: SharedBytes) -> TransferEncoding { - TransferEncoding { - kind: TransferEncodingKind::Chunked(false), - buffer: bytes, - } - } - - #[inline] - pub fn length(len: u64, bytes: SharedBytes) -> TransferEncoding { - TransferEncoding { - kind: TransferEncodingKind::Length(len), - buffer: bytes, - } - } - - #[inline] - pub fn is_eof(&self) -> bool { - match self.kind { - TransferEncodingKind::Eof => true, - TransferEncodingKind::Chunked(ref eof) => *eof, - TransferEncodingKind::Length(ref remaining) => *remaining == 0, - } - } - - /// Encode message. Return `EOF` state of encoder - #[inline] - pub fn encode(&mut self, mut msg: Binary) -> io::Result { - match self.kind { - TransferEncodingKind::Eof => { - let eof = msg.is_empty(); - self.buffer.extend(msg); - Ok(eof) - }, - TransferEncodingKind::Chunked(ref mut eof) => { - if *eof { - return Ok(true); - } - - if msg.is_empty() { - *eof = true; - self.buffer.extend_from_slice(b"0\r\n\r\n"); - } else { - let mut buf = BytesMut::new(); - write!(&mut buf, "{:X}\r\n", msg.len()) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - self.buffer.reserve(buf.len() + msg.len() + 2); - self.buffer.extend(buf.into()); - self.buffer.extend(msg); - self.buffer.extend_from_slice(b"\r\n"); - } - Ok(*eof) - }, - TransferEncodingKind::Length(ref mut remaining) => { - if *remaining > 0 { - if msg.is_empty() { - return Ok(*remaining == 0) - } - let len = cmp::min(*remaining, msg.len() as u64); - self.buffer.extend(msg.take().split_to(len as usize).into()); - - *remaining -= len as u64; - Ok(*remaining == 0) - } else { - Ok(true) - } - }, - } - } - - /// Encode eof. Return `EOF` state of encoder - #[inline] - pub fn encode_eof(&mut self) { - match self.kind { - TransferEncodingKind::Eof | TransferEncodingKind::Length(_) => (), - TransferEncodingKind::Chunked(ref mut eof) => { - if !*eof { - *eof = true; - self.buffer.extend_from_slice(b"0\r\n\r\n"); - } - }, - } - } -} - -impl io::Write for TransferEncoding { - - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.encode(Binary::from_slice(buf))?; - Ok(buf.len()) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - - -struct AcceptEncoding { - encoding: ContentEncoding, - quality: f64, -} - -impl Eq for AcceptEncoding {} - -impl Ord for AcceptEncoding { - fn cmp(&self, other: &AcceptEncoding) -> cmp::Ordering { - if self.quality > other.quality { - cmp::Ordering::Less - } else if self.quality < other.quality { - cmp::Ordering::Greater - } else { - cmp::Ordering::Equal - } - } -} - -impl PartialOrd for AcceptEncoding { - fn partial_cmp(&self, other: &AcceptEncoding) -> Option { - Some(self.cmp(other)) - } -} - -impl PartialEq for AcceptEncoding { - fn eq(&self, other: &AcceptEncoding) -> bool { - self.quality == other.quality - } -} - -impl AcceptEncoding { - fn new(tag: &str) -> Option { - let parts: Vec<&str> = tag.split(';').collect(); - let encoding = match parts.len() { - 0 => return None, - _ => ContentEncoding::from(parts[0]), - }; - let quality = match parts.len() { - 1 => encoding.quality(), - _ => match f64::from_str(parts[1]) { - Ok(q) => q, - Err(_) => 0.0, - } - }; - Some(AcceptEncoding{ encoding, quality }) - } - - /// Parse a raw Accept-Encoding header value into an ordered list. - pub fn parse(raw: &str) -> ContentEncoding { - let mut encodings: Vec<_> = - raw.replace(' ', "").split(',').map(|l| AcceptEncoding::new(l)).collect(); - encodings.sort(); - - for enc in encodings { - if let Some(enc) = enc { - return enc.encoding - } - } - ContentEncoding::Identity - } -} - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_chunked_te() { - let bytes = SharedBytes::default(); - let mut enc = TransferEncoding::chunked(bytes.clone()); - assert!(!enc.encode(Binary::from(b"test".as_ref())).ok().unwrap()); - assert!(enc.encode(Binary::from(b"".as_ref())).ok().unwrap()); - assert_eq!(bytes.get_mut().take().freeze(), - Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n")); - } -} diff --git a/src/server/h1.rs b/src/server/h1.rs deleted file mode 100644 index cb2e0b049..000000000 --- a/src/server/h1.rs +++ /dev/null @@ -1,1507 +0,0 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] - -use std::{self, io}; -use std::rc::Rc; -use std::net::SocketAddr; -use std::time::Duration; -use std::collections::VecDeque; - -use actix::Arbiter; -use httparse; -use http::{Uri, Method, Version, HttpTryFrom, HeaderMap}; -use http::header::{self, HeaderName, HeaderValue}; -use bytes::{Bytes, BytesMut}; -use futures::{Future, Poll, Async}; -use tokio_core::reactor::Timeout; - -use pipeline::Pipeline; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use error::{ParseError, PayloadError, ResponseError}; -use payload::{Payload, PayloadWriter, PayloadStatus}; - -use super::{utils, Writer}; -use super::h1writer::H1Writer; -use super::encoding::PayloadType; -use super::settings::WorkerSettings; -use super::{HttpHandler, HttpHandlerTask, IoStream}; - -const MAX_BUFFER_SIZE: usize = 131_072; -const MAX_HEADERS: usize = 96; -const MAX_PIPELINED_MESSAGES: usize = 16; - -bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const ERROR = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const SHUTDOWN = 0b0000_1000; - } -} - -bitflags! { - struct EntryFlags: u8 { - const EOF = 0b0000_0001; - const ERROR = 0b0000_0010; - const FINISHED = 0b0000_0100; - } -} - -pub(crate) struct Http1 { - flags: Flags, - settings: Rc>, - addr: Option, - stream: H1Writer, - reader: Reader, - read_buf: BytesMut, - tasks: VecDeque, - keepalive_timer: Option, -} - -struct Entry { - pipe: Box, - flags: EntryFlags, -} - -impl Http1 - where T: IoStream, H: HttpHandler + 'static -{ - pub fn new(settings: Rc>, - stream: T, - addr: Option, read_buf: BytesMut) -> Self - { - let bytes = settings.get_shared_bytes(); - Http1{ flags: Flags::KEEPALIVE, - stream: H1Writer::new(stream, bytes, Rc::clone(&settings)), - reader: Reader::new(), - tasks: VecDeque::new(), - keepalive_timer: None, - addr, - read_buf, - settings, - } - } - - pub fn settings(&self) -> &WorkerSettings { - self.settings.as_ref() - } - - pub(crate) fn io(&mut self) -> &mut T { - self.stream.get_mut() - } - - pub fn poll(&mut self) -> Poll<(), ()> { - // keep-alive timer - if let Some(ref mut timer) = self.keepalive_timer { - match timer.poll() { - Ok(Async::Ready(_)) => { - trace!("Keep-alive timeout, close connection"); - self.flags.insert(Flags::SHUTDOWN); - } - Ok(Async::NotReady) => (), - Err(_) => unreachable!(), - } - } - - // shutdown - if self.flags.contains(Flags::SHUTDOWN) { - match self.stream.poll_completed(true) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(_)) => return Ok(Async::Ready(())), - Err(err) => { - debug!("Error sending data: {}", err); - return Err(()) - } - } - } - - loop { - match self.poll_io()? { - Async::Ready(true) => (), - Async::Ready(false) => { - self.flags.insert(Flags::SHUTDOWN); - return self.poll() - }, - Async::NotReady => return Ok(Async::NotReady), - } - } - } - - // TODO: refactor - pub fn poll_io(&mut self) -> Poll { - // read incoming data - let need_read = if !self.flags.intersects(Flags::ERROR) && - self.tasks.len() < MAX_PIPELINED_MESSAGES - { - 'outer: loop { - match self.reader.parse(self.stream.get_mut(), - &mut self.read_buf, &self.settings) { - Ok(Async::Ready(mut req)) => { - self.flags.insert(Flags::STARTED); - - // set remote addr - req.set_peer_addr(self.addr); - - // stop keepalive timer - self.keepalive_timer.take(); - - // start request processing - for h in self.settings.handlers().iter_mut() { - req = match h.handle(req) { - Ok(pipe) => { - self.tasks.push_back( - Entry {pipe, flags: EntryFlags::empty()}); - continue 'outer - }, - Err(req) => req, - } - } - - self.tasks.push_back( - Entry {pipe: Pipeline::error(HttpResponse::NotFound()), - flags: EntryFlags::empty()}); - continue - }, - Ok(Async::NotReady) => (), - Err(err) => { - trace!("Parse error: {:?}", err); - - // notify all tasks - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } - - // kill keepalive - self.flags.remove(Flags::KEEPALIVE); - self.keepalive_timer.take(); - - // on parse error, stop reading stream but tasks need to be completed - self.flags.insert(Flags::ERROR); - - match err { - ReaderError::Disconnect => (), - _ => - if self.tasks.is_empty() { - if let ReaderError::Error(err) = err { - self.tasks.push_back( - Entry {pipe: Pipeline::error(err.error_response()), - flags: EntryFlags::empty()}); - } - } - } - }, - } - break - } - false - } else { - true - }; - - let retry = self.reader.need_read() == PayloadStatus::Read; - - // check in-flight messages - let mut io = false; - let mut idx = 0; - while idx < self.tasks.len() { - let item = &mut self.tasks[idx]; - - if !io && !item.flags.contains(EntryFlags::EOF) { - // io is corrupted, send buffer - if item.flags.contains(EntryFlags::ERROR) { - if let Ok(Async::NotReady) = self.stream.poll_completed(true) { - return Ok(Async::NotReady) - } - return Err(()) - } - - match item.pipe.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); - - if ready { - item.flags.insert(EntryFlags::EOF | EntryFlags::FINISHED); - } else { - item.flags.insert(EntryFlags::FINISHED); - } - }, - // no more IO for this iteration - Ok(Async::NotReady) => { - if self.reader.need_read() == PayloadStatus::Read && !retry { - return Ok(Async::Ready(true)); - } - io = true; - } - Err(err) => { - // it is not possible to recover from error - // during pipe handling, so just drop connection - error!("Unhandled error: {}", err); - item.flags.insert(EntryFlags::ERROR); - - // check stream state, we still can have valid data in buffer - if let Ok(Async::NotReady) = self.stream.poll_completed(true) { - return Ok(Async::NotReady) - } - return Err(()) - } - } - } else if !item.flags.contains(EntryFlags::FINISHED) { - match item.pipe.poll() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => item.flags.insert(EntryFlags::FINISHED), - Err(err) => { - item.flags.insert(EntryFlags::ERROR); - error!("Unhandled error: {}", err); - } - } - } - idx += 1; - } - - // cleanup finished tasks - let mut popped = false; - while !self.tasks.is_empty() { - if self.tasks[0].flags.contains(EntryFlags::EOF | EntryFlags::FINISHED) { - popped = true; - self.tasks.pop_front(); - } else { - break - } - } - if need_read && popped { - return self.poll_io() - } - - // check stream state - if self.flags.contains(Flags::STARTED) { - match self.stream.poll_completed(false) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { - debug!("Error sending data: {}", err); - return Err(()) - } - _ => (), - } - } - - // deal with keep-alive - if self.tasks.is_empty() { - // no keep-alive situations - if self.flags.contains(Flags::ERROR) || - (!self.flags.contains(Flags::KEEPALIVE) - || !self.settings.keep_alive_enabled()) && - self.flags.contains(Flags::STARTED) - { - return Ok(Async::Ready(false)) - } - - // start keep-alive timer - let keep_alive = self.settings.keep_alive(); - if self.keepalive_timer.is_none() && keep_alive > 0 { - trace!("Start keep-alive timer"); - let mut timer = Timeout::new( - Duration::new(keep_alive, 0), Arbiter::handle()).unwrap(); - // register timer - let _ = timer.poll(); - self.keepalive_timer = Some(timer); - } - } - Ok(Async::NotReady) - } -} - -struct Reader { - payload: Option, -} - -enum Decoding { - Ready, - NotReady, -} - -struct PayloadInfo { - tx: PayloadType, - decoder: Decoder, -} - -#[derive(Debug)] -enum ReaderError { - Disconnect, - Payload, - PayloadDropped, - Error(ParseError), -} - -impl Reader { - pub fn new() -> Reader { - Reader { - payload: None, - } - } - - #[inline] - fn need_read(&self) -> PayloadStatus { - if let Some(ref info) = self.payload { - info.tx.need_read() - } else { - PayloadStatus::Read - } - } - - #[inline] - fn decode(&mut self, buf: &mut BytesMut, payload: &mut PayloadInfo) - -> Result - { - while !buf.is_empty() { - match payload.decoder.decode(buf) { - Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes); - if payload.decoder.is_eof() { - payload.tx.feed_eof(); - return Ok(Decoding::Ready) - } - }, - Ok(Async::Ready(None)) => { - payload.tx.feed_eof(); - return Ok(Decoding::Ready) - }, - Ok(Async::NotReady) => return Ok(Decoding::NotReady), - Err(err) => { - payload.tx.set_error(err.into()); - return Err(ReaderError::Payload) - } - } - } - Ok(Decoding::NotReady) - } - - pub fn parse(&mut self, io: &mut T, - buf: &mut BytesMut, - settings: &WorkerSettings) -> Poll - where T: IoStream - { - match self.need_read() { - PayloadStatus::Read => (), - PayloadStatus::Pause => return Ok(Async::NotReady), - PayloadStatus::Dropped => return Err(ReaderError::PayloadDropped), - } - - // read payload - let done = { - if let Some(ref mut payload) = self.payload { - 'buf: loop { - let not_ready = match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - payload.tx.set_error(PayloadError::Incomplete); - - // http channel should not deal with payload errors - return Err(ReaderError::Payload) - }, - Ok(Async::NotReady) => true, - Err(err) => { - payload.tx.set_error(err.into()); - - // http channel should not deal with payload errors - return Err(ReaderError::Payload) - } - _ => false, - }; - loop { - match payload.decoder.decode(buf) { - Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes); - if payload.decoder.is_eof() { - payload.tx.feed_eof(); - break 'buf true - } - }, - Ok(Async::Ready(None)) => { - payload.tx.feed_eof(); - break 'buf true - }, - Ok(Async::NotReady) => { - // if buffer is full then - // socket still can contain more data - if not_ready { - return Ok(Async::NotReady) - } - continue 'buf - }, - Err(err) => { - payload.tx.set_error(err.into()); - return Err(ReaderError::Payload) - } - } - } - } - } else { - false - } - }; - if done { self.payload = None } - - // if buf is empty parse_message will always return NotReady, let's avoid that - if buf.is_empty() { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => return Err(ReaderError::Disconnect), - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(ReaderError::Error(err.into())) - } - }; - - loop { - match Reader::parse_message(buf, settings).map_err(ReaderError::Error)? { - Async::Ready((msg, decoder)) => { - // process payload - if let Some(mut payload) = decoder { - match self.decode(buf, &mut payload)? { - Decoding::Ready => (), - Decoding::NotReady => self.payload = Some(payload), - } - } - return Ok(Async::Ready(msg)); - }, - Async::NotReady => { - if buf.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ReaderError::Error(ParseError::TooLarge)); - } - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - debug!("Ignored premature client disconnection"); - return Err(ReaderError::Disconnect); - }, - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(ReaderError::Error(err.into())), - } - }, - } - } - } - - fn parse_message(buf: &mut BytesMut, settings: &WorkerSettings) - -> Poll<(HttpRequest, Option), ParseError> { - // Parse http message - let mut has_te = false; - let mut has_upgrade = false; - let mut has_length = false; - let msg = { - let bytes_ptr = buf.as_ref().as_ptr() as usize; - let mut headers: [httparse::Header; MAX_HEADERS] = - unsafe{std::mem::uninitialized()}; - - let (len, method, path, version, headers_len) = { - let b = unsafe{ let b: &[u8] = buf; std::mem::transmute(b) }; - let mut req = httparse::Request::new(&mut headers); - match req.parse(b)? { - httparse::Status::Complete(len) => { - let method = Method::from_bytes( - req.method.unwrap().as_bytes()) - .map_err(|_| ParseError::Method)?; - let path = Uri::try_from(req.path.unwrap())?; - let version = if req.version.unwrap() == 1 { - Version::HTTP_11 - } else { - Version::HTTP_10 - }; - (len, method, path, version, req.headers.len()) - } - httparse::Status::Partial => return Ok(Async::NotReady), - } - }; - - let slice = buf.split_to(len).freeze(); - - // convert headers - let msg = settings.get_http_message(); - { - let msg_mut = msg.get_mut(); - for header in headers[..headers_len].iter() { - if let Ok(name) = HeaderName::from_bytes(header.name.as_bytes()) { - has_te = has_te || name == header::TRANSFER_ENCODING; - has_length = has_length || name == header::CONTENT_LENGTH; - has_upgrade = has_upgrade || name == header::UPGRADE; - let v_start = header.value.as_ptr() as usize - bytes_ptr; - let v_end = v_start + header.value.len(); - let value = unsafe { - HeaderValue::from_shared_unchecked( - slice.slice(v_start, v_end)) }; - msg_mut.headers.append(name, value); - } else { - return Err(ParseError::Header) - } - } - - msg_mut.uri = path; - msg_mut.method = method; - msg_mut.version = version; - } - msg - }; - - // https://tools.ietf.org/html/rfc7230#section-3.3.3 - let decoder = if has_te && chunked(&msg.get_mut().headers)? { - // Chunked encoding - Some(Decoder::chunked()) - } else if has_length { - // Content-Length - let len = msg.get_ref().headers.get(header::CONTENT_LENGTH).unwrap(); - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - Some(Decoder::length(len)) - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header) - } - } else { - debug!("illegal Content-Length: {:?}", len); - return Err(ParseError::Header) - } - } else if has_upgrade || msg.get_ref().method == Method::CONNECT { - // upgrade(websocket) or connect - Some(Decoder::eof()) - } else { - None - }; - - if let Some(decoder) = decoder { - let (psender, payload) = Payload::new(false); - let info = PayloadInfo { - tx: PayloadType::new(&msg.get_ref().headers, psender), - decoder, - }; - msg.get_mut().payload = Some(payload); - Ok(Async::Ready((HttpRequest::from_message(msg), Some(info)))) - } else { - Ok(Async::Ready((HttpRequest::from_message(msg), None))) - } - } -} - -/// Check if request has chunked transfer encoding -pub fn chunked(headers: &HeaderMap) -> Result { - if let Some(encodings) = headers.get(header::TRANSFER_ENCODING) { - if let Ok(s) = encodings.to_str() { - Ok(s.to_lowercase().contains("chunked")) - } else { - Err(ParseError::Header) - } - } else { - Ok(false) - } -} - -/// Decoders to handle different Transfer-Encodings. -/// -/// If a message body does not include a Transfer-Encoding, it *should* -/// include a Content-Length header. -#[derive(Debug, Clone, PartialEq)] -pub struct Decoder { - kind: Kind, -} - -impl Decoder { - pub fn length(x: u64) -> Decoder { - Decoder { kind: Kind::Length(x) } - } - - pub fn chunked() -> Decoder { - Decoder { kind: Kind::Chunked(ChunkedState::Size, 0) } - } - - pub fn eof() -> Decoder { - Decoder { kind: Kind::Eof(false) } - } -} - -#[derive(Debug, Clone, PartialEq)] -enum Kind { - /// A Reader used when a Content-Length header is passed with a positive integer. - Length(u64), - /// A Reader used when Transfer-Encoding is `chunked`. - Chunked(ChunkedState, u64), - /// A Reader used for responses that don't indicate a length or chunked. - /// - /// Note: This should only used for `Response`s. It is illegal for a - /// `Request` to be made with both `Content-Length` and - /// `Transfer-Encoding: chunked` missing, as explained from the spec: - /// - /// > If a Transfer-Encoding header field is present in a response and - /// > the chunked transfer coding is not the final encoding, the - /// > message body length is determined by reading the connection until - /// > it is closed by the server. If a Transfer-Encoding header field - /// > is present in a request and the chunked transfer coding is not - /// > the final encoding, the message body length cannot be determined - /// > reliably; the server MUST respond with the 400 (Bad Request) - /// > status code and then close the connection. - Eof(bool), -} - -#[derive(Debug, PartialEq, Clone)] -enum ChunkedState { - Size, - SizeLws, - Extension, - SizeLf, - Body, - BodyCr, - BodyLf, - EndCr, - EndLf, - End, -} - -impl Decoder { - pub fn is_eof(&self) -> bool { - match self.kind { - Kind::Length(0) | Kind::Chunked(ChunkedState::End, _) | Kind::Eof(true) => true, - _ => false, - } - } - - pub fn decode(&mut self, body: &mut BytesMut) -> Poll, io::Error> { - match self.kind { - Kind::Length(ref mut remaining) => { - if *remaining == 0 { - Ok(Async::Ready(None)) - } else { - if body.is_empty() { - return Ok(Async::NotReady) - } - let len = body.len() as u64; - let buf; - if *remaining > len { - buf = body.take().freeze(); - *remaining -= len; - } else { - buf = body.split_to(*remaining as usize).freeze(); - *remaining = 0; - } - trace!("Length read: {}", buf.len()); - Ok(Async::Ready(Some(buf))) - } - } - Kind::Chunked(ref mut state, ref mut size) => { - loop { - let mut buf = None; - // advances the chunked state - *state = try_ready!(state.step(body, size, &mut buf)); - if *state == ChunkedState::End { - trace!("End of chunked stream"); - return Ok(Async::Ready(None)); - } - if let Some(buf) = buf { - return Ok(Async::Ready(Some(buf))); - } - if body.is_empty() { - return Ok(Async::NotReady); - } - } - } - Kind::Eof(ref mut is_eof) => { - if *is_eof { - Ok(Async::Ready(None)) - } else if !body.is_empty() { - Ok(Async::Ready(Some(body.take().freeze()))) - } else { - Ok(Async::NotReady) - } - } - } - } -} - -macro_rules! byte ( - ($rdr:ident) => ({ - if $rdr.len() > 0 { - let b = $rdr[0]; - $rdr.split_to(1); - b - } else { - return Ok(Async::NotReady) - } - }) -); - -impl ChunkedState { - fn step(&self, body: &mut BytesMut, size: &mut u64, buf: &mut Option) - -> Poll - { - use self::ChunkedState::*; - match *self { - Size => ChunkedState::read_size(body, size), - SizeLws => ChunkedState::read_size_lws(body), - Extension => ChunkedState::read_extension(body), - SizeLf => ChunkedState::read_size_lf(body, size), - Body => ChunkedState::read_body(body, size, buf), - BodyCr => ChunkedState::read_body_cr(body), - BodyLf => ChunkedState::read_body_lf(body), - EndCr => ChunkedState::read_end_cr(body), - EndLf => ChunkedState::read_end_lf(body), - End => Ok(Async::Ready(ChunkedState::End)), - } - } - fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll { - let radix = 16; - match byte!(rdr) { - b @ b'0'...b'9' => { - *size *= radix; - *size += u64::from(b - b'0'); - } - b @ b'a'...b'f' => { - *size *= radix; - *size += u64::from(b + 10 - b'a'); - } - b @ b'A'...b'F' => { - *size *= radix; - *size += u64::from(b + 10 - b'A'); - } - b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => return Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)), - _ => { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line: Invalid Size")); - } - } - Ok(Async::Ready(ChunkedState::Size)) - } - fn read_size_lws(rdr: &mut BytesMut) -> Poll { - trace!("read_size_lws"); - match byte!(rdr) { - // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)), - b';' => Ok(Async::Ready(ChunkedState::Extension)), - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => { - Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size linear white space")) - } - } - } - fn read_extension(rdr: &mut BytesMut) -> Poll { - match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), - _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions - } - } - fn read_size_lf(rdr: &mut BytesMut, size: &mut u64) -> Poll { - match byte!(rdr) { - b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), - b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")), - } - } - - fn read_body(rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option) - -> Poll - { - trace!("Chunked read, remaining={:?}", rem); - - let len = rdr.len() as u64; - if len == 0 { - Ok(Async::Ready(ChunkedState::Body)) - } else { - let slice; - if *rem > len { - slice = rdr.take().freeze(); - *rem -= len; - } else { - slice = rdr.split_to(*rem as usize).freeze(); - *rem = 0; - } - *buf = Some(slice); - if *rem > 0 { - Ok(Async::Ready(ChunkedState::Body)) - } else { - Ok(Async::Ready(ChunkedState::BodyCr)) - } - } - } - - fn read_body_cr(rdr: &mut BytesMut) -> Poll { - match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")), - } - } - fn read_body_lf(rdr: &mut BytesMut) -> Poll { - match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::Size)), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")), - } - } - fn read_end_cr(rdr: &mut BytesMut) -> Poll { - match byte!(rdr) { - b'\r' => Ok(Async::Ready(ChunkedState::EndLf)), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end CR")), - } - } - fn read_end_lf(rdr: &mut BytesMut) -> Poll { - match byte!(rdr) { - b'\n' => Ok(Async::Ready(ChunkedState::End)), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end LF")), - } - } -} - -#[cfg(test)] -mod tests { - use std::{io, cmp, time}; - use std::net::Shutdown; - use bytes::{Bytes, BytesMut, Buf}; - use futures::{Async, Stream}; - use tokio_io::{AsyncRead, AsyncWrite}; - use http::{Version, Method}; - - use super::*; - use httpmessage::HttpMessage; - use application::HttpApplication; - use server::settings::WorkerSettings; - use server::{IoStream, KeepAlive}; - - struct Buffer { - buf: Bytes, - err: Option, - } - - impl Buffer { - fn new(data: &'static str) -> Buffer { - Buffer { - buf: Bytes::from(data), - err: None, - } - } - fn feed_data(&mut self, data: &'static str) { - let mut b = BytesMut::from(self.buf.as_ref()); - b.extend(data.as_bytes()); - self.buf = b.take().freeze(); - } - } - - impl AsyncRead for Buffer {} - impl io::Read for Buffer { - fn read(&mut self, dst: &mut [u8]) -> Result { - if self.buf.is_empty() { - if self.err.is_some() { - Err(self.err.take().unwrap()) - } else { - Err(io::Error::new(io::ErrorKind::WouldBlock, "")) - } - } else { - let size = cmp::min(self.buf.len(), dst.len()); - let b = self.buf.split_to(size); - dst[..size].copy_from_slice(&b); - Ok(size) - } - } - } - - impl IoStream for Buffer { - fn shutdown(&mut self, _: Shutdown) -> io::Result<()> { - Ok(()) - } - fn set_nodelay(&mut self, _: bool) -> io::Result<()> { - Ok(()) - } - fn set_linger(&mut self, _: Option) -> io::Result<()> { - Ok(()) - } - } - impl io::Write for Buffer { - fn write(&mut self, buf: &[u8]) -> io::Result {Ok(buf.len())} - fn flush(&mut self) -> io::Result<()> {Ok(())} - } - impl AsyncWrite for Buffer { - fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(Async::Ready(())) } - fn write_buf(&mut self, _: &mut B) -> Poll { - Ok(Async::NotReady) - } - } - - macro_rules! not_ready { - ($e:expr) => (match $e { - Ok(Async::NotReady) => (), - Err(err) => unreachable!("Unexpected error: {:?}", err), - _ => unreachable!("Should not be ready"), - }) - } - - macro_rules! parse_ready { - ($e:expr) => ({ - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - match Reader::new().parse($e, &mut BytesMut::new(), &settings) { - Ok(Async::Ready(req)) => req, - Ok(_) => unreachable!("Eof during parsing http request"), - Err(err) => unreachable!("Error during parsing http request: {:?}", err), - } - }) - } - - macro_rules! reader_parse_ready { - ($e:expr) => ( - match $e { - Ok(Async::Ready(req)) => req, - Ok(_) => unreachable!("Eof during parsing http request"), - Err(err) => unreachable!("Error during parsing http request: {:?}", err), - } - ) - } - - macro_rules! expect_parse_err { - ($e:expr) => ({ - let mut buf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - match Reader::new().parse($e, &mut buf, &settings) { - Err(err) => match err { - ReaderError::Error(_) => (), - _ => unreachable!("Parse error expected"), - }, - _ => { - unreachable!("Error expected") - } - }} - ) - } - - #[test] - fn test_parse() { - let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::GET); - assert_eq!(req.path(), "/test"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_parse_partial() { - let mut buf = Buffer::new("PUT /test HTTP/1"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::NotReady) => (), - _ => unreachable!("Error"), - } - - buf.feed_data(".1\r\n\r\n"); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::PUT); - assert_eq!(req.path(), "/test"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_parse_post() { - let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - assert_eq!(req.version(), Version::HTTP_10); - assert_eq!(*req.method(), Method::POST); - assert_eq!(req.path(), "/test2"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_parse_body() { - let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(mut req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::GET); - assert_eq!(req.path(), "/test"); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"body"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_parse_body_crlf() { - let mut buf = Buffer::new( - "\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(mut req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::GET); - assert_eq!(req.path(), "/test"); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"body"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_parse_partial_eof() { - let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } - - buf.feed_data("\r\n"); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::GET); - assert_eq!(req.path(), "/test"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_headers_split_field() { - let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } - - buf.feed_data("t"); - not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } - - buf.feed_data("es"); - not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } - - buf.feed_data("t: value\r\n\r\n"); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - assert_eq!(req.version(), Version::HTTP_11); - assert_eq!(*req.method(), Method::GET); - assert_eq!(req.path(), "/test"); - assert_eq!(req.headers().get("test").unwrap().as_bytes(), b"value"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_headers_multi_value() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - Set-Cookie: c1=cookie1\r\n\ - Set-Cookie: c2=cookie2\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - match reader.parse(&mut buf, &mut readbuf, &settings) { - Ok(Async::Ready(req)) => { - let val: Vec<_> = req.headers().get_all("Set-Cookie") - .iter().map(|v| v.to_str().unwrap().to_owned()).collect(); - assert_eq!(val[0], "c1=cookie1"); - assert_eq!(val[1], "c2=cookie2"); - } - Ok(_) | Err(_) => unreachable!("Error during parsing http request"), - } - } - - #[test] - fn test_conn_default_1_0() { - let mut buf = Buffer::new("GET /test HTTP/1.0\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(!req.keep_alive()); - } - - #[test] - fn test_conn_default_1_1() { - let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(req.keep_alive()); - } - - #[test] - fn test_conn_close() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - connection: close\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(!req.keep_alive()); - } - - #[test] - fn test_conn_close_1_0() { - let mut buf = Buffer::new( - "GET /test HTTP/1.0\r\n\ - connection: close\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(!req.keep_alive()); - } - - #[test] - fn test_conn_keep_alive_1_0() { - let mut buf = Buffer::new( - "GET /test HTTP/1.0\r\n\ - connection: keep-alive\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(req.keep_alive()); - } - - #[test] - fn test_conn_keep_alive_1_1() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - connection: keep-alive\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(req.keep_alive()); - } - - #[test] - fn test_conn_other_1_0() { - let mut buf = Buffer::new( - "GET /test HTTP/1.0\r\n\ - connection: other\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(!req.keep_alive()); - } - - #[test] - fn test_conn_other_1_1() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - connection: other\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(req.keep_alive()); - } - - #[test] - fn test_conn_upgrade() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - upgrade: websockets\r\n\ - connection: upgrade\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(!req.payload().eof()); - assert!(req.upgrade()); - } - - #[test] - fn test_conn_upgrade_connect_method() { - let mut buf = Buffer::new( - "CONNECT /test HTTP/1.1\r\n\ - content-type: text/plain\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert!(req.upgrade()); - assert!(!req.payload().eof()); - } - - #[test] - fn test_request_chunked() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(val); - } else { - unreachable!("Error"); - } - - // type in chunked - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chnked\r\n\r\n"); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(!val); - } else { - unreachable!("Error"); - } - } - - #[test] - fn test_headers_content_length_err_1() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - content-length: line\r\n\r\n"); - - expect_parse_err!(&mut buf) - } - - #[test] - fn test_headers_content_length_err_2() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - content-length: -1\r\n\r\n"); - - expect_parse_err!(&mut buf); - } - - #[test] - fn test_invalid_header() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - test line\r\n\r\n"); - - expect_parse_err!(&mut buf); - } - - #[test] - fn test_invalid_name() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - test[]: line\r\n\r\n"); - - expect_parse_err!(&mut buf); - } - - #[test] - fn test_http_request_bad_status_line() { - let mut buf = Buffer::new("getpath \r\n\r\n"); - expect_parse_err!(&mut buf); - } - - #[test] - fn test_http_request_upgrade() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - connection: upgrade\r\n\ - upgrade: websocket\r\n\r\n\ - some raw data"); - let mut req = parse_ready!(&mut buf); - assert!(!req.keep_alive()); - assert!(req.upgrade()); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"some raw data"); - } - - #[test] - fn test_http_request_parser_utf8() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - x-test: теÑÑ‚\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert_eq!(req.headers().get("x-test").unwrap().as_bytes(), - "теÑÑ‚".as_bytes()); - } - - #[test] - fn test_http_request_parser_two_slashes() { - let mut buf = Buffer::new( - "GET //path HTTP/1.1\r\n\r\n"); - let req = parse_ready!(&mut buf); - - assert_eq!(req.path(), "//path"); - } - - #[test] - fn test_http_request_parser_bad_method() { - let mut buf = Buffer::new( - "!12%()+=~$ /get HTTP/1.1\r\n\r\n"); - - expect_parse_err!(&mut buf); - } - - #[test] - fn test_http_request_parser_bad_version() { - let mut buf = Buffer::new("GET //get HT/11\r\n\r\n"); - - expect_parse_err!(&mut buf); - } - - #[test] - fn test_http_request_chunked_payload() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(req.chunked().unwrap()); - assert!(!req.payload().eof()); - - buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); - let _ = req.payload_mut().poll(); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(!req.payload().eof()); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); - assert!(req.payload().eof()); - } - - #[test] - fn test_http_request_chunked_payload_and_next_message() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - - let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(req.chunked().unwrap()); - assert!(!req.payload().eof()); - - buf.feed_data( - "4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ - POST /test2 HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let _ = req.payload_mut().poll(); - - let req2 = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert_eq!(*req2.method(), Method::POST); - assert!(req2.chunked().unwrap()); - assert!(!req2.payload().eof()); - - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); - assert!(req.payload().eof()); - } - - #[test] - fn test_http_request_chunked_payload_chunks() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - let _ = req.payload_mut().set_read_buffer_capacity(0); - assert!(req.chunked().unwrap()); - assert!(!req.payload().eof()); - - buf.feed_data("4\r\n1111\r\n"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"1111"); - - buf.feed_data("4\r\ndata\r"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - buf.feed_data("\n4"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - buf.feed_data("\r"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - buf.feed_data("\n"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - buf.feed_data("li"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - buf.feed_data("ne\r\n0\r\n"); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - //trailers - //buf.feed_data("test: test\r\n"); - //not_ready!(reader.parse(&mut buf, &mut readbuf)); - - let _ = req.payload_mut().poll(); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); - assert!(!req.payload().eof()); - - buf.feed_data("\r\n"); - let _ = req.payload_mut().poll(); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(req.payload().eof()); - } - - #[test] - fn test_parse_chunked_payload_chunk_extension() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"); - let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new( - Vec::new(), KeepAlive::Os); - - let mut reader = Reader::new(); - let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(req.chunked().unwrap()); - assert!(!req.payload().eof()); - - buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") - let _ = req.payload_mut().poll(); - not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); - assert!(!req.payload().eof()); - assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); - assert!(req.payload().eof()); - } - - /*#[test] - #[should_panic] - fn test_parse_multiline() { - let mut buf = Buffer::new( - "GET /test HTTP/1.1\r\n\ - test: line\r\n \ - continue\r\n\ - test2: data\r\n\ - \r\n", false); - - let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(res) => (), - Err(err) => unreachable!("{:?}", err), - } - }*/ -} diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs deleted file mode 100644 index ef2a60893..000000000 --- a/src/server/h1writer.rs +++ /dev/null @@ -1,284 +0,0 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] - -use std::{io, mem}; -use std::rc::Rc; -use bytes::BufMut; -use futures::{Async, Poll}; -use tokio_io::AsyncWrite; -use http::{Method, Version}; -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE}; - -use body::{Body, Binary}; -use header::ContentEncoding; -use httprequest::HttpInnerMessage; -use httpresponse::HttpResponse; -use super::helpers; -use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; -use super::shared::SharedBytes; -use super::encoding::ContentEncoder; -use super::settings::WorkerSettings; - -const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific - -bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const UPGRADE = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const DISCONNECTED = 0b0000_1000; - } -} - -pub(crate) struct H1Writer { - flags: Flags, - stream: T, - encoder: ContentEncoder, - written: u64, - headers_size: u32, - buffer: SharedBytes, - buffer_capacity: usize, - settings: Rc>, -} - -impl H1Writer { - - pub fn new(stream: T, buf: SharedBytes, settings: Rc>) - -> H1Writer - { - H1Writer { - flags: Flags::empty(), - encoder: ContentEncoder::empty(buf.clone()), - written: 0, - headers_size: 0, - buffer: buf, - buffer_capacity: 0, - stream, - settings, - } - } - - pub fn get_mut(&mut self) -> &mut T { - &mut self.stream - } - - pub fn reset(&mut self) { - self.written = 0; - self.flags = Flags::empty(); - } - - pub fn disconnected(&mut self) { - self.buffer.take(); - } - - pub fn keepalive(&self) -> bool { - self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) - } - - fn write_data(&mut self, data: &[u8]) -> io::Result { - let mut written = 0; - while written < data.len() { - match self.stream.write(&data[written..]) { - Ok(0) => { - self.disconnected(); - return Err(io::Error::new(io::ErrorKind::WriteZero, "")) - }, - Ok(n) => { - written += n; - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - return Ok(written) - } - Err(err) => return Err(err), - } - } - Ok(written) - } -} - -impl Writer for H1Writer { - - #[inline] - fn written(&self) -> u64 { - self.written - } - - fn start(&mut self, - req: &mut HttpInnerMessage, - msg: &mut HttpResponse, - encoding: ContentEncoding) -> io::Result - { - // prepare task - self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); - if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { - self.flags.insert(Flags::STARTED | Flags::KEEPALIVE); - } else { - self.flags.insert(Flags::STARTED); - } - - // Connection upgrade - let version = msg.version().unwrap_or_else(|| req.version); - if msg.upgrade() { - self.flags.insert(Flags::UPGRADE); - msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); - } - // keep-alive - else if self.flags.contains(Flags::KEEPALIVE) { - if version < Version::HTTP_11 { - msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("keep-alive")); - } - } else if version >= Version::HTTP_11 { - msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("close")); - } - let body = msg.replace_body(Body::Empty); - - // render message - { - let mut buffer = self.buffer.get_mut(); - let reason = msg.reason().as_bytes(); - let mut is_bin = if let Body::Binary(ref bytes) = body { - buffer.reserve( - 256 + msg.headers().len() * AVERAGE_HEADER_SIZE - + bytes.len() + reason.len()); - true - } else { - buffer.reserve( - 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len()); - false - }; - - // status line - helpers::write_status_line(version, msg.status().as_u16(), &mut buffer); - SharedBytes::extend_from_slice_(buffer, reason); - - match body { - Body::Empty => - if req.method != Method::HEAD { - SharedBytes::put_slice(buffer, b"\r\ncontent-length: 0\r\n"); - } else { - SharedBytes::put_slice(buffer, b"\r\n"); - }, - Body::Binary(ref bytes) => - helpers::write_content_length(bytes.len(), &mut buffer), - _ => - SharedBytes::put_slice(buffer, b"\r\n"), - } - - // write headers - let mut pos = 0; - let mut has_date = false; - let mut remaining = buffer.remaining_mut(); - let mut buf: &mut [u8] = unsafe{ mem::transmute(buffer.bytes_mut()) }; - for (key, value) in msg.headers() { - if is_bin && key == CONTENT_LENGTH { - is_bin = false; - continue - } - has_date = has_date || key == DATE; - let v = value.as_ref(); - let k = key.as_str().as_bytes(); - let len = k.len() + v.len() + 4; - if len > remaining { - unsafe{buffer.advance_mut(pos)}; - pos = 0; - buffer.reserve(len); - remaining = buffer.remaining_mut(); - buf = unsafe{ mem::transmute(buffer.bytes_mut()) }; - } - - buf[pos..pos+k.len()].copy_from_slice(k); - pos += k.len(); - buf[pos..pos+2].copy_from_slice(b": "); - pos += 2; - buf[pos..pos+v.len()].copy_from_slice(v); - pos += v.len(); - buf[pos..pos+2].copy_from_slice(b"\r\n"); - pos += 2; - remaining -= len; - } - unsafe{buffer.advance_mut(pos)}; - - // optimized date header, set_date writes \r\n - if !has_date { - self.settings.set_date(&mut buffer); - } else { - // msg eof - SharedBytes::extend_from_slice_(buffer, b"\r\n"); - } - self.headers_size = buffer.len() as u32; - } - - if let Body::Binary(bytes) = body { - self.written = bytes.len() as u64; - self.encoder.write(bytes)?; - } else { - // capacity, makes sense only for streaming or actor - self.buffer_capacity = msg.write_buffer_capacity(); - - msg.replace_body(body); - } - Ok(WriterState::Done) - } - - fn write(&mut self, payload: Binary) -> io::Result { - self.written += payload.len() as u64; - if !self.flags.contains(Flags::DISCONNECTED) { - if self.flags.contains(Flags::STARTED) { - // shortcut for upgraded connection - if self.flags.contains(Flags::UPGRADE) { - if self.buffer.is_empty() { - let pl: &[u8] = payload.as_ref(); - let n = self.write_data(pl)?; - if n < pl.len() { - self.buffer.extend_from_slice(&pl[n..]); - return Ok(WriterState::Done); - } - } else { - self.buffer.extend(payload); - } - } else { - // TODO: add warning, write after EOF - self.encoder.write(payload)?; - } - } else { - // could be response to EXCEPT header - self.buffer.extend_from_slice(payload.as_ref()) - } - } - - if self.buffer.len() > self.buffer_capacity { - Ok(WriterState::Pause) - } else { - Ok(WriterState::Done) - } - } - - fn write_eof(&mut self) -> io::Result { - self.encoder.write_eof()?; - - if !self.encoder.is_eof() { - Err(io::Error::new(io::ErrorKind::Other, - "Last payload item, but eof is not reached")) - } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - Ok(WriterState::Pause) - } else { - Ok(WriterState::Done) - } - } - - #[inline] - fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { - if !self.buffer.is_empty() { - let buf: &[u8] = unsafe{mem::transmute(self.buffer.as_ref())}; - let written = self.write_data(buf)?; - let _ = self.buffer.split_to(written); - if self.buffer.len() > self.buffer_capacity { - return Ok(Async::NotReady) - } - } - if shutdown { - self.stream.shutdown() - } else { - Ok(Async::Ready(())) - } - } -} diff --git a/src/server/h2.rs b/src/server/h2.rs deleted file mode 100644 index 77ddf0847..000000000 --- a/src/server/h2.rs +++ /dev/null @@ -1,399 +0,0 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] - -use std::{io, cmp, mem}; -use std::rc::Rc; -use std::io::{Read, Write}; -use std::time::Duration; -use std::net::SocketAddr; -use std::collections::VecDeque; - -use actix::Arbiter; -use modhttp::request::Parts; -use http2::{Reason, RecvStream}; -use http2::server::{self, Connection, Handshake, SendResponse}; -use bytes::{Buf, Bytes}; -use futures::{Async, Poll, Future, Stream}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_core::reactor::Timeout; - -use pipeline::Pipeline; -use error::PayloadError; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use payload::{Payload, PayloadWriter, PayloadStatus}; - -use super::h2writer::H2Writer; -use super::encoding::PayloadType; -use super::settings::WorkerSettings; -use super::{HttpHandler, HttpHandlerTask, Writer}; - -bitflags! { - struct Flags: u8 { - const DISCONNECTED = 0b0000_0010; - } -} - -/// HTTP/2 Transport -pub(crate) -struct Http2 - where T: AsyncRead + AsyncWrite + 'static, H: 'static -{ - flags: Flags, - settings: Rc>, - addr: Option, - state: State>, - tasks: VecDeque>, - keepalive_timer: Option, -} - -enum State { - Handshake(Handshake), - Connection(Connection), - Empty, -} - -impl Http2 - where T: AsyncRead + AsyncWrite + 'static, - H: HttpHandler + 'static -{ - pub fn new(settings: Rc>, - io: T, - addr: Option, buf: Bytes) -> Self - { - Http2{ flags: Flags::empty(), - tasks: VecDeque::new(), - state: State::Handshake( - server::handshake(IoWrapper{unread: Some(buf), inner: io})), - keepalive_timer: None, - addr, - settings, - } - } - - pub(crate) fn shutdown(&mut self) { - self.state = State::Empty; - self.tasks.clear(); - self.keepalive_timer.take(); - } - - pub fn settings(&self) -> &WorkerSettings { - self.settings.as_ref() - } - - pub fn poll(&mut self) -> Poll<(), ()> { - // server - if let State::Connection(ref mut conn) = self.state { - // keep-alive timer - if let Some(ref mut timeout) = self.keepalive_timer { - match timeout.poll() { - Ok(Async::Ready(_)) => { - trace!("Keep-alive timeout, close connection"); - return Ok(Async::Ready(())) - } - Ok(Async::NotReady) => (), - Err(_) => unreachable!(), - } - } - - loop { - let mut not_ready = true; - - // check in-flight connections - for item in &mut self.tasks { - // read payload - item.poll_payload(); - - if !item.flags.contains(EntryFlags::EOF) { - let retry = item.payload.need_read() == PayloadStatus::Read; - loop { - match item.task.poll_io(&mut item.stream) { - Ok(Async::Ready(ready)) => { - if ready { - item.flags.insert( - EntryFlags::EOF | EntryFlags::FINISHED); - } else { - item.flags.insert(EntryFlags::EOF); - } - not_ready = false; - }, - Ok(Async::NotReady) => { - if item.payload.need_read() == PayloadStatus::Read - && !retry - { - continue - } - }, - Err(err) => { - error!("Unhandled error: {}", err); - item.flags.insert( - EntryFlags::EOF | - EntryFlags::ERROR | - EntryFlags::WRITE_DONE); - item.stream.reset(Reason::INTERNAL_ERROR); - } - } - break - } - } else if !item.flags.contains(EntryFlags::FINISHED) { - match item.task.poll() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - not_ready = false; - item.flags.insert(EntryFlags::FINISHED); - }, - Err(err) => { - item.flags.insert( - EntryFlags::ERROR | EntryFlags::WRITE_DONE | - EntryFlags::FINISHED); - error!("Unhandled error: {}", err); - } - } - } - - if !item.flags.contains(EntryFlags::WRITE_DONE) { - match item.stream.poll_completed(false) { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - not_ready = false; - item.flags.insert(EntryFlags::WRITE_DONE); - } - Err(_err) => { - item.flags.insert(EntryFlags::ERROR); - } - } - } - } - - // cleanup finished tasks - while !self.tasks.is_empty() { - if self.tasks[0].flags.contains(EntryFlags::EOF) && - self.tasks[0].flags.contains(EntryFlags::WRITE_DONE) || - self.tasks[0].flags.contains(EntryFlags::ERROR) - { - self.tasks.pop_front(); - } else { - break - } - } - - // get request - if !self.flags.contains(Flags::DISCONNECTED) { - match conn.poll() { - Ok(Async::Ready(None)) => { - not_ready = false; - self.flags.insert(Flags::DISCONNECTED); - for entry in &mut self.tasks { - entry.task.disconnected() - } - }, - Ok(Async::Ready(Some((req, resp)))) => { - not_ready = false; - let (parts, body) = req.into_parts(); - - // stop keepalive timer - self.keepalive_timer.take(); - - self.tasks.push_back( - Entry::new(parts, body, resp, self.addr, &self.settings)); - } - Ok(Async::NotReady) => { - // start keep-alive timer - if self.tasks.is_empty() { - if self.settings.keep_alive_enabled() { - let keep_alive = self.settings.keep_alive(); - if keep_alive > 0 && self.keepalive_timer.is_none() { - trace!("Start keep-alive timer"); - let mut timeout = Timeout::new( - Duration::new(keep_alive, 0), - Arbiter::handle()).unwrap(); - // register timeout - let _ = timeout.poll(); - self.keepalive_timer = Some(timeout); - } - } else { - // keep-alive disable, drop connection - return conn.poll_close().map_err( - |e| error!("Error during connection close: {}", e)) - } - } else { - // keep-alive unset, rely on operating system - return Ok(Async::NotReady) - } - } - Err(err) => { - trace!("Connection error: {}", err); - self.flags.insert(Flags::DISCONNECTED); - for entry in &mut self.tasks { - entry.task.disconnected() - } - self.keepalive_timer.take(); - }, - } - } - - if not_ready { - if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) { - return conn.poll_close().map_err( - |e| error!("Error during connection close: {}", e)) - } else { - return Ok(Async::NotReady) - } - } - } - } - - // handshake - self.state = if let State::Handshake(ref mut handshake) = self.state { - match handshake.poll() { - Ok(Async::Ready(conn)) => { - State::Connection(conn) - }, - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(err) => { - trace!("Error handling connection: {}", err); - return Err(()) - } - } - } else { - mem::replace(&mut self.state, State::Empty) - }; - - self.poll() - } -} - -bitflags! { - struct EntryFlags: u8 { - const EOF = 0b0000_0001; - const REOF = 0b0000_0010; - const ERROR = 0b0000_0100; - const FINISHED = 0b0000_1000; - const WRITE_DONE = 0b0001_0000; - } -} - -struct Entry { - task: Box, - payload: PayloadType, - recv: RecvStream, - stream: H2Writer, - flags: EntryFlags, -} - -impl Entry { - fn new(parts: Parts, - recv: RecvStream, - resp: SendResponse, - addr: Option, - settings: &Rc>) -> Entry - where H: HttpHandler + 'static - { - // Payload and Content-Encoding - let (psender, payload) = Payload::new(false); - - let msg = settings.get_http_message(); - msg.get_mut().uri = parts.uri; - msg.get_mut().method = parts.method; - msg.get_mut().version = parts.version; - msg.get_mut().headers = parts.headers; - msg.get_mut().payload = Some(payload); - msg.get_mut().addr = addr; - - let mut req = HttpRequest::from_message(msg); - - // Payload sender - let psender = PayloadType::new(req.headers(), psender); - - // start request processing - let mut task = None; - for h in settings.handlers().iter_mut() { - req = match h.handle(req) { - Ok(t) => { - task = Some(t); - break - }, - Err(req) => req, - } - } - - Entry {task: task.unwrap_or_else(|| Pipeline::error(HttpResponse::NotFound())), - payload: psender, - stream: H2Writer::new( - resp, settings.get_shared_bytes(), Rc::clone(settings)), - flags: EntryFlags::empty(), - recv, - } - } - - fn poll_payload(&mut self) { - if !self.flags.contains(EntryFlags::REOF) { - if self.payload.need_read() == PayloadStatus::Read { - if let Err(err) = self.recv.release_capacity().release_capacity(32_768) { - self.payload.set_error(PayloadError::Http2(err)) - } - } else if let Err(err) = self.recv.release_capacity().release_capacity(0) { - self.payload.set_error(PayloadError::Http2(err)) - } - - match self.recv.poll() { - Ok(Async::Ready(Some(chunk))) => { - self.payload.feed_data(chunk); - }, - Ok(Async::Ready(None)) => { - self.flags.insert(EntryFlags::REOF); - }, - Ok(Async::NotReady) => (), - Err(err) => { - self.payload.set_error(PayloadError::Http2(err)) - } - } - } - } -} - -struct IoWrapper { - unread: Option, - inner: T, -} - -impl Read for IoWrapper { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - if let Some(mut bytes) = self.unread.take() { - let size = cmp::min(buf.len(), bytes.len()); - buf[..size].copy_from_slice(&bytes[..size]); - if bytes.len() > size { - bytes.split_to(size); - self.unread = Some(bytes); - } - Ok(size) - } else { - self.inner.read(buf) - } - } -} - -impl Write for IoWrapper { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) - } - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() - } -} - -impl AsyncRead for IoWrapper { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } -} - -impl AsyncWrite for IoWrapper { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.inner.shutdown() - } - fn write_buf(&mut self, buf: &mut B) -> Poll { - self.inner.write_buf(buf) - } -} diff --git a/src/server/h2writer.rs b/src/server/h2writer.rs deleted file mode 100644 index 10deadaf0..000000000 --- a/src/server/h2writer.rs +++ /dev/null @@ -1,219 +0,0 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] - -use std::{io, cmp}; -use std::rc::Rc; -use bytes::{Bytes, BytesMut}; -use futures::{Async, Poll}; -use http2::{Reason, SendStream}; -use http2::server::SendResponse; -use modhttp::Response; - -use http::{Version, HttpTryFrom}; -use http::header::{HeaderValue, CONNECTION, TRANSFER_ENCODING, DATE, CONTENT_LENGTH}; - -use body::{Body, Binary}; -use header::ContentEncoding; -use httprequest::HttpInnerMessage; -use httpresponse::HttpResponse; -use super::helpers; -use super::encoding::ContentEncoder; -use super::shared::SharedBytes; -use super::settings::WorkerSettings; -use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; - -const CHUNK_SIZE: usize = 16_384; - -bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const DISCONNECTED = 0b0000_0010; - const EOF = 0b0000_0100; - const RESERVED = 0b0000_1000; - } -} - -pub(crate) struct H2Writer { - respond: SendResponse, - stream: Option>, - encoder: ContentEncoder, - flags: Flags, - written: u64, - buffer: SharedBytes, - buffer_capacity: usize, - settings: Rc>, -} - -impl H2Writer { - - pub fn new(respond: SendResponse, - buf: SharedBytes, settings: Rc>) -> H2Writer - { - H2Writer { - respond, - settings, - stream: None, - encoder: ContentEncoder::empty(buf.clone()), - flags: Flags::empty(), - written: 0, - buffer: buf, - buffer_capacity: 0, - } - } - - pub fn reset(&mut self, reason: Reason) { - if let Some(mut stream) = self.stream.take() { - stream.send_reset(reason) - } - } -} - -impl Writer for H2Writer { - - fn written(&self) -> u64 { - self.written - } - - fn start(&mut self, - req: &mut HttpInnerMessage, - msg: &mut HttpResponse, - encoding: ContentEncoding) -> io::Result - { - // prepare response - self.flags.insert(Flags::STARTED); - self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); - if let Body::Empty = *msg.body() { - self.flags.insert(Flags::EOF); - } - - // http2 specific - msg.headers_mut().remove(CONNECTION); - msg.headers_mut().remove(TRANSFER_ENCODING); - - // using helpers::date is quite a lot faster - if !msg.headers().contains_key(DATE) { - let mut bytes = BytesMut::with_capacity(29); - self.settings.set_date_simple(&mut bytes); - msg.headers_mut().insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); - } - - let body = msg.replace_body(Body::Empty); - match body { - Body::Binary(ref bytes) => { - let mut val = BytesMut::new(); - helpers::convert_usize(bytes.len(), &mut val); - let l = val.len(); - msg.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(val.split_to(l-2).freeze()).unwrap()); - } - Body::Empty => { - msg.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - }, - _ => (), - } - - let mut resp = Response::new(()); - *resp.status_mut() = msg.status(); - *resp.version_mut() = Version::HTTP_2; - for (key, value) in msg.headers().iter() { - resp.headers_mut().insert(key, value.clone()); - } - - match self.respond.send_response(resp, self.flags.contains(Flags::EOF)) { - Ok(stream) => - self.stream = Some(stream), - Err(_) => - return Err(io::Error::new(io::ErrorKind::Other, "err")), - } - - trace!("Response: {:?}", msg); - - if let Body::Binary(bytes) = body { - self.flags.insert(Flags::EOF); - self.written = bytes.len() as u64; - self.encoder.write(bytes)?; - if let Some(ref mut stream) = self.stream { - self.flags.insert(Flags::RESERVED); - stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); - } - Ok(WriterState::Pause) - } else { - msg.replace_body(body); - self.buffer_capacity = msg.write_buffer_capacity(); - Ok(WriterState::Done) - } - } - - fn write(&mut self, payload: Binary) -> io::Result { - self.written = payload.len() as u64; - - if !self.flags.contains(Flags::DISCONNECTED) { - if self.flags.contains(Flags::STARTED) { - // TODO: add warning, write after EOF - self.encoder.write(payload)?; - } else { - // might be response for EXCEPT - self.buffer.extend_from_slice(payload.as_ref()) - } - } - - if self.buffer.len() > self.buffer_capacity { - Ok(WriterState::Pause) - } else { - Ok(WriterState::Done) - } - } - - fn write_eof(&mut self) -> io::Result { - self.encoder.write_eof()?; - - self.flags.insert(Flags::EOF); - if !self.encoder.is_eof() { - Err(io::Error::new(io::ErrorKind::Other, - "Last payload item, but eof is not reached")) - } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - Ok(WriterState::Pause) - } else { - Ok(WriterState::Done) - } - } - - fn poll_completed(&mut self, _shutdown: bool) -> Poll<(), io::Error> { - if !self.flags.contains(Flags::STARTED) { - return Ok(Async::NotReady); - } - - if let Some(ref mut stream) = self.stream { - // reserve capacity - if !self.flags.contains(Flags::RESERVED) && !self.buffer.is_empty() { - self.flags.insert(Flags::RESERVED); - stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); - } - - loop { - match stream.poll_capacity() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) => return Ok(Async::Ready(())), - Ok(Async::Ready(Some(cap))) => { - let len = self.buffer.len(); - let bytes = self.buffer.split_to(cmp::min(cap, len)); - let eof = self.buffer.is_empty() && self.flags.contains(Flags::EOF); - self.written += bytes.len() as u64; - - if let Err(e) = stream.send_data(bytes.freeze(), eof) { - return Err(io::Error::new(io::ErrorKind::Other, e)) - } else if !self.buffer.is_empty() { - let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - self.flags.remove(Flags::RESERVED); - return Ok(Async::NotReady) - } - } - Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), - } - } - } - Ok(Async::NotReady) - } -} diff --git a/src/server/helpers.rs b/src/server/helpers.rs deleted file mode 100644 index c50317a9d..000000000 --- a/src/server/helpers.rs +++ /dev/null @@ -1,263 +0,0 @@ -use std::{mem, ptr, slice}; -use std::cell::RefCell; -use std::rc::Rc; -use std::collections::VecDeque; -use bytes::{BufMut, BytesMut}; -use http::Version; - -use httprequest::HttpInnerMessage; - -/// Internal use only! unsafe -pub(crate) struct SharedMessagePool(RefCell>>); - -impl SharedMessagePool { - pub fn new() -> SharedMessagePool { - SharedMessagePool(RefCell::new(VecDeque::with_capacity(128))) - } - - #[inline] - pub fn get(&self) -> Rc { - if let Some(msg) = self.0.borrow_mut().pop_front() { - msg - } else { - Rc::new(HttpInnerMessage::default()) - } - } - - #[inline] - pub fn release(&self, mut msg: Rc) { - let v = &mut self.0.borrow_mut(); - if v.len() < 128 { - Rc::get_mut(&mut msg).unwrap().reset(); - v.push_front(msg); - } - } -} - -pub(crate) struct SharedHttpInnerMessage( - Option>, Option>); - -impl Drop for SharedHttpInnerMessage { - fn drop(&mut self) { - if let Some(ref pool) = self.1 { - if let Some(msg) = self.0.take() { - if Rc::strong_count(&msg) == 1 { - pool.release(msg); - } - } - } - } -} - -impl Clone for SharedHttpInnerMessage { - - fn clone(&self) -> SharedHttpInnerMessage { - SharedHttpInnerMessage(self.0.clone(), self.1.clone()) - } -} - -impl Default for SharedHttpInnerMessage { - - fn default() -> SharedHttpInnerMessage { - SharedHttpInnerMessage(Some(Rc::new(HttpInnerMessage::default())), None) - } -} - -impl SharedHttpInnerMessage { - - pub fn from_message(msg: HttpInnerMessage) -> SharedHttpInnerMessage { - SharedHttpInnerMessage(Some(Rc::new(msg)), None) - } - - pub fn new(msg: Rc, pool: Rc) -> SharedHttpInnerMessage { - SharedHttpInnerMessage(Some(msg), Some(pool)) - } - - #[inline(always)] - #[allow(mutable_transmutes)] - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - pub fn get_mut(&self) -> &mut HttpInnerMessage { - let r: &HttpInnerMessage = self.0.as_ref().unwrap().as_ref(); - unsafe{mem::transmute(r)} - } - - #[inline(always)] - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - pub fn get_ref(&self) -> &HttpInnerMessage { - self.0.as_ref().unwrap() - } -} - -const DEC_DIGITS_LUT: &[u8] = - b"0001020304050607080910111213141516171819\ - 2021222324252627282930313233343536373839\ - 4041424344454647484950515253545556575859\ - 6061626364656667686970717273747576777879\ - 8081828384858687888990919293949596979899"; - -pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) { - let mut buf: [u8; 13] = [b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', - b' ', b' ', b' ', b' ', b' ']; - match version { - Version::HTTP_2 => buf[5] = b'2', - Version::HTTP_10 => buf[7] = b'0', - Version::HTTP_09 => {buf[5] = b'0'; buf[7] = b'9';}, - _ => (), - } - - let mut curr: isize = 12; - let buf_ptr = buf.as_mut_ptr(); - let lut_ptr = DEC_DIGITS_LUT.as_ptr(); - let four = n > 999; - - unsafe { - // decode 2 more chars, if > 2 chars - let d1 = (n % 100) << 1; - n /= 100; - curr -= 2; - ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); - - // decode last 1 or 2 chars - if n < 10 { - curr -= 1; - *buf_ptr.offset(curr) = (n as u8) + b'0'; - } else { - let d1 = n << 1; - curr -= 2; - ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); - } - } - - bytes.put_slice(&buf); - if four { - bytes.put(b' '); - } -} - -/// NOTE: bytes object has to contain enough space -pub(crate) fn write_content_length(mut n: usize, bytes: &mut BytesMut) { - if n < 10 { - let mut buf: [u8; 21] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', - b'n',b't',b'-',b'l',b'e',b'n',b'g', - b't',b'h',b':',b' ',b'0',b'\r',b'\n']; - buf[18] = (n as u8) + b'0'; - bytes.put_slice(&buf); - } else if n < 100 { - let mut buf: [u8; 22] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', - b'n',b't',b'-',b'l',b'e',b'n',b'g', - b't',b'h',b':',b' ',b'0',b'0',b'\r',b'\n']; - let d1 = n << 1; - unsafe { - ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(18), 2); - } - bytes.put_slice(&buf); - } else if n < 1000 { - let mut buf: [u8; 23] = [b'\r',b'\n',b'c',b'o',b'n',b't',b'e', - b'n',b't',b'-',b'l',b'e',b'n',b'g', - b't',b'h',b':',b' ',b'0',b'0',b'0',b'\r',b'\n']; - // decode 2 more chars, if > 2 chars - let d1 = (n % 100) << 1; - n /= 100; - unsafe {ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), buf.as_mut_ptr().offset(19), 2)}; - - // decode last 1 - buf[18] = (n as u8) + b'0'; - - bytes.put_slice(&buf); - } else { - bytes.put_slice(b"\r\ncontent-length: "); - convert_usize(n, bytes); - } -} - -pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { - let mut curr: isize = 39; - let mut buf: [u8; 41] = unsafe { mem::uninitialized() }; - buf[39] = b'\r'; - buf[40] = b'\n'; - let buf_ptr = buf.as_mut_ptr(); - let lut_ptr = DEC_DIGITS_LUT.as_ptr(); - - unsafe { - // eagerly decode 4 characters at a time - while n >= 10_000 { - let rem = (n % 10_000) as isize; - n /= 10_000; - - let d1 = (rem / 100) << 1; - let d2 = (rem % 100) << 1; - curr -= 4; - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2); - } - - // if we reach here numbers are <= 9999, so at most 4 chars long - let mut n = n as isize; // possibly reduce 64bit math - - // decode 2 more chars, if > 2 chars - if n >= 100 { - let d1 = (n % 100) << 1; - n /= 100; - curr -= 2; - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - } - - // decode last 1 or 2 chars - if n < 10 { - curr -= 1; - *buf_ptr.offset(curr) = (n as u8) + b'0'; - } else { - let d1 = n << 1; - curr -= 2; - ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); - } - } - - unsafe { - bytes.extend_from_slice( - slice::from_raw_parts(buf_ptr.offset(curr), 41 - curr as usize)); - } -} - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_write_content_length() { - let mut bytes = BytesMut::new(); - bytes.reserve(50); - write_content_length(0, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]); - bytes.reserve(50); - write_content_length(9, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]); - bytes.reserve(50); - write_content_length(10, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]); - bytes.reserve(50); - write_content_length(99, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]); - bytes.reserve(50); - write_content_length(100, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]); - bytes.reserve(50); - write_content_length(101, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]); - bytes.reserve(50); - write_content_length(998, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]); - bytes.reserve(50); - write_content_length(1000, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); - bytes.reserve(50); - write_content_length(1001, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); - bytes.reserve(50); - write_content_length(5909, &mut bytes); - assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); - } -} diff --git a/src/server/mod.rs b/src/server/mod.rs deleted file mode 100644 index 96f53c5ff..000000000 --- a/src/server/mod.rs +++ /dev/null @@ -1,250 +0,0 @@ -//! Http server -use std::{time, io}; -use std::net::Shutdown; - -use actix; -use futures::Poll; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_core::net::TcpStream; - -mod srv; -mod worker; -mod channel; -pub(crate) mod encoding; -pub(crate) mod h1; -mod h2; -mod h1writer; -mod h2writer; -mod settings; -pub(crate) mod helpers; -pub(crate) mod shared; -pub(crate) mod utils; - -pub use self::srv::HttpServer; -pub use self::settings::ServerSettings; - -use body::Binary; -use error::Error; -use header::ContentEncoding; -use httprequest::{HttpInnerMessage, HttpRequest}; -use httpresponse::HttpResponse; - -/// max buffer size 64k -pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; - -/// Create new http server with application factory -/// -/// ```rust -/// # extern crate actix; -/// # extern crate actix_web; -/// use actix::*; -/// use actix_web::{server, App, HttpResponse}; -/// -/// fn main() { -/// let sys = actix::System::new("guide"); -/// -/// server::new( -/// || App::new() -/// .resource("/", |r| r.f(|_| HttpResponse::Ok()))) -/// .bind("127.0.0.1:59080").unwrap() -/// .start(); -/// -/// # actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); -/// let _ = sys.run(); -/// } -/// ``` -pub fn new(factory: F) -> HttpServer - where F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, - H: IntoHttpHandler + 'static -{ - HttpServer::new(factory) -} - -#[derive(Debug, PartialEq, Clone, Copy)] -/// Server keep-alive setting -pub enum KeepAlive { - /// Keep alive in seconds - Timeout(usize), - /// Use `SO_KEEPALIVE` socket option, value in seconds - Tcp(usize), - /// Relay on OS to shutdown tcp connection - Os, - /// Disabled - Disabled, -} - -impl From for KeepAlive { - fn from(keepalive: usize) -> Self { - KeepAlive::Timeout(keepalive) - } -} - -impl From> for KeepAlive { - fn from(keepalive: Option) -> Self { - if let Some(keepalive) = keepalive { - KeepAlive::Timeout(keepalive) - } else { - KeepAlive::Disabled - } - } -} - -/// Pause accepting incoming connections -/// -/// If socket contains some pending connection, they might be dropped. -/// All opened connection remains active. -#[derive(Message)] -pub struct PauseServer; - -/// Resume accepting incoming connections -#[derive(Message)] -pub struct ResumeServer; - -/// Stop incoming connection processing, stop all workers and exit. -/// -/// If server starts with `spawn()` method, then spawned thread get terminated. -pub struct StopServer { - pub graceful: bool -} - -impl actix::Message for StopServer { - type Result = Result<(), ()>; -} - -/// Low level http request handler -#[allow(unused_variables)] -pub trait HttpHandler: 'static { - - /// Handle request - fn handle(&mut self, req: HttpRequest) -> Result, HttpRequest>; -} - -impl HttpHandler for Box { - fn handle(&mut self, req: HttpRequest) -> Result, HttpRequest> { - self.as_mut().handle(req) - } -} - -#[doc(hidden)] -pub trait HttpHandlerTask { - - /// Poll task, this method is used before or after *io* object is available - fn poll(&mut self) -> Poll<(), Error>; - - /// Poll task when *io* object is available - fn poll_io(&mut self, io: &mut Writer) -> Poll; - - /// Connection is disconnected - fn disconnected(&mut self); -} - -/// Conversion helper trait -pub trait IntoHttpHandler { - /// The associated type which is result of conversion. - type Handler: HttpHandler; - - /// Convert into `HttpHandler` object. - fn into_handler(self, settings: ServerSettings) -> Self::Handler; -} - -impl IntoHttpHandler for T { - type Handler = T; - - fn into_handler(self, _: ServerSettings) -> Self::Handler { - self - } -} - -#[doc(hidden)] -#[derive(Debug)] -pub enum WriterState { - Done, - Pause, -} - -#[doc(hidden)] -/// Stream writer -pub trait Writer { - fn written(&self) -> u64; - - fn start(&mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse, encoding: ContentEncoding) - -> io::Result; - - fn write(&mut self, payload: Binary) -> io::Result; - - fn write_eof(&mut self) -> io::Result; - - fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error>; -} - -#[doc(hidden)] -/// Low-level io stream operations -pub trait IoStream: AsyncRead + AsyncWrite + 'static { - fn shutdown(&mut self, how: Shutdown) -> io::Result<()>; - - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()>; - - fn set_linger(&mut self, dur: Option) -> io::Result<()>; -} - -impl IoStream for TcpStream { - #[inline] - fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - TcpStream::shutdown(self, how) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - TcpStream::set_nodelay(self, nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - TcpStream::set_linger(self, dur) - } -} - -#[cfg(feature="alpn")] -use tokio_openssl::SslStream; - -#[cfg(feature="alpn")] -impl IoStream for SslStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) - } -} - -#[cfg(feature="tls")] -use tokio_tls::TlsStream; - -#[cfg(feature="tls")] -impl IoStream for TlsStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) - } -} diff --git a/src/server/settings.rs b/src/server/settings.rs deleted file mode 100644 index 07b000429..000000000 --- a/src/server/settings.rs +++ /dev/null @@ -1,257 +0,0 @@ -use std::{fmt, mem, net}; -use std::fmt::Write; -use std::rc::Rc; -use std::sync::Arc; -use std::cell::{Cell, RefCell, RefMut, UnsafeCell}; -use time; -use bytes::BytesMut; -use http::StatusCode; -use futures_cpupool::{Builder, CpuPool}; - -use super::helpers; -use super::KeepAlive; -use super::channel::Node; -use super::shared::{SharedBytes, SharedBytesPool}; -use body::Body; -use httpresponse::{HttpResponse, HttpResponsePool, HttpResponseBuilder}; - -/// Various server settings -#[derive(Clone)] -pub struct ServerSettings { - addr: Option, - secure: bool, - host: String, - cpu_pool: Arc, - responses: Rc>, -} - -unsafe impl Sync for ServerSettings {} -unsafe impl Send for ServerSettings {} - -struct InnerCpuPool { - cpu_pool: UnsafeCell>, -} - -impl fmt::Debug for InnerCpuPool { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "CpuPool") - } -} - -impl InnerCpuPool { - fn new() -> Self { - InnerCpuPool { - cpu_pool: UnsafeCell::new(None), - } - } - fn cpu_pool(&self) -> &CpuPool { - unsafe { - let val = &mut *self.cpu_pool.get(); - if val.is_none() { - *val = Some(Builder::new().create()); - } - val.as_ref().unwrap() - } - } -} - -unsafe impl Sync for InnerCpuPool {} - -impl Default for ServerSettings { - fn default() -> Self { - ServerSettings { - addr: None, - secure: false, - host: "localhost:8080".to_owned(), - responses: HttpResponsePool::pool(), - cpu_pool: Arc::new(InnerCpuPool::new()), - } - } -} - -impl ServerSettings { - /// Crate server settings instance - pub(crate) fn new(addr: Option, host: &Option, secure: bool) - -> ServerSettings - { - let host = if let Some(ref host) = *host { - host.clone() - } else if let Some(ref addr) = addr { - format!("{}", addr) - } else { - "localhost".to_owned() - }; - let cpu_pool = Arc::new(InnerCpuPool::new()); - let responses = HttpResponsePool::pool(); - ServerSettings { addr, secure, host, cpu_pool, responses } - } - - /// Returns the socket address of the local half of this TCP connection - pub fn local_addr(&self) -> Option { - self.addr - } - - /// Returns true if connection is secure(https) - pub fn secure(&self) -> bool { - self.secure - } - - /// Returns host header value - pub fn host(&self) -> &str { - &self.host - } - - /// Returns default `CpuPool` for server - pub fn cpu_pool(&self) -> &CpuPool { - self.cpu_pool.cpu_pool() - } - - #[inline] - pub(crate) fn get_response(&self, status: StatusCode, body: Body) -> HttpResponse { - HttpResponsePool::get_response(&self.responses, status, body) - } - - #[inline] - pub(crate) fn get_response_builder(&self, status: StatusCode) -> HttpResponseBuilder { - HttpResponsePool::get_builder(&self.responses, status) - } -} - - -// "Sun, 06 Nov 1994 08:49:37 GMT".len() -const DATE_VALUE_LENGTH: usize = 29; - -pub(crate) struct WorkerSettings { - h: RefCell>, - keep_alive: u64, - ka_enabled: bool, - bytes: Rc, - messages: Rc, - channels: Cell, - node: Box>, - date: UnsafeCell, -} - -impl WorkerSettings { - pub(crate) fn new(h: Vec, keep_alive: KeepAlive) -> WorkerSettings { - let (keep_alive, ka_enabled) = match keep_alive { - KeepAlive::Timeout(val) => (val as u64, true), - KeepAlive::Os | KeepAlive::Tcp(_) => (0, true), - KeepAlive::Disabled => (0, false), - }; - - WorkerSettings { - keep_alive, ka_enabled, - h: RefCell::new(h), - bytes: Rc::new(SharedBytesPool::new()), - messages: Rc::new(helpers::SharedMessagePool::new()), - channels: Cell::new(0), - node: Box::new(Node::head()), - date: UnsafeCell::new(Date::new()), - } - } - - pub fn num_channels(&self) -> usize { - self.channels.get() - } - - pub fn head(&self) -> &Node<()> { - &self.node - } - - pub fn handlers(&self) -> RefMut> { - self.h.borrow_mut() - } - - pub fn keep_alive(&self) -> u64 { - self.keep_alive - } - - pub fn keep_alive_enabled(&self) -> bool { - self.ka_enabled - } - - pub fn get_shared_bytes(&self) -> SharedBytes { - SharedBytes::new(self.bytes.get_bytes(), Rc::clone(&self.bytes)) - } - - pub fn get_http_message(&self) -> helpers::SharedHttpInnerMessage { - helpers::SharedHttpInnerMessage::new(self.messages.get(), Rc::clone(&self.messages)) - } - - pub fn add_channel(&self) { - self.channels.set(self.channels.get() + 1); - } - - pub fn remove_channel(&self) { - let num = self.channels.get(); - if num > 0 { - self.channels.set(num-1); - } else { - error!("Number of removed channels is bigger than added channel. Bug in actix-web"); - } - } - - pub fn update_date(&self) { - unsafe{&mut *self.date.get()}.update(); - } - - pub fn set_date(&self, dst: &mut BytesMut) { - let mut buf: [u8; 39] = unsafe { mem::uninitialized() }; - buf[..6].copy_from_slice(b"date: "); - buf[6..35].copy_from_slice(&(unsafe{&*self.date.get()}.bytes)); - buf[35..].copy_from_slice(b"\r\n\r\n"); - dst.extend_from_slice(&buf); - } - - pub fn set_date_simple(&self, dst: &mut BytesMut) { - dst.extend_from_slice(&(unsafe{&*self.date.get()}.bytes)); - } -} - -struct Date { - bytes: [u8; DATE_VALUE_LENGTH], - pos: usize, -} - -impl Date { - fn new() -> Date { - let mut date = Date{bytes: [0; DATE_VALUE_LENGTH], pos: 0}; - date.update(); - date - } - fn update(&mut self) { - self.pos = 0; - write!(self, "{}", time::at_utc(time::get_time()).rfc822()).unwrap(); - } -} - -impl fmt::Write for Date { - fn write_str(&mut self, s: &str) -> fmt::Result { - let len = s.len(); - self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); - self.pos += len; - Ok(()) - } -} - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_date_len() { - assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); - } - - #[test] - fn test_date() { - let settings = WorkerSettings::<()>::new(Vec::new(), KeepAlive::Os); - let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf1); - let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf2); - assert_eq!(buf1, buf2); - } -} diff --git a/src/server/shared.rs b/src/server/shared.rs deleted file mode 100644 index bb3269c05..000000000 --- a/src/server/shared.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::{io, mem}; -use std::cell::RefCell; -use std::rc::Rc; -use std::collections::VecDeque; -use bytes::{BufMut, BytesMut}; - -use body::Binary; - - -/// Internal use only! unsafe -#[derive(Debug)] -pub(crate) struct SharedBytesPool(RefCell>>); - -impl SharedBytesPool { - pub fn new() -> SharedBytesPool { - SharedBytesPool(RefCell::new(VecDeque::with_capacity(128))) - } - - pub fn get_bytes(&self) -> Rc { - if let Some(bytes) = self.0.borrow_mut().pop_front() { - bytes - } else { - Rc::new(BytesMut::new()) - } - } - - pub fn release_bytes(&self, mut bytes: Rc) { - let v = &mut self.0.borrow_mut(); - if v.len() < 128 { - Rc::get_mut(&mut bytes).unwrap().clear(); - v.push_front(bytes); - } - } -} - -#[derive(Debug)] -pub(crate) struct SharedBytes( - Option>, Option>); - -impl Drop for SharedBytes { - fn drop(&mut self) { - if let Some(ref pool) = self.1 { - if let Some(bytes) = self.0.take() { - if Rc::strong_count(&bytes) == 1 { - pool.release_bytes(bytes); - } - } - } - } -} - -impl SharedBytes { - - pub fn empty() -> Self { - SharedBytes(None, None) - } - - pub fn new(bytes: Rc, pool: Rc) -> SharedBytes { - SharedBytes(Some(bytes), Some(pool)) - } - - #[inline(always)] - #[allow(mutable_transmutes)] - #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - pub(crate) fn get_mut(&self) -> &mut BytesMut { - let r: &BytesMut = self.0.as_ref().unwrap().as_ref(); - unsafe{mem::transmute(r)} - } - - #[inline] - pub fn len(&self) -> usize { - self.0.as_ref().unwrap().len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.0.as_ref().unwrap().is_empty() - } - - #[inline] - pub fn as_ref(&self) -> &[u8] { - self.0.as_ref().unwrap().as_ref() - } - - pub fn split_to(&self, n: usize) -> BytesMut { - self.get_mut().split_to(n) - } - - pub fn take(&self) -> BytesMut { - self.get_mut().take() - } - - #[inline] - pub fn reserve(&self, cnt: usize) { - self.get_mut().reserve(cnt) - } - - #[inline] - #[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))] - pub fn extend(&self, data: Binary) { - let buf = self.get_mut(); - let data = data.as_ref(); - buf.reserve(data.len()); - SharedBytes::put_slice(buf, data); - } - - #[inline] - pub fn extend_from_slice(&self, data: &[u8]) { - let buf = self.get_mut(); - buf.reserve(data.len()); - SharedBytes::put_slice(buf, data); - } - - #[inline] - pub(crate) fn put_slice(buf: &mut BytesMut, src: &[u8]) { - let len = src.len(); - unsafe { - buf.bytes_mut()[..len].copy_from_slice(src); - buf.advance_mut(len); - } - } - - #[inline] - pub(crate) fn extend_from_slice_(buf: &mut BytesMut, data: &[u8]) { - buf.reserve(data.len()); - SharedBytes::put_slice(buf, data); - } -} - -impl Default for SharedBytes { - fn default() -> Self { - SharedBytes(Some(Rc::new(BytesMut::new())), None) - } -} - -impl Clone for SharedBytes { - fn clone(&self) -> SharedBytes { - SharedBytes(self.0.clone(), self.1.clone()) - } -} - -impl io::Write for SharedBytes { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} diff --git a/src/server/srv.rs b/src/server/srv.rs deleted file mode 100644 index 041021acf..000000000 --- a/src/server/srv.rs +++ /dev/null @@ -1,846 +0,0 @@ -use std::{io, net, thread}; -use std::rc::Rc; -use std::sync::{Arc, mpsc as sync_mpsc}; -use std::time::Duration; - -use actix::prelude::*; -use actix::actors::signal; -use futures::{Future, Sink, Stream}; -use futures::sync::mpsc; -use tokio_io::{AsyncRead, AsyncWrite}; -use mio; -use num_cpus; -use net2::TcpBuilder; - -#[cfg(feature="tls")] -use native_tls::TlsAcceptor; - -#[cfg(feature="alpn")] -use openssl::ssl::{AlpnError, SslAcceptorBuilder}; - -use super::{IntoHttpHandler, IoStream, KeepAlive}; -use super::{PauseServer, ResumeServer, StopServer}; -use super::channel::{HttpChannel, WrapperStream}; -use super::worker::{Conn, Worker, StreamHandlerType, StopWorker}; -use super::settings::{ServerSettings, WorkerSettings}; - -/// An HTTP Server -pub struct HttpServer where H: IntoHttpHandler + 'static -{ - h: Option>>, - threads: usize, - backlog: i32, - host: Option, - keep_alive: KeepAlive, - factory: Arc Vec + Send + Sync>, - #[cfg_attr(feature="cargo-clippy", allow(type_complexity))] - workers: Vec<(usize, Addr>)>, - sockets: Vec<(net::SocketAddr, net::TcpListener)>, - accept: Vec<(mio::SetReadiness, sync_mpsc::Sender)>, - exit: bool, - shutdown_timeout: u16, - signals: Option>, - no_http2: bool, - no_signals: bool, -} - -unsafe impl Sync for HttpServer where H: IntoHttpHandler {} -unsafe impl Send for HttpServer where H: IntoHttpHandler {} - -#[derive(Clone)] -struct Info { - addr: net::SocketAddr, - handler: StreamHandlerType, -} - -enum ServerCommand { - WorkerDied(usize, Info), -} - -impl Actor for HttpServer where H: IntoHttpHandler { - type Context = Context; -} - -impl HttpServer where H: IntoHttpHandler + 'static -{ - /// Create new http server with application factory - pub fn new(factory: F) -> Self - where F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, - { - let f = move || { - (factory)().into_iter().collect() - }; - - HttpServer{ h: None, - threads: num_cpus::get(), - backlog: 2048, - host: None, - keep_alive: KeepAlive::Os, - factory: Arc::new(f), - workers: Vec::new(), - sockets: Vec::new(), - accept: Vec::new(), - exit: false, - shutdown_timeout: 30, - signals: None, - no_http2: false, - no_signals: false, - } - } - - /// Set number of workers to start. - /// - /// By default http server uses number of available logical cpu as threads count. - pub fn threads(mut self, num: usize) -> Self { - self.threads = num; - self - } - - /// Set the maximum number of pending connections. - /// - /// This refers to the number of clients that can be waiting to be served. - /// Exceeding this number results in the client getting an error when - /// attempting to connect. It should only affect servers under significant load. - /// - /// Generally set in the 64-2048 range. Default value is 2048. - /// - /// This method should be called before `bind()` method call. - pub fn backlog(mut self, num: i32) -> Self { - self.backlog = num; - self - } - - /// Set server keep-alive setting. - /// - /// By default keep alive is set to a `Os`. - pub fn keep_alive>(mut self, val: T) -> Self { - self.keep_alive = val.into(); - self - } - - /// Set server host name. - /// - /// Host name is used by application router aa a hostname for url generation. - /// Check [ConnectionInfo](./dev/struct.ConnectionInfo.html#method.host) documentation - /// for more information. - pub fn server_hostname(mut self, val: String) -> Self { - self.host = Some(val); - self - } - - /// Send `SystemExit` message to actix system - /// - /// `SystemExit` message stops currently running system arbiter and all - /// nested arbiters. - pub fn system_exit(mut self) -> Self { - self.exit = true; - self - } - - /// Set alternative address for `ProcessSignals` actor. - pub fn signals(mut self, addr: Addr) -> Self { - self.signals = Some(addr); - self - } - - /// Disable signal handling - pub fn disable_signals(mut self) -> Self { - self.no_signals = true; - self - } - - /// Timeout for graceful workers shutdown. - /// - /// After receiving a stop signal, workers have this much time to finish serving requests. - /// Workers still alive after the timeout are force dropped. - /// - /// By default shutdown timeout sets to 30 seconds. - pub fn shutdown_timeout(mut self, sec: u16) -> Self { - self.shutdown_timeout = sec; - self - } - - /// Disable `HTTP/2` support - pub fn no_http2(mut self) -> Self { - self.no_http2 = true; - self - } - - /// Get addresses of bound sockets. - pub fn addrs(&self) -> Vec { - self.sockets.iter().map(|s| s.0).collect() - } - - /// Use listener for accepting incoming connection requests - /// - /// HttpServer does not change any configuration for TcpListener, - /// it needs to be configured before passing it to listen() method. - pub fn listen(mut self, lst: net::TcpListener) -> Self { - self.sockets.push((lst.local_addr().unwrap(), lst)); - self - } - - /// The socket address to bind - /// - /// To mind multiple addresses this method can be call multiple times. - pub fn bind(mut self, addr: S) -> io::Result { - let mut err = None; - let mut succ = false; - for addr in addr.to_socket_addrs()? { - match create_tcp_listener(addr, self.backlog) { - Ok(lst) => { - succ = true; - self.sockets.push((lst.local_addr().unwrap(), lst)); - }, - Err(e) => err = Some(e), - } - } - - if !succ { - if let Some(e) = err.take() { - Err(e) - } else { - Err(io::Error::new(io::ErrorKind::Other, "Can not bind to address.")) - } - } else { - Ok(self) - } - } - - fn start_workers(&mut self, settings: &ServerSettings, handler: &StreamHandlerType) - -> Vec<(usize, mpsc::UnboundedSender>)> - { - // start workers - let mut workers = Vec::new(); - for idx in 0..self.threads { - let s = settings.clone(); - let (tx, rx) = mpsc::unbounded::>(); - - let h = handler.clone(); - let ka = self.keep_alive; - let factory = Arc::clone(&self.factory); - let addr = Arbiter::start(move |ctx: &mut Context<_>| { - let apps: Vec<_> = (*factory)() - .into_iter() - .map(|h| h.into_handler(s.clone())).collect(); - ctx.add_message_stream(rx); - Worker::new(apps, h, ka) - }); - workers.push((idx, tx)); - self.workers.push((idx, addr)); - } - info!("Starting {} http workers", self.threads); - workers - } - - // subscribe to os signals - fn subscribe_to_signals(&self) -> Option> { - if !self.no_signals { - if let Some(ref signals) = self.signals { - Some(signals.clone()) - } else { - Some(Arbiter::system_registry().get::()) - } - } else { - None - } - } -} - -impl HttpServer -{ - /// Start listening for incoming connections. - /// - /// This method starts number of http handler workers in separate threads. - /// For each address this method starts separate thread which does `accept()` in a loop. - /// - /// This methods panics if no socket addresses get bound. - /// - /// This method requires to run within properly configured `Actix` system. - /// - /// ```rust - /// extern crate actix; - /// extern crate actix_web; - /// use actix_web::*; - /// - /// fn main() { - /// let sys = actix::System::new("example"); // <- create Actix system - /// - /// HttpServer::new( - /// || App::new() - /// .resource("/", |r| r.h(|_| HttpResponse::Ok()))) - /// .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") - /// .start(); - /// # actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); - /// - /// let _ = sys.run(); // <- Run actix system, this method actually starts all async processes - /// } - /// ``` - pub fn start(mut self) -> Addr - { - if self.sockets.is_empty() { - panic!("HttpServer::bind() has to be called before start()"); - } else { - let (tx, rx) = mpsc::unbounded(); - let addrs: Vec<(net::SocketAddr, net::TcpListener)> = - self.sockets.drain(..).collect(); - let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); - let workers = self.start_workers(&settings, &StreamHandlerType::Normal); - let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal}; - - // start acceptors threads - for (addr, sock) in addrs { - info!("Starting server on http://{}", addr); - self.accept.push( - start_accept_thread( - sock, addr, self.backlog, - tx.clone(), info.clone(), workers.clone())); - } - - // start http server actor - let signals = self.subscribe_to_signals(); - let addr: Addr = Actor::create(move |ctx| { - ctx.add_stream(rx); - self - }); - signals.map(|signals| signals.do_send( - signal::Subscribe(addr.clone().recipient()))); - addr - } - } - - /// Spawn new thread and start listening for incoming connections. - /// - /// This method spawns new thread and starts new actix system. Other than that it is - /// similar to `start()` method. This method blocks. - /// - /// This methods panics if no socket addresses get bound. - /// - /// ```rust,ignore - /// # extern crate futures; - /// # extern crate actix; - /// # extern crate actix_web; - /// # use futures::Future; - /// use actix_web::*; - /// - /// fn main() { - /// HttpServer::new( - /// || App::new() - /// .resource("/", |r| r.h(|_| HttpResponse::Ok()))) - /// .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") - /// .run(); - /// } - /// ``` - pub fn run(mut self) { - self.exit = true; - self.no_signals = false; - - let _ = thread::spawn(move || { - let sys = System::new("http-server"); - self.start(); - let _ = sys.run(); - }).join(); - } -} - -#[cfg(feature="tls")] -impl HttpServer -{ - /// Start listening for incoming tls connections. - pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result> { - if self.sockets.is_empty() { - Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound")) - } else { - let (tx, rx) = mpsc::unbounded(); - let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect(); - let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); - let workers = self.start_workers( - &settings, &StreamHandlerType::Tls(acceptor.clone())); - let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Tls(acceptor)}; - - // start acceptors threads - for (addr, sock) in addrs { - info!("Starting server on https://{}", addr); - self.accept.push( - start_accept_thread( - sock, addr, self.backlog, - tx.clone(), info.clone(), workers.clone())); - } - - // start http server actor - let signals = self.subscribe_to_signals(); - let addr: Addr = Actor::create(|ctx| { - ctx.add_stream(rx); - self - }); - signals.map(|signals| signals.do_send( - signal::Subscribe(addr.clone().recipient()))); - Ok(addr) - } - } -} - -#[cfg(feature="alpn")] -impl HttpServer -{ - /// Start listening for incoming tls connections. - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn start_ssl(mut self, mut builder: SslAcceptorBuilder) -> io::Result> - { - if self.sockets.is_empty() { - Err(io::Error::new(io::ErrorKind::Other, "No socket addresses are bound")) - } else { - // alpn support - if !self.no_http2 { - builder.set_alpn_protos(b"\x02h2\x08http/1.1")?; - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - } - - let (tx, rx) = mpsc::unbounded(); - let acceptor = builder.build(); - let addrs: Vec<(net::SocketAddr, net::TcpListener)> = self.sockets.drain(..).collect(); - let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); - let workers = self.start_workers( - &settings, &StreamHandlerType::Alpn(acceptor.clone())); - let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Alpn(acceptor)}; - - // start acceptors threads - for (addr, sock) in addrs { - info!("Starting server on https://{}", addr); - self.accept.push( - start_accept_thread( - sock, addr, self.backlog, - tx.clone(), info.clone(), workers.clone())); - } - - // start http server actor - let signals = self.subscribe_to_signals(); - let addr: Addr = Actor::create(|ctx| { - ctx.add_stream(rx); - self - }); - signals.map(|signals| signals.do_send( - signal::Subscribe(addr.clone().recipient()))); - Ok(addr) - } - } -} - -impl HttpServer -{ - /// Start listening for incoming connections from a stream. - /// - /// This method uses only one thread for handling incoming connections. - pub fn start_incoming(mut self, stream: S, secure: bool) -> Addr - where S: Stream + 'static, - T: AsyncRead + AsyncWrite + 'static, - A: 'static - { - let (tx, rx) = mpsc::unbounded(); - - if !self.sockets.is_empty() { - let addrs: Vec<(net::SocketAddr, net::TcpListener)> = - self.sockets.drain(..).collect(); - let settings = ServerSettings::new(Some(addrs[0].0), &self.host, false); - let workers = self.start_workers(&settings, &StreamHandlerType::Normal); - let info = Info{addr: addrs[0].0, handler: StreamHandlerType::Normal}; - - // start acceptors threads - for (addr, sock) in addrs { - info!("Starting server on http://{}", addr); - self.accept.push( - start_accept_thread( - sock, addr, self.backlog, - tx.clone(), info.clone(), workers.clone())); - } - } - - // set server settings - let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap(); - let settings = ServerSettings::new(Some(addr), &self.host, secure); - let apps: Vec<_> = (*self.factory)() - .into_iter().map(|h| h.into_handler(settings.clone())).collect(); - self.h = Some(Rc::new(WorkerSettings::new(apps, self.keep_alive))); - - // start server - let signals = self.subscribe_to_signals(); - let addr: Addr = HttpServer::create(move |ctx| { - ctx.add_stream(rx); - ctx.add_message_stream( - stream - .map_err(|_| ()) - .map(move |(t, _)| Conn{io: WrapperStream::new(t), peer: None, http2: false})); - self - }); - signals.map(|signals| signals.do_send( - signal::Subscribe(addr.clone().recipient()))); - addr - } -} - -/// Signals support -/// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)` -/// message to `System` actor. -impl Handler for HttpServer -{ - type Result = (); - - fn handle(&mut self, msg: signal::Signal, ctx: &mut Context) { - match msg.0 { - signal::SignalType::Int => { - info!("SIGINT received, exiting"); - self.exit = true; - Handler::::handle(self, StopServer{graceful: false}, ctx); - } - signal::SignalType::Term => { - info!("SIGTERM received, stopping"); - self.exit = true; - Handler::::handle(self, StopServer{graceful: true}, ctx); - } - signal::SignalType::Quit => { - info!("SIGQUIT received, exiting"); - self.exit = true; - Handler::::handle(self, StopServer{graceful: false}, ctx); - } - _ => (), - } - } -} - -/// Commands from accept threads -impl StreamHandler for HttpServer -{ - fn finished(&mut self, _: &mut Context) {} - fn handle(&mut self, msg: ServerCommand, _: &mut Context) { - match msg { - ServerCommand::WorkerDied(idx, info) => { - let mut found = false; - for i in 0..self.workers.len() { - if self.workers[i].0 == idx { - self.workers.swap_remove(i); - found = true; - break - } - } - - if found { - error!("Worker has died {:?}, restarting", idx); - let (tx, rx) = mpsc::unbounded::>(); - - let mut new_idx = self.workers.len(); - 'found: loop { - for i in 0..self.workers.len() { - if self.workers[i].0 == new_idx { - new_idx += 1; - continue 'found - } - } - break - } - - let h = info.handler; - let ka = self.keep_alive; - let factory = Arc::clone(&self.factory); - let settings = ServerSettings::new(Some(info.addr), &self.host, false); - - let addr = Arbiter::start(move |ctx: &mut Context<_>| { - let apps: Vec<_> = (*factory)() - .into_iter() - .map(|h| h.into_handler(settings.clone())).collect(); - ctx.add_message_stream(rx); - Worker::new(apps, h, ka) - }); - for item in &self.accept { - let _ = item.1.send(Command::Worker(new_idx, tx.clone())); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - - self.workers.push((new_idx, addr)); - } - }, - } - } -} - -impl Handler> for HttpServer - where T: IoStream, - H: IntoHttpHandler, -{ - type Result = (); - - fn handle(&mut self, msg: Conn, _: &mut Context) -> Self::Result { - Arbiter::handle().spawn( - HttpChannel::new( - Rc::clone(self.h.as_ref().unwrap()), msg.io, msg.peer, msg.http2)); - } -} - -impl Handler for HttpServer -{ - type Result = (); - - fn handle(&mut self, _: PauseServer, _: &mut Context) - { - for item in &self.accept { - let _ = item.1.send(Command::Pause); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - } -} - -impl Handler for HttpServer -{ - type Result = (); - - fn handle(&mut self, _: ResumeServer, _: &mut Context) { - for item in &self.accept { - let _ = item.1.send(Command::Resume); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - } -} - -impl Handler for HttpServer -{ - type Result = actix::Response<(), ()>; - - fn handle(&mut self, msg: StopServer, ctx: &mut Context) -> Self::Result { - // stop accept threads - for item in &self.accept { - let _ = item.1.send(Command::Stop); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - - // stop workers - let (tx, rx) = mpsc::channel(1); - - let dur = if msg.graceful { - Some(Duration::new(u64::from(self.shutdown_timeout), 0)) - } else { - None - }; - for worker in &self.workers { - let tx2 = tx.clone(); - worker.1.send(StopWorker{graceful: dur}) - .into_actor(self) - .then(move |_, slf, ctx| { - slf.workers.pop(); - if slf.workers.is_empty() { - let _ = tx2.send(()); - - // we need to stop system if server was spawned - if slf.exit { - ctx.run_later(Duration::from_millis(300), |_, _| { - Arbiter::system().do_send(actix::msgs::SystemExit(0)) - }); - } - } - actix::fut::ok(()) - }).spawn(ctx); - } - - if !self.workers.is_empty() { - Response::async( - rx.into_future().map(|_| ()).map_err(|_| ())) - } else { - // we need to stop system if server was spawned - if self.exit { - ctx.run_later(Duration::from_millis(300), |_, _| { - Arbiter::system().do_send(actix::msgs::SystemExit(0)) - }); - } - Response::reply(Ok(())) - } - } -} - -enum Command { - Pause, - Resume, - Stop, - Worker(usize, mpsc::UnboundedSender>), -} - -fn start_accept_thread( - sock: net::TcpListener, addr: net::SocketAddr, backlog: i32, - srv: mpsc::UnboundedSender, info: Info, - mut workers: Vec<(usize, mpsc::UnboundedSender>)>) - -> (mio::SetReadiness, sync_mpsc::Sender) -{ - let (tx, rx) = sync_mpsc::channel(); - let (reg, readiness) = mio::Registration::new2(); - - // start accept thread - #[cfg_attr(feature="cargo-clippy", allow(cyclomatic_complexity))] - let _ = thread::Builder::new().name(format!("Accept on {}", addr)).spawn(move || { - const SRV: mio::Token = mio::Token(0); - const CMD: mio::Token = mio::Token(1); - - let mut server = Some( - mio::net::TcpListener::from_std(sock) - .expect("Can not create mio::net::TcpListener")); - - // Create a poll instance - let poll = match mio::Poll::new() { - Ok(poll) => poll, - Err(err) => panic!("Can not create mio::Poll: {}", err), - }; - - // Start listening for incoming connections - if let Some(ref srv) = server { - if let Err(err) = poll.register( - srv, SRV, mio::Ready::readable(), mio::PollOpt::edge()) { - panic!("Can not register io: {}", err); - } - } - - // Start listening for incoming commands - if let Err(err) = poll.register(®, CMD, - mio::Ready::readable(), mio::PollOpt::edge()) { - panic!("Can not register Registration: {}", err); - } - - // Create storage for events - let mut events = mio::Events::with_capacity(128); - - // Sleep on error - let sleep = Duration::from_millis(100); - - let mut next = 0; - loop { - if let Err(err) = poll.poll(&mut events, None) { - panic!("Poll error: {}", err); - } - - for event in events.iter() { - match event.token() { - SRV => if let Some(ref server) = server { - loop { - match server.accept_std() { - Ok((sock, addr)) => { - let mut msg = Conn{ - io: sock, peer: Some(addr), http2: false}; - while !workers.is_empty() { - match workers[next].1.unbounded_send(msg) { - Ok(_) => (), - Err(err) => { - let _ = srv.unbounded_send( - ServerCommand::WorkerDied( - workers[next].0, info.clone())); - msg = err.into_inner(); - workers.swap_remove(next); - if workers.is_empty() { - error!("No workers"); - thread::sleep(sleep); - break - } else if workers.len() <= next { - next = 0; - } - continue - } - } - next = (next + 1) % workers.len(); - break - } - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => - break, - Err(ref e) if connection_error(e) => - continue, - Err(e) => { - error!("Error accepting connection: {}", e); - // sleep after error - thread::sleep(sleep); - break - } - } - } - }, - CMD => match rx.try_recv() { - Ok(cmd) => match cmd { - Command::Pause => if let Some(server) = server.take() { - if let Err(err) = poll.deregister(&server) { - error!("Can not deregister server socket {}", err); - } else { - info!("Paused accepting connections on {}", addr); - } - }, - Command::Resume => { - let lst = create_tcp_listener(addr, backlog) - .expect("Can not create net::TcpListener"); - - server = Some( - mio::net::TcpListener::from_std(lst) - .expect("Can not create mio::net::TcpListener")); - - if let Some(ref server) = server { - if let Err(err) = poll.register( - server, SRV, mio::Ready::readable(), mio::PollOpt::edge()) - { - error!("Can not resume socket accept process: {}", err); - } else { - info!("Accepting connections on {} has been resumed", - addr); - } - } - }, - Command::Stop => { - if let Some(server) = server.take() { - let _ = poll.deregister(&server); - } - return - }, - Command::Worker(idx, addr) => { - workers.push((idx, addr)); - }, - }, - Err(err) => match err { - sync_mpsc::TryRecvError::Empty => (), - sync_mpsc::TryRecvError::Disconnected => { - if let Some(server) = server.take() { - let _ = poll.deregister(&server); - } - return - }, - } - }, - _ => unreachable!(), - } - } - } - }); - - (readiness, tx) -} - -fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result { - let builder = match addr { - net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, - net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, - }; - builder.reuse_address(true)?; - builder.bind(addr)?; - Ok(builder.listen(backlog)?) -} - -/// This function defines errors that are per-connection. Which basically -/// means that if we get this error from `accept()` system call it means -/// next connection might be ready to be accepted. -/// -/// All other errors will incur a timeout before next `accept()` is performed. -/// The timeout is useful to handle resource exhaustion errors like ENFILE -/// and EMFILE. Otherwise, could enter into tight loop. -fn connection_error(e: &io::Error) -> bool { - e.kind() == io::ErrorKind::ConnectionRefused || - e.kind() == io::ErrorKind::ConnectionAborted || - e.kind() == io::ErrorKind::ConnectionReset -} diff --git a/src/server/utils.rs b/src/server/utils.rs deleted file mode 100644 index bbc890e94..000000000 --- a/src/server/utils.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::io; -use bytes::{BytesMut, BufMut}; -use futures::{Async, Poll}; - -use super::IoStream; - -const LW_BUFFER_SIZE: usize = 4096; -const HW_BUFFER_SIZE: usize = 32_768; - - -pub fn read_from_io(io: &mut T, buf: &mut BytesMut) -> Poll { - unsafe { - if buf.remaining_mut() < LW_BUFFER_SIZE { - buf.reserve(HW_BUFFER_SIZE); - } - match io.read(buf.bytes_mut()) { - Ok(n) => { - buf.advance_mut(n); - Ok(Async::Ready(n)) - }, - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - Ok(Async::NotReady) - } else { - Err(e) - } - } - } - } -} diff --git a/src/server/worker.rs b/src/server/worker.rs deleted file mode 100644 index 3fe9cec19..000000000 --- a/src/server/worker.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::{net, time}; -use std::rc::Rc; -use futures::Future; -use futures::unsync::oneshot; -use tokio_core::net::TcpStream; -use tokio_core::reactor::Handle; -use net2::TcpStreamExt; - -#[cfg(any(feature="tls", feature="alpn"))] -use futures::future; - -#[cfg(feature="tls")] -use native_tls::TlsAcceptor; -#[cfg(feature="tls")] -use tokio_tls::TlsAcceptorExt; - -#[cfg(feature="alpn")] -use openssl::ssl::SslAcceptor; -#[cfg(feature="alpn")] -use tokio_openssl::SslAcceptorExt; - -use actix::*; -use actix::msgs::StopArbiter; - -use server::{HttpHandler, KeepAlive}; -use server::channel::HttpChannel; -use server::settings::WorkerSettings; - - -#[derive(Message)] -pub(crate) struct Conn { - pub io: T, - pub peer: Option, - pub http2: bool, -} - -/// Stop worker message. Returns `true` on successful shutdown -/// and `false` if some connections still alive. -pub(crate) struct StopWorker { - pub graceful: Option, -} - -impl Message for StopWorker { - type Result = Result; -} - -/// Http worker -/// -/// Worker accepts Socket objects via unbounded channel and start requests processing. -pub(crate) -struct Worker where H: HttpHandler + 'static { - settings: Rc>, - hnd: Handle, - handler: StreamHandlerType, - tcp_ka: Option, -} - -impl Worker { - - pub(crate) fn new(h: Vec, handler: StreamHandlerType, keep_alive: KeepAlive) - -> Worker - { - let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { - Some(time::Duration::new(val as u64, 0)) - } else { - None - }; - - Worker { - settings: Rc::new(WorkerSettings::new(h, keep_alive)), - hnd: Arbiter::handle().clone(), - handler, - tcp_ka, - } - } - - fn update_time(&self, ctx: &mut Context) { - self.settings.update_date(); - ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_time(ctx)); - } - - fn shutdown_timeout(&self, ctx: &mut Context, - tx: oneshot::Sender, dur: time::Duration) { - // sleep for 1 second and then check again - ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| { - let num = slf.settings.num_channels(); - if num == 0 { - let _ = tx.send(true); - Arbiter::arbiter().do_send(StopArbiter(0)); - } else if let Some(d) = dur.checked_sub(time::Duration::new(1, 0)) { - slf.shutdown_timeout(ctx, tx, d); - } else { - info!("Force shutdown http worker, {} connections", num); - slf.settings.head().traverse::(); - let _ = tx.send(false); - Arbiter::arbiter().do_send(StopArbiter(0)); - } - }); - } -} - -impl Actor for Worker where H: HttpHandler + 'static { - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - self.update_time(ctx); - } -} - -impl Handler> for Worker - where H: HttpHandler + 'static, -{ - type Result = (); - - fn handle(&mut self, msg: Conn, _: &mut Context) - { - if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { - error!("Can not set socket keep-alive option"); - } - self.handler.handle(Rc::clone(&self.settings), &self.hnd, msg); - } -} - -/// `StopWorker` message handler -impl Handler for Worker - where H: HttpHandler + 'static, -{ - type Result = Response; - - fn handle(&mut self, msg: StopWorker, ctx: &mut Context) -> Self::Result { - let num = self.settings.num_channels(); - if num == 0 { - info!("Shutting down http worker, 0 connections"); - Response::reply(Ok(true)) - } else if let Some(dur) = msg.graceful { - info!("Graceful http worker shutdown, {} connections", num); - let (tx, rx) = oneshot::channel(); - self.shutdown_timeout(ctx, tx, dur); - Response::async(rx.map_err(|_| ())) - } else { - info!("Force shutdown http worker, {} connections", num); - self.settings.head().traverse::(); - Response::reply(Ok(false)) - } - } -} - -#[derive(Clone)] -pub(crate) enum StreamHandlerType { - Normal, - #[cfg(feature="tls")] - Tls(TlsAcceptor), - #[cfg(feature="alpn")] - Alpn(SslAcceptor), -} - -impl StreamHandlerType { - - fn handle(&mut self, - h: Rc>, - hnd: &Handle, msg: Conn) { - match *self { - StreamHandlerType::Normal => { - let _ = msg.io.set_nodelay(true); - let io = TcpStream::from_stream(msg.io, hnd) - .expect("failed to associate TCP stream"); - - hnd.spawn(HttpChannel::new(h, io, msg.peer, msg.http2)); - } - #[cfg(feature="tls")] - StreamHandlerType::Tls(ref acceptor) => { - let Conn { io, peer, http2 } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_stream(io, hnd) - .expect("failed to associate TCP stream"); - - hnd.spawn( - TlsAcceptorExt::accept_async(acceptor, io).then(move |res| { - match res { - Ok(io) => Arbiter::handle().spawn( - HttpChannel::new(h, io, peer, http2)), - Err(err) => - trace!("Error during handling tls connection: {}", err), - }; - future::result(Ok(())) - }) - ); - } - #[cfg(feature="alpn")] - StreamHandlerType::Alpn(ref acceptor) => { - let Conn { io, peer, .. } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_stream(io, hnd) - .expect("failed to associate TCP stream"); - - hnd.spawn( - SslAcceptorExt::accept_async(acceptor, io).then(move |res| { - match res { - Ok(io) => { - let http2 = if let Some(p) = io.get_ref().ssl().selected_alpn_protocol() - { - p.len() == 2 && &p == b"h2" - } else { - false - }; - Arbiter::handle().spawn( - HttpChannel::new(h, io, peer, http2)); - }, - Err(err) => - trace!("Error during handling tls connection: {}", err), - }; - future::result(Ok(())) - }) - ); - } - } - } -} diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 000000000..b392e6e8b --- /dev/null +++ b/src/service.rs @@ -0,0 +1,602 @@ +use std::cell::{Ref, RefMut}; +use std::rc::Rc; +use std::{fmt, net}; + +use actix_http::body::{Body, MessageBody, ResponseBody}; +use actix_http::http::{HeaderMap, Method, StatusCode, Uri, Version}; +use actix_http::{ + Error, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, + ResponseHead, +}; +use actix_router::{Path, Resource, ResourceDef, Url}; +use actix_service::{IntoServiceFactory, ServiceFactory}; + +use crate::config::{AppConfig, AppService}; +use crate::data::Data; +use crate::dev::insert_slash; +use crate::guard::Guard; +use crate::info::ConnectionInfo; +use crate::request::HttpRequest; +use crate::rmap::ResourceMap; + +pub trait HttpServiceFactory { + fn register(self, config: &mut AppService); +} + +pub(crate) trait AppServiceFactory { + fn register(&mut self, config: &mut AppService); +} + +pub(crate) struct ServiceFactoryWrapper { + factory: Option, +} + +impl ServiceFactoryWrapper { + pub fn new(factory: T) -> Self { + Self { + factory: Some(factory), + } + } +} + +impl AppServiceFactory for ServiceFactoryWrapper +where + T: HttpServiceFactory, +{ + fn register(&mut self, config: &mut AppService) { + if let Some(item) = self.factory.take() { + item.register(config) + } + } +} + +/// An service http request +/// +/// ServiceRequest allows mutable access to request's internal structures +pub struct ServiceRequest(HttpRequest); + +impl ServiceRequest { + /// Construct service request + pub(crate) fn new(req: HttpRequest) -> Self { + ServiceRequest(req) + } + + /// Deconstruct request into parts + pub fn into_parts(mut self) -> (HttpRequest, Payload) { + let pl = Rc::get_mut(&mut (self.0).0).unwrap().payload.take(); + (self.0, pl) + } + + /// Construct request from parts. + /// + /// `ServiceRequest` can be re-constructed only if `req` hasnt been cloned. + pub fn from_parts( + mut req: HttpRequest, + pl: Payload, + ) -> Result { + if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { + Rc::get_mut(&mut req.0).unwrap().payload = pl; + Ok(ServiceRequest(req)) + } else { + Err((req, pl)) + } + } + + /// Construct request from request. + /// + /// `HttpRequest` implements `Clone` trait via `Rc` type. `ServiceRequest` + /// can be re-constructed only if rc's strong pointers count eq 1 and + /// weak pointers count is 0. + pub fn from_request(req: HttpRequest) -> Result { + if Rc::strong_count(&req.0) == 1 && Rc::weak_count(&req.0) == 0 { + Ok(ServiceRequest(req)) + } else { + Err(req) + } + } + + /// Create service response + #[inline] + pub fn into_response>>(self, res: R) -> ServiceResponse { + ServiceResponse::new(self.0, res.into()) + } + + /// Create service response for error + #[inline] + pub fn error_response>(self, err: E) -> ServiceResponse { + let res: Response = err.into().into(); + ServiceResponse::new(self.0, res.into_body()) + } + + /// This method returns reference to the request head + #[inline] + pub fn head(&self) -> &RequestHead { + &self.0.head() + } + + /// This method returns reference to the request head + #[inline] + pub fn head_mut(&mut self) -> &mut RequestHead { + self.0.head_mut() + } + + /// Request's uri. + #[inline] + pub fn uri(&self) -> &Uri { + &self.head().uri + } + + /// Read the Request method. + #[inline] + pub fn method(&self) -> &Method { + &self.head().method + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + #[inline] + /// Returns mutable request's headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head_mut().headers + } + + /// The target path of this Request. + #[inline] + pub fn path(&self) -> &str { + self.head().uri.path() + } + + /// The query string in the URL. + /// + /// E.g., id=10 + #[inline] + pub fn query_string(&self) -> &str { + if let Some(query) = self.uri().query().as_ref() { + query + } else { + "" + } + } + + /// Peer socket address + /// + /// Peer address is actual socket address, if proxy is used in front of + /// actix http server, then peer address would be address of this proxy. + /// + /// To get client connection information `ConnectionInfo` should be used. + #[inline] + pub fn peer_addr(&self) -> Option { + self.head().peer_addr + } + + /// Get *ConnectionInfo* for the current request. + #[inline] + pub fn connection_info(&self) -> Ref { + ConnectionInfo::get(self.head(), &*self.app_config()) + } + + /// Get a reference to the Path parameters. + /// + /// Params is a container for url parameters. + /// A variable segment is specified in the form `{identifier}`, + /// where the identifier can be used later in a request handler to + /// access the matched value for that segment. + #[inline] + pub fn match_info(&self) -> &Path { + self.0.match_info() + } + + #[inline] + /// Get a mutable reference to the Path parameters. + pub fn match_info_mut(&mut self) -> &mut Path { + self.0.match_info_mut() + } + + #[inline] + /// Get a reference to a `ResourceMap` of current application. + pub fn resource_map(&self) -> &ResourceMap { + self.0.resource_map() + } + + /// Service configuration + #[inline] + pub fn app_config(&self) -> &AppConfig { + self.0.app_config() + } + + /// Get an application data stored with `App::data()` method during + /// application configuration. + pub fn app_data(&self) -> Option> { + if let Some(st) = (self.0).0.app_data.get::>() { + Some(st.clone()) + } else { + None + } + } + + /// Set request payload. + pub fn set_payload(&mut self, payload: Payload) { + Rc::get_mut(&mut (self.0).0).unwrap().payload = payload; + } + + #[doc(hidden)] + /// Set new app data container + pub fn set_data_container(&mut self, extensions: Rc) { + Rc::get_mut(&mut (self.0).0).unwrap().app_data = extensions; + } +} + +impl Resource for ServiceRequest { + fn resource_path(&mut self) -> &mut Path { + self.match_info_mut() + } +} + +impl HttpMessage for ServiceRequest { + type Stream = PayloadStream; + + #[inline] + /// Returns Request's headers. + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Request extensions + #[inline] + fn extensions(&self) -> Ref { + self.0.extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + fn extensions_mut(&self) -> RefMut { + self.0.extensions_mut() + } + + #[inline] + fn take_payload(&mut self) -> Payload { + Rc::get_mut(&mut (self.0).0).unwrap().payload.take() + } +} + +impl fmt::Debug for ServiceRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nServiceRequest {:?} {}:{}", + self.head().version, + self.head().method, + self.path() + )?; + if !self.query_string().is_empty() { + writeln!(f, " query: ?{:?}", self.query_string())?; + } + if !self.match_info().is_empty() { + writeln!(f, " params: {:?}", self.match_info())?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +pub struct ServiceResponse { + request: HttpRequest, + response: Response, +} + +impl ServiceResponse { + /// Create service response instance + pub fn new(request: HttpRequest, response: Response) -> Self { + ServiceResponse { request, response } + } + + /// Create service response from the error + pub fn from_err>(err: E, request: HttpRequest) -> Self { + let e: Error = err.into(); + let res: Response = e.into(); + ServiceResponse { + request, + response: res.into_body(), + } + } + + /// Create service response for error + #[inline] + pub fn error_response>(self, err: E) -> Self { + Self::from_err(err, self.request) + } + + /// Create service response + #[inline] + pub fn into_response(self, response: Response) -> ServiceResponse { + ServiceResponse::new(self.request, response) + } + + /// Get reference to original request + #[inline] + pub fn request(&self) -> &HttpRequest { + &self.request + } + + /// Get reference to response + #[inline] + pub fn response(&self) -> &Response { + &self.response + } + + /// Get mutable reference to response + #[inline] + pub fn response_mut(&mut self) -> &mut Response { + &mut self.response + } + + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.response.status() + } + + #[inline] + /// Returns response's headers. + pub fn headers(&self) -> &HeaderMap { + self.response.headers() + } + + #[inline] + /// Returns mutable response's headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + self.response.headers_mut() + } + + /// Execute closure and in case of error convert it to response. + pub fn checked_expr(mut self, f: F) -> Self + where + F: FnOnce(&mut Self) -> Result<(), E>, + E: Into, + { + match f(&mut self) { + Ok(_) => self, + Err(err) => { + let res: Response = err.into().into(); + ServiceResponse::new(self.request, res.into_body()) + } + } + } + + /// Extract response body + pub fn take_body(&mut self) -> ResponseBody { + self.response.take_body() + } +} + +impl ServiceResponse { + /// Set a new body + pub fn map_body(self, f: F) -> ServiceResponse + where + F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, + { + let response = self.response.map_body(f); + + ServiceResponse { + response, + request: self.request, + } + } +} + +impl Into> for ServiceResponse { + fn into(self) -> Response { + self.response + } +} + +impl fmt::Debug for ServiceResponse { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = writeln!( + f, + "\nServiceResponse {:?} {}{}", + self.response.head().version, + self.response.head().status, + self.response.head().reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in self.response.head().headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + let _ = writeln!(f, " body: {:?}", self.response.body().size()); + res + } +} + +pub struct WebService { + rdef: String, + name: Option, + guards: Vec>, +} + +impl WebService { + /// Create new `WebService` instance. + pub fn new(path: &str) -> Self { + WebService { + rdef: path.to_string(), + name: None, + guards: Vec::new(), + } + } + + /// Set service name. + /// + /// Name is used for url generation. + pub fn name(mut self, name: &str) -> Self { + self.name = Some(name.to_string()); + self + } + + /// Add match guard to a web service. + /// + /// ```rust + /// use actix_web::{web, guard, dev, App, Error, HttpResponse}; + /// + /// async fn index(req: dev::ServiceRequest) -> Result { + /// Ok(req.into_response(HttpResponse::Ok().finish())) + /// } + /// + /// fn main() { + /// let app = App::new() + /// .service( + /// web::service("/app") + /// .guard(guard::Header("content-type", "text/plain")) + /// .finish(index) + /// ); + /// } + /// ``` + pub fn guard(mut self, guard: G) -> Self { + self.guards.push(Box::new(guard)); + self + } + + /// Set a service factory implementation and generate web service. + pub fn finish(self, service: F) -> impl HttpServiceFactory + where + F: IntoServiceFactory, + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + 'static, + { + WebServiceImpl { + srv: service.into_factory(), + rdef: self.rdef, + name: self.name, + guards: self.guards, + } + } +} + +struct WebServiceImpl { + srv: T, + rdef: String, + name: Option, + guards: Vec>, +} + +impl HttpServiceFactory for WebServiceImpl +where + T: ServiceFactory< + Config = (), + Request = ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + 'static, +{ + fn register(mut self, config: &mut AppService) { + let guards = if self.guards.is_empty() { + None + } else { + Some(std::mem::replace(&mut self.guards, Vec::new())) + }; + + let mut rdef = if config.is_root() || !self.rdef.is_empty() { + ResourceDef::new(&insert_slash(&self.rdef)) + } else { + ResourceDef::new(&self.rdef) + }; + if let Some(ref name) = self.name { + *rdef.name_mut() = name.clone(); + } + config.register_service(rdef, guards, self.srv, None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::{init_service, TestRequest}; + use crate::{guard, http, web, App, HttpResponse}; + use actix_service::Service; + use futures::future::ok; + + #[test] + fn test_service_request() { + let req = TestRequest::default().to_srv_request(); + let (r, pl) = req.into_parts(); + assert!(ServiceRequest::from_parts(r, pl).is_ok()); + + let req = TestRequest::default().to_srv_request(); + let (r, pl) = req.into_parts(); + let _r2 = r.clone(); + assert!(ServiceRequest::from_parts(r, pl).is_err()); + + let req = TestRequest::default().to_srv_request(); + let (r, _pl) = req.into_parts(); + assert!(ServiceRequest::from_request(r).is_ok()); + + let req = TestRequest::default().to_srv_request(); + let (r, _pl) = req.into_parts(); + let _r2 = r.clone(); + assert!(ServiceRequest::from_request(r).is_err()); + } + + #[actix_rt::test] + async fn test_service() { + let mut srv = init_service( + App::new().service(web::service("/test").name("test").finish( + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), + )), + ) + .await; + let req = TestRequest::with_uri("/test").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::OK); + + let mut srv = init_service( + App::new().service(web::service("/test").guard(guard::Get()).finish( + |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())), + )), + ) + .await; + let req = TestRequest::with_uri("/test") + .method(http::Method::PUT) + .to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), http::StatusCode::NOT_FOUND); + } + + #[test] + fn test_fmt_debug() { + let req = TestRequest::get() + .uri("/index.html?test=1") + .header("x-test", "111") + .to_srv_request(); + let s = format!("{:?}", req); + assert!(s.contains("ServiceRequest")); + assert!(s.contains("test=1")); + assert!(s.contains("x-test")); + + let res = HttpResponse::Ok().header("x-test", "111").finish(); + let res = TestRequest::post() + .uri("/index.html?test=1") + .to_srv_response(res); + + let s = format!("{:?}", res); + assert!(s.contains("ServiceResponse")); + assert!(s.contains("x-test")); + } +} diff --git a/src/test.rs b/src/test.rs index b6fd22d2c..e19393156 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,395 +1,271 @@ //! Various helpers for Actix applications to use during testing. - -use std::{net, thread}; use std::rc::Rc; -use std::sync::mpsc; -use std::str::FromStr; -use actix::{Actor, Arbiter, Addr, Syn, System, SystemRunner, Unsync, msgs}; -use cookie::Cookie; -use http::{Uri, Method, Version, HeaderMap, HttpTryFrom}; -use http::header::HeaderName; -use futures::Future; -use tokio_core::net::TcpListener; -use tokio_core::reactor::Core; -use net2::TcpBuilder; +use actix_http::http::header::{ContentType, Header, HeaderName, IntoHeaderValue}; +use actix_http::http::{HttpTryFrom, Method, StatusCode, Uri, Version}; +use actix_http::test::TestRequest as HttpTestRequest; +use actix_http::{cookie::Cookie, Extensions, Request}; +use actix_router::{Path, ResourceDef, Url}; +use actix_server_config::ServerConfig; +use actix_service::{IntoService, IntoServiceFactory, Service, ServiceFactory}; +use bytes::{Bytes, BytesMut}; +use futures::future::ok; +use futures::stream::{Stream, StreamExt}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json; -#[cfg(feature="alpn")] -use openssl::ssl::SslAcceptor; +pub use actix_http::test::TestBuffer; -use ws; -use body::Binary; -use error::Error; -use header::{Header, IntoHeaderValue}; -use handler::{Handler, Responder, ReplyItem}; -use middleware::Middleware; -use application::{App, HttpApplication}; -use param::Params; -use router::Router; -use payload::Payload; -use resource::ResourceHandler; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; -use server::{HttpServer, IntoHttpHandler, ServerSettings}; -use client::{ClientRequest, ClientRequestBuilder, ClientConnector}; +use crate::config::{AppConfig, AppConfigInner}; +use crate::data::Data; +use crate::dev::{Body, MessageBody, Payload}; +use crate::request::HttpRequestPool; +use crate::rmap::ResourceMap; +use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{Error, HttpRequest, HttpResponse}; -/// The `TestServer` type. -/// -/// `TestServer` is very simple test server that simplify process of writing -/// integration tests cases for actix web applications. -/// -/// # Examples +/// Create service that always responds with `HttpResponse::Ok()` +pub fn ok_service( +) -> impl Service, Error = Error> +{ + default_service(StatusCode::OK) +} + +/// Create service that responds with response with specified status code +pub fn default_service( + status_code: StatusCode, +) -> impl Service, Error = Error> +{ + (move |req: ServiceRequest| { + ok(req.into_response(HttpResponse::build(status_code).finish())) + }) + .into_service() +} + +/// This method accepts application builder instance, and constructs +/// service. /// /// ```rust -/// # extern crate actix; -/// # extern crate actix_web; -/// # use actix_web::*; -/// # -/// # fn my_handler(req: HttpRequest) -> HttpResponse { -/// # HttpResponse::Ok().into() -/// # } -/// # -/// # fn main() { -/// use actix_web::test::TestServer; +/// use actix_service::Service; +/// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; /// -/// let mut srv = TestServer::new(|app| app.handler(my_handler)); +/// #[actix_rt::test] +/// async fn test_init_service() { +/// let mut app = test::init_service( +/// App::new() +/// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) +/// ); /// -/// let req = srv.get().finish().unwrap(); -/// let response = srv.execute(req.send()).unwrap(); -/// assert!(response.status().is_success()); -/// # } +/// // Create request object +/// let req = test::TestRequest::with_uri("/test").to_request(); +/// +/// // Execute application +/// let resp = app.call(req).await.unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// } /// ``` -pub struct TestServer { - addr: net::SocketAddr, - thread: Option>, - system: SystemRunner, - server_sys: Addr, - ssl: bool, - conn: Addr, +pub async fn init_service( + app: R, +) -> impl Service, Error = E> +where + R: IntoServiceFactory, + S: ServiceFactory< + Config = ServerConfig, + Request = Request, + Response = ServiceResponse, + Error = E, + >, + S::InitError: std::fmt::Debug, +{ + let cfg = ServerConfig::new("127.0.0.1:8080".parse().unwrap()); + let srv = app.into_factory(); + srv.new_service(&cfg).await.unwrap() } -impl TestServer { - - /// Start new test server - /// - /// This method accepts configuration method. You can add - /// middlewares or set handlers for test application. - pub fn new(config: F) -> Self - where F: Sync + Send + 'static + Fn(&mut TestApp<()>) - { - TestServerBuilder::new(||()).start(config) - } - - /// Create test server builder - pub fn build() -> TestServerBuilder<()> { - TestServerBuilder::new(||()) - } - - /// Create test server builder with specific state factory - /// - /// This method can be used for constructing application state. - /// Also it can be used for external dependecy initialization, - /// like creating sync actors for diesel integration. - pub fn build_with_state(state: F) -> TestServerBuilder - where F: Fn() -> S + Sync + Send + 'static, - S: 'static, - { - TestServerBuilder::new(state) - } - - /// Start new test server with application factory - pub fn with_factory(factory: F) -> Self - where F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, - H: IntoHttpHandler + 'static, - { - let (tx, rx) = mpsc::channel(); - - // run server in separate thread - let join = thread::spawn(move || { - let sys = System::new("actix-test-server"); - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); - let tcp = TcpListener::from_listener( - tcp, &local_addr, Arbiter::handle()).unwrap(); - - HttpServer::new(factory) - .disable_signals() - .start_incoming(tcp.incoming(), false); - - tx.send((Arbiter::system(), local_addr)).unwrap(); - let _ = sys.run(); - }); - - let sys = System::new("actix-test"); - let (server_sys, addr) = rx.recv().unwrap(); - TestServer { - addr, - server_sys, - ssl: false, - conn: TestServer::get_conn(), - thread: Some(join), - system: sys, - } - } - - fn get_conn() -> Addr { - #[cfg(feature="alpn")] - { - use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode}; - - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - ClientConnector::with_connector(builder.build()).start() - } - #[cfg(not(feature="alpn"))] - { - ClientConnector::default().start() - } - } - - /// Get firat available unused address - pub fn unused_addr() -> net::SocketAddr { - let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = TcpBuilder::new_v4().unwrap(); - socket.bind(&addr).unwrap(); - socket.reuse_address(true).unwrap(); - let tcp = socket.to_tcp_listener().unwrap(); - tcp.local_addr().unwrap() - } - - /// Construct test server url - pub fn addr(&self) -> net::SocketAddr { - self.addr - } - - /// Construct test server url - pub fn url(&self, uri: &str) -> String { - if uri.starts_with('/') { - format!("{}://{}{}", if self.ssl {"https"} else {"http"}, self.addr, uri) - } else { - format!("{}://{}/{}", if self.ssl {"https"} else {"http"}, self.addr, uri) - } - } - - /// Stop http server - fn stop(&mut self) { - if let Some(handle) = self.thread.take() { - self.server_sys.do_send(msgs::SystemExit(0)); - let _ = handle.join(); - } - } - - /// Execute future on current core - pub fn execute(&mut self, fut: F) -> Result - where F: Future - { - self.system.run_until_complete(fut) - } - - /// Connect to websocket server - pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> { - let url = self.url("/"); - self.system.run_until_complete( - ws::Client::with_connector(url, self.conn.clone()).connect()) - } - - /// Create `GET` request - pub fn get(&self) -> ClientRequestBuilder { - ClientRequest::get(self.url("/").as_str()) - } - - /// Create `POST` request - pub fn post(&self) -> ClientRequestBuilder { - ClientRequest::get(self.url("/").as_str()) - } - - /// Create `HEAD` request - pub fn head(&self) -> ClientRequestBuilder { - ClientRequest::head(self.url("/").as_str()) - } - - /// Connect to test http server - pub fn client(&self, meth: Method, path: &str) -> ClientRequestBuilder { - ClientRequest::build() - .method(meth) - .uri(self.url(path).as_str()) - .with_connector(self.conn.clone()) - .take() - } -} - -impl Drop for TestServer { - fn drop(&mut self) { - self.stop() - } -} - -/// An `TestServer` builder -/// -/// This type can be used to construct an instance of `TestServer` through a -/// builder-like pattern. -pub struct TestServerBuilder { - state: Box S + Sync + Send + 'static>, - #[cfg(feature="alpn")] - ssl: Option, -} - -impl TestServerBuilder { - - pub fn new(state: F) -> TestServerBuilder - where F: Fn() -> S + Sync + Send + 'static - { - TestServerBuilder { - state: Box::new(state), - #[cfg(feature="alpn")] - ssl: None, - } - } - - #[cfg(feature="alpn")] - /// Create ssl server - pub fn ssl(mut self, ssl: SslAcceptor) -> Self { - self.ssl = Some(ssl); - self - } - - #[allow(unused_mut)] - /// Configure test application and run test server - pub fn start(mut self, config: F) -> TestServer - where F: Sync + Send + 'static + Fn(&mut TestApp), - { - let (tx, rx) = mpsc::channel(); - - #[cfg(feature="alpn")] - let ssl = self.ssl.is_some(); - #[cfg(not(feature="alpn"))] - let ssl = false; - - // run server in separate thread - let join = thread::spawn(move || { - let sys = System::new("actix-test-server"); - - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); - let tcp = TcpListener::from_listener( - tcp, &local_addr, Arbiter::handle()).unwrap(); - - let state = self.state; - - let srv = HttpServer::new(move || { - let mut app = TestApp::new(state()); - config(&mut app); - vec![app]}) - .disable_signals(); - - #[cfg(feature="alpn")] - { - use std::io; - use futures::Stream; - use tokio_openssl::SslAcceptorExt; - - let ssl = self.ssl.take(); - if let Some(ssl) = ssl { - srv.start_incoming( - tcp.incoming() - .and_then(move |(sock, addr)| { - ssl.accept_async(sock) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - .map(move |s| (s, addr)) - }), - false); - } else { - srv.start_incoming(tcp.incoming(), false); - } - } - #[cfg(not(feature="alpn"))] - { - srv.start_incoming(tcp.incoming(), false); - } - - tx.send((Arbiter::system(), local_addr)).unwrap(); - let _ = sys.run(); - }); - - let system = System::new("actix-test"); - let (server_sys, addr) = rx.recv().unwrap(); - TestServer { - addr, - server_sys, - ssl, - system, - conn: TestServer::get_conn(), - thread: Some(join), - } - } -} - -/// Test application helper for testing request handlers. -pub struct TestApp { - app: Option>, -} - -impl TestApp { - fn new(state: S) -> TestApp { - let app = App::with_state(state); - TestApp{app: Some(app)} - } - - /// Register handler for "/" - pub fn handler>(&mut self, handler: H) { - self.app = Some(self.app.take().unwrap().resource("/", |r| r.h(handler))); - } - - /// Register middleware - pub fn middleware(&mut self, mw: T) -> &mut TestApp - where T: Middleware + 'static - { - self.app = Some(self.app.take().unwrap().middleware(mw)); - self - } - - /// Register resource. This method is similar - /// to `App::resource()` method. - pub fn resource(&mut self, path: &str, f: F) -> &mut TestApp - where F: FnOnce(&mut ResourceHandler) + 'static - { - self.app = Some(self.app.take().unwrap().resource(path, f)); - self - } -} - -impl IntoHttpHandler for TestApp { - type Handler = HttpApplication; - - fn into_handler(mut self, settings: ServerSettings) -> HttpApplication { - self.app.take().unwrap().into_handler(settings) - } -} - -#[doc(hidden)] -impl Iterator for TestApp { - type Item = HttpApplication; - - fn next(&mut self) -> Option { - if let Some(mut app) = self.app.take() { - Some(app.finish()) - } else { - None - } - } -} - -/// Test `HttpRequest` builder +/// Calls service and waits for response future completion. /// /// ```rust -/// # extern crate http; -/// # extern crate actix_web; -/// # use http::{header, StatusCode}; -/// # use actix_web::*; -/// use actix_web::test::TestRequest; +/// use actix_web::{test, App, HttpResponse, http::StatusCode}; +/// use actix_service::Service; /// -/// fn index(req: HttpRequest) -> HttpResponse { +/// #[test] +/// fn test_response() { +/// let mut app = test::init_service( +/// App::new() +/// .service(web::resource("/test").to(|| async { +/// HttpResponse::Ok() +/// })) +/// ).await; +/// +/// // Create request object +/// let req = test::TestRequest::with_uri("/test").to_request(); +/// +/// // Call application +/// let resp = test::call_service(&mut app, req).await; +/// assert_eq!(resp.status(), StatusCode::OK); +/// } +/// ``` +pub async fn call_service(app: &mut S, req: R) -> S::Response +where + S: Service, Error = E>, + E: std::fmt::Debug, +{ + app.call(req).await.unwrap() +} + +/// Helper function that returns a response body of a TestRequest +/// +/// ```rust +/// use actix_web::{test, web, App, HttpResponse, http::header}; +/// use bytes::Bytes; +/// +/// #[actix_rt::test] +/// async fn test_index() { +/// let mut app = test::init_service( +/// App::new().service( +/// web::resource("/index.html") +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; +/// +/// let req = test::TestRequest::post() +/// .uri("/index.html") +/// .header(header::CONTENT_TYPE, "application/json") +/// .to_request(); +/// +/// let result = test::read_response(&mut app, req).await; +/// assert_eq!(result, Bytes::from_static(b"welcome!")); +/// } +/// ``` +pub async fn read_response(app: &mut S, req: Request) -> Bytes +where + S: Service, Error = Error>, + B: MessageBody, +{ + let mut resp = app + .call(req) + .await + .unwrap_or_else(|_| panic!("read_response failed at application call")); + + let mut body = resp.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() +} + +/// Helper function that returns a response body of a ServiceResponse. +/// +/// ```rust +/// use actix_web::{test, web, App, HttpResponse, http::header}; +/// use bytes::Bytes; +/// +/// #[actix_rt::test] +/// async fn test_index() { +/// let mut app = test::init_service( +/// App::new().service( +/// web::resource("/index.html") +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; +/// +/// let req = test::TestRequest::post() +/// .uri("/index.html") +/// .header(header::CONTENT_TYPE, "application/json") +/// .to_request(); +/// +/// let resp = test::call_service(&mut app, req).await; +/// let result = test::read_body(resp); +/// assert_eq!(result, Bytes::from_static(b"welcome!")); +/// } +/// ``` +pub async fn read_body(mut res: ServiceResponse) -> Bytes +where + B: MessageBody, +{ + let mut body = res.take_body(); + let mut bytes = BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&item.unwrap()); + } + bytes.freeze() +} + +pub async fn load_stream(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut data = BytesMut::new(); + while let Some(item) = stream.next().await { + data.extend_from_slice(&item?); + } + Ok(data.freeze()) +} + +/// Helper function that returns a deserialized response body of a TestRequest +/// +/// ```rust +/// use actix_web::{App, test, web, HttpResponse, http::header}; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct Person { +/// id: String, +/// name: String +/// } +/// +/// #[actix_rt::test] +/// async fn test_add_person() { +/// let mut app = test::init_service( +/// App::new().service( +/// web::resource("/people") +/// .route(web::post().to(|person: web::Json| async { +/// HttpResponse::Ok() +/// .json(person.into_inner())}) +/// )) +/// ).await; +/// +/// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); +/// +/// let req = test::TestRequest::post() +/// .uri("/people") +/// .header(header::CONTENT_TYPE, "application/json") +/// .set_payload(payload) +/// .to_request(); +/// +/// let result: Person = test::read_response_json(&mut app, req).await; +/// } +/// ``` +pub async fn read_response_json(app: &mut S, req: Request) -> T +where + S: Service, Error = Error>, + B: MessageBody, + T: DeserializeOwned, +{ + let body = read_response(app, req).await; + + serde_json::from_slice(&body) + .unwrap_or_else(|_| panic!("read_response_json failed during deserialization")) +} + +/// Test `Request` builder. +/// +/// For unit testing, actix provides a request builder type and a simple handler runner. TestRequest implements a builder-like pattern. +/// You can generate various types of request via TestRequest's methods: +/// * `TestRequest::to_request` creates `actix_http::Request` instance. +/// * `TestRequest::to_srv_request` creates `ServiceRequest` instance, which is used for testing middlewares and chain adapters. +/// * `TestRequest::to_srv_response` creates `ServiceResponse` instance. +/// * `TestRequest::to_http_request` creates `HttpRequest` instance, which is used for testing handlers. +/// +/// ```rust +/// use actix_web::{test, HttpRequest, HttpResponse, HttpMessage}; +/// use actix_web::http::{header, StatusCode}; +/// +/// async fn index(req: HttpRequest) -> HttpResponse { /// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { /// HttpResponse::Ok().into() /// } else { @@ -397,204 +273,455 @@ impl Iterator for TestApp { /// } /// } /// -/// fn main() { -/// let resp = TestRequest::with_header("content-type", "text/plain") -/// .run(index).unwrap(); +/// #[test] +/// fn test_index() { +/// let req = test::TestRequest::with_header("content-type", "text/plain") +/// .to_http_request(); +/// +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::OK); /// -/// let resp = TestRequest::default() -/// .run(index).unwrap(); +/// let req = test::TestRequest::default().to_http_request(); +/// let resp = index(req).await.unwrap(); /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); /// } /// ``` -pub struct TestRequest { - state: S, - version: Version, - method: Method, - uri: Uri, - headers: HeaderMap, - params: Params<'static>, - cookies: Option>>, - payload: Option, +pub struct TestRequest { + req: HttpTestRequest, + rmap: ResourceMap, + config: AppConfigInner, + path: Path, + app_data: Extensions, } -impl Default for TestRequest<()> { - - fn default() -> TestRequest<()> { +impl Default for TestRequest { + fn default() -> TestRequest { TestRequest { - state: (), - method: Method::GET, - uri: Uri::from_str("/").unwrap(), - version: Version::HTTP_11, - headers: HeaderMap::new(), - params: Params::new(), - cookies: None, - payload: None, + req: HttpTestRequest::default(), + rmap: ResourceMap::new(ResourceDef::new("")), + config: AppConfigInner::default(), + path: Path::new(Url::new(Uri::default())), + app_data: Extensions::new(), } } } -impl TestRequest<()> { - +#[allow(clippy::wrong_self_convention)] +impl TestRequest { /// Create TestRequest and set request uri - pub fn with_uri(path: &str) -> TestRequest<()> { + pub fn with_uri(path: &str) -> TestRequest { TestRequest::default().uri(path) } /// Create TestRequest and set header - pub fn with_hdr(hdr: H) -> TestRequest<()> - { + pub fn with_hdr(hdr: H) -> TestRequest { TestRequest::default().set(hdr) } /// Create TestRequest and set header - pub fn with_header(key: K, value: V) -> TestRequest<()> - where HeaderName: HttpTryFrom, V: IntoHeaderValue, + pub fn with_header(key: K, value: V) -> TestRequest + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, { TestRequest::default().header(key, value) } -} -impl TestRequest { + /// Create TestRequest and set method to `Method::GET` + pub fn get() -> TestRequest { + TestRequest::default().method(Method::GET) + } - /// Start HttpRequest build process with application state - pub fn with_state(state: S) -> TestRequest { - TestRequest { - state, - method: Method::GET, - uri: Uri::from_str("/").unwrap(), - version: Version::HTTP_11, - headers: HeaderMap::new(), - params: Params::new(), - cookies: None, - payload: None, - } + /// Create TestRequest and set method to `Method::POST` + pub fn post() -> TestRequest { + TestRequest::default().method(Method::POST) + } + + /// Create TestRequest and set method to `Method::PUT` + pub fn put() -> TestRequest { + TestRequest::default().method(Method::PUT) + } + + /// Create TestRequest and set method to `Method::PATCH` + pub fn patch() -> TestRequest { + TestRequest::default().method(Method::PATCH) + } + + /// Create TestRequest and set method to `Method::DELETE` + pub fn delete() -> TestRequest { + TestRequest::default().method(Method::DELETE) } /// Set HTTP version of this request pub fn version(mut self, ver: Version) -> Self { - self.version = ver; + self.req.version(ver); self } /// Set HTTP method of this request pub fn method(mut self, meth: Method) -> Self { - self.method = meth; + self.req.method(meth); self } /// Set HTTP Uri of this request pub fn uri(mut self, path: &str) -> Self { - self.uri = Uri::from_str(path).unwrap(); + self.req.uri(path); self } /// Set a header - pub fn set(mut self, hdr: H) -> Self - { - if let Ok(value) = hdr.try_into() { - self.headers.append(H::name(), value); - return self - } - panic!("Can not set header"); + pub fn set(mut self, hdr: H) -> Self { + self.req.set(hdr); + self } /// Set a header pub fn header(mut self, key: K, value: V) -> Self - where HeaderName: HttpTryFrom, V: IntoHeaderValue + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, { - if let Ok(key) = HeaderName::try_from(key) { - if let Ok(value) = value.try_into() { - self.headers.append(key, value); - return self - } - } - panic!("Can not create header"); + self.req.header(key, value); + self + } + + /// Set cookie for this request + pub fn cookie(mut self, cookie: Cookie) -> Self { + self.req.cookie(cookie); + self } /// Set request path pattern parameter pub fn param(mut self, name: &'static str, value: &'static str) -> Self { - self.params.add(name, value); + self.path.add_static(name, value); self } /// Set request payload - pub fn set_payload>(mut self, data: B) -> Self { - let mut data = data.into(); - let mut payload = Payload::empty(); - payload.unread_data(data.take()); - self.payload = Some(payload); + pub fn set_payload>(mut self, data: B) -> Self { + self.req.set_payload(data); self } - /// Complete request creation and generate `HttpRequest` instance - pub fn finish(self) -> HttpRequest { - let TestRequest { state, method, uri, version, headers, params, cookies, payload } = self; - let req = HttpRequest::new(method, uri, version, headers, payload); - req.as_mut().cookies = cookies; - req.as_mut().params = params; - let (router, _) = Router::new::("/", ServerSettings::default(), Vec::new()); - req.with_state(Rc::new(state), router) + /// Serialize `data` to a URL encoded form and set it as the request payload. The `Content-Type` + /// header is set to `application/x-www-form-urlencoded`. + pub fn set_form(mut self, data: &T) -> Self { + let bytes = serde_urlencoded::to_string(data) + .expect("Failed to serialize test data as a urlencoded form"); + self.req.set_payload(bytes); + self.req.set(ContentType::form_url_encoded()); + self + } + + /// Serialize `data` to JSON and set it as the request payload. The `Content-Type` header is + /// set to `application/json`. + pub fn set_json(mut self, data: &T) -> Self { + let bytes = + serde_json::to_string(data).expect("Failed to serialize test data to json"); + self.req.set_payload(bytes); + self.req.set(ContentType::json()); + self + } + + /// Set application data. This is equivalent of `App::data()` method + /// for testing purpose. + pub fn data(mut self, data: T) -> Self { + self.app_data.insert(Data::new(data)); + self } #[cfg(test)] + /// Set request config + pub(crate) fn rmap(mut self, rmap: ResourceMap) -> Self { + self.rmap = rmap; + self + } + + /// Complete request creation and generate `Request` instance + pub fn to_request(mut self) -> Request { + self.req.finish() + } + + /// Complete request creation and generate `ServiceRequest` instance + pub fn to_srv_request(mut self) -> ServiceRequest { + let (head, payload) = self.req.finish().into_parts(); + self.path.get_mut().update(&head.uri); + + ServiceRequest::new(HttpRequest::new( + self.path, + head, + payload, + Rc::new(self.rmap), + AppConfig::new(self.config), + Rc::new(self.app_data), + HttpRequestPool::create(), + )) + } + + /// Complete request creation and generate `ServiceResponse` instance + pub fn to_srv_response(self, res: HttpResponse) -> ServiceResponse { + self.to_srv_request().into_response(res) + } + /// Complete request creation and generate `HttpRequest` instance - pub(crate) fn finish_with_router(self, router: Router) -> HttpRequest { - let TestRequest { state, method, uri, - version, headers, params, cookies, payload } = self; + pub fn to_http_request(mut self) -> HttpRequest { + let (head, payload) = self.req.finish().into_parts(); + self.path.get_mut().update(&head.uri); - let req = HttpRequest::new(method, uri, version, headers, payload); - req.as_mut().cookies = cookies; - req.as_mut().params = params; - req.with_state(Rc::new(state), router) + HttpRequest::new( + self.path, + head, + payload, + Rc::new(self.rmap), + AppConfig::new(self.config), + Rc::new(self.app_data), + HttpRequestPool::create(), + ) } - /// This method generates `HttpRequest` instance and runs handler - /// with generated request. - /// - /// This method panics is handler returns actor or async result. - pub fn run>(self, mut h: H) -> - Result>::Result as Responder>::Error> - { - let req = self.finish(); - let resp = h.handle(req.clone()); + /// Complete request creation and generate `HttpRequest` and `Payload` instances + pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { + let (head, payload) = self.req.finish().into_parts(); + self.path.get_mut().update(&head.uri); - match resp.respond_to(req.without_state()) { - Ok(resp) => { - match resp.into().into() { - ReplyItem::Message(resp) => Ok(resp), - ReplyItem::Future(_) => panic!("Async handler is not supported."), - } - }, - Err(err) => Err(err), - } - } + let req = HttpRequest::new( + self.path, + head, + Payload::None, + Rc::new(self.rmap), + AppConfig::new(self.config), + Rc::new(self.app_data), + HttpRequestPool::create(), + ); - /// This method generates `HttpRequest` instance and runs handler - /// with generated request. - /// - /// This method panics is handler returns actor. - pub fn run_async(self, h: H) -> Result - where H: Fn(HttpRequest) -> F + 'static, - F: Future + 'static, - R: Responder + 'static, - E: Into + 'static - { - let req = self.finish(); - let fut = h(req.clone()); - - let mut core = Core::new().unwrap(); - match core.run(fut) { - Ok(r) => { - match r.respond_to(req.without_state()) { - Ok(reply) => match reply.into().into() { - ReplyItem::Message(resp) => Ok(resp), - _ => panic!("Nested async replies are not supported"), - }, - Err(e) => Err(e), - } - }, - Err(err) => Err(err), - } + (req, payload) } } + +#[cfg(test)] +mod tests { + use actix_http::httpmessage::HttpMessage; + use serde::{Deserialize, Serialize}; + use std::time::SystemTime; + + use super::*; + use crate::{http::header, web, App, HttpResponse}; + + #[actix_rt::test] + async fn test_basics() { + let req = TestRequest::with_hdr(header::ContentType::json()) + .version(Version::HTTP_2) + .set(header::Date(SystemTime::now().into())) + .param("test", "123") + .data(10u32) + .to_http_request(); + assert!(req.headers().contains_key(header::CONTENT_TYPE)); + assert!(req.headers().contains_key(header::DATE)); + assert_eq!(&req.match_info()["test"], "123"); + assert_eq!(req.version(), Version::HTTP_2); + let data = req.get_app_data::().unwrap(); + assert!(req.get_app_data::().is_none()); + assert_eq!(*data, 10); + assert_eq!(*data.get_ref(), 10); + + assert!(req.app_data::().is_none()); + let data = req.app_data::().unwrap(); + assert_eq!(*data, 10); + } + + #[actix_rt::test] + async fn test_request_methods() { + let mut app = init_service( + App::new().service( + web::resource("/index.html") + .route(web::put().to(|| async { HttpResponse::Ok().body("put!") })) + .route( + web::patch().to(|| async { HttpResponse::Ok().body("patch!") }), + ) + .route( + web::delete() + .to(|| async { HttpResponse::Ok().body("delete!") }), + ), + ), + ) + .await; + + let put_req = TestRequest::put() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); + + let result = read_response(&mut app, put_req).await; + assert_eq!(result, Bytes::from_static(b"put!")); + + let patch_req = TestRequest::patch() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); + + let result = read_response(&mut app, patch_req).await; + assert_eq!(result, Bytes::from_static(b"patch!")); + + let delete_req = TestRequest::delete().uri("/index.html").to_request(); + let result = read_response(&mut app, delete_req).await; + assert_eq!(result, Bytes::from_static(b"delete!")); + } + + #[actix_rt::test] + async fn test_response() { + let mut app = + init_service(App::new().service(web::resource("/index.html").route( + web::post().to(|| async { HttpResponse::Ok().body("welcome!") }), + ))) + .await; + + let req = TestRequest::post() + .uri("/index.html") + .header(header::CONTENT_TYPE, "application/json") + .to_request(); + + let result = read_response(&mut app, req).await; + assert_eq!(result, Bytes::from_static(b"welcome!")); + } + + #[derive(Serialize, Deserialize)] + pub struct Person { + id: String, + name: String, + } + + #[actix_rt::test] + async fn test_response_json() { + let mut app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; + + let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); + + let req = TestRequest::post() + .uri("/people") + .header(header::CONTENT_TYPE, "application/json") + .set_payload(payload) + .to_request(); + + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + } + + #[actix_rt::test] + async fn test_request_response_form() { + let mut app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Form| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; + + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; + + let req = TestRequest::post() + .uri("/people") + .set_form(&payload) + .to_request(); + + assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); + + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + } + + #[actix_rt::test] + async fn test_request_response_json() { + let mut app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| { + async { HttpResponse::Ok().json(person.into_inner()) } + }), + ))) + .await; + + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; + + let req = TestRequest::post() + .uri("/people") + .set_json(&payload) + .to_request(); + + assert_eq!(req.content_type(), "application/json"); + + let result: Person = read_response_json(&mut app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + } + + #[actix_rt::test] + async fn test_async_with_block() { + async fn async_with_block() -> Result { + let res = web::block(move || Some(4usize).ok_or("wrong")).await; + + match res? { + Ok(value) => Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(format!("Async with block value: {}", value))), + Err(_) => panic!("Unexpected"), + } + } + + let mut app = init_service( + App::new().service(web::resource("/index.html").to(async_with_block)), + ) + .await; + + let req = TestRequest::post().uri("/index.html").to_request(); + let res = app.call(req).await.unwrap(); + assert!(res.status().is_success()); + } + + // #[actix_rt::test] + // fn test_actor() { + // use actix::Actor; + + // struct MyActor; + + // struct Num(usize); + // impl actix::Message for Num { + // type Result = usize; + // } + // impl actix::Actor for MyActor { + // type Context = actix::Context; + // } + // impl actix::Handler for MyActor { + // type Result = usize; + // fn handle(&mut self, msg: Num, _: &mut Self::Context) -> Self::Result { + // msg.0 + // } + // } + + // let addr = run_on(|| MyActor.start()); + // let mut app = init_service(App::new().service( + // web::resource("/index.html").to(move || { + // addr.send(Num(1)).from_err().and_then(|res| { + // if res == 1 { + // HttpResponse::Ok() + // } else { + // HttpResponse::BadRequest() + // } + // }) + // }), + // )); + + // let req = TestRequest::post().uri("/index.html").to_request(); + // let res = block_fn(|| app.call(req)).unwrap(); + // assert!(res.status().is_success()); + // } +} diff --git a/src/types/form.rs b/src/types/form.rs new file mode 100644 index 000000000..e1bd52375 --- /dev/null +++ b/src/types/form.rs @@ -0,0 +1,496 @@ +//! Form extractor + +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{fmt, ops}; + +use actix_http::{Error, HttpMessage, Payload, Response}; +use bytes::BytesMut; +use encoding_rs::{Encoding, UTF_8}; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::StreamExt; +use serde::de::DeserializeOwned; +use serde::Serialize; + +use crate::dev::Decompress; +use crate::error::UrlencodedError; +use crate::extract::FromRequest; +use crate::http::{ + header::{ContentType, CONTENT_LENGTH}, + StatusCode, +}; +use crate::request::HttpRequest; +use crate::responder::Responder; + +/// Form data helper (`application/x-www-form-urlencoded`) +/// +/// Can be use to extract url-encoded data from the request body, +/// or send url-encoded data as the response. +/// +/// ## Extract +/// +/// To extract typed information from request's body, the type `T` must +/// implement the `Deserialize` trait from *serde*. +/// +/// [**FormConfig**](struct.FormConfig.html) allows to configure extraction +/// process. +/// +/// ### Example +/// ```rust +/// use actix_web::web; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct FormData { +/// username: String, +/// } +/// +/// /// Extract form data using serde. +/// /// This handler get called only if content type is *x-www-form-urlencoded* +/// /// and content of the request could be deserialized to a `FormData` struct +/// fn index(form: web::Form) -> String { +/// format!("Welcome {}!", form.username) +/// } +/// # fn main() {} +/// ``` +/// +/// ## Respond +/// +/// The `Form` type also allows you to respond with well-formed url-encoded data: +/// simply return a value of type Form where T is the type to be url-encoded. +/// The type must implement `serde::Serialize`; +/// +/// ### Example +/// ```rust +/// use actix_web::*; +/// use serde_derive::Serialize; +/// +/// #[derive(Serialize)] +/// struct SomeForm { +/// name: String, +/// age: u8 +/// } +/// +/// // Will return a 200 response with header +/// // `Content-Type: application/x-www-form-urlencoded` +/// // and body "name=actix&age=123" +/// fn index() -> web::Form { +/// web::Form(SomeForm { +/// name: "actix".into(), +/// age: 123 +/// }) +/// } +/// # fn main() {} +/// ``` +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub struct Form(pub T); + +impl Form { + /// Deconstruct to an inner value + pub fn into_inner(self) -> T { + self.0 + } +} + +impl ops::Deref for Form { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +impl ops::DerefMut for Form { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl FromRequest for Form +where + T: DeserializeOwned + 'static, +{ + type Config = FormConfig; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + let req2 = req.clone(); + let (limit, err) = req + .app_data::() + .map(|c| (c.limit, c.ehandler.clone())) + .unwrap_or((16384, None)); + + UrlEncoded::new(req, payload) + .limit(limit) + .map(move |res| match res { + Err(e) => { + if let Some(err) = err { + Err((*err)(e, &req2)) + } else { + Err(e.into()) + } + } + Ok(item) => Ok(Form(item)), + }) + .boxed_local() + } +} + +impl fmt::Debug for Form { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for Form { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Responder for Form { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + let body = match serde_urlencoded::to_string(&self.0) { + Ok(body) => body, + Err(e) => return err(e.into()), + }; + + ok(Response::build(StatusCode::OK) + .set(ContentType::form_url_encoded()) + .body(body)) + } +} + +/// Form extractor configuration +/// +/// ```rust +/// use actix_web::{web, App, FromRequest, Result}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct FormData { +/// username: String, +/// } +/// +/// /// Extract form data using serde. +/// /// Custom configuration is used for this handler, max payload size is 4k +/// async fn index(form: web::Form) -> Result { +/// Ok(format!("Welcome {}!", form.username)) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html") +/// // change `Form` extractor configuration +/// .data( +/// web::Form::::configure(|cfg| cfg.limit(4097)) +/// ) +/// .route(web::get().to(index)) +/// ); +/// } +/// ``` +#[derive(Clone)] +pub struct FormConfig { + limit: usize, + ehandler: Option Error>>, +} + +impl FormConfig { + /// Change max size of payload. By default max size is 16Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set custom error handler + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(UrlencodedError, &HttpRequest) -> Error + 'static, + { + self.ehandler = Some(Rc::new(f)); + self + } +} + +impl Default for FormConfig { + fn default() -> Self { + FormConfig { + limit: 16384, + ehandler: None, + } + } +} + +/// Future that resolves to a parsed urlencoded values. +/// +/// Parse `application/x-www-form-urlencoded` encoded request's body. +/// Return `UrlEncoded` future. Form can be deserialized to any type that +/// implements `Deserialize` trait from *serde*. +/// +/// Returns error: +/// +/// * content type is not `application/x-www-form-urlencoded` +/// * content-length is greater than 32k +/// +pub struct UrlEncoded { + stream: Option>, + limit: usize, + length: Option, + encoding: &'static Encoding, + err: Option, + fut: Option>>, +} + +impl UrlEncoded { + /// Create a new future to URL encode a request + pub fn new(req: &HttpRequest, payload: &mut Payload) -> UrlEncoded { + // check content type + if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { + return Self::err(UrlencodedError::ContentType); + } + let encoding = match req.encoding() { + Ok(enc) => enc, + Err(_) => return Self::err(UrlencodedError::ContentType), + }; + + let mut len = None; + if let Some(l) = req.headers().get(&CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } else { + return Self::err(UrlencodedError::UnknownLength); + } + } else { + return Self::err(UrlencodedError::UnknownLength); + } + }; + + let payload = Decompress::from_headers(payload.take(), req.headers()); + UrlEncoded { + encoding, + stream: Some(payload), + limit: 32_768, + length: len, + fut: None, + err: None, + } + } + + fn err(e: UrlencodedError) -> Self { + UrlEncoded { + stream: None, + limit: 32_768, + fut: None, + err: Some(e), + length: None, + encoding: UTF_8, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for UrlEncoded +where + U: DeserializeOwned + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Some(ref mut fut) = self.fut { + return Pin::new(fut).poll(cx); + } + + if let Some(err) = self.err.take() { + return Poll::Ready(Err(err)); + } + + // payload size + let limit = self.limit; + if let Some(len) = self.length.take() { + if len > limit { + return Poll::Ready(Err(UrlencodedError::Overflow { size: len, limit })); + } + } + + // future + let encoding = self.encoding; + let mut stream = self.stream.take().unwrap(); + + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(UrlencodedError::Overflow { + size: body.len() + chunk.len(), + limit, + }); + } else { + body.extend_from_slice(&chunk); + } + } + + if encoding == UTF_8 { + serde_urlencoded::from_bytes::(&body) + .map_err(|_| UrlencodedError::Parse) + } else { + let body = encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or(UrlencodedError::Parse)?; + serde_urlencoded::from_str::(&body) + .map_err(|_| UrlencodedError::Parse) + } + } + .boxed_local(), + ); + self.poll(cx) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use serde::{Deserialize, Serialize}; + + use super::*; + use crate::http::header::{HeaderValue, CONTENT_TYPE}; + use crate::test::TestRequest; + + #[derive(Deserialize, Serialize, Debug, PartialEq)] + struct Info { + hello: String, + counter: i64, + } + + #[actix_rt::test] + async fn test_form() { + let (req, mut pl) = + TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); + + let Form(s) = Form::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!( + s, + Info { + hello: "world".into(), + counter: 123 + } + ); + } + + fn eq(err: UrlencodedError, other: UrlencodedError) -> bool { + match err { + UrlencodedError::Overflow { .. } => match other { + UrlencodedError::Overflow { .. } => true, + _ => false, + }, + UrlencodedError::UnknownLength => match other { + UrlencodedError::UnknownLength => true, + _ => false, + }, + UrlencodedError::ContentType => match other { + UrlencodedError::ContentType => true, + _ => false, + }, + _ => false, + } + } + + #[actix_rt::test] + async fn test_urlencoded_error() { + let (req, mut pl) = + TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(CONTENT_LENGTH, "xxxx") + .to_http_parts(); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength)); + + let (req, mut pl) = + TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(CONTENT_LENGTH, "1000000") + .to_http_parts(); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq( + info.err().unwrap(), + UrlencodedError::Overflow { size: 0, limit: 0 } + )); + + let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain") + .header(CONTENT_LENGTH, "10") + .to_http_parts(); + let info = UrlEncoded::::new(&req, &mut pl).await; + assert!(eq(info.err().unwrap(), UrlencodedError::ContentType)); + } + + #[actix_rt::test] + async fn test_urlencoded() { + let (req, mut pl) = + TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); + + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); + assert_eq!( + info, + Info { + hello: "world".to_owned(), + counter: 123 + } + ); + + let (req, mut pl) = TestRequest::with_header( + CONTENT_TYPE, + "application/x-www-form-urlencoded; charset=utf-8", + ) + .header(CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world&counter=123")) + .to_http_parts(); + + let info = UrlEncoded::::new(&req, &mut pl).await.unwrap(); + assert_eq!( + info, + Info { + hello: "world".to_owned(), + counter: 123 + } + ); + } + + #[actix_rt::test] + async fn test_responder() { + let req = TestRequest::default().to_http_request(); + + let form = Form(Info { + hello: "world".to_string(), + counter: 123, + }); + let resp = form.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/x-www-form-urlencoded") + ); + + use crate::responder::tests::BodyTest; + assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123"); + } +} diff --git a/src/types/json.rs b/src/types/json.rs new file mode 100644 index 000000000..028092d1a --- /dev/null +++ b/src/types/json.rs @@ -0,0 +1,645 @@ +//! Json extractor/responder + +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, ops}; + +use bytes::BytesMut; +use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; +use futures::StreamExt; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json; + +use actix_http::http::{header::CONTENT_LENGTH, StatusCode}; +use actix_http::{HttpMessage, Payload, Response}; + +use crate::dev::Decompress; +use crate::error::{Error, JsonPayloadError}; +use crate::extract::FromRequest; +use crate::request::HttpRequest; +use crate::responder::Responder; + +/// Json helper +/// +/// Json can be used for two different purpose. First is for json response +/// generation and second is for extracting typed information from request's +/// payload. +/// +/// To extract typed information from request's body, the type `T` must +/// implement the `Deserialize` trait from *serde*. +/// +/// [**JsonConfig**](struct.JsonConfig.html) allows to configure extraction +/// process. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's body +/// async fn index(info: web::Json) -> String { +/// format!("Welcome {}!", info.username) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route( +/// web::post().to(index)) +/// ); +/// } +/// ``` +/// +/// The `Json` type allows you to respond with well-formed JSON data: simply +/// return a value of type Json where T is the type of a structure +/// to serialize into *JSON*. The type `T` must implement the `Serialize` +/// trait from *serde*. +/// +/// ```rust +/// use actix_web::*; +/// use serde_derive::Serialize; +/// +/// #[derive(Serialize)] +/// struct MyObj { +/// name: String, +/// } +/// +/// fn index(req: HttpRequest) -> Result> { +/// Ok(web::Json(MyObj { +/// name: req.match_info().get("name").unwrap().to_string(), +/// })) +/// } +/// # fn main() {} +/// ``` +pub struct Json(pub T); + +impl Json { + /// Deconstruct to an inner value + pub fn into_inner(self) -> T { + self.0 + } +} + +impl ops::Deref for Json { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +impl ops::DerefMut for Json { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl fmt::Debug for Json +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Json: {:?}", self.0) + } +} + +impl fmt::Display for Json +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl Responder for Json { + type Error = Error; + type Future = Ready>; + + fn respond_to(self, _: &HttpRequest) -> Self::Future { + let body = match serde_json::to_string(&self.0) { + Ok(body) => body, + Err(e) => return err(e.into()), + }; + + ok(Response::build(StatusCode::OK) + .content_type("application/json") + .body(body)) + } +} + +/// Json extractor. Allow to extract typed information from request's +/// payload. +/// +/// To extract typed information from request's body, the type `T` must +/// implement the `Deserialize` trait from *serde*. +/// +/// [**JsonConfig**](struct.JsonConfig.html) allows to configure extraction +/// process. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's body +/// async fn index(info: web::Json) -> String { +/// format!("Welcome {}!", info.username) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route( +/// web::post().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for Json +where + T: DeserializeOwned + 'static, +{ + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + type Config = JsonConfig; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + let req2 = req.clone(); + let (limit, err, ctype) = req + .app_data::() + .map(|c| (c.limit, c.ehandler.clone(), c.content_type.clone())) + .unwrap_or((32768, None, None)); + + JsonBody::new(req, payload, ctype) + .limit(limit) + .map(move |res| match res { + Err(e) => { + log::debug!( + "Failed to deserialize Json from payload. \ + Request path: {}", + req2.path() + ); + if let Some(err) = err { + Err((*err)(e, &req2)) + } else { + Err(e.into()) + } + } + Ok(data) => Ok(Json(data)), + }) + .boxed_local() + } +} + +/// Json extractor configuration +/// +/// ```rust +/// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's body, max payload size is 4kb +/// async fn index(info: web::Json) -> String { +/// format!("Welcome {}!", info.username) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").data( +/// // change json extractor configuration +/// web::Json::::configure(|cfg| { +/// cfg.limit(4096) +/// .content_type(|mime| { // <- accept text/plain content type +/// mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN +/// }) +/// .error_handler(|err, req| { // <- create custom error response +/// error::InternalError::from_response( +/// err, HttpResponse::Conflict().finish()).into() +/// }) +/// })) +/// .route(web::post().to(index)) +/// ); +/// } +/// ``` +#[derive(Clone)] +pub struct JsonConfig { + limit: usize, + ehandler: Option Error + Send + Sync>>, + content_type: Option bool + Send + Sync>>, +} + +impl JsonConfig { + /// Change max size of payload. By default max size is 32Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set custom error handler + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.ehandler = Some(Arc::new(f)); + self + } + + /// Set predicate for allowed content types + pub fn content_type(mut self, predicate: F) -> Self + where + F: Fn(mime::Mime) -> bool + Send + Sync + 'static, + { + self.content_type = Some(Arc::new(predicate)); + self + } +} + +impl Default for JsonConfig { + fn default() -> Self { + JsonConfig { + limit: 32768, + ehandler: None, + content_type: None, + } + } +} + +/// Request's payload json parser, it resolves to a deserialized `T` value. +/// This future could be used with `ServiceRequest` and `ServiceFromRequest`. +/// +/// Returns error: +/// +/// * content type is not `application/json` +/// (unless specified in [`JsonConfig`](struct.JsonConfig.html)) +/// * content length is greater than 256k +pub struct JsonBody { + limit: usize, + length: Option, + stream: Option>, + err: Option, + fut: Option>>, +} + +impl JsonBody +where + U: DeserializeOwned + 'static, +{ + /// Create `JsonBody` for request. + pub fn new( + req: &HttpRequest, + payload: &mut Payload, + ctype: Option bool + Send + Sync>>, + ) -> Self { + // check content-type + let json = if let Ok(Some(mime)) = req.mime_type() { + mime.subtype() == mime::JSON + || mime.suffix() == Some(mime::JSON) + || ctype.as_ref().map_or(false, |predicate| predicate(mime)) + } else { + false + }; + + if !json { + return JsonBody { + limit: 262_144, + length: None, + stream: None, + fut: None, + err: Some(JsonPayloadError::ContentType), + }; + } + + let len = req + .headers() + .get(&CONTENT_LENGTH) + .and_then(|l| l.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let payload = Decompress::from_headers(payload.take(), req.headers()); + + JsonBody { + limit: 262_144, + length: len, + stream: Some(payload), + fut: None, + err: None, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for JsonBody +where + U: DeserializeOwned + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Some(ref mut fut) = self.fut { + return Pin::new(fut).poll(cx); + } + + if let Some(err) = self.err.take() { + return Poll::Ready(Err(err)); + } + + let limit = self.limit; + if let Some(len) = self.length.take() { + if len > limit { + return Poll::Ready(Err(JsonPayloadError::Overflow)); + } + } + let mut stream = self.stream.take().unwrap(); + + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if (body.len() + chunk.len()) > limit { + return Err(JsonPayloadError::Overflow); + } else { + body.extend_from_slice(&chunk); + } + } + Ok(serde_json::from_slice::(&body)?) + } + .boxed_local(), + ); + + self.poll(cx) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use serde_derive::{Deserialize, Serialize}; + + use super::*; + use crate::error::InternalError; + use crate::http::header; + use crate::test::{load_stream, TestRequest}; + use crate::HttpResponse; + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { + match err { + JsonPayloadError::Overflow => match other { + JsonPayloadError::Overflow => true, + _ => false, + }, + JsonPayloadError::ContentType => match other { + JsonPayloadError::ContentType => true, + _ => false, + }, + _ => false, + } + } + + #[actix_rt::test] + async fn test_responder() { + let req = TestRequest::default().to_http_request(); + + let j = Json(MyObject { + name: "test".to_string(), + }); + let resp = j.respond_to(&req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + header::HeaderValue::from_static("application/json") + ); + + use crate::responder::tests::BodyTest; + assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); + } + + #[actix_rt::test] + async fn test_custom_error_responder() { + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(10).error_handler(|err, _| { + let msg = MyObject { + name: "invalid request".to_string(), + }; + let resp = HttpResponse::BadRequest() + .body(serde_json::to_string(&msg).unwrap()); + InternalError::from_response(err, resp).into() + })) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + let mut resp = Response::from_error(s.err().unwrap().into()); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let body = load_stream(resp.take_body()).await.unwrap(); + let msg: MyObject = serde_json::from_slice(&body).unwrap(); + assert_eq!(msg.name, "invalid request"); + } + + #[actix_rt::test] + async fn test_extract() { + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.name, "test"); + assert_eq!( + s.into_inner(), + MyObject { + name: "test".to_string() + } + ); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(10)) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(format!("{}", s.err().unwrap()) + .contains("Json payload size is bigger than allowed")); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data( + JsonConfig::default() + .limit(10) + .error_handler(|_, _| JsonPayloadError::ContentType.into()), + ) + .to_http_parts(); + let s = Json::::from_request(&req, &mut pl).await; + assert!(format!("{}", s.err().unwrap()).contains("Content type error")); + } + + #[actix_rt::test] + async fn test_json_body() { + let (req, mut pl) = TestRequest::default().to_http_parts(); + let json = JsonBody::::new(&req, &mut pl, None).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + ) + .to_http_parts(); + let json = JsonBody::::new(&req, &mut pl, None).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + ) + .to_http_parts(); + + let json = JsonBody::::new(&req, &mut pl, None) + .limit(100) + .await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); + + let (req, mut pl) = TestRequest::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .to_http_parts(); + + let json = JsonBody::::new(&req, &mut pl, None).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + } + + #[actix_rt::test] + async fn test_with_json_and_bad_content_type() { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().limit(4096)) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_err()) + } + + #[actix_rt::test] + async fn test_with_json_and_good_custom_content_type() { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_ok()) + } + + #[actix_rt::test] + async fn test_with_json_and_bad_custom_content_type() { + let (req, mut pl) = TestRequest::with_header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/html"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .data(JsonConfig::default().content_type(|mime: mime::Mime| { + mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN + })) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_err()) + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 000000000..b32711e2a --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,15 @@ +//! Helper types + +pub(crate) mod form; +pub(crate) mod json; +mod path; +pub(crate) mod payload; +mod query; +pub(crate) mod readlines; + +pub use self::form::{Form, FormConfig}; +pub use self::json::{Json, JsonConfig}; +pub use self::path::{Path, PathConfig}; +pub use self::payload::{Payload, PayloadConfig}; +pub use self::query::{Query, QueryConfig}; +pub use self::readlines::Readlines; diff --git a/src/types/path.rs b/src/types/path.rs new file mode 100644 index 000000000..404759300 --- /dev/null +++ b/src/types/path.rs @@ -0,0 +1,377 @@ +//! Path extractor +use std::sync::Arc; +use std::{fmt, ops}; + +use actix_http::error::{Error, ErrorNotFound}; +use actix_router::PathDeserializer; +use futures::future::{ready, Ready}; +use serde::de; + +use crate::dev::Payload; +use crate::error::PathError; +use crate::request::HttpRequest; +use crate::FromRequest; + +#[derive(PartialEq, Eq, PartialOrd, Ord)] +/// Extract typed information from the request's path. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// +/// /// extract path info from "/{username}/{count}/index.html" url +/// /// {username} - deserializes to a String +/// /// {count} - - deserializes to a u32 +/// async fn index(info: web::Path<(String, u32)>) -> String { +/// format!("Welcome {}! {}", info.0, info.1) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{username}/{count}/index.html") // <- define path parameters +/// .route(web::get().to(index)) // <- register handler with `Path` extractor +/// ); +/// } +/// ``` +/// +/// It is possible to extract path information to a specific type that +/// implements `Deserialize` trait from *serde*. +/// +/// ```rust +/// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// extract `Info` from a path using serde +/// async fn index(info: web::Path) -> Result { +/// Ok(format!("Welcome {}!", info.username)) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{username}/index.html") // <- define path parameters +/// .route(web::get().to(index)) // <- use handler with Path` extractor +/// ); +/// } +/// ``` +pub struct Path { + inner: T, +} + +impl Path { + /// Deconstruct to an inner value + pub fn into_inner(self) -> T { + self.inner + } +} + +impl AsRef for Path { + fn as_ref(&self) -> &T { + &self.inner + } +} + +impl ops::Deref for Path { + type Target = T; + + fn deref(&self) -> &T { + &self.inner + } +} + +impl ops::DerefMut for Path { + fn deref_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl From for Path { + fn from(inner: T) -> Path { + Path { inner } + } +} + +impl fmt::Debug for Path { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl fmt::Display for Path { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.inner.fmt(f) + } +} + +/// Extract typed information from the request's path. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// +/// /// extract path info from "/{username}/{count}/index.html" url +/// /// {username} - deserializes to a String +/// /// {count} - - deserializes to a u32 +/// async fn index(info: web::Path<(String, u32)>) -> String { +/// format!("Welcome {}! {}", info.0, info.1) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{username}/{count}/index.html") // <- define path parameters +/// .route(web::get().to(index)) // <- register handler with `Path` extractor +/// ); +/// } +/// ``` +/// +/// It is possible to extract path information to a specific type that +/// implements `Deserialize` trait from *serde*. +/// +/// ```rust +/// use actix_web::{web, App, Error}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// extract `Info` from a path using serde +/// async fn index(info: web::Path) -> Result { +/// Ok(format!("Welcome {}!", info.username)) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{username}/index.html") // <- define path parameters +/// .route(web::get().to(index)) // <- use handler with Path` extractor +/// ); +/// } +/// ``` +impl FromRequest for Path +where + T: de::DeserializeOwned, +{ + type Error = Error; + type Future = Ready>; + type Config = PathConfig; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let error_handler = req + .app_data::() + .map(|c| c.ehandler.clone()) + .unwrap_or(None); + + ready( + de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) + .map(|inner| Path { inner }) + .map_err(move |e| { + log::debug!( + "Failed during Path extractor deserialization. \ + Request path: {:?}", + req.path() + ); + if let Some(error_handler) = error_handler { + let e = PathError::Deserialize(e); + (error_handler)(e, req) + } else { + ErrorNotFound(e) + } + }), + ) + } +} + +/// Path extractor configuration +/// +/// ```rust +/// use actix_web::web::PathConfig; +/// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize, Debug)] +/// enum Folder { +/// #[serde(rename = "inbox")] +/// Inbox, +/// #[serde(rename = "outbox")] +/// Outbox, +/// } +/// +/// // deserialize `Info` from request's path +/// async fn index(folder: web::Path) -> String { +/// format!("Selected folder: {:?}!", folder) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/messages/{folder}") +/// .data(PathConfig::default().error_handler(|err, req| { +/// error::InternalError::from_response( +/// err, +/// HttpResponse::Conflict().finish(), +/// ) +/// .into() +/// })) +/// .route(web::post().to(index)), +/// ); +/// } +/// ``` +#[derive(Clone)] +pub struct PathConfig { + ehandler: Option Error + Send + Sync>>, +} + +impl PathConfig { + /// Set custom error handler + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(PathError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.ehandler = Some(Arc::new(f)); + self + } +} + +impl Default for PathConfig { + fn default() -> Self { + PathConfig { ehandler: None } + } +} + +#[cfg(test)] +mod tests { + use actix_router::ResourceDef; + use derive_more::Display; + use serde_derive::Deserialize; + + use super::*; + use crate::test::TestRequest; + use crate::{error, http, HttpResponse}; + + #[derive(Deserialize, Debug, Display)] + #[display(fmt = "MyStruct({}, {})", key, value)] + struct MyStruct { + key: String, + value: String, + } + + #[derive(Deserialize)] + struct Test2 { + key: String, + value: u32, + } + + #[actix_rt::test] + async fn test_extract_path_single() { + let resource = ResourceDef::new("/{value}/"); + + let mut req = TestRequest::with_uri("/32/").to_srv_request(); + resource.match_path(req.match_info_mut()); + + let (req, mut pl) = req.into_parts(); + assert_eq!(*Path::::from_request(&req, &mut pl).await.unwrap(), 32); + assert!(Path::::from_request(&req, &mut pl).await.is_err()); + } + + #[actix_rt::test] + async fn test_tuple_extract() { + let resource = ResourceDef::new("/{key}/{value}/"); + + let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + resource.match_path(req.match_info_mut()); + + let (req, mut pl) = req.into_parts(); + let res = <(Path<(String, String)>,)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!((res.0).0, "name"); + assert_eq!((res.0).1, "user1"); + + let res = <(Path<(String, String)>, Path<(String, String)>)>::from_request( + &req, &mut pl, + ) + .await + .unwrap(); + assert_eq!((res.0).0, "name"); + assert_eq!((res.0).1, "user1"); + assert_eq!((res.1).0, "name"); + assert_eq!((res.1).1, "user1"); + + let () = <()>::from_request(&req, &mut pl).await.unwrap(); + } + + #[actix_rt::test] + async fn test_request_extract() { + let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + + let resource = ResourceDef::new("/{key}/{value}/"); + resource.match_path(req.match_info_mut()); + + let (req, mut pl) = req.into_parts(); + let mut s = Path::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, "user1"); + s.value = "user2".to_string(); + assert_eq!(s.value, "user2"); + assert_eq!( + format!("{}, {:?}", s, s), + "MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }" + ); + let s = s.into_inner(); + assert_eq!(s.value, "user2"); + + let s = Path::<(String, String)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, "user1"); + + let mut req = TestRequest::with_uri("/name/32/").to_srv_request(); + let resource = ResourceDef::new("/{key}/{value}/"); + resource.match_path(req.match_info_mut()); + + let (req, mut pl) = req.into_parts(); + let s = Path::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.as_ref().key, "name"); + assert_eq!(s.value, 32); + + let s = Path::<(String, u8)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); + + let res = Path::>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(res[0], "name".to_owned()); + assert_eq!(res[1], "32".to_owned()); + } + + #[actix_rt::test] + async fn test_custom_err_handler() { + let (req, mut pl) = TestRequest::with_uri("/name/user1/") + .data(PathConfig::default().error_handler(|err, _| { + error::InternalError::from_response( + err, + HttpResponse::Conflict().finish(), + ) + .into() + })) + .to_http_parts(); + + let s = Path::<(usize,)>::from_request(&req, &mut pl) + .await + .unwrap_err(); + let res: HttpResponse = s.into(); + + assert_eq!(res.status(), http::StatusCode::CONFLICT); + } +} diff --git a/src/types/payload.rs b/src/types/payload.rs new file mode 100644 index 000000000..2969e385a --- /dev/null +++ b/src/types/payload.rs @@ -0,0 +1,473 @@ +//! Payload/Bytes/String extractors +use std::future::Future; +use std::pin::Pin; +use std::str; +use std::task::{Context, Poll}; + +use actix_http::error::{Error, ErrorBadRequest, PayloadError}; +use actix_http::HttpMessage; +use bytes::{Bytes, BytesMut}; +use encoding_rs::UTF_8; +use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::{Stream, StreamExt}; +use mime::Mime; + +use crate::dev; +use crate::extract::FromRequest; +use crate::http::header; +use crate::request::HttpRequest; + +/// Payload extractor returns request 's payload stream. +/// +/// ## Example +/// +/// ```rust +/// use futures::{Future, Stream, StreamExt}; +/// use actix_web::{web, error, App, Error, HttpResponse}; +/// +/// /// extract binary data from request +/// async fn index(mut body: web::Payload) -> Result +/// { +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route( +/// web::get().to(index)) +/// ); +/// } +/// ``` +pub struct Payload(pub crate::dev::Payload); + +impl Payload { + /// Deconstruct to a inner value + pub fn into_inner(self) -> crate::dev::Payload { + self.0 + } +} + +impl Stream for Payload { + type Item = Result; + + #[inline] + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + +/// Get request's payload stream +/// +/// ## Example +/// +/// ```rust +/// use futures::{Future, Stream, StreamExt}; +/// use actix_web::{web, error, App, Error, HttpResponse}; +/// +/// /// extract binary data from request +/// async fn index(mut body: web::Payload) -> Result +/// { +/// let mut bytes = web::BytesMut::new(); +/// while let Some(item) = body.next().await { +/// bytes.extend_from_slice(&item?); +/// } +/// +/// format!("Body {:?}!", bytes); +/// Ok(HttpResponse::Ok().finish()) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route( +/// web::get().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for Payload { + type Config = PayloadConfig; + type Error = Error; + type Future = Ready>; + + #[inline] + fn from_request(_: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { + ok(Payload(payload.take())) + } +} + +/// Request binary data from a request's payload. +/// +/// Loads request's payload and construct Bytes instance. +/// +/// [**PayloadConfig**](struct.PayloadConfig.html) allows to configure +/// extraction process. +/// +/// ## Example +/// +/// ```rust +/// use bytes::Bytes; +/// use actix_web::{web, App}; +/// +/// /// extract binary data from request +/// async fn index(body: Bytes) -> String { +/// format!("Body {:?}!", body) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route( +/// web::get().to(index)) +/// ); +/// } +/// ``` +impl FromRequest for Bytes { + type Config = PayloadConfig; + type Error = Error; + type Future = Either< + LocalBoxFuture<'static, Result>, + Ready>, + >; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { + let tmp; + let cfg = if let Some(cfg) = req.app_data::() { + cfg + } else { + tmp = PayloadConfig::default(); + &tmp + }; + + if let Err(e) = cfg.check_mimetype(req) { + return Either::Right(err(e)); + } + + let limit = cfg.limit; + let fut = HttpMessageBody::new(req, payload).limit(limit); + Either::Left(async move { Ok(fut.await?) }.boxed_local()) + } +} + +/// Extract text information from a request's body. +/// +/// Text extractor automatically decode body according to the request's charset. +/// +/// [**PayloadConfig**](struct.PayloadConfig.html) allows to configure +/// extraction process. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App, FromRequest}; +/// +/// /// extract text data from request +/// async fn index(text: String) -> String { +/// format!("Body {}!", text) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html") +/// .data(String::configure(|cfg| { // <- limit size of the payload +/// cfg.limit(4096) +/// })) +/// .route(web::get().to(index)) // <- register handler with extractor params +/// ); +/// } +/// ``` +impl FromRequest for String { + type Config = PayloadConfig; + type Error = Error; + type Future = Either< + LocalBoxFuture<'static, Result>, + Ready>, + >; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { + let tmp; + let cfg = if let Some(cfg) = req.app_data::() { + cfg + } else { + tmp = PayloadConfig::default(); + &tmp + }; + + // check content-type + if let Err(e) = cfg.check_mimetype(req) { + return Either::Right(err(e)); + } + + // check charset + let encoding = match req.encoding() { + Ok(enc) => enc, + Err(e) => return Either::Right(err(e.into())), + }; + let limit = cfg.limit; + let fut = HttpMessageBody::new(req, payload).limit(limit); + + Either::Left( + async move { + let body = fut.await?; + + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) + } + } + .boxed_local(), + ) + } +} +/// Payload configuration for request's payload. +#[derive(Clone)] +pub struct PayloadConfig { + limit: usize, + mimetype: Option, +} + +impl PayloadConfig { + /// Create `PayloadConfig` instance and set max size of payload. + pub fn new(limit: usize) -> Self { + let mut cfg = Self::default(); + cfg.limit = limit; + cfg + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set required mime-type of the request. By default mime type is not + /// enforced. + pub fn mimetype(mut self, mt: Mime) -> Self { + self.mimetype = Some(mt); + self + } + + fn check_mimetype(&self, req: &HttpRequest) -> Result<(), Error> { + // check content-type + if let Some(ref mt) = self.mimetype { + match req.mime_type() { + Ok(Some(ref req_mt)) => { + if mt != req_mt { + return Err(ErrorBadRequest("Unexpected Content-Type")); + } + } + Ok(None) => { + return Err(ErrorBadRequest("Content-Type is expected")); + } + Err(err) => { + return Err(err.into()); + } + } + } + Ok(()) + } +} + +impl Default for PayloadConfig { + fn default() -> Self { + PayloadConfig { + limit: 262_144, + mimetype: None, + } + } +} + +/// Future that resolves to a complete http message body. +/// +/// Load http message body. +/// +/// By default only 256Kb payload reads to a memory, then +/// `PayloadError::Overflow` get returned. Use `MessageBody::limit()` +/// method to change upper limit. +pub struct HttpMessageBody { + limit: usize, + length: Option, + stream: Option>, + err: Option, + fut: Option>>, +} + +impl HttpMessageBody { + /// Create `MessageBody` for request. + pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody { + let mut len = None; + if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } else { + return Self::err(PayloadError::UnknownLength); + } + } else { + return Self::err(PayloadError::UnknownLength); + } + } + + HttpMessageBody { + stream: Some(dev::Decompress::from_headers(payload.take(), req.headers())), + limit: 262_144, + length: len, + fut: None, + err: None, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + fn err(e: PayloadError) -> Self { + HttpMessageBody { + stream: None, + limit: 262_144, + fut: None, + err: Some(e), + length: None, + } + } +} + +impl Future for HttpMessageBody { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Some(ref mut fut) = self.fut { + return Pin::new(fut).poll(cx); + } + + if let Some(err) = self.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = self.length.take() { + if len > self.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + } + + // future + let limit = self.limit; + let mut stream = self.stream.take().unwrap(); + self.fut = Some( + async move { + let mut body = BytesMut::with_capacity(8192); + + while let Some(item) = stream.next().await { + let chunk = item?; + if body.len() + chunk.len() > limit { + return Err(PayloadError::Overflow); + } else { + body.extend_from_slice(&chunk); + } + } + Ok(body.freeze()) + } + .boxed_local(), + ); + self.poll(cx) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + use crate::http::header; + use crate::test::TestRequest; + + #[actix_rt::test] + async fn test_payload_config() { + let req = TestRequest::default().to_http_request(); + let cfg = PayloadConfig::default().mimetype(mime::APPLICATION_JSON); + assert!(cfg.check_mimetype(&req).is_err()); + + let req = TestRequest::with_header( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) + .to_http_request(); + assert!(cfg.check_mimetype(&req).is_err()); + + let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json") + .to_http_request(); + assert!(cfg.check_mimetype(&req).is_ok()); + } + + #[actix_rt::test] + async fn test_bytes() { + let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + + let s = Bytes::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s, Bytes::from_static(b"hello=world")); + } + + #[actix_rt::test] + async fn test_string() { + let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "11") + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + + let s = String::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s, "hello=world"); + } + + #[actix_rt::test] + async fn test_message_body() { + let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx") + .to_srv_request() + .into_parts(); + let res = HttpMessageBody::new(&req, &mut pl).await; + match res.err().unwrap() { + PayloadError::UnknownLength => (), + _ => unreachable!("error"), + } + + let (req, mut pl) = TestRequest::with_header(header::CONTENT_LENGTH, "1000000") + .to_srv_request() + .into_parts(); + let res = HttpMessageBody::new(&req, &mut pl).await; + match res.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } + + let (req, mut pl) = TestRequest::default() + .set_payload(Bytes::from_static(b"test")) + .to_http_parts(); + let res = HttpMessageBody::new(&req, &mut pl).await; + assert_eq!(res.ok().unwrap(), Bytes::from_static(b"test")); + + let (req, mut pl) = TestRequest::default() + .set_payload(Bytes::from_static(b"11111111111111")) + .to_http_parts(); + let res = HttpMessageBody::new(&req, &mut pl).limit(5).await; + match res.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } + } +} diff --git a/src/types/query.rs b/src/types/query.rs new file mode 100644 index 000000000..b1f4572fa --- /dev/null +++ b/src/types/query.rs @@ -0,0 +1,295 @@ +//! Query extractor + +use std::sync::Arc; +use std::{fmt, ops}; + +use actix_http::error::Error; +use futures::future::{err, ok, Ready}; +use serde::de; +use serde_urlencoded; + +use crate::dev::Payload; +use crate::error::QueryPayloadError; +use crate::extract::FromRequest; +use crate::request::HttpRequest; + +/// Extract typed information from the request's query. +/// +/// **Note**: A query string consists of unordered `key=value` pairs, therefore it cannot +/// be decoded into any type which depends upon data ordering e.g. tuples or tuple-structs. +/// Attempts to do so will *fail at runtime*. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Debug, Deserialize)] +/// pub enum ResponseType { +/// Token, +/// Code +/// } +/// +/// #[derive(Deserialize)] +/// pub struct AuthRequest { +/// id: u64, +/// response_type: ResponseType, +/// } +/// +/// // Use `Query` extractor for query information (and destructure it within the signature). +/// // This handler gets called only if the request's query string contains a `username` field. +/// // The correct request for this handler would be `/index.html?id=64&response_type=Code"`. +/// async fn index(web::Query(info): web::Query) -> String { +/// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").route(web::get().to(index))); // <- use `Query` extractor +/// } +/// ``` +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub struct Query(pub T); + +impl Query { + /// Deconstruct to a inner value + pub fn into_inner(self) -> T { + self.0 + } + + /// Get query parameters from the path + pub fn from_query(query_str: &str) -> Result + where + T: de::DeserializeOwned, + { + serde_urlencoded::from_str::(query_str) + .map(|val| Ok(Query(val))) + .unwrap_or_else(move |e| Err(QueryPayloadError::Deserialize(e))) + } +} + +impl ops::Deref for Query { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +impl ops::DerefMut for Query { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl fmt::Debug for Query { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for Query { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +/// Extract typed information from the request's query. +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{web, App}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Debug, Deserialize)] +/// pub enum ResponseType { +/// Token, +/// Code +/// } +/// +/// #[derive(Deserialize)] +/// pub struct AuthRequest { +/// id: u64, +/// response_type: ResponseType, +/// } +/// +/// // Use `Query` extractor for query information. +/// // This handler get called only if request's query contains `username` field +/// // The correct request for this handler would be `/index.html?id=64&response_type=Code"` +/// async fn index(info: web::Query) -> String { +/// format!("Authorization request for client with id={} and type={:?}!", info.id, info.response_type) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html") +/// .route(web::get().to(index))); // <- use `Query` extractor +/// } +/// ``` +impl FromRequest for Query +where + T: de::DeserializeOwned, +{ + type Error = Error; + type Future = Ready>; + type Config = QueryConfig; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + let error_handler = req + .app_data::() + .map(|c| c.ehandler.clone()) + .unwrap_or(None); + + serde_urlencoded::from_str::(req.query_string()) + .map(|val| ok(Query(val))) + .unwrap_or_else(move |e| { + let e = QueryPayloadError::Deserialize(e); + + log::debug!( + "Failed during Query extractor deserialization. \ + Request path: {:?}", + req.path() + ); + + let e = if let Some(error_handler) = error_handler { + (error_handler)(e, req) + } else { + e.into() + }; + + err(e) + }) + } +} + +/// Query extractor configuration +/// +/// ## Example +/// +/// ```rust +/// use actix_web::{error, web, App, FromRequest, HttpResponse}; +/// use serde_derive::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's querystring +/// async fn index(info: web::Query) -> String { +/// format!("Welcome {}!", info.username) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/index.html").data( +/// // change query extractor configuration +/// web::Query::::configure(|cfg| { +/// cfg.error_handler(|err, req| { // <- create custom error response +/// error::InternalError::from_response( +/// err, HttpResponse::Conflict().finish()).into() +/// }) +/// })) +/// .route(web::post().to(index)) +/// ); +/// } +/// ``` +#[derive(Clone)] +pub struct QueryConfig { + ehandler: + Option Error + Send + Sync>>, +} + +impl QueryConfig { + /// Set custom error handler + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(QueryPayloadError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.ehandler = Some(Arc::new(f)); + self + } +} + +impl Default for QueryConfig { + fn default() -> Self { + QueryConfig { ehandler: None } + } +} + +#[cfg(test)] +mod tests { + use actix_http::http::StatusCode; + use derive_more::Display; + use serde_derive::Deserialize; + + use super::*; + use crate::error::InternalError; + use crate::test::TestRequest; + use crate::HttpResponse; + + #[derive(Deserialize, Debug, Display)] + struct Id { + id: String, + } + + #[actix_rt::test] + async fn test_service_request_extract() { + let req = TestRequest::with_uri("/name/user1/").to_srv_request(); + assert!(Query::::from_query(&req.query_string()).is_err()); + + let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + let mut s = Query::::from_query(&req.query_string()).unwrap(); + + assert_eq!(s.id, "test"); + assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); + + s.id = "test1".to_string(); + let s = s.into_inner(); + assert_eq!(s.id, "test1"); + } + + #[actix_rt::test] + async fn test_request_extract() { + let req = TestRequest::with_uri("/name/user1/").to_srv_request(); + let (req, mut pl) = req.into_parts(); + assert!(Query::::from_request(&req, &mut pl).await.is_err()); + + let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); + let (req, mut pl) = req.into_parts(); + + let mut s = Query::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(s.id, "test"); + assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); + + s.id = "test1".to_string(); + let s = s.into_inner(); + assert_eq!(s.id, "test1"); + } + + #[actix_rt::test] + async fn test_custom_error_responder() { + let req = TestRequest::with_uri("/name/user1/") + .data(QueryConfig::default().error_handler(|e, _| { + let resp = HttpResponse::UnprocessableEntity().finish(); + InternalError::from_response(e, resp).into() + })) + .to_srv_request(); + + let (req, mut pl) = req.into_parts(); + let query = Query::::from_request(&req, &mut pl).await; + + assert!(query.is_err()); + assert_eq!( + query + .unwrap_err() + .as_response_error() + .error_response() + .status(), + StatusCode::UNPROCESSABLE_ENTITY + ); + } +} diff --git a/src/types/readlines.rs b/src/types/readlines.rs new file mode 100644 index 000000000..123f8102b --- /dev/null +++ b/src/types/readlines.rs @@ -0,0 +1,203 @@ +use std::borrow::Cow; +use std::pin::Pin; +use std::str; +use std::task::{Context, Poll}; + +use bytes::{Bytes, BytesMut}; +use encoding_rs::{Encoding, UTF_8}; +use futures::Stream; + +use crate::dev::Payload; +use crate::error::{PayloadError, ReadlinesError}; +use crate::HttpMessage; + +/// Stream to read request line by line. +pub struct Readlines { + stream: Payload, + buff: BytesMut, + limit: usize, + checked_buff: bool, + encoding: &'static Encoding, + err: Option, +} + +impl Readlines +where + T: HttpMessage, + T::Stream: Stream> + Unpin, +{ + /// Create a new stream to read request line by line. + pub fn new(req: &mut T) -> Self { + let encoding = match req.encoding() { + Ok(enc) => enc, + Err(err) => return Self::err(err.into()), + }; + + Readlines { + stream: req.take_payload(), + buff: BytesMut::with_capacity(262_144), + limit: 262_144, + checked_buff: true, + err: None, + encoding, + } + } + + /// Change max line size. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + fn err(err: ReadlinesError) -> Self { + Readlines { + stream: Payload::None, + buff: BytesMut::new(), + limit: 262_144, + checked_buff: true, + encoding: UTF_8, + err: Some(err), + } + } +} + +impl Stream for Readlines +where + T: HttpMessage, + T::Stream: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Some(Err(err))); + } + + // check if there is a newline in the buffer + if !this.checked_buff { + let mut found: Option = None; + for (ind, b) in this.buff.iter().enumerate() { + if *b == b'\n' { + found = Some(ind); + break; + } + } + if let Some(ind) = found { + // check if line is longer than limit + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); + } + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff.split_to(ind + 1)) + .map_err(|_| ReadlinesError::EncodingError)? + .to_owned() + } else { + this.encoding + .decode_without_bom_handling_and_without_replacement( + &this.buff.split_to(ind + 1), + ) + .map(Cow::into_owned) + .ok_or(ReadlinesError::EncodingError)? + }; + return Poll::Ready(Some(Ok(line))); + } + this.checked_buff = true; + } + // poll req for more bytes + match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(mut bytes))) => { + // check if there is a newline in bytes + let mut found: Option = None; + for (ind, b) in bytes.iter().enumerate() { + if *b == b'\n' { + found = Some(ind); + break; + } + } + if let Some(ind) = found { + // check if line is longer than limit + if ind + 1 > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); + } + let line = if this.encoding == UTF_8 { + str::from_utf8(&bytes.split_to(ind + 1)) + .map_err(|_| ReadlinesError::EncodingError)? + .to_owned() + } else { + this.encoding + .decode_without_bom_handling_and_without_replacement( + &bytes.split_to(ind + 1), + ) + .map(Cow::into_owned) + .ok_or(ReadlinesError::EncodingError)? + }; + // extend buffer with rest of the bytes; + this.buff.extend_from_slice(&bytes); + this.checked_buff = false; + return Poll::Ready(Some(Ok(line))); + } + this.buff.extend_from_slice(&bytes); + Poll::Pending + } + Poll::Pending => Poll::Pending, + Poll::Ready(None) => { + if this.buff.is_empty() { + return Poll::Ready(None); + } + if this.buff.len() > this.limit { + return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow))); + } + let line = if this.encoding == UTF_8 { + str::from_utf8(&this.buff) + .map_err(|_| ReadlinesError::EncodingError)? + .to_owned() + } else { + this.encoding + .decode_without_bom_handling_and_without_replacement(&this.buff) + .map(Cow::into_owned) + .ok_or(ReadlinesError::EncodingError)? + }; + this.buff.clear(); + Poll::Ready(Some(Ok(line))) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ReadlinesError::from(e)))), + } + } +} + +#[cfg(test)] +mod tests { + use futures::stream::StreamExt; + + use super::*; + use crate::test::TestRequest; + + #[actix_rt::test] + async fn test_readlines() { + let mut req = TestRequest::default() + .set_payload(Bytes::from_static( + b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\ + industry. Lorem Ipsum has been the industry's standard dummy\n\ + Contrary to popular belief, Lorem Ipsum is not simply random text.", + )) + .to_request(); + + let mut stream = Readlines::new(&mut req); + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Lorem Ipsum is simply dummy text of the printing and typesetting\n" + ); + + assert_eq!( + stream.next().await.unwrap().unwrap(), + "industry. Lorem Ipsum has been the industry's standard dummy\n" + ); + + assert_eq!( + stream.next().await.unwrap().unwrap(), + "Contrary to popular belief, Lorem Ipsum is not simply random text." + ); + } +} diff --git a/src/web.rs b/src/web.rs new file mode 100644 index 000000000..7f1e8d8f6 --- /dev/null +++ b/src/web.rs @@ -0,0 +1,283 @@ +//! Essentials helper functions and types for application registration. +use actix_http::http::Method; +use futures::Future; + +pub use actix_http::Response as HttpResponse; +pub use bytes::{Bytes, BytesMut}; +pub use futures::channel::oneshot::Canceled; + +use crate::extract::FromRequest; +use crate::handler::Factory; +use crate::resource::Resource; +use crate::responder::Responder; +use crate::route::Route; +use crate::scope::Scope; +use crate::service::WebService; + +pub use crate::config::ServiceConfig; +pub use crate::data::Data; +pub use crate::request::HttpRequest; +pub use crate::types::*; + +/// Create resource for a specific path. +/// +/// Resources may have variable path segments. For example, a +/// resource with the path `/a/{name}/c` would match all incoming +/// requests with paths such as `/a/b/c`, `/a/1/c`, or `/a/etc/c`. +/// +/// A variable segment is specified in the form `{identifier}`, +/// where the identifier can be used later in a request handler to +/// access the matched value for that segment. This is done by +/// looking up the identifier in the `Params` object returned by +/// `HttpRequest.match_info()` method. +/// +/// By default, each segment matches the regular expression `[^{}/]+`. +/// +/// You can also specify a custom regex in the form `{identifier:regex}`: +/// +/// For instance, to route `GET`-requests on any route matching +/// `/users/{userid}/{friend}` and store `userid` and `friend` in +/// the exposed `Params` object: +/// +/// ```rust +/// # extern crate actix_web; +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/users/{userid}/{friend}") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +pub fn resource(path: &str) -> Resource { + Resource::new(path) +} + +/// Configure scope for common root path. +/// +/// Scopes collect multiple paths under a common path prefix. +/// Scope path can contain variable path segments as resources. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::scope("/{project_id}") +/// .service(web::resource("/path1").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path2").to(|| HttpResponse::Ok())) +/// .service(web::resource("/path3").to(|| HttpResponse::MethodNotAllowed())) +/// ); +/// } +/// ``` +/// +/// In the above example, three routes get added: +/// * /{project_id}/path1 +/// * /{project_id}/path2 +/// * /{project_id}/path3 +/// +pub fn scope(path: &str) -> Scope { + Scope::new(path) +} + +/// Create *route* without configuration. +pub fn route() -> Route { + Route::new() +} + +/// Create *route* with `GET` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `GET` route get added: +/// * /{project_id} +/// +pub fn get() -> Route { + method(Method::GET) +} + +/// Create *route* with `POST` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::post().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `POST` route get added: +/// * /{project_id} +/// +pub fn post() -> Route { + method(Method::POST) +} + +/// Create *route* with `PUT` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::put().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `PUT` route get added: +/// * /{project_id} +/// +pub fn put() -> Route { + method(Method::PUT) +} + +/// Create *route* with `PATCH` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::patch().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `PATCH` route get added: +/// * /{project_id} +/// +pub fn patch() -> Route { + method(Method::PATCH) +} + +/// Create *route* with `DELETE` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::delete().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `DELETE` route get added: +/// * /{project_id} +/// +pub fn delete() -> Route { + method(Method::DELETE) +} + +/// Create *route* with `HEAD` method guard. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::head().to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `HEAD` route get added: +/// * /{project_id} +/// +pub fn head() -> Route { + method(Method::HEAD) +} + +/// Create *route* and add method guard. +/// +/// ```rust +/// use actix_web::{web, http, App, HttpResponse}; +/// +/// fn main() { +/// let app = App::new().service( +/// web::resource("/{project_id}") +/// .route(web::method(http::Method::GET).to(|| HttpResponse::Ok())) +/// ); +/// } +/// ``` +/// +/// In the above example, one `GET` route get added: +/// * /{project_id} +/// +pub fn method(method: Method) -> Route { + Route::new().method(method) +} + +/// Create a new route and add handler. +/// +/// ```rust +/// use actix_web::{web, App, HttpResponse, Responder}; +/// +/// async fn index() -> impl Responder { +/// HttpResponse::Ok() +/// } +/// +/// App::new().service( +/// web::resource("/").route( +/// web::to(index)) +/// ); +/// ``` +pub fn to(handler: F) -> Route +where + F: Factory, + I: FromRequest + 'static, + R: Future + 'static, + U: Responder + 'static, +{ + Route::new().to(handler) +} + +/// Create raw service for a specific path. +/// +/// ```rust +/// use actix_web::{dev, web, guard, App, Error, HttpResponse}; +/// +/// async fn my_service(req: dev::ServiceRequest) -> Result { +/// Ok(req.into_response(HttpResponse::Ok().finish())) +/// } +/// +/// fn main() { +/// let app = App::new().service( +/// web::service("/users/*") +/// .guard(guard::Header("content-type", "text/plain")) +/// .finish(my_service) +/// ); +/// } +/// ``` +pub fn service(path: &str) -> WebService { + WebService::new(path) +} + +/// Execute blocking function on a thread pool, returns future that resolves +/// to result of the function execution. +pub fn block(f: F) -> impl Future> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + actix_threadpool::run(f) +} diff --git a/src/with.rs b/src/with.rs deleted file mode 100644 index 2a4420392..000000000 --- a/src/with.rs +++ /dev/null @@ -1,437 +0,0 @@ -use std::rc::Rc; -use std::cell::UnsafeCell; -use std::marker::PhantomData; -use futures::{Async, Future, Poll}; - -use error::Error; -use handler::{Handler, FromRequest, Reply, ReplyItem, Responder}; -use httprequest::HttpRequest; -use httpresponse::HttpResponse; - -pub struct With - where F: Fn(T) -> R -{ - hnd: Rc>, - _t: PhantomData, - _s: PhantomData, -} - -impl With - where F: Fn(T) -> R, -{ - pub fn new(f: F) -> Self { - With{hnd: Rc::new(UnsafeCell::new(f)), _t: PhantomData, _s: PhantomData} - } -} - -impl Handler for With - where F: Fn(T) -> R + 'static, - R: Responder + 'static, - T: FromRequest + 'static, - S: 'static -{ - type Result = Reply; - - fn handle(&mut self, req: HttpRequest) -> Self::Result { - let mut fut = WithHandlerFut{ - req, - started: false, - hnd: Rc::clone(&self.hnd), - fut1: None, - fut2: None, - }; - - match fut.poll() { - Ok(Async::Ready(resp)) => Reply::response(resp), - Ok(Async::NotReady) => Reply::async(fut), - Err(e) => Reply::response(e), - } - } -} - -struct WithHandlerFut - where F: Fn(T) -> R, - R: Responder, - T: FromRequest + 'static, - S: 'static -{ - started: bool, - hnd: Rc>, - req: HttpRequest, - fut1: Option>>, - fut2: Option>>, -} - -impl Future for WithHandlerFut - where F: Fn(T) -> R, - R: Responder + 'static, - T: FromRequest + 'static, - S: 'static -{ - type Item = HttpResponse; - type Error = Error; - - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut2 { - return fut.poll() - } - - let item = if !self.started { - self.started = true; - let mut fut = T::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item)) => item, - Ok(Async::NotReady) => { - self.fut1 = Some(Box::new(fut)); - return Ok(Async::NotReady) - }, - Err(e) => return Err(e), - } - } else { - match self.fut1.as_mut().unwrap().poll()? { - Async::Ready(item) => item, - Async::NotReady => return Ok(Async::NotReady), - } - }; - - let hnd: &mut F = unsafe{&mut *self.hnd.get()}; - let item = match (*hnd)(item).respond_to(self.req.without_state()) { - Ok(item) => item.into(), - Err(e) => return Err(e.into()), - }; - - match item.into() { - ReplyItem::Message(resp) => Ok(Async::Ready(resp)), - ReplyItem::Future(fut) => { - self.fut2 = Some(fut); - self.poll() - } - } - } -} - -pub struct With2 where F: Fn(T1, T2) -> R -{ - hnd: Rc>, - _t1: PhantomData, - _t2: PhantomData, - _s: PhantomData, -} - -impl With2 where F: Fn(T1, T2) -> R -{ - pub fn new(f: F) -> Self { - With2{hnd: Rc::new(UnsafeCell::new(f)), - _t1: PhantomData, _t2: PhantomData, _s: PhantomData} - } -} - -impl Handler for With2 - where F: Fn(T1, T2) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - S: 'static -{ - type Result = Reply; - - fn handle(&mut self, req: HttpRequest) -> Self::Result { - let mut fut = WithHandlerFut2{ - req, - started: false, - hnd: Rc::clone(&self.hnd), - item: None, - fut1: None, - fut2: None, - fut3: None, - }; - match fut.poll() { - Ok(Async::Ready(resp)) => Reply::response(resp), - Ok(Async::NotReady) => Reply::async(fut), - Err(e) => Reply::response(e), - } - } -} - -struct WithHandlerFut2 - where F: Fn(T1, T2) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - S: 'static -{ - started: bool, - hnd: Rc>, - req: HttpRequest, - item: Option, - fut1: Option>>, - fut2: Option>>, - fut3: Option>>, -} - -impl Future for WithHandlerFut2 - where F: Fn(T1, T2) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - S: 'static -{ - type Item = HttpResponse; - type Error = Error; - - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut3 { - return fut.poll() - } - - if !self.started { - self.started = true; - let mut fut = T1::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item1)) => { - let mut fut = T2::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item2)) => { - let hnd: &mut F = unsafe{&mut *self.hnd.get()}; - match (*hnd)(item1, item2) - .respond_to(self.req.without_state()) - { - Ok(item) => match item.into().into() { - ReplyItem::Message(resp) => - return Ok(Async::Ready(resp)), - ReplyItem::Future(fut) => { - self.fut3 = Some(fut); - return self.poll() - } - }, - Err(e) => return Err(e.into()), - } - }, - Ok(Async::NotReady) => { - self.item = Some(item1); - self.fut2 = Some(Box::new(fut)); - return Ok(Async::NotReady); - }, - Err(e) => return Err(e), - } - }, - Ok(Async::NotReady) => { - self.fut1 = Some(Box::new(fut)); - return Ok(Async::NotReady); - } - Err(e) => return Err(e), - } - } - - if self.fut1.is_some() { - match self.fut1.as_mut().unwrap().poll()? { - Async::Ready(item) => { - self.item = Some(item); - self.fut1.take(); - self.fut2 = Some(Box::new(T2::from_request(&self.req))); - }, - Async::NotReady => return Ok(Async::NotReady), - } - } - - let item = match self.fut2.as_mut().unwrap().poll()? { - Async::Ready(item) => item, - Async::NotReady => return Ok(Async::NotReady), - }; - - let hnd: &mut F = unsafe{&mut *self.hnd.get()}; - let item = match (*hnd)(self.item.take().unwrap(), item) - .respond_to(self.req.without_state()) - { - Ok(item) => item.into(), - Err(err) => return Err(err.into()), - }; - - match item.into() { - ReplyItem::Message(resp) => return Ok(Async::Ready(resp)), - ReplyItem::Future(fut) => self.fut3 = Some(fut), - } - - self.poll() - } -} - -pub struct With3 where F: Fn(T1, T2, T3) -> R { - hnd: Rc>, - _t1: PhantomData, - _t2: PhantomData, - _t3: PhantomData, - _s: PhantomData, -} - - -impl With3 - where F: Fn(T1, T2, T3) -> R, -{ - pub fn new(f: F) -> Self { - With3{hnd: Rc::new(UnsafeCell::new(f)), - _s: PhantomData, _t1: PhantomData, _t2: PhantomData, _t3: PhantomData} - } -} - -impl Handler for With3 - where F: Fn(T1, T2, T3) -> R + 'static, - R: Responder + 'static, - T1: FromRequest, - T2: FromRequest, - T3: FromRequest, - T1: 'static, T2: 'static, T3: 'static, S: 'static -{ - type Result = Reply; - - fn handle(&mut self, req: HttpRequest) -> Self::Result { - let mut fut = WithHandlerFut3{ - req, - hnd: Rc::clone(&self.hnd), - started: false, - item1: None, - item2: None, - fut1: None, - fut2: None, - fut3: None, - fut4: None, - }; - match fut.poll() { - Ok(Async::Ready(resp)) => Reply::response(resp), - Ok(Async::NotReady) => Reply::async(fut), - Err(e) => Reply::response(e), - } - } -} - -struct WithHandlerFut3 - where F: Fn(T1, T2, T3) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - T3: FromRequest + 'static, - S: 'static -{ - hnd: Rc>, - req: HttpRequest, - started: bool, - item1: Option, - item2: Option, - fut1: Option>>, - fut2: Option>>, - fut3: Option>>, - fut4: Option>>, -} - -impl Future for WithHandlerFut3 - where F: Fn(T1, T2, T3) -> R + 'static, - R: Responder + 'static, - T1: FromRequest + 'static, - T2: FromRequest + 'static, - T3: FromRequest + 'static, - S: 'static -{ - type Item = HttpResponse; - type Error = Error; - - fn poll(&mut self) -> Poll { - if let Some(ref mut fut) = self.fut4 { - return fut.poll() - } - - if !self.started { - self.started = true; - let mut fut = T1::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item1)) => { - let mut fut = T2::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item2)) => { - let mut fut = T3::from_request(&self.req); - match fut.poll() { - Ok(Async::Ready(item3)) => { - let hnd: &mut F = unsafe{&mut *self.hnd.get()}; - match (*hnd)(item1, item2, item3) - .respond_to(self.req.without_state()) - { - Ok(item) => match item.into().into() { - ReplyItem::Message(resp) => - return Ok(Async::Ready(resp)), - ReplyItem::Future(fut) => { - self.fut4 = Some(fut); - return self.poll() - } - }, - Err(e) => return Err(e.into()), - } - }, - Ok(Async::NotReady) => { - self.item1 = Some(item1); - self.item2 = Some(item2); - self.fut3 = Some(Box::new(fut)); - return Ok(Async::NotReady); - }, - Err(e) => return Err(e), - } - }, - Ok(Async::NotReady) => { - self.item1 = Some(item1); - self.fut2 = Some(Box::new(fut)); - return Ok(Async::NotReady); - }, - Err(e) => return Err(e), - } - }, - Ok(Async::NotReady) => { - self.fut1 = Some(Box::new(fut)); - return Ok(Async::NotReady); - } - Err(e) => return Err(e), - } - } - - if self.fut1.is_some() { - match self.fut1.as_mut().unwrap().poll()? { - Async::Ready(item) => { - self.item1 = Some(item); - self.fut1.take(); - self.fut2 = Some(Box::new(T2::from_request(&self.req))); - }, - Async::NotReady => return Ok(Async::NotReady), - } - } - - if self.fut2.is_some() { - match self.fut2.as_mut().unwrap().poll()? { - Async::Ready(item) => { - self.item2 = Some(item); - self.fut2.take(); - self.fut3 = Some(Box::new(T3::from_request(&self.req))); - }, - Async::NotReady => return Ok(Async::NotReady), - } - } - - let item = match self.fut3.as_mut().unwrap().poll()? { - Async::Ready(item) => item, - Async::NotReady => return Ok(Async::NotReady), - }; - - let hnd: &mut F = unsafe{&mut *self.hnd.get()}; - let item = match (*hnd)(self.item1.take().unwrap(), - self.item2.take().unwrap(), - item) - .respond_to(self.req.without_state()) - { - Ok(item) => item.into(), - Err(err) => return Err(err.into()), - }; - - match item.into() { - ReplyItem::Message(resp) => return Ok(Async::Ready(resp)), - ReplyItem::Future(fut) => self.fut4 = Some(fut), - } - - self.poll() - } -} diff --git a/src/ws/client.rs b/src/ws/client.rs deleted file mode 100644 index 7372832f5..000000000 --- a/src/ws/client.rs +++ /dev/null @@ -1,545 +0,0 @@ -//! Http client request -use std::{fmt, io, str}; -use std::rc::Rc; -use std::cell::UnsafeCell; -use std::time::Duration; - -use base64; -use rand; -use bytes::Bytes; -use cookie::Cookie; -use byteorder::{ByteOrder, NetworkEndian}; -use http::{HttpTryFrom, StatusCode, Error as HttpError}; -use http::header::{self, HeaderName, HeaderValue}; -use sha1::Sha1; -use futures::{Async, Future, Poll, Stream}; -use futures::unsync::mpsc::{unbounded, UnboundedSender}; - -use actix::prelude::*; - -use body::{Body, Binary}; -use error::{Error, UrlParseError}; -use header::IntoHeaderValue; -use payload::PayloadHelper; -use httpmessage::HttpMessage; - -use client::{ClientRequest, ClientRequestBuilder, ClientResponse, - ClientConnector, SendRequest, SendRequestError, - HttpResponseParserError}; - -use super::{Message, ProtocolError}; -use super::frame::Frame; -use super::proto::{CloseCode, OpCode}; - - -/// Websocket client error -#[derive(Fail, Debug)] -pub enum ClientError { - #[fail(display="Invalid url")] - InvalidUrl, - #[fail(display="Invalid response status")] - InvalidResponseStatus(StatusCode), - #[fail(display="Invalid upgrade header")] - InvalidUpgradeHeader, - #[fail(display="Invalid connection header")] - InvalidConnectionHeader(HeaderValue), - #[fail(display="Missing CONNECTION header")] - MissingConnectionHeader, - #[fail(display="Missing SEC-WEBSOCKET-ACCEPT header")] - MissingWebSocketAcceptHeader, - #[fail(display="Invalid challenge response")] - InvalidChallengeResponse(String, HeaderValue), - #[fail(display="Http parsing error")] - Http(Error), - #[fail(display="Url parsing error")] - Url(UrlParseError), - #[fail(display="Response parsing error")] - ResponseParseError(HttpResponseParserError), - #[fail(display="{}", _0)] - SendRequest(SendRequestError), - #[fail(display="{}", _0)] - Protocol(#[cause] ProtocolError), - #[fail(display="{}", _0)] - Io(io::Error), - #[fail(display="Disconnected")] - Disconnected, -} - -impl From for ClientError { - fn from(err: Error) -> ClientError { - ClientError::Http(err) - } -} - -impl From for ClientError { - fn from(err: UrlParseError) -> ClientError { - ClientError::Url(err) - } -} - -impl From for ClientError { - fn from(err: SendRequestError) -> ClientError { - ClientError::SendRequest(err) - } -} - -impl From for ClientError { - fn from(err: ProtocolError) -> ClientError { - ClientError::Protocol(err) - } -} - -impl From for ClientError { - fn from(err: io::Error) -> ClientError { - ClientError::Io(err) - } -} - -impl From for ClientError { - fn from(err: HttpResponseParserError) -> ClientError { - ClientError::ResponseParseError(err) - } -} - -/// `WebSocket` client -/// -/// Example of `WebSocket` client usage is available in -/// [websocket example]( -/// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24) -pub struct Client { - request: ClientRequestBuilder, - err: Option, - http_err: Option, - origin: Option, - protocols: Option, - conn: Addr, - max_size: usize, -} - -impl Client { - - /// Create new websocket connection - pub fn new>(uri: S) -> Client { - Client::with_connector(uri, ClientConnector::from_registry()) - } - - /// Create new websocket connection with custom `ClientConnector` - pub fn with_connector>(uri: S, conn: Addr) -> Client { - let mut cl = Client { - request: ClientRequest::build(), - err: None, - http_err: None, - origin: None, - protocols: None, - max_size: 65_536, - conn, - }; - cl.request.uri(uri.as_ref()); - cl - } - - /// Set supported websocket protocols - pub fn protocols(mut self, protos: U) -> Self - where U: IntoIterator + 'static, - V: AsRef - { - let mut protos = protos.into_iter() - .fold(String::new(), |acc, s| {acc + s.as_ref() + ","}); - protos.pop(); - self.protocols = Some(protos); - self - } - - /// Set cookie for handshake request - pub fn cookie(mut self, cookie: Cookie) -> Self { - self.request.cookie(cookie); - self - } - - /// Set request Origin - pub fn origin(mut self, origin: V) -> Self - where HeaderValue: HttpTryFrom - { - match HeaderValue::try_from(origin) { - Ok(value) => self.origin = Some(value), - Err(e) => self.http_err = Some(e.into()), - } - self - } - - /// Set max frame size - /// - /// By default max size is set to 64kb - pub fn max_frame_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } - - /// Set write buffer capacity - /// - /// Default buffer capacity is 32kb - pub fn write_buffer_capacity(mut self, cap: usize) -> Self { - self.request.write_buffer_capacity(cap); - self - } - - /// Set request header - pub fn header(mut self, key: K, value: V) -> Self - where HeaderName: HttpTryFrom, V: IntoHeaderValue - { - self.request.header(key, value); - self - } - - /// Set websocket handshake timeout - /// - /// Handshake timeout is a total time for successful handshake. - /// Default value is 5 seconds. - pub fn timeout(mut self, timeout: Duration) -> Self { - self.request.timeout(timeout); - self - } - - /// Connect to websocket server and do ws handshake - pub fn connect(&mut self) -> ClientHandshake { - if let Some(e) = self.err.take() { - ClientHandshake::error(e) - } - else if let Some(e) = self.http_err.take() { - ClientHandshake::error(Error::from(e).into()) - } else { - // origin - if let Some(origin) = self.origin.take() { - self.request.set_header(header::ORIGIN, origin); - } - - self.request.upgrade(); - self.request.set_header(header::UPGRADE, "websocket"); - self.request.set_header(header::CONNECTION, "upgrade"); - self.request.set_header(header::SEC_WEBSOCKET_VERSION, "13"); - self.request.with_connector(self.conn.clone()); - - if let Some(protocols) = self.protocols.take() { - self.request.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str()); - } - let request = match self.request.finish() { - Ok(req) => req, - Err(err) => return ClientHandshake::error(err.into()), - }; - - if request.uri().host().is_none() { - return ClientHandshake::error(ClientError::InvalidUrl) - } - if let Some(scheme) = request.uri().scheme_part() { - if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { - return ClientHandshake::error(ClientError::InvalidUrl) - } - } else { - return ClientHandshake::error(ClientError::InvalidUrl) - } - - // start handshake - ClientHandshake::new(request, self.max_size) - } - } -} - -struct Inner { - tx: UnboundedSender, - rx: PayloadHelper, - closed: bool, -} - -/// Future that implementes client websocket handshake process. -/// -/// It resolves to a pair of `ClientReadr` and `ClientWriter` that -/// can be used for reading and writing websocket frames. -pub struct ClientHandshake { - request: Option, - tx: Option>, - key: String, - error: Option, - max_size: usize, -} - -impl ClientHandshake { - fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake - { - // Generate a random key for the `Sec-WebSocket-Key` header. - // a base64-encoded (see Section 4 of [RFC4648]) value that, - // when decoded, is 16 bytes in length (RFC 6455) - let sec_key: [u8; 16] = rand::random(); - let key = base64::encode(&sec_key); - - request.headers_mut().insert( - header::SEC_WEBSOCKET_KEY, - HeaderValue::try_from(key.as_str()).unwrap()); - - let (tx, rx) = unbounded(); - request.set_body(Body::Streaming( - Box::new(rx.map_err(|_| io::Error::new( - io::ErrorKind::Other, "disconnected").into())))); - - ClientHandshake { - key, - max_size, - request: Some(request.send()), - tx: Some(tx), - error: None, - } - } - - fn error(err: ClientError) -> ClientHandshake { - ClientHandshake { - key: String::new(), - request: None, - tx: None, - error: Some(err), - max_size: 0, - } - } - - /// Set handshake timeout - /// - /// Handshake timeout is a total time before handshake should be completed. - /// Default value is 5 seconds. - pub fn timeout(mut self, timeout: Duration) -> Self { - if let Some(request) = self.request.take() { - self.request = Some(request.timeout(timeout)); - } - self - } - - /// Set connection timeout - /// - /// Connection timeout includes resolving hostname and actual connection to - /// the host. - /// Default value is 1 second. - pub fn conn_timeout(mut self, timeout: Duration) -> Self { - if let Some(request) = self.request.take() { - self.request = Some(request.conn_timeout(timeout)); - } - self - } -} - -impl Future for ClientHandshake { - type Item = (ClientReader, ClientWriter); - type Error = ClientError; - - fn poll(&mut self) -> Poll { - if let Some(err) = self.error.take() { - return Err(err) - } - - let resp = match self.request.as_mut().unwrap().poll()? { - Async::Ready(response) => { - self.request.take(); - response - }, - Async::NotReady => return Ok(Async::NotReady) - }; - - // verify response - if resp.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(ClientError::InvalidResponseStatus(resp.status())) - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - trace!("Invalid upgrade header"); - return Err(ClientError::InvalidUpgradeHeader) - } - // Check for "CONNECTION" header - if let Some(conn) = resp.headers().get(header::CONNECTION) { - if let Ok(s) = conn.to_str() { - if !s.to_lowercase().contains("upgrade") { - trace!("Invalid connection header: {}", s); - return Err(ClientError::InvalidConnectionHeader(conn.clone())) - } - } else { - trace!("Invalid connection header: {:?}", conn); - return Err(ClientError::InvalidConnectionHeader(conn.clone())) - } - } else { - trace!("Missing connection header"); - return Err(ClientError::MissingConnectionHeader) - } - - if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) - { - // field is constructed by concatenating /key/ - // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut sha1 = Sha1::new(); - sha1.update(self.key.as_ref()); - sha1.update(WS_GUID); - let encoded = base64::encode(&sha1.digest().bytes()); - if key.as_bytes() != encoded.as_bytes() { - trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, key); - return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); - } - } else { - trace!("Missing SEC-WEBSOCKET-ACCEPT header"); - return Err(ClientError::MissingWebSocketAcceptHeader) - }; - - let inner = Inner { - tx: self.tx.take().unwrap(), - rx: PayloadHelper::new(resp), - closed: false, - }; - - let inner = Rc::new(UnsafeCell::new(inner)); - Ok(Async::Ready( - (ClientReader{inner: Rc::clone(&inner), max_size: self.max_size}, - ClientWriter{inner}))) - } -} - - -pub struct ClientReader { - inner: Rc>, - max_size: usize, -} - -impl fmt::Debug for ClientReader { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ws::ClientReader()") - } -} - -impl ClientReader { - #[inline] - fn as_mut(&mut self) -> &mut Inner { - unsafe{ &mut *self.inner.get() } - } -} - -impl Stream for ClientReader { - type Item = Message; - type Error = ProtocolError; - - fn poll(&mut self) -> Poll, Self::Error> { - let max_size = self.max_size; - let inner = self.as_mut(); - if inner.closed { - return Ok(Async::Ready(None)) - } - - // read - match Frame::parse(&mut inner.rx, false, max_size) { - Ok(Async::Ready(Some(frame))) => { - let (_finished, opcode, payload) = frame.unpack(); - - match opcode { - // continuation is not supported - OpCode::Continue => { - inner.closed = true; - Err(ProtocolError::NoContinuation) - }, - OpCode::Bad => { - inner.closed = true; - Err(ProtocolError::BadOpCode) - }, - OpCode::Close => { - inner.closed = true; - let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; - Ok(Async::Ready(Some(Message::Close(CloseCode::from(code))))) - }, - OpCode::Ping => - Ok(Async::Ready(Some( - Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Pong => - Ok(Async::Ready(Some( - Message::Pong( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Binary => - Ok(Async::Ready(Some(Message::Binary(payload)))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => - Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { - inner.closed = true; - Err(ProtocolError::BadEncoding) - } - } - } - } - } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - inner.closed = true; - Err(e) - } - } - } -} - -pub struct ClientWriter { - inner: Rc> -} - -impl ClientWriter { - #[inline] - fn as_mut(&mut self) -> &mut Inner { - unsafe{ &mut *self.inner.get() } - } -} - -impl ClientWriter { - - /// Write payload - #[inline] - fn write(&mut self, mut data: Binary) { - if !self.as_mut().closed { - let _ = self.as_mut().tx.unbounded_send(data.take()); - } else { - warn!("Trying to write to disconnected response"); - } - } - - /// Send text frame - #[inline] - pub fn text>(&mut self, text: T) { - self.write(Frame::message(text.into(), OpCode::Text, true, true)); - } - - /// Send binary frame - #[inline] - pub fn binary>(&mut self, data: B) { - self.write(Frame::message(data, OpCode::Binary, true, true)); - } - - /// Send ping frame - #[inline] - pub fn ping(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Ping, true, true)); - } - - /// Send pong frame - #[inline] - pub fn pong(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Pong, true, true)); - } - - /// Send close frame - #[inline] - pub fn close(&mut self, code: CloseCode, reason: &str) { - self.write(Frame::close(code, reason, true)); - } -} diff --git a/src/ws/context.rs b/src/ws/context.rs deleted file mode 100644 index 92151c0d4..000000000 --- a/src/ws/context.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::mem; -use futures::{Async, Poll}; -use futures::sync::oneshot::Sender; -use futures::unsync::oneshot; -use smallvec::SmallVec; - -use actix::{Actor, ActorState, ActorContext, AsyncContext, - Addr, Handler, Message, Syn, Unsync, SpawnHandle}; -use actix::fut::ActorFuture; -use actix::dev::{ContextImpl, ToEnvelope, SyncEnvelope}; - -use body::{Body, Binary}; -use error::{Error, ErrorInternalServerError}; -use httprequest::HttpRequest; -use context::{Frame as ContextFrame, ActorHttpContext, Drain}; - -use ws::frame::Frame; -use ws::proto::{OpCode, CloseCode}; - - -/// Execution context for `WebSockets` actors -pub struct WebsocketContext where A: Actor>, -{ - inner: ContextImpl, - stream: Option>, - request: HttpRequest, - disconnected: bool, -} - -impl ActorContext for WebsocketContext where A: Actor -{ - fn stop(&mut self) { - self.inner.stop(); - } - fn terminate(&mut self) { - self.inner.terminate() - } - fn state(&self) -> ActorState { - self.inner.state() - } -} - -impl AsyncContext for WebsocketContext where A: Actor -{ - fn spawn(&mut self, fut: F) -> SpawnHandle - where F: ActorFuture + 'static - { - self.inner.spawn(fut) - } - - fn wait(&mut self, fut: F) - where F: ActorFuture + 'static - { - self.inner.wait(fut) - } - - #[doc(hidden)] - #[inline] - fn waiting(&self) -> bool { - self.inner.waiting() || self.inner.state() == ActorState::Stopping || - self.inner.state() == ActorState::Stopped - } - - fn cancel_future(&mut self, handle: SpawnHandle) -> bool { - self.inner.cancel_future(handle) - } - - #[doc(hidden)] - #[inline] - fn unsync_address(&mut self) -> Addr { - self.inner.unsync_address() - } - - #[doc(hidden)] - #[inline] - fn sync_address(&mut self) -> Addr { - self.inner.sync_address() - } -} - -impl WebsocketContext where A: Actor { - - #[inline] - pub fn new(req: HttpRequest, actor: A) -> WebsocketContext { - WebsocketContext::from_request(req).actor(actor) - } - - pub fn from_request(req: HttpRequest) -> WebsocketContext { - WebsocketContext { - inner: ContextImpl::new(None), - stream: None, - request: req, - disconnected: false, - } - } - - #[inline] - pub fn actor(mut self, actor: A) -> WebsocketContext { - self.inner.set_actor(actor); - self - } -} - -impl WebsocketContext where A: Actor { - - /// Write payload - #[inline] - fn write(&mut self, data: Binary) { - if !self.disconnected { - if self.stream.is_none() { - self.stream = Some(SmallVec::new()); - } - let stream = self.stream.as_mut().unwrap(); - stream.push(ContextFrame::Chunk(Some(data))); - self.inner.modify(); - } else { - warn!("Trying to write to disconnected response"); - } - } - - /// Shared application state - #[inline] - pub fn state(&self) -> &S { - self.request.state() - } - - /// Incoming request - #[inline] - pub fn request(&mut self) -> &mut HttpRequest { - &mut self.request - } - - /// Send text frame - #[inline] - pub fn text>(&mut self, text: T) { - self.write(Frame::message(text.into(), OpCode::Text, true, false)); - } - - /// Send binary frame - #[inline] - pub fn binary>(&mut self, data: B) { - self.write(Frame::message(data, OpCode::Binary, true, false)); - } - - /// Send ping frame - #[inline] - pub fn ping(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false)); - } - - /// Send pong frame - #[inline] - pub fn pong(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false)); - } - - /// Send close frame - #[inline] - pub fn close(&mut self, code: CloseCode, reason: &str) { - self.write(Frame::close(code, reason, false)); - } - - /// Returns drain future - pub fn drain(&mut self) -> Drain { - let (tx, rx) = oneshot::channel(); - self.inner.modify(); - self.add_frame(ContextFrame::Drain(tx)); - Drain::new(rx) - } - - /// Check if connection still open - #[inline] - pub fn connected(&self) -> bool { - !self.disconnected - } - - #[inline] - fn add_frame(&mut self, frame: ContextFrame) { - if self.stream.is_none() { - self.stream = Some(SmallVec::new()); - } - self.stream.as_mut().map(|s| s.push(frame)); - self.inner.modify(); - } - - /// Handle of the running future - /// - /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. - pub fn handle(&self) -> SpawnHandle { - self.inner.curr_handle() - } -} - -impl ActorHttpContext for WebsocketContext where A: Actor, S: 'static { - - #[inline] - fn disconnected(&mut self) { - self.disconnected = true; - self.stop(); - } - - fn poll(&mut self) -> Poll>, Error> { - let ctx: &mut WebsocketContext = unsafe { - mem::transmute(self as &mut WebsocketContext) - }; - - if self.inner.alive() && self.inner.poll(ctx).is_err() { - return Err(ErrorInternalServerError("error")) - } - - // frames - if let Some(data) = self.stream.take() { - Ok(Async::Ready(Some(data))) - } else if self.inner.alive() { - Ok(Async::NotReady) - } else { - Ok(Async::Ready(None)) - } - } -} - -impl ToEnvelope for WebsocketContext - where A: Actor> + Handler, - M: Message + Send + 'static, M::Result: Send -{ - fn pack(msg: M, tx: Option>) -> SyncEnvelope { - SyncEnvelope::new(msg, tx) - } -} - -impl From> for Body - where A: Actor>, S: 'static -{ - fn from(ctx: WebsocketContext) -> Body { - Body::Actor(Box::new(ctx)) - } -} diff --git a/src/ws/frame.rs b/src/ws/frame.rs deleted file mode 100644 index 2afcd0358..000000000 --- a/src/ws/frame.rs +++ /dev/null @@ -1,501 +0,0 @@ -use std::{fmt, mem, ptr}; -use std::iter::FromIterator; -use bytes::{Bytes, BytesMut, BufMut}; -use byteorder::{ByteOrder, BigEndian, NetworkEndian}; -use futures::{Async, Poll, Stream}; -use rand; - -use body::Binary; -use error::{PayloadError}; -use payload::PayloadHelper; - -use ws::ProtocolError; -use ws::proto::{OpCode, CloseCode}; -use ws::mask::apply_mask; - -/// A struct representing a `WebSocket` frame. -#[derive(Debug)] -pub struct Frame { - finished: bool, - opcode: OpCode, - payload: Binary, -} - -impl Frame { - - /// Destruct frame - pub fn unpack(self) -> (bool, OpCode, Binary) { - (self.finished, self.opcode, self.payload) - } - - /// Create a new Close control frame. - #[inline] - pub fn close(code: CloseCode, reason: &str, genmask: bool) -> Binary { - let raw: [u8; 2] = unsafe { - let u: u16 = code.into(); - mem::transmute(u.to_be()) - }; - - let payload = if let CloseCode::Empty = code { - Vec::new() - } else { - Vec::from_iter( - raw[..].iter() - .chain(reason.as_bytes().iter()) - .cloned()) - }; - - Frame::message(payload, OpCode::Close, true, genmask) - } - - #[cfg_attr(feature="cargo-clippy", allow(type_complexity))] - fn read_copy_md(pl: &mut PayloadHelper, - server: bool, - max_size: usize - ) -> Poll)>, ProtocolError> - where S: Stream - { - let mut idx = 2; - let buf = match pl.copy(2)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let first = buf[0]; - let second = buf[1]; - let finished = first & 0x80 != 0; - - // check masking - let masked = second & 0x80 != 0; - if !masked && server { - return Err(ProtocolError::UnmaskedFrame) - } else if masked && !server { - return Err(ProtocolError::MaskedFrame) - } - - // Op code - let opcode = OpCode::from(first & 0x0F); - - if let OpCode::Bad = opcode { - return Err(ProtocolError::InvalidOpcode(first & 0x0F)) - } - - let len = second & 0x7F; - let length = if len == 126 { - let buf = match pl.copy(4)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let len = NetworkEndian::read_uint(&buf[idx..], 2) as usize; - idx += 2; - len - } else if len == 127 { - let buf = match pl.copy(10)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - let len = NetworkEndian::read_uint(&buf[idx..], 8) as usize; - idx += 8; - len - } else { - len as usize - }; - - // check for max allowed size - if length > max_size { - return Err(ProtocolError::Overflow) - } - - let mask = if server { - let buf = match pl.copy(idx + 4)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => return Ok(Async::NotReady), - }; - - let mask: &[u8] = &buf[idx..idx+4]; - let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; - idx += 4; - Some(mask_u32) - } else { - None - }; - - Ok(Async::Ready(Some((idx, finished, opcode, length, mask)))) - } - - fn read_chunk_md(chunk: &[u8], server: bool, max_size: usize) - -> Poll<(usize, bool, OpCode, usize, Option), ProtocolError> - { - let chunk_len = chunk.len(); - - let mut idx = 2; - if chunk_len < 2 { - return Ok(Async::NotReady) - } - - let first = chunk[0]; - let second = chunk[1]; - let finished = first & 0x80 != 0; - - // check masking - let masked = second & 0x80 != 0; - if !masked && server { - return Err(ProtocolError::UnmaskedFrame) - } else if masked && !server { - return Err(ProtocolError::MaskedFrame) - } - - // Op code - let opcode = OpCode::from(first & 0x0F); - - if let OpCode::Bad = opcode { - return Err(ProtocolError::InvalidOpcode(first & 0x0F)) - } - - let len = second & 0x7F; - let length = if len == 126 { - if chunk_len < 4 { - return Ok(Async::NotReady) - } - let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize; - idx += 2; - len - } else if len == 127 { - if chunk_len < 10 { - return Ok(Async::NotReady) - } - let len = NetworkEndian::read_uint(&chunk[idx..], 8) as usize; - idx += 8; - len - } else { - len as usize - }; - - // check for max allowed size - if length > max_size { - return Err(ProtocolError::Overflow) - } - - let mask = if server { - if chunk_len < idx + 4 { - return Ok(Async::NotReady) - } - - let mask: &[u8] = &chunk[idx..idx+4]; - let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; - idx += 4; - Some(mask_u32) - } else { - None - }; - - Ok(Async::Ready((idx, finished, opcode, length, mask))) - } - - /// Parse the input stream into a frame. - pub fn parse(pl: &mut PayloadHelper, server: bool, max_size: usize) - -> Poll, ProtocolError> - where S: Stream - { - // try to parse ws frame md from one chunk - let result = match pl.get_chunk()? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::Ready(Some(chunk)) => Frame::read_chunk_md(chunk, server, max_size)?, - }; - - let (idx, finished, opcode, length, mask) = match result { - // we may need to join several chunks - Async::NotReady => match Frame::read_copy_md(pl, server, max_size)? { - Async::Ready(Some(item)) => item, - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => return Ok(Async::Ready(None)), - }, - Async::Ready(item) => item, - }; - - match pl.can_read(idx + length)? { - Async::Ready(Some(true)) => (), - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::Ready(Some(false)) | Async::NotReady => return Ok(Async::NotReady), - } - - // remove prefix - pl.drop_payload(idx); - - // no need for body - if length == 0 { - return Ok(Async::Ready(Some(Frame { - finished, opcode, payload: Binary::from("") }))); - } - - let data = match pl.read_exact(length)? { - Async::Ready(Some(buf)) => buf, - Async::Ready(None) => return Ok(Async::Ready(None)), - Async::NotReady => panic!(), - }; - - // control frames must have length <= 125 - match opcode { - OpCode::Ping | OpCode::Pong if length > 125 => { - return Err(ProtocolError::InvalidLength(length)) - } - OpCode::Close if length > 125 => { - debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Async::Ready(Some(Frame::default()))) - } - _ => () - } - - // unmask - if let Some(mask) = mask { - #[allow(mutable_transmutes)] - let p: &mut [u8] = unsafe{let ptr: &[u8] = &data; mem::transmute(ptr)}; - apply_mask(p, mask); - } - - Ok(Async::Ready(Some(Frame { - finished, opcode, payload: data.into() }))) - } - - /// Generate binary representation - pub fn message>(data: B, code: OpCode, - finished: bool, genmask: bool) -> Binary - { - let payload = data.into(); - let one: u8 = if finished { - 0x80 | Into::::into(code) - } else { - code.into() - }; - let payload_len = payload.len(); - let (two, p_len) = if genmask { - (0x80, payload_len + 4) - } else { - (0, payload_len) - }; - - let mut buf = if payload_len < 126 { - let mut buf = BytesMut::with_capacity(p_len + 2); - buf.put_slice(&[one, two | payload_len as u8]); - buf - } else if payload_len <= 65_535 { - let mut buf = BytesMut::with_capacity(p_len + 4); - buf.put_slice(&[one, two | 126]); - { - let buf_mut = unsafe{buf.bytes_mut()}; - BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16); - } - unsafe{buf.advance_mut(2)}; - buf - } else { - let mut buf = BytesMut::with_capacity(p_len + 10); - buf.put_slice(&[one, two | 127]); - { - let buf_mut = unsafe{buf.bytes_mut()}; - BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64); - } - unsafe{buf.advance_mut(8)}; - buf - }; - - if genmask { - let mask = rand::random::(); - unsafe { - { - let buf_mut = buf.bytes_mut(); - *(buf_mut as *mut _ as *mut u32) = mask; - buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref()); - apply_mask(&mut buf_mut[4..], mask); - } - buf.advance_mut(payload_len + 4); - } - buf.into() - } else { - buf.put_slice(payload.as_ref()); - buf.into() - } - } -} - -impl Default for Frame { - fn default() -> Frame { - Frame { - finished: true, - opcode: OpCode::Close, - payload: Binary::from(&b""[..]), - } - } -} - -impl fmt::Display for Frame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, - " - - final: {} - opcode: {} - payload length: {} - payload: 0x{} -", - self.finished, - self.opcode, - self.payload.len(), - self.payload.as_ref().iter().map( - |byte| format!("{:x}", byte)).collect::()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures::stream::once; - - fn is_none(frm: Poll, ProtocolError>) -> bool { - match frm { - Ok(Async::Ready(None)) => true, - _ => false, - } - } - - fn extract(frm: Poll, ProtocolError>) -> Frame { - match frm { - Ok(Async::Ready(Some(frame))) => frame, - _ => unreachable!("error"), - } - } - - #[test] - fn test_parse() { - let mut buf = PayloadHelper::new( - once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze()))); - assert!(is_none(Frame::parse(&mut buf, false, 1024))); - - let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); - buf.extend(b"1"); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - let frame = extract(Frame::parse(&mut buf, false, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload.as_ref(), &b"1"[..]); - } - - #[test] - fn test_parse_length0() { - let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - let frame = extract(Frame::parse(&mut buf, false, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert!(frame.payload.is_empty()); - } - - #[test] - fn test_parse_length2() { - let buf = BytesMut::from(&[0b00000001u8, 126u8][..]); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(is_none(Frame::parse(&mut buf, false, 1024))); - - let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]); - buf.extend(&[0u8, 4u8][..]); - buf.extend(b"1234"); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - let frame = extract(Frame::parse(&mut buf, false, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload.as_ref(), &b"1234"[..]); - } - - #[test] - fn test_parse_length4() { - let buf = BytesMut::from(&[0b00000001u8, 127u8][..]); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(is_none(Frame::parse(&mut buf, false, 1024))); - - let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]); - buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); - buf.extend(b"1234"); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - let frame = extract(Frame::parse(&mut buf, false, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload.as_ref(), &b"1234"[..]); - } - - #[test] - fn test_parse_frame_mask() { - let mut buf = BytesMut::from(&[0b00000001u8, 0b10000001u8][..]); - buf.extend(b"0001"); - buf.extend(b"1"); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - assert!(Frame::parse(&mut buf, false, 1024).is_err()); - - let frame = extract(Frame::parse(&mut buf, true, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload, vec![1u8].into()); - } - - #[test] - fn test_parse_frame_no_mask() { - let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); - buf.extend(&[1u8]); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - assert!(Frame::parse(&mut buf, true, 1024).is_err()); - - let frame = extract(Frame::parse(&mut buf, false, 1024)); - assert!(!frame.finished); - assert_eq!(frame.opcode, OpCode::Text); - assert_eq!(frame.payload, vec![1u8].into()); - } - - #[test] - fn test_parse_frame_max_size() { - let mut buf = BytesMut::from(&[0b00000001u8, 0b00000010u8][..]); - buf.extend(&[1u8, 1u8]); - let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - - assert!(Frame::parse(&mut buf, true, 1).is_err()); - - if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) { - } else { - unreachable!("error"); - } - } - - #[test] - fn test_ping_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); - - let mut v = vec![137u8, 4u8]; - v.extend(b"data"); - assert_eq!(frame, v.into()); - } - - #[test] - fn test_pong_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false); - - let mut v = vec![138u8, 4u8]; - v.extend(b"data"); - assert_eq!(frame, v.into()); - } - - #[test] - fn test_close_frame() { - let frame = Frame::close(CloseCode::Normal, "data", false); - - let mut v = vec![136u8, 6u8, 3u8, 232u8]; - v.extend(b"data"); - assert_eq!(frame, v.into()); - } -} diff --git a/src/ws/mask.rs b/src/ws/mask.rs deleted file mode 100644 index 33216bf23..000000000 --- a/src/ws/mask.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs) -use std::cmp::min; -use std::mem::uninitialized; -use std::ptr::copy_nonoverlapping; - -/// Mask/unmask a frame. -#[inline] -pub fn apply_mask(buf: &mut [u8], mask: u32) { - apply_mask_fast32(buf, mask) -} - -/// A safe unoptimized mask application. -#[inline] -#[allow(dead_code)] -fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { - for (i, byte) in buf.iter_mut().enumerate() { - *byte ^= mask[i & 3]; - } -} - -/// Faster version of `apply_mask()` which operates on 8-byte blocks. -#[inline] -#[cfg_attr(feature="cargo-clippy", allow(cast_lossless))] -fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) { - let mut ptr = buf.as_mut_ptr(); - let mut len = buf.len(); - - // Possible first unaligned block. - let head = min(len, (8 - (ptr as usize & 0x7)) & 0x3); - let mask_u32 = if head > 0 { - let n = if head > 4 { head - 4 } else { head }; - - let mask_u32 = if n > 0 { - unsafe { - xor_mem(ptr, mask_u32, n); - ptr = ptr.offset(head as isize); - } - len -= n; - if cfg!(target_endian = "big") { - mask_u32.rotate_left(8 * n as u32) - } else { - mask_u32.rotate_right(8 * n as u32) - } - } else { - mask_u32 - }; - - if head > 4 { - unsafe { - *(ptr as *mut u32) ^= mask_u32; - ptr = ptr.offset(4); - len -= 4; - } - } - mask_u32 - } else { - mask_u32 - }; - - if len > 0 { - debug_assert_eq!(ptr as usize % 4, 0); - } - - // Properly aligned middle of the data. - if len >= 8 { - let mut mask_u64 = mask_u32 as u64; - mask_u64 = mask_u64 << 32 | mask_u32 as u64; - - while len >= 8 { - unsafe { - *(ptr as *mut u64) ^= mask_u64; - ptr = ptr.offset(8); - len -= 8; - } - } - } - - while len >= 4 { - unsafe { - *(ptr as *mut u32) ^= mask_u32; - ptr = ptr.offset(4); - len -= 4; - } - } - - // Possible last block. - if len > 0 { - unsafe { xor_mem(ptr, mask_u32, len); } - } -} - -#[inline] -// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient, -// it could be done better. The compiler does not see that len is limited to 3. -unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) { - let mut b: u32 = uninitialized(); - #[allow(trivial_casts)] - copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); - b ^= mask; - #[allow(trivial_casts)] - copy_nonoverlapping(&b as *const _ as *const u8, ptr, len); -} - -#[cfg(test)] -mod tests { - use std::ptr; - use super::{apply_mask_fallback, apply_mask_fast32}; - - #[test] - fn test_apply_mask() { - let mask = [ - 0x6d, 0xb6, 0xb2, 0x80, - ]; - let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; - - let unmasked = vec![ - 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, - 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, - ]; - - // Check masking with proper alignment. - { - let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked, &mask); - - let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast, mask_u32); - - assert_eq!(masked, masked_fast); - } - - // Check masking without alignment. - { - let mut masked = unmasked.clone(); - apply_mask_fallback(&mut masked[1..], &mask); - - let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast[1..], mask_u32); - - assert_eq!(masked, masked_fast); - } - } -} diff --git a/src/ws/mod.rs b/src/ws/mod.rs deleted file mode 100644 index 9c5c74c53..000000000 --- a/src/ws/mod.rs +++ /dev/null @@ -1,424 +0,0 @@ -//! `WebSocket` support for Actix -//! -//! To setup a `WebSocket`, first do web socket handshake then on success convert `Payload` -//! into a `WsStream` stream and then use `WsWriter` to communicate with the peer. -//! -//! ## Example -//! -//! ```rust -//! # extern crate actix; -//! # extern crate actix_web; -//! # use actix::*; -//! # use actix_web::*; -//! use actix_web::{ws, HttpRequest, HttpResponse}; -//! -//! // do websocket handshake and start actor -//! fn ws_index(req: HttpRequest) -> Result { -//! ws::start(req, Ws) -//! } -//! -//! struct Ws; -//! -//! impl Actor for Ws { -//! type Context = ws::WebsocketContext; -//! } -//! -//! // Handler for ws::Message messages -//! impl StreamHandler for Ws { -//! -//! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { -//! match msg { -//! ws::Message::Ping(msg) => ctx.pong(&msg), -//! ws::Message::Text(text) => ctx.text(text), -//! ws::Message::Binary(bin) => ctx.binary(bin), -//! _ => (), -//! } -//! } -//! } -//! # -//! # fn main() { -//! # App::new() -//! # .resource("/ws/", |r| r.f(ws_index)) // <- register websocket route -//! # .finish(); -//! # } -//! ``` -use bytes::Bytes; -use http::{Method, StatusCode, header}; -use futures::{Async, Poll, Stream}; -use byteorder::{ByteOrder, NetworkEndian}; - -use actix::{Actor, AsyncContext, StreamHandler}; - -use body::Binary; -use payload::PayloadHelper; -use error::{Error, PayloadError, ResponseError}; -use httpmessage::HttpMessage; -use httprequest::HttpRequest; -use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; - -mod frame; -mod proto; -mod context; -mod mask; -mod client; - -pub use self::frame::Frame; -pub use self::proto::OpCode; -pub use self::proto::CloseCode; -pub use self::context::WebsocketContext; -pub use self::client::{Client, ClientError, - ClientReader, ClientWriter, ClientHandshake}; - -/// Websocket protocol errors -#[derive(Fail, Debug)] -pub enum ProtocolError { - /// Received an unmasked frame from client - #[fail(display="Received an unmasked frame from client")] - UnmaskedFrame, - /// Received a masked frame from server - #[fail(display="Received a masked frame from server")] - MaskedFrame, - /// Encountered invalid opcode - #[fail(display="Invalid opcode: {}", _0)] - InvalidOpcode(u8), - /// Invalid control frame length - #[fail(display="Invalid control frame length: {}", _0)] - InvalidLength(usize), - /// Bad web socket op code - #[fail(display="Bad web socket op code")] - BadOpCode, - /// A payload reached size limit. - #[fail(display="A payload reached size limit.")] - Overflow, - /// Continuation is not supported - #[fail(display="Continuation is not supported.")] - NoContinuation, - /// Bad utf-8 encoding - #[fail(display="Bad utf-8 encoding.")] - BadEncoding, - /// Payload error - #[fail(display="Payload error: {}", _0)] - Payload(#[cause] PayloadError), -} - -impl ResponseError for ProtocolError {} - -impl From for ProtocolError { - fn from(err: PayloadError) -> ProtocolError { - ProtocolError::Payload(err) - } -} - -/// Websocket handshake errors -#[derive(Fail, PartialEq, Debug)] -pub enum HandshakeError { - /// Only get method is allowed - #[fail(display="Method not allowed")] - GetMethodRequired, - /// Upgrade header if not set to websocket - #[fail(display="Websocket upgrade is expected")] - NoWebsocketUpgrade, - /// Connection header is not set to upgrade - #[fail(display="Connection upgrade is expected")] - NoConnectionUpgrade, - /// Websocket version header is not set - #[fail(display="Websocket version header is required")] - NoVersionHeader, - /// Unsupported websocket version - #[fail(display="Unsupported version")] - UnsupportedVersion, - /// Websocket key is not set or wrong - #[fail(display="Unknown websocket key")] - BadWebsocketKey, -} - -impl ResponseError for HandshakeError { - - fn error_response(&self) -> HttpResponse { - match *self { - HandshakeError::GetMethodRequired => { - HttpResponse::MethodNotAllowed().header(header::ALLOW, "GET").finish() - } - HandshakeError::NoWebsocketUpgrade => HttpResponse::BadRequest() - .reason("No WebSocket UPGRADE header found").finish(), - HandshakeError::NoConnectionUpgrade => HttpResponse::BadRequest() - .reason("No CONNECTION upgrade").finish(), - HandshakeError::NoVersionHeader => HttpResponse::BadRequest() - .reason("Websocket version header is required").finish(), - HandshakeError::UnsupportedVersion => HttpResponse::BadRequest() - .reason("Unsupported version").finish(), - HandshakeError::BadWebsocketKey => HttpResponse::BadRequest() - .reason("Handshake error").finish(), - } - } -} - -/// `WebSocket` Message -#[derive(Debug, PartialEq, Message)] -pub enum Message { - Text(String), - Binary(Binary), - Ping(String), - Pong(String), - Close(CloseCode), -} - -/// Do websocket handshake and start actor -pub fn start(req: HttpRequest, actor: A) -> Result - where A: Actor> + StreamHandler, - S: 'static -{ - let mut resp = handshake(&req)?; - let stream = WsStream::new(req.clone()); - - let mut ctx = WebsocketContext::new(req, actor); - ctx.add_stream(stream); - - Ok(resp.body(ctx)) -} - -/// Prepare `WebSocket` handshake response. -/// -/// This function returns handshake `HttpResponse`, ready to send to peer. -/// It does not perform any IO. -/// -// /// `protocols` is a sequence of known protocols. On successful handshake, -// /// the returned response headers contain the first protocol in this list -// /// which the server also knows. -pub fn handshake(req: &HttpRequest) -> Result { - // WebSocket accepts only GET - if *req.method() != Method::GET { - return Err(HandshakeError::GetMethodRequired) - } - - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - return Err(HandshakeError::NoWebsocketUpgrade) - } - - // Upgrade connection - if !req.upgrade() { - return Err(HandshakeError::NoConnectionUpgrade) - } - - // check supported version - if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { - return Err(HandshakeError::NoVersionHeader) - } - let supported_ver = { - if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { - hdr == "13" || hdr == "8" || hdr == "7" - } else { - false - } - }; - if !supported_ver { - return Err(HandshakeError::UnsupportedVersion) - } - - // check client handshake for validity - if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { - return Err(HandshakeError::BadWebsocketKey) - } - let key = { - let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); - proto::hash_key(key.as_ref()) - }; - - Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) - .connection_type(ConnectionType::Upgrade) - .header(header::UPGRADE, "websocket") - .header(header::TRANSFER_ENCODING, "chunked") - .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) - .take()) -} - -/// Maps `Payload` stream into stream of `ws::Message` items -pub struct WsStream { - rx: PayloadHelper, - closed: bool, - max_size: usize, -} - -impl WsStream where S: Stream { - /// Create new websocket frames stream - pub fn new(stream: S) -> WsStream { - WsStream { rx: PayloadHelper::new(stream), - closed: false, - max_size: 65_536, - } - } - - /// Set max frame size - /// - /// By default max size is set to 64kb - pub fn max_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } -} - -impl Stream for WsStream where S: Stream { - type Item = Message; - type Error = ProtocolError; - - fn poll(&mut self) -> Poll, Self::Error> { - if self.closed { - return Ok(Async::Ready(None)) - } - - match Frame::parse(&mut self.rx, true, self.max_size) { - Ok(Async::Ready(Some(frame))) => { - let (finished, opcode, payload) = frame.unpack(); - - // continuation is not supported - if !finished { - self.closed = true; - return Err(ProtocolError::NoContinuation) - } - - match opcode { - OpCode::Continue => Err(ProtocolError::NoContinuation), - OpCode::Bad => { - self.closed = true; - Err(ProtocolError::BadOpCode) - } - OpCode::Close => { - self.closed = true; - let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; - Ok(Async::Ready( - Some(Message::Close(CloseCode::from(code))))) - }, - OpCode::Ping => - Ok(Async::Ready(Some( - Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Pong => - Ok(Async::Ready(Some( - Message::Pong(String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Binary => - Ok(Async::Ready(Some(Message::Binary(payload)))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => - Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { - self.closed = true; - Err(ProtocolError::BadEncoding) - } - } - } - } - } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - self.closed = true; - Err(e) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::str::FromStr; - use http::{Method, HeaderMap, Version, Uri, header}; - - #[test] - fn test_handshake() { - let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); - - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, HeaderMap::new(), None); - assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("test")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("websocket")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("websocket")); - headers.insert(header::CONNECTION, - header::HeaderValue::from_static("upgrade")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("websocket")); - headers.insert(header::CONNECTION, - header::HeaderValue::from_static("upgrade")); - headers.insert(header::SEC_WEBSOCKET_VERSION, - header::HeaderValue::from_static("5")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("websocket")); - headers.insert(header::CONNECTION, - header::HeaderValue::from_static("upgrade")); - headers.insert(header::SEC_WEBSOCKET_VERSION, - header::HeaderValue::from_static("13")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap()); - - let mut headers = HeaderMap::new(); - headers.insert(header::UPGRADE, - header::HeaderValue::from_static("websocket")); - headers.insert(header::CONNECTION, - header::HeaderValue::from_static("upgrade")); - headers.insert(header::SEC_WEBSOCKET_VERSION, - header::HeaderValue::from_static("13")); - headers.insert(header::SEC_WEBSOCKET_KEY, - header::HeaderValue::from_static("13")); - let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert_eq!(StatusCode::SWITCHING_PROTOCOLS, - handshake(&req).unwrap().finish().status()); - } - - #[test] - fn test_wserror_http_response() { - let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } -} diff --git a/test-server/CHANGES.md b/test-server/CHANGES.md new file mode 100644 index 000000000..57068fe95 --- /dev/null +++ b/test-server/CHANGES.md @@ -0,0 +1,59 @@ +# Changes + +## [0.2.5] - 2019-0917 + +### Changed + +* Update serde_urlencoded to "0.6.1" +* Increase TestServerRuntime timeouts from 500ms to 3000ms + +### Fixed + +* Do not override current `System` + + +## [0.2.4] - 2019-07-18 + +* Update actix-server to 0.6 + +## [0.2.3] - 2019-07-16 + +* Add `delete`, `options`, `patch` methods to `TestServerRunner` + +## [0.2.2] - 2019-06-16 + +* Add .put() and .sput() methods + +## [0.2.1] - 2019-06-05 + +* Add license files + +## [0.2.0] - 2019-05-12 + +* Update awc and actix-http deps + +## [0.1.1] - 2019-04-24 + +* Always make new connection for http client + + +## [0.1.0] - 2019-04-16 + +* No changes + + +## [0.1.0-alpha.3] - 2019-04-02 + +* Request functions accept path #743 + + +## [0.1.0-alpha.2] - 2019-03-29 + +* Added TestServerRuntime::load_body() method + +* Update actix-http and awc libraries + + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/test-server/Cargo.toml b/test-server/Cargo.toml new file mode 100644 index 000000000..e59e439fe --- /dev/null +++ b/test-server/Cargo.toml @@ -0,0 +1,61 @@ +[package] +name = "actix-http-test" +version = "0.3.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Actix http test server" +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-http-test/" +categories = ["network-programming", "asynchronous", + "web-programming::http-server", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +edition = "2018" +workspace = ".." + +[package.metadata.docs.rs] +features = [] + +[lib] +name = "actix_http_test" +path = "src/lib.rs" + +[features] +default = [] + +# openssl +openssl = ["open-ssl", "actix-server/openssl", "awc/openssl"] + +[dependencies] +actix-service = "1.0.0-alpha.1" +actix-codec = "0.2.0-alpha.1" +actix-connect = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +actix-rt = "1.0.0-alpha.1" +actix-server = "0.8.0-alpha.1" +actix-server-config = "0.3.0-alpha.1" +actix-testing = "0.3.0-alpha.1" +awc = "0.3.0-alpha.1" + +base64 = "0.10" +bytes = "0.4" +futures = "0.3.1" +http = "0.1.8" +log = "0.4" +env_logger = "0.6" +net2 = "0.2" +serde = "1.0" +serde_json = "1.0" +sha1 = "0.6" +slab = "0.4" +serde_urlencoded = "0.6.1" +time = "0.1" +tokio-net = "0.2.0-alpha.6" +open-ssl = { version="0.10", package="openssl", optional = true } + +[dev-dependencies] +actix-web = "2.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" diff --git a/test-server/LICENSE-APACHE b/test-server/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/test-server/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/test-server/LICENSE-MIT b/test-server/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/test-server/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/test-server/README.md b/test-server/README.md new file mode 100644 index 000000000..e40650124 --- /dev/null +++ b/test-server/README.md @@ -0,0 +1,9 @@ +# Actix http test server [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-http-test)](https://crates.io/crates/actix-http-test) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-http-test/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-http-test](https://crates.io/crates/actix-http-test) +* Minimum supported Rust version: 1.33 or later diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs new file mode 100644 index 000000000..9ad06397c --- /dev/null +++ b/test-server/src/lib.rs @@ -0,0 +1,266 @@ +//! Various helpers for Actix applications to use during testing. +use std::sync::mpsc; +use std::{net, thread, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_rt::System; +use actix_server::{Server, ServiceFactory}; +use awc::{error::PayloadError, ws, Client, ClientRequest, ClientResponse, Connector}; +use bytes::Bytes; +use futures::Stream; +use http::Method; +use net2::TcpBuilder; +use tokio_net::tcp::TcpStream; + +pub use actix_testing::*; + +/// The `TestServer` type. +/// +/// `TestServer` is very simple test server that simplify process of writing +/// integration tests cases for actix web applications. +/// +/// # Examples +/// +/// ```rust +/// use actix_http::HttpService; +/// use actix_http_test::TestServer; +/// use actix_web::{web, App, HttpResponse, Error}; +/// +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok().into()) +/// } +/// +/// #[actix_rt::test] +/// async fn test_example() { +/// let mut srv = TestServer::start( +/// || HttpService::new( +/// App::new().service( +/// web::resource("/").to(my_handler)) +/// ) +/// ); +/// +/// let req = srv.get("/"); +/// let response = req.send().await.unwrap(); +/// assert!(response.status().is_success()); +/// } +/// ``` +pub struct TestServer; + +/// Test server controller +pub struct TestServerRuntime { + addr: net::SocketAddr, + client: Client, + system: System, +} + +impl TestServer { + #[allow(clippy::new_ret_no_self)] + /// Start new test server with application factory + pub fn start>(factory: F) -> TestServerRuntime { + let (tx, rx) = mpsc::channel(); + + // run server in separate thread + thread::spawn(move || { + let sys = System::new("actix-test-server"); + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + + Server::build() + .listen("test", tcp, factory)? + .workers(1) + .disable_signals() + .start(); + + tx.send((System::current(), local_addr)).unwrap(); + sys.run() + }); + + let (system, addr) = rx.recv().unwrap(); + + let client = { + let connector = { + #[cfg(feature = "openssl")] + { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(3000)) + .ssl(builder.build()) + .finish() + } + #[cfg(not(feature = "openssl"))] + { + Connector::new() + .conn_lifetime(time::Duration::from_secs(0)) + .timeout(time::Duration::from_millis(3000)) + .finish() + } + }; + + Client::build().connector(connector).finish() + }; + actix_connect::start_default_resolver(); + + TestServerRuntime { + addr, + client, + system, + } + } + + /// Get first available unused address + pub fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpBuilder::new_v4().unwrap(); + socket.bind(&addr).unwrap(); + socket.reuse_address(true).unwrap(); + let tcp = socket.to_tcp_listener().unwrap(); + tcp.local_addr().unwrap() + } +} + +impl TestServerRuntime { + /// Construct test server url + pub fn addr(&self) -> net::SocketAddr { + self.addr + } + + /// Construct test server url + pub fn url(&self, uri: &str) -> String { + if uri.starts_with('/') { + format!("http://localhost:{}{}", self.addr.port(), uri) + } else { + format!("http://localhost:{}/{}", self.addr.port(), uri) + } + } + + /// Construct test https server url + pub fn surl(&self, uri: &str) -> String { + if uri.starts_with('/') { + format!("https://localhost:{}{}", self.addr.port(), uri) + } else { + format!("https://localhost:{}/{}", self.addr.port(), uri) + } + } + + /// Create `GET` request + pub fn get>(&self, path: S) -> ClientRequest { + self.client.get(self.url(path.as_ref()).as_str()) + } + + /// Create https `GET` request + pub fn sget>(&self, path: S) -> ClientRequest { + self.client.get(self.surl(path.as_ref()).as_str()) + } + + /// Create `POST` request + pub fn post>(&self, path: S) -> ClientRequest { + self.client.post(self.url(path.as_ref()).as_str()) + } + + /// Create https `POST` request + pub fn spost>(&self, path: S) -> ClientRequest { + self.client.post(self.surl(path.as_ref()).as_str()) + } + + /// Create `HEAD` request + pub fn head>(&self, path: S) -> ClientRequest { + self.client.head(self.url(path.as_ref()).as_str()) + } + + /// Create https `HEAD` request + pub fn shead>(&self, path: S) -> ClientRequest { + self.client.head(self.surl(path.as_ref()).as_str()) + } + + /// Create `PUT` request + pub fn put>(&self, path: S) -> ClientRequest { + self.client.put(self.url(path.as_ref()).as_str()) + } + + /// Create https `PUT` request + pub fn sput>(&self, path: S) -> ClientRequest { + self.client.put(self.surl(path.as_ref()).as_str()) + } + + /// Create `PATCH` request + pub fn patch>(&self, path: S) -> ClientRequest { + self.client.patch(self.url(path.as_ref()).as_str()) + } + + /// Create https `PATCH` request + pub fn spatch>(&self, path: S) -> ClientRequest { + self.client.patch(self.surl(path.as_ref()).as_str()) + } + + /// Create `DELETE` request + pub fn delete>(&self, path: S) -> ClientRequest { + self.client.delete(self.url(path.as_ref()).as_str()) + } + + /// Create https `DELETE` request + pub fn sdelete>(&self, path: S) -> ClientRequest { + self.client.delete(self.surl(path.as_ref()).as_str()) + } + + /// Create `OPTIONS` request + pub fn options>(&self, path: S) -> ClientRequest { + self.client.options(self.url(path.as_ref()).as_str()) + } + + /// Create https `OPTIONS` request + pub fn soptions>(&self, path: S) -> ClientRequest { + self.client.options(self.surl(path.as_ref()).as_str()) + } + + /// Connect to test http server + pub fn request>(&self, method: Method, path: S) -> ClientRequest { + self.client.request(method, path.as_ref()) + } + + pub async fn load_body( + &mut self, + mut response: ClientResponse, + ) -> Result + where + S: Stream> + Unpin + 'static, + { + response.body().limit(10_485_760).await + } + + /// Connect to websocket server at a given path + pub async fn ws_at( + &mut self, + path: &str, + ) -> Result, awc::error::WsClientError> + { + let url = self.url(path); + let connect = self.client.ws(url).connect(); + connect.await.map(|(_, framed)| framed) + } + + /// Connect to a websocket server + pub async fn ws( + &mut self, + ) -> Result, awc::error::WsClientError> + { + self.ws_at("/").await + } + + /// Stop http server + fn stop(&mut self) { + self.system.stop(); + } +} + +impl Drop for TestServerRuntime { + fn drop(&mut self) { + self.stop() + } +} diff --git a/tests/cert.pem b/tests/cert.pem index 159aacea2..0eeb6721d 100644 --- a/tests/cert.pem +++ b/tests/cert.pem @@ -1,31 +1,19 @@ -----BEGIN CERTIFICATE----- -MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV -UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww -CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx -NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD -QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY -MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1 -sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U -NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy -voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr -odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND -xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA -CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI -yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U -UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO -vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un -CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN -BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk -3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI -JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD -JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL -d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu -ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC -CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur -y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7 -YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh -g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt -tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y -1QU= +MIIDEDCCAfgCCQCQdmIZc/Ib/jANBgkqhkiG9w0BAQsFADBKMQswCQYDVQQGEwJ1 +czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZIhvcNAQkBFhJmYWZo +cmQ5MUBnbWFpbC5jb20wHhcNMTkxMTE5MTEwNjU1WhcNMjkxMTE2MTEwNjU1WjBK +MQswCQYDVQQGEwJ1czELMAkGA1UECAwCY2ExCzAJBgNVBAcMAnNmMSEwHwYJKoZI +hvcNAQkBFhJmYWZocmQ5MUBnbWFpbC5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDcnaz12CKzUL7248V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C ++kLWKjAc2coqDSbGsrsR6KiH2g06Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezy +XRe/olcHFTeCk/Tllz4xGEplhPua6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqc +K2xntIPreumXpiE3QY4+MWyteiJko4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvu +GccHd/ex8cOwotUqd6emZb+0bVE24Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zU +b2GJosbmfGaf+xTfnGGhTLLL7kCtva+NvZr5AgMBAAEwDQYJKoZIhvcNAQELBQAD +ggEBANftoL8zDGrjCwWvct8kOOqset2ukK8vjIGwfm88CKsy0IfSochNz2qeIu9R +ZuO7c0pfjmRkir9ZQdq9vXgG3ccL9UstFsferPH9W3YJ83kgXg3fa0EmCiN/0hwz +6Ij1ZBiN1j3+d6+PJPgyYFNu2nGwox5mJ9+aRAGe0/9c63PEOY8P2TI4HsiPmYSl +fFR8k/03vr6e+rTKW85BgctjvYKe/TnFxeCQ7dZ+na7vlEtch4tNmy6O/vEk2kCt +5jW0DUxhmRsv2wGmfFRI0+LotHjoXQQZi6nN5aGL3odaGF3gYwIVlZNd3AdkwDQz +BzG0ZwXuDDV9bSs3MfWEWcy4xuU= -----END CERTIFICATE----- diff --git a/tests/key.pem b/tests/key.pem index aac387c64..a6d308168 100644 --- a/tests/key.pem +++ b/tests/key.pem @@ -1,51 +1,28 @@ ------BEGIN RSA PRIVATE KEY----- -MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP -n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M -IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5 -4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ -WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk -oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli -JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6 -/stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD -YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP -wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA -69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA -AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/ -9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm -YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR -6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM -ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI -7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab -L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+ -vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ -b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz -0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL -OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI -6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC -71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g -9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu -bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb -IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga -/BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc -KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2 -iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP -tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD -jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY -l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj -gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh -Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q -1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW -t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI -fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9 -5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt -+oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc -3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf -cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T -qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU -DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K -5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc -fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc -Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ -4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6 -I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c= ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDcnaz12CKzUL72 +48V7Axhms/O9UQXfAdw0yolEfC3P5jADa/1C+kLWKjAc2coqDSbGsrsR6KiH2g06 +Kunx+tSGqUO+Sct7HEehmxndiSwx/hfMWezyXRe/olcHFTeCk/Tllz4xGEplhPua +6GLhJygLOhAMiV8cwCYrgyPqsDduExLDFCqcK2xntIPreumXpiE3QY4+MWyteiJk +o4IWDFf/UwwsdCY5MlFfw1F/Uv9vz7FfOfvuGccHd/ex8cOwotUqd6emZb+0bVE2 +4Sv8U+yLnHIVx/tOkxgMAnJEpAnf2G3Wp3zUb2GJosbmfGaf+xTfnGGhTLLL7kCt +va+NvZr5AgMBAAECggEBAKoU0UwzVgVCQgca8Jt2dnBvWYDhnxIfYAI/BvaKedMm +1ms87OKfB7oOiksjyI0E2JklH72dzZf2jm4CuZt5UjGC+xwPzlTaJ4s6hQVbBHyC +NRyxU1BCXtW5tThbrhD4OjxqjmLRJEIB9OunLtwAEQoeuFLB8Va7+HFhR+Zd9k3f +7aVA93pC5A50NRbZlke4miJ3Q8n7ZF0+UmxkBfm3fbqLk7aMWkoEKwLLTadjRlu1 +bBp0YDStX66I/p1kujqBOdh6VpPvxFOa1sV9pq0jeiGc9YfSkzRSKzIn8GoyviFB +fHeszQdNlcnrSDSNnMABAw+ZpxUO7SCaftjwejEmKZUCgYEA+TY43VpmV95eY7eo +WKwGepiHE0fwQLuKGELmZdZI80tFi73oZMuiB5WzwmkaKGcJmm7KGE9KEvHQCo9j +xvmktBR0VEZH8pmVfun+4h6+0H7m/NKMBBeOyv/IK8jBgHjkkB6e6nmeR7CqTxCw +tf9tbajl1QN8gNzXZSjBDT/lanMCgYEA4qANOKOSiEARtgwyXQeeSJcM2uPv6zF3 +ffM7vjSedtuEOHUSVeyBP/W8KDt7zyPppO/WNbURHS+HV0maS9yyj6zpVS2HGmbs +3fetswsQ+zYVdokW89x4oc2z4XOGHd1LcSlyhRwPt0u2g1E9L0irwTQLWU0npFmG +PRf7sN9+LeMCgYAGkDUDL2ROoB6gRa/7Vdx90hKMoXJkYgwLA4gJ2pDlR3A3c/Lw +5KQJyxmG3zm/IqeQF6be6QesZA30mT4peV2rGHbP2WH/s6fKReNelSy1VQJEWk8x +tGUgV4gwDwN5nLV4TjYlOrq+bJqvpmLhCC8bmj0jVQosYqSRl3cuICasnQKBgGlV +VO/Xb1su1EyWPK5qxRIeSxZOTYw2sMB01nbgxCqge0M2fvA6/hQ5ZlwY0cIEgits +YlcSMsMq/TAAANxz1vbaupUhlSMbZcsBvNV0Nk9c4vr2Wxm7hsJF9u66IEMvQUp2 +pkjiMxfR9CHzF4orr9EcHI5EQ0Grbq5kwFKEfoRbAoGAcWoFPILeJOlp2yW/Ds3E +g2fQdI9BAamtEZEaslJmZMmsDTg5ACPcDkOSFEQIaJ7wLPXeZy74FVk/NrY5F8Gz +bjX9OD/xzwp852yW5L9r62vYJakAlXef5jI6CFdYKDDCcarU0S7W5k6kq9n+wrBR +i1NklYmUAMr2q59uJA5zsic= +-----END PRIVATE KEY----- diff --git a/tests/skeptic.rs b/tests/skeptic.rs deleted file mode 100644 index a0e0f9b3c..000000000 --- a/tests/skeptic.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(unix)] -include!(concat!(env!("OUT_DIR"), "/skeptic-tests.rs")); diff --git a/tests/test_client.rs b/tests/test_client.rs deleted file mode 100644 index c0e0e6da9..000000000 --- a/tests/test_client.rs +++ /dev/null @@ -1,393 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate bytes; -extern crate futures; -extern crate flate2; -extern crate rand; - -use std::io::Read; - -use bytes::Bytes; -use futures::Future; -use futures::stream::once; -use flate2::read::GzDecoder; -use rand::Rng; - -use actix_web::*; - - -const STR: &str = - "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - - -#[test] -fn test_simple() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| HttpResponse::Ok().body(STR))); - - let request = srv.get().header("x-test", "111").finish().unwrap(); - let repr = format!("{:?}", request); - assert!(repr.contains("ClientRequest")); - assert!(repr.contains("x-test")); - - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - - let request = srv.post().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_with_query_parameter() { - let mut srv = test::TestServer::new( - |app| app.handler(|req: HttpRequest| match req.query().get("qp") { - Some(_) => HttpResponse::Ok().finish(), - None => HttpResponse::BadRequest().finish(), - })); - - let request = srv.get().uri(srv.url("/?qp=5").as_str()).finish().unwrap(); - - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); -} - - -#[test] -fn test_no_decompress() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| HttpResponse::Ok().body(STR))); - - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); - - // POST - let request = srv.post().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - - let bytes = srv.execute(response.body()).unwrap(); - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_client_gzip_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Deflate) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.post() - .content_encoding(http::ContentEncoding::Gzip) - .body(STR).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_client_gzip_encoding_large() { - let data = STR.repeat(10); - - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Deflate) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.post() - .content_encoding(http::ContentEncoding::Gzip) - .body(data.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[test] -fn test_client_gzip_encoding_large_random() { - let data = rand::thread_rng() - .gen_ascii_chars() - .take(100_000) - .collect::(); - - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Deflate) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.post() - .content_encoding(http::ContentEncoding::Gzip) - .body(data.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[cfg(feature="brotli")] -#[test] -fn test_client_brotli_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.client(http::Method::POST, "/") - .content_encoding(http::ContentEncoding::Br) - .body(STR).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[cfg(feature="brotli")] -#[test] -fn test_client_brotli_encoding_large_random() { - let data = rand::thread_rng() - .gen_ascii_chars() - .take(70_000) - .collect::(); - - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(move |bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.client(http::Method::POST, "/") - .content_encoding(http::ContentEncoding::Br) - .body(data.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); -} - -#[cfg(feature="brotli")] -#[test] -fn test_client_deflate_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Br) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.post() - .content_encoding(http::ContentEncoding::Deflate) - .body(STR).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[cfg(feature="brotli")] -#[test] -fn test_client_deflate_encoding_large_random() { - let data = rand::thread_rng() - .gen_ascii_chars() - .take(70_000) - .collect::(); - - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Br) - .body(bytes)) - }).responder()} - )); - - // client request - let request = srv.post() - .content_encoding(http::ContentEncoding::Deflate) - .body(data.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[test] -fn test_client_streaming_explicit() { - let mut srv = test::TestServer::new( - |app| app.handler( - |req: HttpRequest| req.body() - .map_err(Error::from) - .and_then(|body| { - Ok(HttpResponse::Ok() - .chunked() - .content_encoding(http::ContentEncoding::Identity) - .body(body))}) - .responder())); - - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - - let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_body_streaming_implicit() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(Body::Streaming(Box::new(body)))})); - - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_client_cookie_handling() { - use actix_web::http::Cookie; - fn err() -> Error { - use std::io::{ErrorKind, Error as IoError}; - // stub some generic error - Error::from(IoError::from(ErrorKind::NotFound)) - } - let cookie1 = Cookie::build("cookie1", "value1").finish(); - let cookie2 = Cookie::build("cookie2", "value2") - .domain("www.example.org") - .path("/") - .secure(true) - .http_only(true) - .finish(); - // Q: are all these clones really necessary? A: Yes, possibly - let cookie1b = cookie1.clone(); - let cookie2b = cookie2.clone(); - let mut srv = test::TestServer::new( - move |app| { - let cookie1 = cookie1b.clone(); - let cookie2 = cookie2b.clone(); - app.handler(move |req: HttpRequest| { - // Check cookies were sent correctly - req.cookie("cookie1").ok_or_else(err) - .and_then(|c1| if c1.value() == "value1" { - Ok(()) - } else { - Err(err()) - }) - .and_then(|()| req.cookie("cookie2").ok_or_else(err)) - .and_then(|c2| if c2.value() == "value2" { - Ok(()) - } else { - Err(err()) - }) - // Send some cookies back - .map(|_| HttpResponse::Ok() - .cookie(cookie1.clone()) - .cookie(cookie2.clone()) - .finish() - ) - }) - }); - - let request = srv.get() - .cookie(cookie1.clone()) - .cookie(cookie2.clone()) - .finish() - .unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - let c1 = response.cookie("cookie1").expect("Missing cookie1"); - assert_eq!(c1, &cookie1); - let c2 = response.cookie("cookie2").expect("Missing cookie2"); - assert_eq!(c2, &cookie2); -} diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs deleted file mode 100644 index 909c9ddf9..000000000 --- a/tests/test_handlers.rs +++ /dev/null @@ -1,134 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate tokio_core; -extern crate futures; -extern crate h2; -extern crate http; -extern crate bytes; -#[macro_use] extern crate serde_derive; - -use actix_web::*; -use bytes::Bytes; -use http::StatusCode; - -#[derive(Deserialize)] -struct PParam { - username: String, -} - -#[test] -fn test_path_extractor() { - let mut srv = test::TestServer::new(|app| { - app.resource( - "/{username}/index.html", |r| r.with( - |p: Path| format!("Welcome {}!", p.username))); - } - ); - - // client request - let request = srv.get().uri(srv.url("/test/index.html")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"Welcome test!")); -} - -#[test] -fn test_query_extractor() { - let mut srv = test::TestServer::new(|app| { - app.resource( - "/index.html", |r| r.with( - |p: Query| format!("Welcome {}!", p.username))); - } - ); - - // client request - let request = srv.get().uri(srv.url("/index.html?username=test")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"Welcome test!")); - - // client request - let request = srv.get().uri(srv.url("/index.html")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); -} - -#[test] -fn test_path_and_query_extractor() { - let mut srv = test::TestServer::new(|app| { - app.resource( - "/{username}/index.html", |r| r.route().with2( - |p: Path, q: Query| - format!("Welcome {} - {}!", p.username, q.username))); - } - ); - - // client request - let request = srv.get().uri(srv.url("/test1/index.html?username=test2")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!")); - - // client request - let request = srv.get().uri(srv.url("/test1/index.html")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); -} - -#[test] -fn test_path_and_query_extractor2() { - let mut srv = test::TestServer::new(|app| { - app.resource( - "/{username}/index.html", |r| r.route().with3( - |_: HttpRequest, p: Path, q: Query| - format!("Welcome {} - {}!", p.username, q.username))); - } - ); - - // client request - let request = srv.get().uri(srv.url("/test1/index.html?username=test2")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"Welcome test1 - test2!")); - - // client request - let request = srv.get().uri(srv.url("/test1/index.html")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); -} - -#[test] -fn test_non_ascii_route() { - let mut srv = test::TestServer::new(|app| { - app.resource("/中文/index.html", |r| r.f(|_| "success")); - }); - - // client request - let request = srv.get().uri(srv.url("/中文/index.html")) - .finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(b"success")); -} diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs new file mode 100644 index 000000000..d19c46ee7 --- /dev/null +++ b/tests/test_httpserver.rs @@ -0,0 +1,144 @@ +use net2::TcpBuilder; +use std::sync::mpsc; +use std::{net, thread, time::Duration}; + +#[cfg(feature = "openssl")] +use open_ssl::ssl::SslAcceptorBuilder; + +use actix_http::Response; +use actix_web::{web, App, HttpServer}; + +fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpBuilder::new_v4().unwrap(); + socket.bind(&addr).unwrap(); + socket.reuse_address(true).unwrap(); + let tcp = socket.to_tcp_listener().unwrap(); + tcp.local_addr().unwrap() +} + +#[cfg(unix)] +#[actix_rt::test] +async fn test_start() { + let addr = unused_addr(); + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + let sys = actix_rt::System::new("test"); + + let srv = HttpServer::new(|| { + App::new().service( + web::resource("/").route(web::to(|| Response::Ok().body("test"))), + ) + }) + .workers(1) + .backlog(1) + .maxconn(10) + .maxconnrate(10) + .keep_alive(10) + .client_timeout(5000) + .client_shutdown(0) + .server_hostname("localhost") + .system_exit() + .disable_signals() + .bind(format!("{}", addr)) + .unwrap() + .start(); + + let _ = tx.send((srv, actix_rt::System::current())); + let _ = sys.run(); + }); + let (srv, sys) = rx.recv().unwrap(); + + #[cfg(feature = "client")] + { + use actix_http::client; + + let client = awc::Client::build() + .connector( + client::Connector::new() + .timeout(Duration::from_millis(100)) + .finish(), + ) + .finish(); + + let host = format!("http://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); + } + + // stop + let _ = srv.stop(false); + + thread::sleep(Duration::from_millis(100)); + let _ = sys.stop(); +} + +#[cfg(feature = "openssl")] +fn ssl_acceptor() -> std::io::Result { + use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + Ok(builder) +} + +#[actix_rt::test] +#[cfg(feature = "openssl")] +async fn test_start_ssl() { + let addr = unused_addr(); + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + let sys = actix_rt::System::new("test"); + let builder = ssl_acceptor().unwrap(); + + let srv = HttpServer::new(|| { + App::new().service( + web::resource("/").route(web::to(|| Response::Ok().body("test"))), + ) + }) + .workers(1) + .shutdown_timeout(1) + .system_exit() + .disable_signals() + .bind_openssl(format!("{}", addr), builder) + .unwrap() + .start(); + + let _ = tx.send((srv, actix_rt::System::current())); + let _ = sys.run(); + }); + let (srv, sys) = rx.recv().unwrap(); + + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + + let client = awc::Client::build() + .connector( + awc::Connector::new() + .ssl(builder.build()) + .timeout(Duration::from_millis(100)) + .finish(), + ) + .finish(); + + let host = format!("https://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); + + // stop + let _ = srv.stop(false); + + thread::sleep(Duration::from_millis(100)); + let _ = sys.stop(); +} diff --git a/tests/test_server.rs b/tests/test_server.rs index a13fc2f85..bfdf3f0ee 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,213 +1,83 @@ -extern crate actix; -extern crate actix_web; -extern crate tokio_core; -extern crate futures; -extern crate h2; -extern crate http as modhttp; -extern crate bytes; -extern crate flate2; -extern crate rand; - -#[cfg(feature="brotli")] -extern crate brotli2; - -use std::{net, thread, time}; use std::io::{Read, Write}; -use std::sync::{Arc, mpsc}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use flate2::Compression; + +use actix_http::http::header::{ + ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + TRANSFER_ENCODING, +}; +use actix_http::{h1, Error, HttpService, Response}; +use actix_http_test::TestServer; +use brotli2::write::{BrotliDecoder, BrotliEncoder}; +use bytes::Bytes; use flate2::read::GzDecoder; -use flate2::write::{GzEncoder, DeflateEncoder, DeflateDecoder}; -#[cfg(feature="brotli")] -use brotli2::write::{BrotliEncoder, BrotliDecoder}; -use futures::{Future, Stream}; -use futures::stream::once; -use h2::client as h2client; -use bytes::{Bytes, BytesMut}; -use modhttp::Request; -use tokio_core::net::TcpStream; -use tokio_core::reactor::Core; -use rand::Rng; +use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; +use flate2::Compression; +use futures::{future::ok, stream::once}; +use rand::{distributions::Alphanumeric, Rng}; -use actix::System; -use actix_web::*; +use actix_web::middleware::{BodyEncoding, Compress}; +use actix_web::{dev, http, web, App, HttpResponse, HttpServer}; +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; -const STR: &str = - "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; - -#[test] -fn test_start() { - let _ = test::TestServer::unused_addr(); - let (tx, rx) = mpsc::channel(); - - thread::spawn(move || { - let sys = System::new("test"); - let srv = HttpServer::new( - || vec![App::new() - .resource( - "/", |r| r.method(http::Method::GET) - .f(|_|HttpResponse::Ok()))]); - - let srv = srv.bind("127.0.0.1:0").unwrap(); - let addr = srv.addrs()[0]; - let srv_addr = srv.start(); - let _ = tx.send((addr, srv_addr)); - sys.run(); +#[actix_rt::test] +async fn test_body() { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), + ) }); - let (addr, srv_addr) = rx.recv().unwrap(); - let mut sys = System::new("test-server"); - - { - let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); - let response = sys.run_until_complete(req.send()).unwrap(); - assert!(response.status().is_success()); - } - - // pause - let _ = srv_addr.send(server::PauseServer).wait(); - thread::sleep(time::Duration::from_millis(100)); - assert!(net::TcpStream::connect(addr).is_err()); - - // resume - let _ = srv_addr.send(server::ResumeServer).wait(); - - { - let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); - let response = sys.run_until_complete(req.send()).unwrap(); - assert!(response.status().is_success()); - } -} - -#[test] -#[cfg(unix)] -fn test_shutdown() { - let _ = test::TestServer::unused_addr(); - let (tx, rx) = mpsc::channel(); - - thread::spawn(move || { - let sys = System::new("test"); - let srv = HttpServer::new( - || vec![App::new() - .resource( - "/", |r| r.method(http::Method::GET).f(|_| HttpResponse::Ok()))]); - - let srv = srv.bind("127.0.0.1:0").unwrap(); - let addr = srv.addrs()[0]; - let srv_addr = srv.shutdown_timeout(1).start(); - let _ = tx.send((addr, srv_addr)); - sys.run(); - }); - let (addr, srv_addr) = rx.recv().unwrap(); - - let mut sys = System::new("test-server"); - - { - let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); - let response = sys.run_until_complete(req.send()).unwrap(); - srv_addr.do_send(server::StopServer{graceful: true}); - assert!(response.status().is_success()); - } - - thread::sleep(time::Duration::from_millis(1000)); - assert!(net::TcpStream::connect(addr).is_err()); -} - -#[test] -fn test_simple() { - let mut srv = test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok())); - let req = srv.get().finish().unwrap(); - let response = srv.execute(req.send()).unwrap(); - assert!(response.status().is_success()); -} - -#[test] -fn test_headers() { - let data = STR.repeat(10); - let srv_data = Arc::new(data.clone()); - let mut srv = test::TestServer::new( - move |app| { - let data = srv_data.clone(); - app.handler(move |_| { - let mut builder = HttpResponse::Ok(); - for idx in 0..90 { - builder.header( - format!("X-TEST-{}", idx).as_str(), - "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ - TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST "); - } - builder.body(data.as_ref())}) - }); - - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[test] -fn test_body() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| HttpResponse::Ok().body(STR))); - - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_body_gzip() { - let mut srv = test::TestServer::new( - |app| app.handler( - |_| HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(STR))); +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_gzip() { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| Response::Ok().body(STR)))), + ) + }); - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -216,25 +86,128 @@ fn test_body_gzip() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_body_gzip_large() { - let data = STR.repeat(10); - let srv_data = Arc::new(data.clone()); +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_gzip2() { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + Response::Ok().body(STR).into_body::() + }))), + ) + }); - let mut srv = test::TestServer::new( - move |app| { - let data = srv_data.clone(); - app.handler( - move |_| HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(data.as_ref()))}); - - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); + + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); +} + +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_encoding_override() { + let srv = TestServer::start(|| { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::to(|| { + Response::Ok().encoding(ContentEncoding::Deflate).body(STR) + }))) + .service(web::resource("/raw").route(web::to(|| { + let body = actix_web::dev::Body::Bytes(STR.into()); + let mut response = + Response::with_body(actix_web::http::StatusCode::OK, body); + + response.encoding(ContentEncoding::Deflate); + + response + }))), + ) + }); + + // Builder + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + + // decode + let mut e = ZlibDecoder::new(Vec::new()); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + // Raw Response + let mut response = srv + .request(actix_web::http::Method::GET, srv.url("/raw")) + .no_decompress() + .header(ACCEPT_ENCODING, "deflate") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + + // decode + let mut e = ZlibDecoder::new(Vec::new()); + e.write_all(bytes.as_ref()).unwrap(); + let dec = e.finish().unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); +} + +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_gzip_large() { + let data = STR.repeat(10); + let srv_data = data.clone(); + + let srv = TestServer::start(move || { + let data = srv_data.clone(); + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || Response::Ok().body(data.clone()))), + ), + ) + }); + + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -243,28 +216,38 @@ fn test_body_gzip_large() { assert_eq!(Bytes::from(dec), Bytes::from(data)); } -#[test] -fn test_body_gzip_large_random() { +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_gzip_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&Alphanumeric) .take(70_000) .collect::(); - let srv_data = Arc::new(data.clone()); + let srv_data = data.clone(); - let mut srv = test::TestServer::new( - move |app| { - let data = srv_data.clone(); - app.handler( - move |_| HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(data.as_ref()))}); + let srv = TestServer::start(move || { + let data = srv_data.clone(); + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service( + web::resource("/") + .route(web::to(move || Response::Ok().body(data.clone()))), + ), + ) + }); - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -274,21 +257,36 @@ fn test_body_gzip_large_random() { assert_eq!(Bytes::from(dec), Bytes::from(data)); } -#[test] -fn test_body_chunked_implicit() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Gzip) - .body(Body::Streaming(Box::new(body)))})); +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +#[actix_rt::test] +async fn test_body_chunked_implicit() { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Gzip)) + .service(web::resource("/").route(web::get().to(move || { + Response::Ok().streaming(once(ok::<_, Error>(Bytes::from_static( + STR.as_ref(), + )))) + }))), + ) + }); - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .no_decompress() + .header(ACCEPT_ENCODING, "gzip") + .send() + .await + .unwrap(); assert!(response.status().is_success()); + assert_eq!( + response.headers().get(TRANSFER_ENCODING).unwrap(), + &b"chunked"[..] + ); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode let mut e = GzDecoder::new(&bytes[..]); @@ -297,22 +295,29 @@ fn test_body_chunked_implicit() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[cfg(feature="brotli")] -#[test] -fn test_body_br_streaming() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Br) - .body(Body::Streaming(Box::new(body)))})); +#[actix_rt::test] +#[cfg(feature = "brotli")] +async fn test_body_br_streaming() { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( + web::resource("/").route(web::to(move || { + Response::Ok() + .streaming(once(ok::<_, Error>(Bytes::from_static(STR.as_ref())))) + })), + )) + }); - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode br let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); @@ -321,150 +326,102 @@ fn test_body_br_streaming() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_head_empty() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - HttpResponse::Ok() - .content_length(STR.len() as u64).finish()})); +#[actix_rt::test] +async fn test_head_binary() { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().service(web::resource("/").route( + web::head().to(move || Response::Ok().content_length(100).body(STR)), + ))) + }); - let request = srv.head().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv.head("/").send().await.unwrap(); assert!(response.status().is_success()); { - let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); + let len = response.headers().get(CONTENT_LENGTH).unwrap(); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); } // read response - //let bytes = srv.execute(response.body()).unwrap(); - //assert!(bytes.is_empty()); + let bytes = response.body().await.unwrap(); + assert!(bytes.is_empty()); } -#[test] -fn test_head_binary() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .content_length(100).body(STR)})); +#[actix_rt::test] +async fn test_no_chunking() { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().service(web::resource("/").route(web::to( + move || { + Response::Ok() + .no_chunking() + .content_length(STR.len() as u64) + .streaming(once(ok::<_, Error>(Bytes::from_static(STR.as_ref())))) + }, + )))) + }); - let request = srv.head().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); - - { - let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + assert!(!response.headers().contains_key(TRANSFER_ENCODING)); // read response - //let bytes = srv.execute(response.body()).unwrap(); - //assert!(bytes.is_empty()); -} - -#[test] -fn test_head_binary2() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(STR) - })); - - let request = srv.head().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - { - let len = response.headers().get(http::header::CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } -} - -#[test] -fn test_body_length() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - HttpResponse::Ok() - .content_length(STR.len() as u64) - .content_encoding(http::ContentEncoding::Identity) - .body(Body::Streaming(Box::new(body)))})); - - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_body_chunked_explicit() { - let mut srv = test::TestServer::new( - |app| app.handler(|_| { - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - HttpResponse::Ok() - .chunked() - .content_encoding(http::ContentEncoding::Gzip) - .body(Body::Streaming(Box::new(body)))})); - - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - -#[test] -fn test_body_deflate() { - let mut srv = test::TestServer::new( - |app| app.handler( - |_| HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Deflate) - .body(STR))); +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_body_deflate() { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new() + .wrap(Compress::new(ContentEncoding::Deflate)) + .service( + web::resource("/").route(web::to(move || Response::Ok().body(STR))), + ), + ) + }); // client request - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "deflate") + .no_decompress() + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); - // decode deflate - let mut e = DeflateDecoder::new(Vec::new()); + let mut e = ZlibDecoder::new(Vec::new()); e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[cfg(feature="brotli")] -#[test] -fn test_body_brotli() { - let mut srv = test::TestServer::new( - |app| app.handler( - |_| HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Br) - .body(STR))); +#[actix_rt::test] +#[cfg(any(feature = "brotli"))] +async fn test_body_brotli() { + let srv = TestServer::start(move || { + h1::H1Service::new(App::new().wrap(Compress::new(ContentEncoding::Br)).service( + web::resource("/").route(web::to(move || Response::Ok().body(STR))), + )) + }); // client request - let request = srv.get().disable_decompress().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = srv + .get("/") + .header(ACCEPT_ENCODING, "br") + .no_decompress() + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); // decode brotli let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); @@ -473,362 +430,652 @@ fn test_body_brotli() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_gzip_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_encoding() { + let srv = TestServer::start(move || { + HttpService::new( + App::new().wrap(Compress::default()).service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); // client request let mut e = GzEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.as_ref()).unwrap(); let enc = e.finish().unwrap(); - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "gzip") - .body(enc.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_gzip_encoding_large() { +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_gzip_encoding() { + let srv = TestServer::start(move || { + HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); + + // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_gzip_encoding_large() { let data = STR.repeat(10); - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); // client request let mut e = GzEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "gzip") - .body(enc.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -fn test_reading_gzip_encoding_large_random() { +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_reading_gzip_encoding_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&Alphanumeric) .take(60_000) .collect::(); - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); + let srv = TestServer::start(move || { + HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); // client request let mut e = GzEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "gzip") - .body(enc.clone()).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "gzip") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); } -#[test] -fn test_reading_deflate_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_reading_deflate_encoding() { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.as_ref()).unwrap(); let enc = e.finish().unwrap(); // client request - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "deflate") - .body(enc).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[test] -fn test_reading_deflate_encoding_large() { +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_reading_deflate_encoding_large() { let data = STR.repeat(10); - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); // client request - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "deflate") - .body(enc).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -fn test_reading_deflate_encoding_large_random() { +#[actix_rt::test] +#[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] +async fn test_reading_deflate_encoding_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&Alphanumeric) .take(160_000) .collect::(); - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); // client request - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "deflate") - .body(enc).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "deflate") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); } -#[cfg(feature="brotli")] -#[test] -fn test_brotli_encoding() { - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); +#[actix_rt::test] +#[cfg(feature = "brotli")] +async fn test_brotli_encoding() { + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); let mut e = BrotliEncoder::new(Vec::new(), 5); e.write_all(STR.as_ref()).unwrap(); let enc = e.finish().unwrap(); // client request - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "br") - .body(enc).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "br") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } -#[cfg(feature="brotli")] -#[test] -fn test_brotli_encoding_large() { +#[cfg(feature = "brotli")] +#[actix_rt::test] +async fn test_brotli_encoding_large() { let data = STR.repeat(10); - let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { - req.body() - .and_then(|bytes: Bytes| { - Ok(HttpResponse::Ok() - .content_encoding(http::ContentEncoding::Identity) - .body(bytes)) - }).responder()} - )); + let srv = TestServer::start(move || { + h1::H1Service::new( + App::new().service( + web::resource("/") + .route(web::to(move |body: Bytes| Response::Ok().body(body))), + ), + ) + }); let mut e = BrotliEncoder::new(Vec::new(), 5); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); // client request - let request = srv.post() - .header(http::header::CONTENT_ENCODING, "br") - .body(enc).unwrap(); - let response = srv.execute(request.send()).unwrap(); + let request = srv + .post("/") + .header(CONTENT_ENCODING, "br") + .send_body(enc.clone()); + let mut response = request.await.unwrap(); assert!(response.status().is_success()); // read response - let bytes = srv.execute(response.body()).unwrap(); + let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[test] -fn test_h2() { - let srv = test::TestServer::new(|app| app.handler(|_|{ - HttpResponse::Ok().body(STR) - })); - let addr = srv.addr(); +// #[cfg(all(feature = "brotli", feature = "ssl"))] +// #[actix_rt::test] +// async fn test_brotli_encoding_large_ssl() { +// use actix::{Actor, System}; +// use openssl::ssl::{ +// SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, +// }; +// // load ssl keys +// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); +// builder +// .set_private_key_file("tests/key.pem", SslFiletype::PEM) +// .unwrap(); +// builder +// .set_certificate_chain_file("tests/cert.pem") +// .unwrap(); - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let tcp = TcpStream::connect(&addr, &handle); +// let data = STR.repeat(10); +// let srv = test::TestServer::build().ssl(builder).start(|app| { +// app.handler(|req: &HttpRequest| { +// req.body() +// .and_then(|bytes: Bytes| { +// Ok(HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Identity) +// .body(bytes)) +// }) +// .responder() +// }) +// }); +// let mut rt = System::new("test"); - let tcp = tcp.then(|res| { - h2client::handshake(res.unwrap()) - }).then(move |res| { - let (mut client, h2) = res.unwrap(); +// // client connector +// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); +// builder.set_verify(SslVerifyMode::NONE); +// let conn = client::ClientConnector::with_connector(builder.build()).start(); - let request = Request::builder() - .uri(format!("https://{}/", addr).as_str()) - .body(()) - .unwrap(); - let (response, _) = client.send_request(request, false).unwrap(); +// // body +// let mut e = BrotliEncoder::new(Vec::new(), 5); +// e.write_all(data.as_ref()).unwrap(); +// let enc = e.finish().unwrap(); - // Spawn a task to run the conn... - handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); +// // client request +// let request = client::ClientRequest::build() +// .uri(srv.url("/")) +// .method(http::Method::POST) +// .header(http::header::CONTENT_ENCODING, "br") +// .with_connector(conn) +// .body(enc) +// .unwrap(); +// let response = rt.block_on(request.send()).unwrap(); +// assert!(response.status().is_success()); - response.and_then(|response| { - assert_eq!(response.status(), http::StatusCode::OK); +// // read response +// let bytes = rt.block_on(response.body()).unwrap(); +// assert_eq!(bytes, Bytes::from(data)); +// } - let (_, body) = response.into_parts(); +#[cfg(all( + feature = "rustls", + feature = "openssl", + any(feature = "flate2-zlib", feature = "flate2-rust") +))] +#[actix_rt::test] +async fn test_reading_deflate_encoding_large_random_ssl() { + use open_ssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use rust_tls::internal::pemfile::{certs, pkcs8_private_keys}; + use rust_tls::{NoClientAuth, ServerConfig}; + use std::fs::File; + use std::io::BufReader; + use std::sync::mpsc; - body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> { - b.extend(c); - Ok(b) - }) + let addr = TestServer::unused_addr(); + let (tx, rx) = mpsc::channel(); + + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .collect::(); + + std::thread::spawn(move || { + let sys = actix_rt::System::new("test"); + + // load ssl keys + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let srv = HttpServer::new(|| { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + async move { + Ok::<_, Error>( + HttpResponse::Ok() + .encoding(http::ContentEncoding::Identity) + .body(bytes), + ) + } + }))) }) + .bind_rustls(addr, config) + .unwrap() + .start(); + + let _ = tx.send((srv, actix_rt::System::current())); + let _ = sys.run(); }); - let _res = core.run(tcp); - // assert_eq!(res.unwrap(), Bytes::from_static(STR.as_ref())); -} + let (srv, _sys) = rx.recv().unwrap(); + let client = { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap(); -#[test] -fn test_application() { - let mut srv = test::TestServer::with_factory( - || App::new().resource("/", |r| r.f(|_| HttpResponse::Ok()))); + awc::Client::build() + .connector( + awc::Connector::new() + .timeout(std::time::Duration::from_millis(500)) + .ssl(builder.build()) + .finish(), + ) + .finish() + }; - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); -} + // encode data + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); -struct MiddlewareTest { - start: Arc, - response: Arc, - finish: Arc, -} + // client request + let req = client + .post(format!("https://localhost:{}/", addr.port())) + .header(http::header::CONTENT_ENCODING, "deflate") + .send_body(enc); -impl middleware::Middleware for MiddlewareTest { - fn start(&self, _: &mut HttpRequest) -> Result { - self.start.store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); - Ok(middleware::Started::Done) - } - - fn response(&self, _: &mut HttpRequest, resp: HttpResponse) -> Result { - self.response.store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); - Ok(middleware::Response::Done(resp)) - } - - fn finish(&self, _: &mut HttpRequest, _: &HttpResponse) -> middleware::Finished { - self.finish.store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); - middleware::Finished::Done - } -} - -#[test] -fn test_middlewares() { - let num1 = Arc::new(AtomicUsize::new(0)); - let num2 = Arc::new(AtomicUsize::new(0)); - let num3 = Arc::new(AtomicUsize::new(0)); - - let act_num1 = Arc::clone(&num1); - let act_num2 = Arc::clone(&num2); - let act_num3 = Arc::clone(&num3); - - let mut srv = test::TestServer::new( - move |app| app.middleware(MiddlewareTest{start: Arc::clone(&act_num1), - response: Arc::clone(&act_num2), - finish: Arc::clone(&act_num3)}) - .handler(|_| HttpResponse::Ok()) - ); - - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); + let mut response = req.await.unwrap(); assert!(response.status().is_success()); - assert_eq!(num1.load(Ordering::Relaxed), 1); - assert_eq!(num2.load(Ordering::Relaxed), 1); - assert_eq!(num3.load(Ordering::Relaxed), 1); + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); + + // stop + let _ = srv.stop(false); } +// #[cfg(all(feature = "tls", feature = "ssl"))] +// #[test] +// fn test_reading_deflate_encoding_large_random_tls() { +// use native_tls::{Identity, TlsAcceptor}; +// use openssl::ssl::{ +// SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, +// }; +// use std::fs::File; +// use std::sync::mpsc; -#[test] -fn test_resource_middlewares() { - let num1 = Arc::new(AtomicUsize::new(0)); - let num2 = Arc::new(AtomicUsize::new(0)); - let num3 = Arc::new(AtomicUsize::new(0)); +// use actix::{Actor, System}; +// let (tx, rx) = mpsc::channel(); - let act_num1 = Arc::clone(&num1); - let act_num2 = Arc::clone(&num2); - let act_num3 = Arc::clone(&num3); +// // load ssl keys +// let mut file = File::open("tests/identity.pfx").unwrap(); +// let mut identity = vec![]; +// file.read_to_end(&mut identity).unwrap(); +// let identity = Identity::from_pkcs12(&identity, "1").unwrap(); +// let acceptor = TlsAcceptor::new(identity).unwrap(); - let mut srv = test::TestServer::new( - move |app| app - .middleware(MiddlewareTest{start: Arc::clone(&act_num1), - response: Arc::clone(&act_num2), - finish: Arc::clone(&act_num3)}) - .handler(|_| HttpResponse::Ok()) - ); +// // load ssl keys +// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); +// builder +// .set_private_key_file("tests/key.pem", SslFiletype::PEM) +// .unwrap(); +// builder +// .set_certificate_chain_file("tests/cert.pem") +// .unwrap(); - let request = srv.get().finish().unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); +// let data = rand::thread_rng() +// .sample_iter(&Alphanumeric) +// .take(160_000) +// .collect::(); - assert_eq!(num1.load(Ordering::Relaxed), 1); - assert_eq!(num2.load(Ordering::Relaxed), 1); - // assert_eq!(num3.load(Ordering::Relaxed), 1); -} +// let addr = test::TestServer::unused_addr(); +// thread::spawn(move || { +// System::run(move || { +// server::new(|| { +// App::new().handler("/", |req: &HttpRequest| { +// req.body() +// .and_then(|bytes: Bytes| { +// Ok(HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Identity) +// .body(bytes)) +// }) +// .responder() +// }) +// }) +// .bind_tls(addr, acceptor) +// .unwrap() +// .start(); +// let _ = tx.send(System::current()); +// }); +// }); +// let sys = rx.recv().unwrap(); + +// let mut rt = System::new("test"); + +// // client connector +// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); +// builder.set_verify(SslVerifyMode::NONE); +// let conn = client::ClientConnector::with_connector(builder.build()).start(); + +// // encode data +// let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); +// e.write_all(data.as_ref()).unwrap(); +// let enc = e.finish().unwrap(); + +// // client request +// let request = client::ClientRequest::build() +// .uri(format!("https://{}/", addr)) +// .method(http::Method::POST) +// .header(http::header::CONTENT_ENCODING, "deflate") +// .with_connector(conn) +// .body(enc) +// .unwrap(); +// let response = rt.block_on(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = rt.block_on(response.body()).unwrap(); +// assert_eq!(bytes.len(), data.len()); +// assert_eq!(bytes, Bytes::from(data)); + +// let _ = sys.stop(); +// } + +// #[test] +// fn test_server_cookies() { +// use actix_web::http; + +// let srv = test::TestServer::with_factory(|| { +// App::new().resource("/", |r| { +// r.f(|_| { +// HttpResponse::Ok() +// .cookie( +// http::CookieBuilder::new("first", "first_value") +// .http_only(true) +// .finish(), +// ) +// .cookie(http::Cookie::new("second", "first_value")) +// .cookie(http::Cookie::new("second", "second_value")) +// .finish() +// }) +// }) +// }); + +// let first_cookie = http::CookieBuilder::new("first", "first_value") +// .http_only(true) +// .finish(); +// let second_cookie = http::Cookie::new("second", "second_value"); + +// let request = srv.get("/").finish().unwrap(); +// let response = srv.execute(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// let cookies = response.cookies().expect("To have cookies"); +// assert_eq!(cookies.len(), 2); +// if cookies[0] == first_cookie { +// assert_eq!(cookies[1], second_cookie); +// } else { +// assert_eq!(cookies[0], second_cookie); +// assert_eq!(cookies[1], first_cookie); +// } + +// let first_cookie = first_cookie.to_string(); +// let second_cookie = second_cookie.to_string(); +// //Check that we have exactly two instances of raw cookie headers +// let cookies = response +// .headers() +// .get_all(http::header::SET_COOKIE) +// .iter() +// .map(|header| header.to_str().expect("To str").to_string()) +// .collect::>(); +// assert_eq!(cookies.len(), 2); +// if cookies[0] == first_cookie { +// assert_eq!(cookies[1], second_cookie); +// } else { +// assert_eq!(cookies[0], second_cookie); +// assert_eq!(cookies[1], first_cookie); +// } +// } + +// #[test] +// fn test_slow_request() { +// use actix::System; +// use std::net; +// use std::sync::mpsc; +// let (tx, rx) = mpsc::channel(); + +// let addr = test::TestServer::unused_addr(); +// thread::spawn(move || { +// System::run(move || { +// let srv = server::new(|| { +// vec![App::new().resource("/", |r| { +// r.method(http::Method::GET).f(|_| HttpResponse::Ok()) +// })] +// }); + +// let srv = srv.bind(addr).unwrap(); +// srv.client_timeout(200).start(); +// let _ = tx.send(System::current()); +// }); +// }); +// let sys = rx.recv().unwrap(); + +// thread::sleep(time::Duration::from_millis(200)); + +// let mut stream = net::TcpStream::connect(addr).unwrap(); +// let mut data = String::new(); +// let _ = stream.read_to_string(&mut data); +// assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + +// let mut stream = net::TcpStream::connect(addr).unwrap(); +// let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); +// let mut data = String::new(); +// let _ = stream.read_to_string(&mut data); +// assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + +// sys.stop(); +// } + +// #[test] +// #[cfg(feature = "ssl")] +// fn test_ssl_handshake_timeout() { +// use actix::System; +// use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; +// use std::net; +// use std::sync::mpsc; + +// let (tx, rx) = mpsc::channel(); +// let addr = test::TestServer::unused_addr(); + +// // load ssl keys +// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); +// builder +// .set_private_key_file("tests/key.pem", SslFiletype::PEM) +// .unwrap(); +// builder +// .set_certificate_chain_file("tests/cert.pem") +// .unwrap(); + +// thread::spawn(move || { +// System::run(move || { +// let srv = server::new(|| { +// App::new().resource("/", |r| { +// r.method(http::Method::GET).f(|_| HttpResponse::Ok()) +// }) +// }); + +// srv.bind_ssl(addr, builder) +// .unwrap() +// .workers(1) +// .client_timeout(200) +// .start(); +// let _ = tx.send(System::current()); +// }); +// }); +// let sys = rx.recv().unwrap(); + +// let mut stream = net::TcpStream::connect(addr).unwrap(); +// let mut data = String::new(); +// let _ = stream.read_to_string(&mut data); +// assert!(data.is_empty()); + +// let _ = sys.stop(); +// } diff --git a/tests/test_ws.rs b/tests/test_ws.rs deleted file mode 100644 index 6ebb69bda..000000000 --- a/tests/test_ws.rs +++ /dev/null @@ -1,193 +0,0 @@ -extern crate actix; -extern crate actix_web; -extern crate futures; -extern crate http; -extern crate bytes; -extern crate rand; - -use bytes::Bytes; -use futures::Stream; -use rand::Rng; - -#[cfg(feature="alpn")] -extern crate openssl; - -use actix_web::*; -use actix::prelude::*; - -struct Ws; - -impl Actor for Ws { - type Context = ws::WebsocketContext; -} - -impl StreamHandler for Ws { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(reason) => ctx.close(reason, ""), - _ => (), - } - } -} - -#[test] -fn test_simple() { - let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws))); - let (reader, mut writer) = srv.ws().unwrap(); - - writer.text("text"); - let (item, reader) = srv.execute(reader.into_future()).unwrap(); - assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); - - writer.binary(b"text".as_ref()); - let (item, reader) = srv.execute(reader.into_future()).unwrap(); - assert_eq!(item, Some(ws::Message::Binary(Bytes::from_static(b"text").into()))); - - writer.ping("ping"); - let (item, reader) = srv.execute(reader.into_future()).unwrap(); - assert_eq!(item, Some(ws::Message::Pong("ping".to_owned()))); - - writer.close(ws::CloseCode::Normal, ""); - let (item, _) = srv.execute(reader.into_future()).unwrap(); - assert_eq!(item, Some(ws::Message::Close(ws::CloseCode::Normal))); -} - -#[test] -fn test_large_text() { - let data = rand::thread_rng() - .gen_ascii_chars() - .take(65_536) - .collect::(); - - let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws))); - let (mut reader, mut writer) = srv.ws().unwrap(); - - for _ in 0..100 { - writer.text(data.clone()); - let (item, r) = srv.execute(reader.into_future()).unwrap(); - reader = r; - assert_eq!(item, Some(ws::Message::Text(data.clone()))); - } -} - -#[test] -fn test_large_bin() { - let data = rand::thread_rng() - .gen_ascii_chars() - .take(65_536) - .collect::(); - - let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws))); - let (mut reader, mut writer) = srv.ws().unwrap(); - - for _ in 0..100 { - writer.binary(data.clone()); - let (item, r) = srv.execute(reader.into_future()).unwrap(); - reader = r; - assert_eq!(item, Some(ws::Message::Binary(Binary::from(data.clone())))); - } -} - -struct Ws2 { - count: usize, - bin: bool, -} - -impl Actor for Ws2 { - type Context = ws::WebsocketContext; - - fn started(&mut self, ctx: &mut Self::Context) { - self.send(ctx); - } -} - -impl Ws2 { - fn send(&mut self, ctx: &mut ws::WebsocketContext) { - if self.bin { - ctx.binary(Vec::from("0".repeat(65_536))); - } else { - ctx.text("0".repeat(65_536)); - } - ctx.drain().and_then(|_, act, ctx| { - act.count += 1; - if act.count != 10_000 { - act.send(ctx); - } - actix::fut::ok(()) - }).wait(ctx); - } -} - -impl StreamHandler for Ws2 { - - fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { - match msg { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(reason) => ctx.close(reason, ""), - _ => (), - } - } -} - -#[test] -fn test_server_send_text() { - let data = Some(ws::Message::Text("0".repeat(65_536))); - - let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false}))); - let (mut reader, _writer) = srv.ws().unwrap(); - - for _ in 0..10_000 { - let (item, r) = srv.execute(reader.into_future()).unwrap(); - reader = r; - assert_eq!(item, data); - } -} - -#[test] -fn test_server_send_bin() { - let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536)))); - - let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: true}))); - let (mut reader, _writer) = srv.ws().unwrap(); - - for _ in 0..10_000 { - let (item, r) = srv.execute(reader.into_future()).unwrap(); - reader = r; - assert_eq!(item, data); - } -} - -#[test] -#[cfg(feature="alpn")] -fn test_ws_server_ssl() { - extern crate openssl; - use openssl::ssl::{SslMethod, SslAcceptor, SslFiletype}; - - // load ssl keys - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder.set_private_key_file("tests/key.pem", SslFiletype::PEM).unwrap(); - builder.set_certificate_chain_file("tests/cert.pem").unwrap(); - - let mut srv = test::TestServer::build() - .ssl(builder.build()) - .start(|app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false}))); - let (mut reader, _writer) = srv.ws().unwrap(); - - let data = Some(ws::Message::Text("0".repeat(65_536))); - for _ in 0..10_000 { - let (item, r) = srv.execute(reader.into_future()).unwrap(); - reader = r; - assert_eq!(item, data); - } -} diff --git a/tools/wsload/Cargo.toml b/tools/wsload/Cargo.toml deleted file mode 100644 index ff782817c..000000000 --- a/tools/wsload/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "wsclient" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[[bin]] -name = "wsclient" -path = "src/wsclient.rs" - -[dependencies] -env_logger = "*" -futures = "0.1" -clap = "2" -url = "1.6" -rand = "0.4" -time = "*" -num_cpus = "1" -tokio-core = "0.1" -actix = "0.5" -actix-web = { path="../../" } diff --git a/tools/wsload/src/wsclient.rs b/tools/wsload/src/wsclient.rs deleted file mode 100644 index ab5cbe765..000000000 --- a/tools/wsload/src/wsclient.rs +++ /dev/null @@ -1,307 +0,0 @@ -//! Simple websocket client. - -#![allow(unused_variables)] -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate futures; -extern crate tokio_core; -extern crate url; -extern crate clap; -extern crate rand; -extern crate time; -extern crate num_cpus; - -use std::time::Duration; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use futures::Future; -use rand::{thread_rng, Rng}; - -use actix::prelude::*; -use actix_web::ws; - - -fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); - let _ = env_logger::init(); - - let matches = clap::App::new("ws tool") - .version("0.1") - .about("Applies load to websocket server") - .args_from_usage( - " 'WebSocket url' - [bin]... -b, 'use binary frames' - -s, --size=[NUMBER] 'size of PUBLISH packet payload to send in KB' - -w, --warm-up=[SECONDS] 'seconds before counter values are considered for reporting' - -r, --sample-rate=[SECONDS] 'seconds between average reports' - -c, --concurrency=[NUMBER] 'number of websocket connections to open and use concurrently for sending' - -t, --threads=[NUMBER] 'number of threads to use' - --max-payload=[NUMBER] 'max size of payload before reconnect KB'", - ) - .get_matches(); - - let bin: bool = matches.value_of("bin").is_some(); - let ws_url = matches.value_of("url").unwrap().to_owned(); - let _ = url::Url::parse(&ws_url).map_err(|e| { - println!("Invalid url: {}", ws_url); - std::process::exit(0); - }); - - let threads = parse_u64_default(matches.value_of("threads"), num_cpus::get() as u64); - let concurrency = parse_u64_default(matches.value_of("concurrency"), 1); - let payload_size: usize = match matches.value_of("size") { - Some(s) => parse_u64_default(Some(s), 1) as usize * 1024, - None => 1024, - }; - let max_payload_size: usize = match matches.value_of("max-payload") { - Some(s) => parse_u64_default(Some(s), 0) as usize * 1024, - None => 0, - }; - let warmup_seconds = parse_u64_default(matches.value_of("warm-up"), 2) as u64; - let sample_rate = parse_u64_default(matches.value_of("sample-rate"), 1) as usize; - - let perf_counters = Arc::new(PerfCounters::new()); - let payload = Arc::new(thread_rng() - .gen_ascii_chars() - .take(payload_size) - .collect::()); - - let sys = actix::System::new("ws-client"); - - let _: () = Perf{counters: perf_counters.clone(), - payload: payload.len(), - sample_rate_secs: sample_rate}.start(); - - for t in 0..threads { - let pl = payload.clone(); - let ws = ws_url.clone(); - let perf = perf_counters.clone(); - let addr = Arbiter::new(format!("test {}", t)); - - addr.do_send(actix::msgs::Execute::new(move || -> Result<(), ()> { - for _ in 0..concurrency { - let pl2 = pl.clone(); - let perf2 = perf.clone(); - let ws2 = ws.clone(); - - Arbiter::handle().spawn( - ws::Client::new(&ws) - .write_buffer_capacity(0) - .connect() - .map_err(|e| { - println!("Error: {}", e); - //Arbiter::system().do_send(actix::msgs::SystemExit(0)); - () - }) - .map(move |(reader, writer)| { - let addr: Addr = ChatClient::create(move |ctx| { - ChatClient::add_stream(reader, ctx); - ChatClient{url: ws2, - conn: writer, - payload: pl2, - bin: bin, - ts: time::precise_time_ns(), - perf_counters: perf2, - sent: 0, - max_payload_size: max_payload_size, - } - }); - }) - ); - } - Ok(()) - })); - } - - let res = sys.run(); -} - -fn parse_u64_default(input: Option<&str>, default: u64) -> u64 { - input.map(|v| v.parse().expect(&format!("not a valid number: {}", v))) - .unwrap_or(default) -} - -struct Perf { - counters: Arc, - payload: usize, - sample_rate_secs: usize, -} - -impl Actor for Perf { - type Context = Context; - - fn started(&mut self, ctx: &mut Context) { - self.sample_rate(ctx); - } -} - -impl Perf { - fn sample_rate(&self, ctx: &mut Context) { - ctx.run_later(Duration::new(self.sample_rate_secs as u64, 0), |act, ctx| { - let req_count = act.counters.pull_request_count(); - if req_count != 0 { - let conns = act.counters.pull_connections_count(); - let latency = act.counters.pull_latency_ns(); - let latency_max = act.counters.pull_latency_max_ns(); - println!( - "rate: {}, conns: {}, throughput: {:?} kb, latency: {}, latency max: {}", - req_count / act.sample_rate_secs, - conns / act.sample_rate_secs, - (((req_count * act.payload) as f64) / 1024.0) / - act.sample_rate_secs as f64, - time::Duration::nanoseconds((latency / req_count as u64) as i64), - time::Duration::nanoseconds(latency_max as i64) - ); - } - - act.sample_rate(ctx); - }); - } -} - -struct ChatClient{ - url: String, - conn: ws::ClientWriter, - payload: Arc, - ts: u64, - bin: bool, - perf_counters: Arc, - sent: usize, - max_payload_size: usize, -} - -impl Actor for ChatClient { - type Context = Context; - - fn started(&mut self, ctx: &mut Context) { - self.send_text(); - self.perf_counters.register_connection(); - } -} - -impl ChatClient { - - fn send_text(&mut self) -> bool { - self.sent += self.payload.len(); - - if self.max_payload_size > 0 && self.sent > self.max_payload_size { - let ws = self.url.clone(); - let pl = self.payload.clone(); - let bin = self.bin; - let perf_counters = self.perf_counters.clone(); - let max_payload_size = self.max_payload_size; - - Arbiter::handle().spawn( - ws::Client::new(&self.url).connect() - .map_err(|e| { - println!("Error: {}", e); - Arbiter::system().do_send(actix::msgs::SystemExit(0)); - () - }) - .map(move |(reader, writer)| { - let addr: Addr = ChatClient::create(move |ctx| { - ChatClient::add_stream(reader, ctx); - ChatClient{url: ws, - conn: writer, - payload: pl, - bin: bin, - ts: time::precise_time_ns(), - perf_counters: perf_counters, - sent: 0, - max_payload_size: max_payload_size, - } - }); - }) - ); - false - } else { - self.ts = time::precise_time_ns(); - if self.bin { - self.conn.binary(&self.payload); - } else { - self.conn.text(&self.payload); - } - true - } - } -} - -/// Handle server websocket messages -impl StreamHandler for ChatClient { - - fn finished(&mut self, ctx: &mut Context) { - ctx.stop() - } - - fn handle(&mut self, msg: ws::Message, ctx: &mut Context) { - match msg { - ws::Message::Text(txt) => { - if txt == self.payload.as_ref().as_str() { - self.perf_counters.register_request(); - self.perf_counters.register_latency(time::precise_time_ns() - self.ts); - if !self.send_text() { - ctx.stop(); - } - } else { - println!("not eaqual"); - } - }, - _ => () - } - } -} - - -pub struct PerfCounters { - req: AtomicUsize, - conn: AtomicUsize, - lat: AtomicUsize, - lat_max: AtomicUsize -} - -impl PerfCounters { - pub fn new() -> PerfCounters { - PerfCounters { - req: AtomicUsize::new(0), - conn: AtomicUsize::new(0), - lat: AtomicUsize::new(0), - lat_max: AtomicUsize::new(0), - } - } - - pub fn pull_request_count(&self) -> usize { - self.req.swap(0, Ordering::SeqCst) - } - - pub fn pull_connections_count(&self) -> usize { - self.conn.swap(0, Ordering::SeqCst) - } - - pub fn pull_latency_ns(&self) -> u64 { - self.lat.swap(0, Ordering::SeqCst) as u64 - } - - pub fn pull_latency_max_ns(&self) -> u64 { - self.lat_max.swap(0, Ordering::SeqCst) as u64 - } - - pub fn register_request(&self) { - self.req.fetch_add(1, Ordering::SeqCst); - } - - pub fn register_connection(&self) { - self.conn.fetch_add(1, Ordering::SeqCst); - } - - pub fn register_latency(&self, nanos: u64) { - let nanos = nanos as usize; - self.lat.fetch_add(nanos, Ordering::SeqCst); - loop { - let current = self.lat_max.load(Ordering::SeqCst); - if current >= nanos || self.lat_max.compare_and_swap(current, nanos, Ordering::SeqCst) == current { - break; - } - } - } -}
  • {}/
  • { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nRequest {:?} {}:{}", + self.version(), + self.method(), + self.path() + )?; + if let Some(q) = self.uri().query().as_ref() { + writeln!(f, " query: ?{:?}", q)?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::HttpTryFrom; + + #[test] + fn test_basics() { + let msg = Message::new(); + let mut req = Request::from(msg); + req.headers_mut().insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("text/plain"), + ); + assert!(req.headers().contains_key(header::CONTENT_TYPE)); + + *req.uri_mut() = Uri::try_from("/index.html?q=1").unwrap(); + assert_eq!(req.uri().path(), "/index.html"); + assert_eq!(req.uri().query(), Some("q=1")); + + let s = format!("{:?}", req); + assert!(s.contains("Request HTTP/1.1 GET:/index.html")); + } +} diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs new file mode 100644 index 000000000..e9147aa4b --- /dev/null +++ b/actix-http/src/response.rs @@ -0,0 +1,1090 @@ +//! Http response +use std::cell::{Ref, RefMut}; +use std::future::Future; +use std::io::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, str}; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::stream::Stream; +use serde::Serialize; +use serde_json; + +use crate::body::{Body, BodyStream, MessageBody, ResponseBody}; +use crate::cookie::{Cookie, CookieJar}; +use crate::error::Error; +use crate::extensions::Extensions; +use crate::header::{Header, IntoHeaderValue}; +use crate::http::header::{self, HeaderName, HeaderValue}; +use crate::http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode}; +use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead}; + +/// An HTTP Response +pub struct Response { + head: BoxedResponseHead, + body: ResponseBody, + error: Option, +} + +impl Response { + /// Create http response builder with specific status. + #[inline] + pub fn build(status: StatusCode) -> ResponseBuilder { + ResponseBuilder::new(status) + } + + /// Create http response builder + #[inline] + pub fn build_from>(source: T) -> ResponseBuilder { + source.into() + } + + /// Constructs a response + #[inline] + pub fn new(status: StatusCode) -> Response { + Response { + head: BoxedResponseHead::new(status), + body: ResponseBody::Body(Body::Empty), + error: None, + } + } + + /// Constructs an error response + #[inline] + pub fn from_error(error: Error) -> Response { + let mut resp = error.as_response_error().error_response(); + if resp.head.status == StatusCode::INTERNAL_SERVER_ERROR { + error!("Internal Server Error: {:?}", error); + } + resp.error = Some(error); + resp + } + + /// Convert response to response with body + pub fn into_body(self) -> Response { + let b = match self.body { + ResponseBody::Body(b) => b, + ResponseBody::Other(b) => b, + }; + Response { + head: self.head, + error: self.error, + body: ResponseBody::Other(b), + } + } +} + +impl Response { + /// Constructs a response with body + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Response { + Response { + head: BoxedResponseHead::new(status), + body: ResponseBody::Body(body), + error: None, + } + } + + #[inline] + /// Http message part of the response + pub fn head(&self) -> &ResponseHead { + &*self.head + } + + #[inline] + /// Mutable reference to a http message part of the response + pub fn head_mut(&mut self) -> &mut ResponseHead { + &mut *self.head + } + + /// The source `error` for this response + #[inline] + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.head.status + } + + /// Set the `StatusCode` for this response + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + &mut self.head.status + } + + /// Get the headers from the response + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head.headers + } + + /// Get an iterator for the cookies set by this response + #[inline] + pub fn cookies(&self) -> CookieIter { + CookieIter { + iter: self.head.headers.get_all(header::SET_COOKIE), + } + } + + /// Add a cookie to this response + #[inline] + pub fn add_cookie(&mut self, cookie: &Cookie) -> Result<(), HttpError> { + let h = &mut self.head.headers; + HeaderValue::from_str(&cookie.to_string()) + .map(|c| { + h.append(header::SET_COOKIE, c); + }) + .map_err(|e| e.into()) + } + + /// Remove all cookies with the given name from this response. Returns + /// the number of cookies removed. + #[inline] + pub fn del_cookie(&mut self, name: &str) -> usize { + let h = &mut self.head.headers; + let vals: Vec = h + .get_all(header::SET_COOKIE) + .map(|v| v.to_owned()) + .collect(); + h.remove(header::SET_COOKIE); + + let mut count: usize = 0; + for v in vals { + if let Ok(s) = v.to_str() { + if let Ok(c) = Cookie::parse_encoded(s) { + if c.name() == name { + count += 1; + continue; + } + } + } + h.append(header::SET_COOKIE, v); + } + count + } + + /// Connection upgrade status + #[inline] + pub fn upgrade(&self) -> bool { + self.head.upgrade() + } + + /// Keep-alive status for this connection + pub fn keep_alive(&self) -> bool { + self.head.keep_alive() + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.head.extensions.borrow() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut { + self.head.extensions.borrow_mut() + } + + /// Get body of this response + #[inline] + pub fn body(&self) -> &ResponseBody { + &self.body + } + + /// Set a body + pub fn set_body(self, body: B2) -> Response { + Response { + head: self.head, + body: ResponseBody::Body(body), + error: None, + } + } + + /// Split response and body + pub fn into_parts(self) -> (Response<()>, ResponseBody) { + ( + Response { + head: self.head, + body: ResponseBody::Body(()), + error: self.error, + }, + self.body, + ) + } + + /// Drop request's body + pub fn drop_body(self) -> Response<()> { + Response { + head: self.head, + body: ResponseBody::Body(()), + error: None, + } + } + + /// Set a body and return previous body value + pub(crate) fn replace_body(self, body: B2) -> (Response, ResponseBody) { + ( + Response { + head: self.head, + body: ResponseBody::Body(body), + error: self.error, + }, + self.body, + ) + } + + /// Set a body and return previous body value + pub fn map_body(mut self, f: F) -> Response + where + F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, + { + let body = f(&mut self.head, self.body); + + Response { + body, + head: self.head, + error: self.error, + } + } + + /// Extract response body + pub fn take_body(&mut self) -> ResponseBody { + self.body.take_body() + } +} + +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = writeln!( + f, + "\nResponse {:?} {}{}", + self.head.version, + self.head.status, + self.head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in self.head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + let _ = writeln!(f, " body: {:?}", self.body.size()); + res + } +} + +impl Future for Response { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll { + Poll::Ready(Ok(Response { + head: self.head.take(), + body: self.body.take_body(), + error: self.error.take(), + })) + } +} + +pub struct CookieIter<'a> { + iter: header::GetAll<'a>, +} + +impl<'a> Iterator for CookieIter<'a> { + type Item = Cookie<'a>; + + #[inline] + fn next(&mut self) -> Option> { + for v in self.iter.by_ref() { + if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { + return Some(c); + } + } + None + } +} + +/// An HTTP response builder +/// +/// This type can be used to construct an instance of `Response` through a +/// builder-like pattern. +pub struct ResponseBuilder { + head: Option, + err: Option, + cookies: Option, +} + +impl ResponseBuilder { + #[inline] + /// Create response builder + pub fn new(status: StatusCode) -> Self { + ResponseBuilder { + head: Some(BoxedResponseHead::new(status)), + err: None, + cookies: None, + } + } + + /// Set HTTP status code of this response. + #[inline] + pub fn status(&mut self, status: StatusCode) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.status = status; + } + self + } + + /// Set a header. + /// + /// ```rust + /// use actix_http::{http, Request, Response, Result}; + /// + /// fn index(req: Request) -> Result { + /// Ok(Response::Ok() + /// .set(http::header::IfModifiedSince( + /// "Sun, 07 Nov 1994 08:48:37 GMT".parse()?, + /// )) + /// .finish()) + /// } + /// fn main() {} + /// ``` + #[doc(hidden)] + pub fn set(&mut self, hdr: H) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + match hdr.try_into() { + Ok(value) => { + parts.headers.append(H::name(), value); + } + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Append a header to existing headers. + /// + /// ```rust + /// use actix_http::{http, Request, Response}; + /// + /// fn index(req: Request) -> Response { + /// Response::Ok() + /// .header("X-TEST", "value") + /// .header(http::header::CONTENT_TYPE, "application/json") + /// .finish() + /// } + /// fn main() {} + /// ``` + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set a header. + /// + /// ```rust + /// use actix_http::{http, Request, Response}; + /// + /// fn index(req: Request) -> Response { + /// Response::Ok() + /// .set_header("X-TEST", "value") + /// .set_header(http::header::CONTENT_TYPE, "application/json") + /// .finish() + /// } + /// fn main() {} + /// ``` + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set the custom reason for the response. + #[inline] + pub fn reason(&mut self, reason: &'static str) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.reason = Some(reason); + } + self + } + + /// Set connection type to KeepAlive + #[inline] + pub fn keep_alive(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::KeepAlive); + } + self + } + + /// Set connection type to Upgrade + #[inline] + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::Upgrade); + } + self.set_header(header::UPGRADE, value) + } + + /// Force close connection, even if it is marked as keep-alive + #[inline] + pub fn force_close(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::Close); + } + self + } + + /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. + #[inline] + pub fn no_chunking(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.no_chunking(true); + } + self + } + + /// Set response content type + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + HeaderValue: HttpTryFrom, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderValue::try_from(value) { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set content length + #[inline] + pub fn content_length(&mut self, len: u64) -> &mut Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + + /// Set a cookie + /// + /// ```rust + /// use actix_http::{http, Request, Response}; + /// + /// fn index(req: Request) -> Response { + /// Response::Ok() + /// .cookie( + /// http::Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .finish() + /// } + /// ``` + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Remove cookie + /// + /// ```rust + /// use actix_http::{http, Request, Response, HttpMessage}; + /// + /// fn index(req: Request) -> Response { + /// let mut builder = Response::Ok(); + /// + /// if let Some(ref cookie) = req.cookie("name") { + /// builder.del_cookie(cookie); + /// } + /// + /// builder.finish() + /// } + /// ``` + pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { + if self.cookies.is_none() { + self.cookies = Some(CookieJar::new()) + } + let jar = self.cookies.as_mut().unwrap(); + let cookie = cookie.clone().into_owned(); + jar.add_original(cookie.clone()); + jar.remove(cookie); + self + } + + /// This method calls provided closure with builder reference if value is + /// true. + pub fn if_true(&mut self, value: bool, f: F) -> &mut Self + where + F: FnOnce(&mut ResponseBuilder), + { + if value { + f(self); + } + self + } + + /// This method calls provided closure with builder reference if value is + /// Some. + pub fn if_some(&mut self, value: Option, f: F) -> &mut Self + where + F: FnOnce(T, &mut ResponseBuilder), + { + if let Some(val) = value { + f(val, self); + } + self + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow_mut() + } + + #[inline] + /// Set a body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn body>(&mut self, body: B) -> Response { + self.message_body(body.into()) + } + + /// Set a body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn message_body(&mut self, body: B) -> Response { + if let Some(e) = self.err.take() { + return Response::from(Error::from(e)).into_body(); + } + + let mut response = self.head.take().expect("cannot reuse response builder"); + + if let Some(ref jar) = self.cookies { + for cookie in jar.delta() { + match HeaderValue::from_str(&cookie.to_string()) { + Ok(val) => response.headers.append(header::SET_COOKIE, val), + Err(e) => return Response::from(Error::from(e)).into_body(), + }; + } + } + + Response { + head: response, + body: ResponseBody::Body(body), + error: None, + } + } + + #[inline] + /// Set a streaming body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn streaming(&mut self, stream: S) -> Response + where + S: Stream> + 'static, + E: Into + 'static, + { + self.body(Body::from_message(BodyStream::new(stream))) + } + + #[inline] + /// Set a json body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn json(&mut self, value: T) -> Response { + self.json2(&value) + } + + /// Set a json body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn json2(&mut self, value: &T) -> Response { + match serde_json::to_string(value) { + Ok(body) => { + let contains = if let Some(parts) = parts(&mut self.head, &self.err) { + parts.headers.contains_key(header::CONTENT_TYPE) + } else { + true + }; + if !contains { + self.header(header::CONTENT_TYPE, "application/json"); + } + + self.body(Body::from(body)) + } + Err(e) => Error::from(e).into(), + } + } + + #[inline] + /// Set an empty body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn finish(&mut self) -> Response { + self.body(Body::Empty) + } + + /// This method construct new `ResponseBuilder` + pub fn take(&mut self) -> ResponseBuilder { + ResponseBuilder { + head: self.head.take(), + err: self.err.take(), + cookies: self.cookies.take(), + } + } +} + +#[inline] +fn parts<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut ResponseHead> { + if err.is_some() { + return None; + } + parts.as_mut().map(|r| &mut **r) +} + +/// Convert `Response` to a `ResponseBuilder`. Body get dropped. +impl From> for ResponseBuilder { + fn from(res: Response) -> ResponseBuilder { + // If this response has cookies, load them into a jar + let mut jar: Option = None; + for c in res.cookies() { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } + } + + ResponseBuilder { + head: Some(res.head), + err: None, + cookies: jar, + } + } +} + +/// Convert `ResponseHead` to a `ResponseBuilder` +impl<'a> From<&'a ResponseHead> for ResponseBuilder { + fn from(head: &'a ResponseHead) -> ResponseBuilder { + // If this response has cookies, load them into a jar + let mut jar: Option = None; + + let cookies = CookieIter { + iter: head.headers.get_all(header::SET_COOKIE), + }; + for c in cookies { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } + } + + let mut msg = BoxedResponseHead::new(head.status); + msg.version = head.version; + msg.reason = head.reason; + for (k, v) in &head.headers { + msg.headers.append(k.clone(), v.clone()); + } + msg.no_chunking(!head.chunked()); + + ResponseBuilder { + head: Some(msg), + err: None, + cookies: jar, + } + } +} + +impl Future for ResponseBuilder { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll { + Poll::Ready(Ok(self.finish())) + } +} + +impl fmt::Debug for ResponseBuilder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let head = self.head.as_ref().unwrap(); + + let res = writeln!( + f, + "\nResponseBuilder {:?} {}{}", + head.version, + head.status, + head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + res + } +} + +/// Helper converters +impl, E: Into> From> for Response { + fn from(res: Result) -> Self { + match res { + Ok(val) => val.into(), + Err(err) => err.into().into(), + } + } +} + +impl From for Response { + fn from(mut builder: ResponseBuilder) -> Self { + builder.finish() + } +} + +impl From<&'static str> for Response { + fn from(val: &'static str) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl From<&'static [u8]> for Response { + fn from(val: &'static [u8]) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +impl From for Response { + fn from(val: String) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl<'a> From<&'a String> for Response { + fn from(val: &'a String) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl From for Response { + fn from(val: Bytes) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +impl From for Response { + fn from(val: BytesMut) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::body::Body; + use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE, SET_COOKIE}; + + #[test] + fn test_debug() { + let resp = Response::Ok() + .header(COOKIE, HeaderValue::from_static("cookie1=value1; ")) + .header(COOKIE, HeaderValue::from_static("cookie2=value2; ")) + .finish(); + let dbg = format!("{:?}", resp); + assert!(dbg.contains("Response")); + } + + #[test] + fn test_response_cookies() { + use crate::httpmessage::HttpMessage; + + let req = crate::test::TestRequest::default() + .header(COOKIE, "cookie1=value1") + .header(COOKIE, "cookie2=value2") + .finish(); + let cookies = req.cookies().unwrap(); + + let resp = Response::Ok() + .cookie( + crate::http::Cookie::build("name", "value") + .domain("www.rust-lang.org") + .path("/test") + .http_only(true) + .max_age_time(time::Duration::days(1)) + .finish(), + ) + .del_cookie(&cookies[1]) + .finish(); + + let mut val: Vec<_> = resp + .headers() + .get_all(SET_COOKIE) + .map(|v| v.to_str().unwrap().to_owned()) + .collect(); + val.sort(); + assert!(val[0].starts_with("cookie1=; Max-Age=0;")); + assert_eq!( + val[1], + "name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400" + ); + } + + #[test] + fn test_update_response_cookies() { + let mut r = Response::Ok() + .cookie(crate::http::Cookie::new("original", "val100")) + .finish(); + + r.add_cookie(&crate::http::Cookie::new("cookie2", "val200")) + .unwrap(); + r.add_cookie(&crate::http::Cookie::new("cookie2", "val250")) + .unwrap(); + r.add_cookie(&crate::http::Cookie::new("cookie3", "val300")) + .unwrap(); + + assert_eq!(r.cookies().count(), 4); + r.del_cookie("cookie2"); + + let mut iter = r.cookies(); + let v = iter.next().unwrap(); + assert_eq!((v.name(), v.value()), ("cookie3", "val300")); + let v = iter.next().unwrap(); + assert_eq!((v.name(), v.value()), ("original", "val100")); + } + + #[test] + fn test_basic_builder() { + let resp = Response::Ok().header("X-TEST", "value").finish(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_upgrade() { + let resp = Response::build(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); + } + + #[test] + fn test_force_close() { + let resp = Response::build(StatusCode::OK).force_close().finish(); + assert!(!resp.keep_alive()) + } + + #[test] + fn test_content_type() { + let resp = Response::build(StatusCode::OK) + .content_type("text/plain") + .body(Body::Empty); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") + } + + #[test] + fn test_json() { + let resp = Response::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json_ct() { + let resp = Response::build(StatusCode::OK) + .header(CONTENT_TYPE, "text/json") + .json(vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json2() { + let resp = Response::build(StatusCode::OK).json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json2_ct() { + let resp = Response::build(StatusCode::OK) + .header(CONTENT_TYPE, "text/json") + .json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_serde_json_in_body() { + use serde_json::json; + let resp = + Response::build(StatusCode::OK).body(json!({"test-key":"test-value"})); + assert_eq!(resp.body().get_ref(), br#"{"test-key":"test-value"}"#); + } + + #[test] + fn test_into_response() { + let resp: Response = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = b"test".as_ref().into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = "test".to_owned().into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = (&"test".to_owned()).into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = Bytes::from_static(b"test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = Bytes::from_static(b"test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = BytesMut::from("test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + } + + #[test] + fn test_into_builder() { + let mut resp: Response = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + + resp.add_cookie(&crate::http::Cookie::new("cookie1", "val100")) + .unwrap(); + + let mut builder: ResponseBuilder = resp.into(); + let resp = builder.status(StatusCode::BAD_REQUEST).finish(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let cookie = resp.cookies().next().unwrap(); + assert_eq!((cookie.name(), cookie.value()), ("cookie1", "val100")); + } +} diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs new file mode 100644 index 000000000..7340c15fd --- /dev/null +++ b/actix-http/src/service.rs @@ -0,0 +1,703 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io, net, rc}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_server_config::{ + Io as ServerIo, IoStream, Protocol, ServerConfig as SrvConfig, +}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{ready, Future}; +use h2::server::{self, Handshake}; +use pin_project::{pin_project, project}; + +use crate::body::MessageBody; +use crate::builder::HttpServiceBuilder; +use crate::cloneable::CloneableService; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::{DispatchError, Error}; +use crate::helpers::DataFactory; +use crate::request::Request; +use crate::response::Response; +use crate::{h1, h2::Dispatcher}; + +/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation +pub struct HttpService> { + srv: S, + cfg: ServiceConfig, + expect: X, + upgrade: Option, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B)>, +} + +impl HttpService +where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + /// Create builder for `HttpService` instance. + pub fn build() -> HttpServiceBuilder { + HttpServiceBuilder::new() + } +} + +impl HttpService +where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, +{ + /// Create new `HttpService` instance. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + HttpService { + cfg, + srv: service.into_factory(), + expect: h1::ExpectHandler, + upgrade: None, + on_connect: None, + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub(crate) fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + HttpService { + cfg, + srv: service.into_factory(), + expect: h1::ExpectHandler, + upgrade: None, + on_connect: None, + _t: PhantomData, + } + } +} + +impl HttpService +where + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody, +{ + /// Provide service for `EXPECT: 100-Continue` support. + /// + /// Service get called with request that contains `EXPECT` header. + /// Service must return request in case of success, in that case + /// request will be forwarded to main service. + pub fn expect(self, expect: X1) -> HttpService + where + X1: ServiceFactory, + X1::Error: Into, + X1::InitError: fmt::Debug, + ::Future: 'static, + { + HttpService { + expect, + cfg: self.cfg, + srv: self.srv, + upgrade: self.upgrade, + on_connect: self.on_connect, + _t: PhantomData, + } + } + + /// Provide service for custom `Connection: UPGRADE` support. + /// + /// If service is provided then normal requests handling get halted + /// and this service get called with original request and framed object. + pub fn upgrade(self, upgrade: Option) -> HttpService + where + U1: ServiceFactory< + Config = SrvConfig, + Request = (Request, Framed), + Response = (), + >, + U1::Error: fmt::Display, + U1::InitError: fmt::Debug, + ::Future: 'static, + { + HttpService { + upgrade, + cfg: self.cfg, + srv: self.srv, + expect: self.expect, + on_connect: self.on_connect, + _t: PhantomData, + } + } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } +} + +impl ServiceFactory for HttpService +where + T: IoStream, + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory< + Config = SrvConfig, + Request = (Request, Framed), + Response = (), + >, + U::Error: fmt::Display, + U::InitError: fmt::Debug, + ::Future: 'static, +{ + type Config = SrvConfig; + type Request = ServerIo; + type Response = (); + type Error = DispatchError; + type InitError = (); + type Service = HttpServiceHandler; + type Future = HttpServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + HttpServiceResponse { + fut: self.srv.new_service(cfg), + fut_ex: Some(self.expect.new_service(cfg)), + fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), + expect: None, + upgrade: None, + on_connect: self.on_connect.clone(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +#[pin_project] +pub struct HttpServiceResponse< + T, + P, + S: ServiceFactory, + B, + X: ServiceFactory, + U: ServiceFactory, +> { + #[pin] + fut: S::Future, + #[pin] + fut_ex: Option, + #[pin] + fut_upg: Option, + expect: Option, + upgrade: Option, + on_connect: Option Box>>, + cfg: Option, + _t: PhantomData<(T, P, B)>, +} + +impl Future for HttpServiceResponse +where + T: IoStream, + S: ServiceFactory, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + ::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, + X::Error: Into, + X::InitError: fmt::Debug, + ::Future: 'static, + U: ServiceFactory), Response = ()>, + U::Error: fmt::Display, + U::InitError: fmt::Debug, + ::Future: 'static, +{ + type Output = + Result, ()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut().project(); + + if let Some(fut) = this.fut_ex.as_pin_mut() { + let expect = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.expect = Some(expect); + this.fut_ex.set(None); + } + + if let Some(fut) = this.fut_upg.as_pin_mut() { + let upgrade = ready!(fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e)))?; + this = self.as_mut().project(); + *this.upgrade = Some(upgrade); + this.fut_ex.set(None); + } + + let result = ready!(this + .fut + .poll(cx) + .map_err(|e| log::error!("Init http service error: {:?}", e))); + Poll::Ready(result.map(|service| { + let this = self.as_mut().project(); + HttpServiceHandler::new( + this.cfg.take().unwrap(), + service, + this.expect.take().unwrap(), + this.upgrade.take(), + this.on_connect.clone(), + ) + })) + } +} + +/// `Service` implementation for http transport +pub struct HttpServiceHandler { + srv: CloneableService, + expect: CloneableService, + upgrade: Option>, + cfg: ServiceConfig, + on_connect: Option Box>>, + _t: PhantomData<(T, P, B, X)>, +} + +impl HttpServiceHandler +where + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + fn new( + cfg: ServiceConfig, + srv: S, + expect: X, + upgrade: Option, + on_connect: Option Box>>, + ) -> HttpServiceHandler { + HttpServiceHandler { + cfg, + on_connect, + srv: CloneableService::new(srv), + expect: CloneableService::new(expect), + upgrade: upgrade.map(CloneableService::new), + _t: PhantomData, + } + } +} + +impl Service for HttpServiceHandler +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + type Request = ServerIo; + type Response = (); + type Error = DispatchError; + type Future = HttpServiceHandlerResponse; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + let ready = self + .expect + .poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready(); + + let ready = self + .srv + .poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready; + + if ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let (io, _, proto) = req.into_parts(); + + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + + match proto { + Protocol::Http2 => { + let peer_addr = io.peer_addr(); + let io = Io { + inner: io, + unread: None, + }; + HttpServiceHandlerResponse { + state: State::Handshake(Some(( + server::handshake(io), + self.cfg.clone(), + self.srv.clone(), + peer_addr, + on_connect, + ))), + } + } + Protocol::Http10 | Protocol::Http11 => HttpServiceHandlerResponse { + state: State::H1(h1::Dispatcher::new( + io, + self.cfg.clone(), + self.srv.clone(), + self.expect.clone(), + self.upgrade.clone(), + on_connect, + )), + }, + _ => HttpServiceHandlerResponse { + state: State::Unknown(Some(( + io, + BytesMut::with_capacity(14), + self.cfg.clone(), + self.srv.clone(), + self.expect.clone(), + self.upgrade.clone(), + on_connect, + ))), + }, + } + } +} + +#[pin_project] +enum State +where + S: Service, + S::Future: 'static, + S::Error: Into, + T: IoStream, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + H1(#[pin] h1::Dispatcher), + H2(#[pin] Dispatcher, S, B>), + Unknown( + Option<( + T, + BytesMut, + ServiceConfig, + CloneableService, + CloneableService, + Option>, + Option>, + )>, + ), + Handshake( + Option<( + Handshake, Bytes>, + ServiceConfig, + CloneableService, + Option, + Option>, + )>, + ), +} + +#[pin_project] +pub struct HttpServiceHandlerResponse +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + #[pin] + state: State, +} + +const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; + +impl Future for HttpServiceHandlerResponse +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Future: 'static, + S::Response: Into> + 'static, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + type Output = Result<(), DispatchError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.project().state.poll(cx) + } +} + +impl State +where + T: IoStream, + S: Service, + S::Error: Into + 'static, + S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + #[project] + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + #[project] + match self.as_mut().project() { + State::H1(disp) => disp.poll(cx), + State::H2(disp) => disp.poll(cx), + State::Unknown(ref mut data) => { + if let Some(ref mut item) = data { + loop { + // Safety - we only write to the returned slice. + let b = unsafe { item.1.bytes_mut() }; + let n = ready!(Pin::new(&mut item.0).poll_read(cx, b))?; + if n == 0 { + return Poll::Ready(Ok(())); + } + // Safety - we know that 'n' bytes have + // been initialized via the contract of + // 'poll_read' + unsafe { item.1.advance_mut(n) }; + if item.1.len() >= HTTP2_PREFACE.len() { + break; + } + } + } else { + panic!() + } + let (io, buf, cfg, srv, expect, upgrade, on_connect) = + data.take().unwrap(); + if buf[..14] == HTTP2_PREFACE[..] { + let peer_addr = io.peer_addr(); + let io = Io { + inner: io, + unread: Some(buf), + }; + self.set(State::Handshake(Some(( + server::handshake(io), + cfg, + srv, + peer_addr, + on_connect, + )))); + } else { + self.set(State::H1(h1::Dispatcher::with_timeout( + io, + h1::Codec::new(cfg.clone()), + cfg, + buf, + None, + srv, + expect, + upgrade, + on_connect, + ))) + } + self.poll(cx) + } + State::Handshake(ref mut data) => { + let conn = if let Some(ref mut item) = data { + match Pin::new(&mut item.0).poll(cx) { + Poll::Ready(Ok(conn)) => conn, + Poll::Ready(Err(err)) => { + trace!("H2 handshake error: {}", err); + return Poll::Ready(Err(err.into())); + } + Poll::Pending => return Poll::Pending, + } + } else { + panic!() + }; + let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); + self.set(State::H2(Dispatcher::new( + srv, conn, on_connect, cfg, None, peer_addr, + ))); + self.poll(cx) + } + } + } +} + +/// Wrapper for `AsyncRead + AsyncWrite` types +#[pin_project::pin_project] +struct Io { + unread: Option, + #[pin] + inner: T, +} + +impl io::Read for Io { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if let Some(mut bytes) = self.unread.take() { + let size = std::cmp::min(buf.len(), bytes.len()); + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); + self.unread = Some(bytes); + } + Ok(size) + } else { + self.inner.read(buf) + } + } +} + +impl io::Write for Io { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncRead for Io { + // unsafe fn initializer(&self) -> io::Initializer { + // self.get_mut().inner.initializer() + // } + + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + + if let Some(mut bytes) = this.unread.take() { + let size = std::cmp::min(buf.len(), bytes.len()); + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); + *this.unread = Some(bytes); + } + Poll::Ready(Ok(size)) + } else { + this.inner.poll_read(cx, buf) + } + } + + // fn poll_read_vectored( + // self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // bufs: &mut [io::IoSliceMut<'_>], + // ) -> Poll> { + // self.get_mut().inner.poll_read_vectored(cx, bufs) + // } +} + +impl actix_codec::AsyncWrite for Io { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().inner.poll_shutdown(cx) + } +} + +impl actix_server_config::IoStream for Io { + #[inline] + fn peer_addr(&self) -> Option { + self.inner.peer_addr() + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.inner.set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.inner.set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.inner.set_keepalive(dur) + } +} diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs new file mode 100644 index 000000000..744f057dc --- /dev/null +++ b/actix-http/src/test.rs @@ -0,0 +1,288 @@ +//! Test Various helpers for Actix applications to use during testing. +use std::fmt::Write as FmtWrite; +use std::io::{self, Read, Write}; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_server_config::IoStream; +use bytes::{Bytes, BytesMut}; +use http::header::{self, HeaderName, HeaderValue}; +use http::{HttpTryFrom, Method, Uri, Version}; +use percent_encoding::percent_encode; + +use crate::cookie::{Cookie, CookieJar, USERINFO}; +use crate::header::HeaderMap; +use crate::header::{Header, IntoHeaderValue}; +use crate::payload::Payload; +use crate::Request; + +/// Test `Request` builder +/// +/// ```rust,ignore +/// # extern crate http; +/// # extern crate actix_web; +/// # use http::{header, StatusCode}; +/// # use actix_web::*; +/// use actix_web::test::TestRequest; +/// +/// fn index(req: &HttpRequest) -> Response { +/// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { +/// Response::Ok().into() +/// } else { +/// Response::BadRequest().into() +/// } +/// } +/// +/// fn main() { +/// let resp = TestRequest::with_header("content-type", "text/plain") +/// .run(&index) +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// let resp = TestRequest::default().run(&index).unwrap(); +/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +/// } +/// ``` +pub struct TestRequest(Option); + +struct Inner { + version: Version, + method: Method, + uri: Uri, + headers: HeaderMap, + cookies: CookieJar, + payload: Option, +} + +impl Default for TestRequest { + fn default() -> TestRequest { + TestRequest(Some(Inner { + method: Method::GET, + uri: Uri::from_str("/").unwrap(), + version: Version::HTTP_11, + headers: HeaderMap::new(), + cookies: CookieJar::new(), + payload: None, + })) + } +} + +impl TestRequest { + /// Create TestRequest and set request uri + pub fn with_uri(path: &str) -> TestRequest { + TestRequest::default().uri(path).take() + } + + /// Create TestRequest and set header + pub fn with_hdr(hdr: H) -> TestRequest { + TestRequest::default().set(hdr).take() + } + + /// Create TestRequest and set header + pub fn with_header(key: K, value: V) -> TestRequest + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + TestRequest::default().header(key, value).take() + } + + /// Set HTTP version of this request + pub fn version(&mut self, ver: Version) -> &mut Self { + parts(&mut self.0).version = ver; + self + } + + /// Set HTTP method of this request + pub fn method(&mut self, meth: Method) -> &mut Self { + parts(&mut self.0).method = meth; + self + } + + /// Set HTTP Uri of this request + pub fn uri(&mut self, path: &str) -> &mut Self { + parts(&mut self.0).uri = Uri::from_str(path).unwrap(); + self + } + + /// Set a header + pub fn set(&mut self, hdr: H) -> &mut Self { + if let Ok(value) = hdr.try_into() { + parts(&mut self.0).headers.append(H::name(), value); + return self; + } + panic!("Can not set header"); + } + + /// Set a header + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Ok(key) = HeaderName::try_from(key) { + if let Ok(value) = value.try_into() { + parts(&mut self.0).headers.append(key, value); + return self; + } + } + panic!("Can not create header"); + } + + /// Set cookie for this request + pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self { + parts(&mut self.0).cookies.add(cookie.into_owned()); + self + } + + /// Set request payload + pub fn set_payload>(&mut self, data: B) -> &mut Self { + let mut payload = crate::h1::Payload::empty(); + payload.unread_data(data.into()); + parts(&mut self.0).payload = Some(payload.into()); + self + } + + pub fn take(&mut self) -> TestRequest { + TestRequest(self.0.take()) + } + + /// Complete request creation and generate `Request` instance + pub fn finish(&mut self) -> Request { + let inner = self.0.take().expect("cannot reuse test request builder"); + + let mut req = if let Some(pl) = inner.payload { + Request::with_payload(pl) + } else { + Request::with_payload(crate::h1::Payload::empty().into()) + }; + + let head = req.head_mut(); + head.uri = inner.uri; + head.method = inner.method; + head.version = inner.version; + head.headers = inner.headers; + + let mut cookie = String::new(); + for c in inner.cookies.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO); + let value = percent_encode(c.value().as_bytes(), USERINFO); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + if !cookie.is_empty() { + head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + + req + } +} + +#[inline] +fn parts(parts: &mut Option) -> &mut Inner { + parts.as_mut().expect("cannot reuse test request builder") +} + +/// Async io buffer +pub struct TestBuffer { + pub read_buf: BytesMut, + pub write_buf: BytesMut, + pub err: Option, +} + +impl TestBuffer { + /// Create new TestBuffer instance + pub fn new(data: T) -> TestBuffer + where + BytesMut: From, + { + TestBuffer { + read_buf: BytesMut::from(data), + write_buf: BytesMut::new(), + err: None, + } + } + + /// Create new empty TestBuffer instance + pub fn empty() -> TestBuffer { + TestBuffer::new("") + } + + /// Add extra data to read buffer. + pub fn extend_read_buf>(&mut self, data: T) { + self.read_buf.extend_from_slice(data.as_ref()) + } +} + +impl io::Read for TestBuffer { + fn read(&mut self, dst: &mut [u8]) -> Result { + if self.read_buf.is_empty() { + if self.err.is_some() { + Err(self.err.take().unwrap()) + } else { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + } + } else { + let size = std::cmp::min(self.read_buf.len(), dst.len()); + let b = self.read_buf.split_to(size); + dst[..size].copy_from_slice(&b); + Ok(size) + } + } +} + +impl io::Write for TestBuffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_buf.extend(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsyncRead for TestBuffer { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(self.get_mut().read(buf)) + } +} + +impl AsyncWrite for TestBuffer { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(self.get_mut().write(buf)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl IoStream for TestBuffer { + fn set_nodelay(&mut self, _nodelay: bool) -> io::Result<()> { + Ok(()) + } + + fn set_linger(&mut self, _dur: Option) -> io::Result<()> { + Ok(()) + } + + fn set_keepalive(&mut self, _dur: Option) -> io::Result<()> { + Ok(()) + } +} diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs new file mode 100644 index 000000000..9891bfa6e --- /dev/null +++ b/actix-http/src/ws/codec.rs @@ -0,0 +1,150 @@ +use actix_codec::{Decoder, Encoder}; +use bytes::{Bytes, BytesMut}; + +use super::frame::Parser; +use super::proto::{CloseReason, OpCode}; +use super::ProtocolError; + +/// `WebSocket` Message +#[derive(Debug, PartialEq)] +pub enum Message { + /// Text message + Text(String), + /// Binary message + Binary(Bytes), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), + /// No-op. Useful for actix-net services + Nop, +} + +/// `WebSocket` frame +#[derive(Debug, PartialEq)] +pub enum Frame { + /// Text frame, codec does not verify utf8 encoding + Text(Option), + /// Binary frame + Binary(Option), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), +} + +#[derive(Debug, Copy, Clone)] +/// WebSockets protocol codec +pub struct Codec { + max_size: usize, + server: bool, +} + +impl Codec { + /// Create new websocket frames decoder + pub fn new() -> Codec { + Codec { + max_size: 65_536, + server: true, + } + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + pub fn client_mode(mut self) -> Self { + self.server = false; + self + } +} + +impl Encoder for Codec { + type Item = Message; + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + Message::Text(txt) => { + Parser::write_message(dst, txt, OpCode::Text, true, !self.server) + } + Message::Binary(bin) => { + Parser::write_message(dst, bin, OpCode::Binary, true, !self.server) + } + Message::Ping(txt) => { + Parser::write_message(dst, txt, OpCode::Ping, true, !self.server) + } + Message::Pong(txt) => { + Parser::write_message(dst, txt, OpCode::Pong, true, !self.server) + } + Message::Close(reason) => Parser::write_close(dst, reason, !self.server), + Message::Nop => (), + } + Ok(()) + } +} + +impl Decoder for Codec { + type Item = Frame; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Parser::parse(src, self.server, self.max_size) { + Ok(Some((finished, opcode, payload))) => { + // continuation is not supported + if !finished { + return Err(ProtocolError::NoContinuation); + } + + match opcode { + OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Bad => Err(ProtocolError::BadOpCode), + OpCode::Close => { + if let Some(ref pl) = payload { + let close_reason = Parser::parse_close_payload(pl); + Ok(Some(Frame::Close(close_reason))) + } else { + Ok(Some(Frame::Close(None))) + } + } + OpCode::Ping => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Ping(String::new()))) + } + } + OpCode::Pong => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Pong(String::new()))) + } + } + OpCode::Binary => Ok(Some(Frame::Binary(payload))), + OpCode::Text => { + Ok(Some(Frame::Text(payload))) + //let tmp = Vec::from(payload.as_ref()); + //match String::from_utf8(tmp) { + // Ok(s) => Ok(Some(Message::Text(s))), + // Err(_) => Err(ProtocolError::BadEncoding), + //} + } + } + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs new file mode 100644 index 000000000..46e9f36db --- /dev/null +++ b/actix-http/src/ws/frame.rs @@ -0,0 +1,384 @@ +use std::convert::TryFrom; + +use bytes::{BufMut, Bytes, BytesMut}; +use log::debug; +use rand; + +use crate::ws::mask::apply_mask; +use crate::ws::proto::{CloseCode, CloseReason, OpCode}; +use crate::ws::ProtocolError; + +/// A struct representing a `WebSocket` frame. +#[derive(Debug)] +pub struct Parser; + +impl Parser { + fn parse_metadata( + src: &[u8], + server: bool, + max_size: usize, + ) -> Result)>, ProtocolError> { + let chunk_len = src.len(); + + let mut idx = 2; + if chunk_len < 2 { + return Ok(None); + } + + let first = src[0]; + let second = src[1]; + let finished = first & 0x80 != 0; + + // check masking + let masked = second & 0x80 != 0; + if !masked && server { + return Err(ProtocolError::UnmaskedFrame); + } else if masked && !server { + return Err(ProtocolError::MaskedFrame); + } + + // Op code + let opcode = OpCode::from(first & 0x0F); + + if let OpCode::Bad = opcode { + return Err(ProtocolError::InvalidOpcode(first & 0x0F)); + } + + let len = second & 0x7F; + let length = if len == 126 { + if chunk_len < 4 { + return Ok(None); + } + let len = usize::from(u16::from_be_bytes( + TryFrom::try_from(&src[idx..idx + 2]).unwrap(), + )); + idx += 2; + len + } else if len == 127 { + if chunk_len < 10 { + return Ok(None); + } + let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap()); + if len > max_size as u64 { + return Err(ProtocolError::Overflow); + } + idx += 8; + len as usize + } else { + len as usize + }; + + // check for max allowed size + if length > max_size { + return Err(ProtocolError::Overflow); + } + + let mask = if server { + if chunk_len < idx + 4 { + return Ok(None); + } + + let mask = + u32::from_le_bytes(TryFrom::try_from(&src[idx..idx + 4]).unwrap()); + idx += 4; + Some(mask) + } else { + None + }; + + Ok(Some((idx, finished, opcode, length, mask))) + } + + /// Parse the input stream into a frame. + pub fn parse( + src: &mut BytesMut, + server: bool, + max_size: usize, + ) -> Result)>, ProtocolError> { + // try to parse ws frame metadata + let (idx, finished, opcode, length, mask) = + match Parser::parse_metadata(src, server, max_size)? { + None => return Ok(None), + Some(res) => res, + }; + + // not enough data + if src.len() < idx + length { + return Ok(None); + } + + // remove prefix + src.split_to(idx); + + // no need for body + if length == 0 { + return Ok(Some((finished, opcode, None))); + } + + let mut data = src.split_to(length); + + // control frames must have length <= 125 + match opcode { + OpCode::Ping | OpCode::Pong if length > 125 => { + return Err(ProtocolError::InvalidLength(length)); + } + OpCode::Close if length > 125 => { + debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); + return Ok(Some((true, OpCode::Close, None))); + } + _ => (), + } + + // unmask + if let Some(mask) = mask { + apply_mask(&mut data, mask); + } + + Ok(Some((finished, opcode, Some(data)))) + } + + /// Parse the payload of a close frame. + pub fn parse_close_payload(payload: &[u8]) -> Option { + if payload.len() >= 2 { + let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap()); + let code = CloseCode::from(raw_code); + let description = if payload.len() > 2 { + Some(String::from_utf8_lossy(&payload[2..]).into()) + } else { + None + }; + Some(CloseReason { code, description }) + } else { + None + } + } + + /// Generate binary representation + pub fn write_message>( + dst: &mut BytesMut, + pl: B, + op: OpCode, + fin: bool, + mask: bool, + ) { + let payload = pl.into(); + let one: u8 = if fin { + 0x80 | Into::::into(op) + } else { + op.into() + }; + let payload_len = payload.len(); + let (two, p_len) = if mask { + (0x80, payload_len + 4) + } else { + (0, payload_len) + }; + + if payload_len < 126 { + dst.reserve(p_len + 2 + if mask { 4 } else { 0 }); + dst.put_slice(&[one, two | payload_len as u8]); + } else if payload_len <= 65_535 { + dst.reserve(p_len + 4 + if mask { 4 } else { 0 }); + dst.put_slice(&[one, two | 126]); + dst.put_u16_be(payload_len as u16); + } else { + dst.reserve(p_len + 10 + if mask { 4 } else { 0 }); + dst.put_slice(&[one, two | 127]); + dst.put_u64_be(payload_len as u64); + }; + + if mask { + let mask = rand::random::(); + dst.put_u32_le(mask); + dst.put_slice(payload.as_ref()); + let pos = dst.len() - payload_len; + apply_mask(&mut dst[pos..], mask); + } else { + dst.put_slice(payload.as_ref()); + } + } + + /// Create a new Close control frame. + #[inline] + pub fn write_close(dst: &mut BytesMut, reason: Option, mask: bool) { + let payload = match reason { + None => Vec::new(), + Some(reason) => { + let mut payload = Into::::into(reason.code).to_be_bytes().to_vec(); + if let Some(description) = reason.description { + payload.extend(description.as_bytes()); + } + payload + } + }; + + Parser::write_message(dst, payload, OpCode::Close, true, mask) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + struct F { + finished: bool, + opcode: OpCode, + payload: Bytes, + } + + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { + match *frm { + Ok(None) => true, + _ => false, + } + } + + fn extract( + frm: Result)>, ProtocolError>, + ) -> F { + match frm { + Ok(Some((finished, opcode, payload))) => F { + finished, + opcode, + payload: payload + .map(|b| b.freeze()) + .unwrap_or_else(|| Bytes::from("")), + }, + _ => unreachable!("error"), + } + } + + #[test] + fn test_parse() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + buf.extend(b"1"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1"[..]); + } + + #[test] + fn test_parse_length0() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert!(frame.payload.is_empty()); + } + + #[test] + fn test_parse_length2() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); + buf.extend(&[0u8, 4u8][..]); + buf.extend(b"1234"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1234"[..]); + } + + #[test] + fn test_parse_length4() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); + buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); + buf.extend(b"1234"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1234"[..]); + } + + #[test] + fn test_parse_frame_mask() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]); + buf.extend(b"0001"); + buf.extend(b"1"); + + assert!(Parser::parse(&mut buf, false, 1024).is_err()); + + let frame = extract(Parser::parse(&mut buf, true, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); + } + + #[test] + fn test_parse_frame_no_mask() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + buf.extend(&[1u8]); + + assert!(Parser::parse(&mut buf, true, 1024).is_err()); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); + } + + #[test] + fn test_parse_frame_max_size() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); + buf.extend(&[1u8, 1u8]); + + assert!(Parser::parse(&mut buf, true, 1).is_err()); + + if let Err(ProtocolError::Overflow) = Parser::parse(&mut buf, false, 0) { + } else { + unreachable!("error"); + } + } + + #[test] + fn test_ping_frame() { + let mut buf = BytesMut::new(); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + + let mut v = vec![137u8, 4u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_pong_frame() { + let mut buf = BytesMut::new(); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + + let mut v = vec![138u8, 4u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_close_frame() { + let mut buf = BytesMut::new(); + let reason = (CloseCode::Normal, "data"); + Parser::write_close(&mut buf, Some(reason.into()), false); + + let mut v = vec![136u8, 6u8, 3u8, 232u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_empty_close_frame() { + let mut buf = BytesMut::new(); + Parser::write_close(&mut buf, None, false); + assert_eq!(&buf[..], &vec![0x88, 0x00][..]); + } +} diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs new file mode 100644 index 000000000..9f7304039 --- /dev/null +++ b/actix-http/src/ws/mask.rs @@ -0,0 +1,148 @@ +//! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs) +#![allow(clippy::cast_ptr_alignment)] +use std::ptr::copy_nonoverlapping; +use std::slice; + +// Holds a slice guaranteed to be shorter than 8 bytes +struct ShortSlice<'a>(&'a mut [u8]); + +impl<'a> ShortSlice<'a> { + unsafe fn new(slice: &'a mut [u8]) -> Self { + // Sanity check for debug builds + debug_assert!(slice.len() < 8); + ShortSlice(slice) + } + fn len(&self) -> usize { + self.0.len() + } +} + +/// Faster version of `apply_mask()` which operates on 8-byte blocks. +#[inline] +#[allow(clippy::cast_lossless)] +pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) { + // Extend the mask to 64 bits + let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64); + // Split the buffer into three segments + let (head, mid, tail) = align_buf(buf); + + // Initial unaligned segment + let head_len = head.len(); + if head_len > 0 { + xor_short(head, mask_u64); + if cfg!(target_endian = "big") { + mask_u64 = mask_u64.rotate_left(8 * head_len as u32); + } else { + mask_u64 = mask_u64.rotate_right(8 * head_len as u32); + } + } + // Aligned segment + for v in mid { + *v ^= mask_u64; + } + // Final unaligned segment + if tail.len() > 0 { + xor_short(tail, mask_u64); + } +} + +#[inline] +// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so +// inefficient, it could be done better. The compiler does not understand that +// a `ShortSlice` must be smaller than a u64. +#[allow(clippy::needless_pass_by_value)] +fn xor_short(buf: ShortSlice, mask: u64) { + // Unsafe: we know that a `ShortSlice` fits in a u64 + unsafe { + let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len()); + let mut b: u64 = 0; + #[allow(trivial_casts)] + copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); + b ^= mask; + #[allow(trivial_casts)] + copy_nonoverlapping(&b as *const _ as *const u8, ptr, len); + } +} + +#[inline] +// Unsafe: caller must ensure the buffer has the correct size and alignment +unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] { + // Assert correct size and alignment in debug builds + debug_assert!(buf.len().trailing_zeros() >= 3); + debug_assert!((buf.as_ptr() as usize).trailing_zeros() >= 3); + + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3) +} + +#[inline] +// Splits a slice into three parts: an unaligned short head and tail, plus an aligned +// u64 mid section. +fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) { + let start_ptr = buf.as_ptr() as usize; + let end_ptr = start_ptr + buf.len(); + + // Round *up* to next aligned boundary for start + let start_aligned = (start_ptr + 7) & !0x7; + // Round *down* to last aligned boundary for end + let end_aligned = end_ptr & !0x7; + + if end_aligned >= start_aligned { + // We have our three segments (head, mid, tail) + let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr); + let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr); + + // Unsafe: we know the middle section is correctly aligned, and the outer + // sections are smaller than 8 bytes + unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice(tail)) } + } else { + // We didn't cross even one aligned boundary! + + // Unsafe: The outer sections are smaller than 8 bytes + unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) } + } +} + +#[cfg(test)] +mod tests { + use super::apply_mask; + + /// A safe unoptimized mask application. + fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { + for (i, byte) in buf.iter_mut().enumerate() { + *byte ^= mask[i & 3]; + } + } + + #[test] + fn test_apply_mask() { + let mask = [0x6d, 0xb6, 0xb2, 0x80]; + let mask_u32 = u32::from_le_bytes(mask); + + let unmasked = vec![ + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, + 0x74, 0xf9, 0x12, 0x03, + ]; + + // Check masking with proper alignment. + { + let mut masked = unmasked.clone(); + apply_mask_fallback(&mut masked, &mask); + + let mut masked_fast = unmasked.clone(); + apply_mask(&mut masked_fast, mask_u32); + + assert_eq!(masked, masked_fast); + } + + // Check masking without alignment. + { + let mut masked = unmasked.clone(); + apply_mask_fallback(&mut masked[1..], &mask); + + let mut masked_fast = unmasked.clone(); + apply_mask(&mut masked_fast[1..], mask_u32); + + assert_eq!(masked, masked_fast); + } + } +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs new file mode 100644 index 000000000..891d5110d --- /dev/null +++ b/actix-http/src/ws/mod.rs @@ -0,0 +1,315 @@ +//! WebSocket protocol support. +//! +//! To setup a `WebSocket`, first do web socket handshake then on success +//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to +//! communicate with the peer. +use std::io; + +use derive_more::{Display, From}; +use http::{header, Method, StatusCode}; + +use crate::error::ResponseError; +use crate::message::RequestHead; +use crate::response::{Response, ResponseBuilder}; + +mod codec; +mod frame; +mod mask; +mod proto; +mod transport; + +pub use self::codec::{Codec, Frame, Message}; +pub use self::frame::Parser; +pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; +pub use self::transport::Transport; + +/// Websocket protocol errors +#[derive(Debug, Display, From)] +pub enum ProtocolError { + /// Received an unmasked frame from client + #[display(fmt = "Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[display(fmt = "Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[display(fmt = "Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[display(fmt = "Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Bad web socket op code + #[display(fmt = "Bad web socket op code")] + BadOpCode, + /// A payload reached size limit. + #[display(fmt = "A payload reached size limit.")] + Overflow, + /// Continuation is not supported + #[display(fmt = "Continuation is not supported.")] + NoContinuation, + /// Bad utf-8 encoding + #[display(fmt = "Bad utf-8 encoding.")] + BadEncoding, + /// Io error + #[display(fmt = "io error: {}", _0)] + Io(io::Error), +} + +impl ResponseError for ProtocolError {} + +/// Websocket handshake errors +#[derive(PartialEq, Debug, Display)] +pub enum HandshakeError { + /// Only get method is allowed + #[display(fmt = "Method not allowed")] + GetMethodRequired, + /// Upgrade header if not set to websocket + #[display(fmt = "Websocket upgrade is expected")] + NoWebsocketUpgrade, + /// Connection header is not set to upgrade + #[display(fmt = "Connection upgrade is expected")] + NoConnectionUpgrade, + /// Websocket version header is not set + #[display(fmt = "Websocket version header is required")] + NoVersionHeader, + /// Unsupported websocket version + #[display(fmt = "Unsupported version")] + UnsupportedVersion, + /// Websocket key is not set or wrong + #[display(fmt = "Unknown websocket key")] + BadWebsocketKey, +} + +impl ResponseError for HandshakeError { + fn error_response(&self) -> Response { + match *self { + HandshakeError::GetMethodRequired => Response::MethodNotAllowed() + .header(header::ALLOW, "GET") + .finish(), + HandshakeError::NoWebsocketUpgrade => Response::BadRequest() + .reason("No WebSocket UPGRADE header found") + .finish(), + HandshakeError::NoConnectionUpgrade => Response::BadRequest() + .reason("No CONNECTION upgrade") + .finish(), + HandshakeError::NoVersionHeader => Response::BadRequest() + .reason("Websocket version header is required") + .finish(), + HandshakeError::UnsupportedVersion => Response::BadRequest() + .reason("Unsupported version") + .finish(), + HandshakeError::BadWebsocketKey => { + Response::BadRequest().reason("Handshake error").finish() + } + } + } +} + +/// Verify `WebSocket` handshake request and create handshake reponse. +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn handshake(req: &RequestHead) -> Result { + verify_handshake(req)?; + Ok(handshake_response(req)) +} + +/// Verify `WebSocket` handshake request. +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { + // WebSocket accepts only GET + if req.method != Method::GET { + return Err(HandshakeError::GetMethodRequired); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + return Err(HandshakeError::NoWebsocketUpgrade); + } + + // Upgrade connection + if !req.upgrade() { + return Err(HandshakeError::NoConnectionUpgrade); + } + + // check supported version + if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { + return Err(HandshakeError::NoVersionHeader); + } + let supported_ver = { + if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false + } + }; + if !supported_ver { + return Err(HandshakeError::UnsupportedVersion); + } + + // check client handshake for validity + if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { + return Err(HandshakeError::BadWebsocketKey); + } + Ok(()) +} + +/// Create websocket's handshake response +/// +/// This function returns handshake `Response`, ready to send to peer. +pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { + let key = { + let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); + proto::hash_key(key.as_ref()) + }; + + Response::build(StatusCode::SWITCHING_PROTOCOLS) + .upgrade("websocket") + .header(header::TRANSFER_ENCODING, "chunked") + .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) + .take() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestRequest; + use http::{header, Method}; + + #[test] + fn test_handshake() { + let req = TestRequest::default().method(Method::POST).finish(); + assert_eq!( + HandshakeError::GetMethodRequired, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default().finish(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header(header::UPGRADE, header::HeaderValue::from_static("test")) + .finish(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .finish(); + assert_eq!( + HandshakeError::NoConnectionUpgrade, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .finish(); + assert_eq!( + HandshakeError::NoVersionHeader, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("5"), + ) + .finish(); + assert_eq!( + HandshakeError::UnsupportedVersion, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .finish(); + assert_eq!( + HandshakeError::BadWebsocketKey, + verify_handshake(req.head()).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .finish(); + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake_response(req.head()).finish().status() + ); + } + + #[test] + fn test_wserror_http_response() { + let resp: Response = HandshakeError::GetMethodRequired.error_response(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let resp: Response = HandshakeError::NoWebsocketUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::NoConnectionUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::NoVersionHeader.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::UnsupportedVersion.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::BadWebsocketKey.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/src/ws/proto.rs b/actix-http/src/ws/proto.rs similarity index 75% rename from src/ws/proto.rs rename to actix-http/src/ws/proto.rs index 5f077a4b9..e14651a56 100644 --- a/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -1,7 +1,7 @@ -use std::fmt; -use std::convert::{Into, From}; -use sha1; use base64; +use sha1; +use std::convert::{From, Into}; +use std::fmt; use self::OpCode::*; /// Operation codes as part of rfc6455. @@ -26,52 +26,51 @@ pub enum OpCode { impl fmt::Display for OpCode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Continue => write!(f, "CONTINUE"), - Text => write!(f, "TEXT"), - Binary => write!(f, "BINARY"), - Close => write!(f, "CLOSE"), - Ping => write!(f, "PING"), - Pong => write!(f, "PONG"), - Bad => write!(f, "BAD"), + Continue => write!(f, "CONTINUE"), + Text => write!(f, "TEXT"), + Binary => write!(f, "BINARY"), + Close => write!(f, "CLOSE"), + Ping => write!(f, "PING"), + Pong => write!(f, "PONG"), + Bad => write!(f, "BAD"), } } } impl Into for OpCode { - fn into(self) -> u8 { match self { - Continue => 0, - Text => 1, - Binary => 2, - Close => 8, - Ping => 9, - Pong => 10, - Bad => { - debug_assert!(false, "Attempted to convert invalid opcode to u8. This is a bug."); - 8 // if this somehow happens, a close frame will help us tear down quickly + Continue => 0, + Text => 1, + Binary => 2, + Close => 8, + Ping => 9, + Pong => 10, + Bad => { + log::error!("Attempted to convert invalid opcode to u8. This is a bug."); + 8 // if this somehow happens, a close frame will help us tear down quickly } } } } impl From for OpCode { - fn from(byte: u8) -> OpCode { match byte { - 0 => Continue, - 1 => Text, - 2 => Binary, - 8 => Close, - 9 => Ping, - 10 => Pong, - _ => Bad + 0 => Continue, + 1 => Text, + 2 => Binary, + 8 => Close, + 9 => Ping, + 10 => Pong, + _ => Bad, } } } use self::CloseCode::*; -/// Status code used to indicate why an endpoint is closing the `WebSocket` connection. +/// Status code used to indicate why an endpoint is closing the `WebSocket` +/// connection. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum CloseCode { /// Indicates a normal closure, meaning that the purpose for @@ -88,10 +87,6 @@ pub enum CloseCode { /// endpoint that understands only text data MAY send this if it /// receives a binary message). Unsupported, - /// Indicates that no status code was included in a closing frame. This - /// close code makes it possible to use a single method, `on_close` to - /// handle even cases where no close code was provided. - Status, /// Indicates an abnormal closure. If the abnormal closure was due to an /// error, this close code will not be used. Instead, the `on_error` method /// of the handler will be called with the error. However, if the connection @@ -125,54 +120,48 @@ pub enum CloseCode { /// it encountered an unexpected condition that prevented it from /// fulfilling the request. Error, - /// Indicates that the server is restarting. A client may choose to reconnect, - /// and if it does, it should use a randomized delay of 5-30 seconds between attempts. + /// Indicates that the server is restarting. A client may choose to + /// reconnect, and if it does, it should use a randomized delay of 5-30 + /// seconds between attempts. Restart, - /// Indicates that the server is overloaded and the client should either connect - /// to a different IP (when multiple targets exist), or reconnect to the same IP - /// when a user has performed an action. + /// Indicates that the server is overloaded and the client should either + /// connect to a different IP (when multiple targets exist), or + /// reconnect to the same IP when a user has performed an action. Again, #[doc(hidden)] Tls, #[doc(hidden)] - Empty, - #[doc(hidden)] Other(u16), } impl Into for CloseCode { - fn into(self) -> u16 { match self { - Normal => 1000, - Away => 1001, - Protocol => 1002, - Unsupported => 1003, - Status => 1005, - Abnormal => 1006, - Invalid => 1007, - Policy => 1008, - Size => 1009, - Extension => 1010, - Error => 1011, - Restart => 1012, - Again => 1013, - Tls => 1015, - Empty => 0, - Other(code) => code, + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Other(code) => code, } } } impl From for CloseCode { - fn from(code: u16) -> CloseCode { match code { 1000 => Normal, 1001 => Away, 1002 => Protocol, 1003 => Unsupported, - 1005 => Status, 1006 => Abnormal, 1007 => Invalid, 1008 => Policy, @@ -182,16 +171,42 @@ impl From for CloseCode { 1012 => Restart, 1013 => Again, 1015 => Tls, - 0 => Empty, _ => Other(code), } } } -static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +#[derive(Debug, Eq, PartialEq, Clone)] +/// Reason for closing the connection +pub struct CloseReason { + /// Exit code + pub code: CloseCode, + /// Optional description of the exit code + pub description: Option, +} + +impl From for CloseReason { + fn from(code: CloseCode) -> Self { + CloseReason { + code, + description: None, + } + } +} + +impl> From<(CloseCode, T)> for CloseReason { + fn from(info: (CloseCode, T)) -> Self { + CloseReason { + code: info.0, + description: Some(info.1.into()), + } + } +} + +static WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // TODO: hash is always same size, we dont need String -pub(crate) fn hash_key(key: &[u8]) -> String { +pub fn hash_key(key: &[u8]) -> String { let mut hasher = sha1::Sha1::new(); hasher.update(key); @@ -200,7 +215,6 @@ pub(crate) fn hash_key(key: &[u8]) -> String { base64::encode(&hasher.digest().bytes()) } - #[cfg(test)] mod test { #![allow(unused_imports, unused_variables, dead_code)] @@ -210,9 +224,9 @@ mod test { ($from:expr => $opcode:pat) => { match OpCode::from($from) { e @ $opcode => (), - e => unreachable!("{:?}", e) + e => unreachable!("{:?}", e), } - } + }; } macro_rules! opcode_from { @@ -220,9 +234,9 @@ mod test { let res: u8 = $from.into(); match res { e @ $opcode => (), - e => unreachable!("{:?}", e) + e => unreachable!("{:?}", e), } - } + }; } #[test] @@ -269,7 +283,6 @@ mod test { assert_eq!(CloseCode::from(1001u16), CloseCode::Away); assert_eq!(CloseCode::from(1002u16), CloseCode::Protocol); assert_eq!(CloseCode::from(1003u16), CloseCode::Unsupported); - assert_eq!(CloseCode::from(1005u16), CloseCode::Status); assert_eq!(CloseCode::from(1006u16), CloseCode::Abnormal); assert_eq!(CloseCode::from(1007u16), CloseCode::Invalid); assert_eq!(CloseCode::from(1008u16), CloseCode::Policy); @@ -279,7 +292,6 @@ mod test { assert_eq!(CloseCode::from(1012u16), CloseCode::Restart); assert_eq!(CloseCode::from(1013u16), CloseCode::Again); assert_eq!(CloseCode::from(1015u16), CloseCode::Tls); - assert_eq!(CloseCode::from(0u16), CloseCode::Empty); assert_eq!(CloseCode::from(2000u16), CloseCode::Other(2000)); } @@ -289,7 +301,6 @@ mod test { assert_eq!(1001u16, Into::::into(CloseCode::Away)); assert_eq!(1002u16, Into::::into(CloseCode::Protocol)); assert_eq!(1003u16, Into::::into(CloseCode::Unsupported)); - assert_eq!(1005u16, Into::::into(CloseCode::Status)); assert_eq!(1006u16, Into::::into(CloseCode::Abnormal)); assert_eq!(1007u16, Into::::into(CloseCode::Invalid)); assert_eq!(1008u16, Into::::into(CloseCode::Policy)); @@ -299,7 +310,6 @@ mod test { assert_eq!(1012u16, Into::::into(CloseCode::Restart)); assert_eq!(1013u16, Into::::into(CloseCode::Again)); assert_eq!(1015u16, Into::::into(CloseCode::Tls)); - assert_eq!(0u16, Into::::into(CloseCode::Empty)); assert_eq!(2000u16, Into::::into(CloseCode::Other(2000))); } } diff --git a/actix-http/src/ws/transport.rs b/actix-http/src/ws/transport.rs new file mode 100644 index 000000000..58ba3160f --- /dev/null +++ b/actix-http/src/ws/transport.rs @@ -0,0 +1,51 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::{IntoService, Service}; +use actix_utils::framed::{FramedTransport, FramedTransportError}; + +use super::{Codec, Frame, Message}; + +pub struct Transport +where + S: Service + 'static, + T: AsyncRead + AsyncWrite, +{ + inner: FramedTransport, +} + +impl Transport +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + pub fn new>(io: T, service: F) -> Self { + Transport { + inner: FramedTransport::new(Framed::new(io, Codec::new()), service), + } + } + + pub fn with>(framed: Framed, service: F) -> Self { + Transport { + inner: FramedTransport::new(framed, service), + } + } +} + +impl Future for Transport +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + type Output = Result<(), FramedTransportError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.inner).poll(cx) + } +} diff --git a/actix-http/tests/test.binary b/actix-http/tests/test.binary new file mode 100644 index 000000000..ef8ff0245 --- /dev/null +++ b/actix-http/tests/test.binary @@ -0,0 +1 @@ +ÂTÇ‘É‚Vù2þvI ª–\ÇRË™–ˆæeÞvDØ:è—½¬RVÖYpíÿ;ÍÏGñùp!2÷CŒ.– û®õpA !ûߦÙx j+Uc÷±©X”c%Û;ï"yì­AI \ No newline at end of file diff --git a/actix-http/tests/test.png b/actix-http/tests/test.png new file mode 100644 index 000000000..6b7cdc0b8 Binary files /dev/null and b/actix-http/tests/test.png differ diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs new file mode 100644 index 000000000..cdcaea028 --- /dev/null +++ b/actix-http/tests/test_client.rs @@ -0,0 +1,84 @@ +use actix_service::ServiceFactory; +use bytes::Bytes; +use futures::future::{self, ok}; + +use actix_http::{http, HttpService, Request, Response}; +use actix_http_test::TestServer; + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h1_v2() { + let srv = TestServer::start(move || { + HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.get("/").header("x-test", "111").send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_connection_close() { + let srv = TestServer::start(move || { + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }); + + let response = srv.get("/").force_close().send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_with_query_parameter() { + let srv = TestServer::start(move || { + HttpService::build() + .finish(|req: Request| { + if req.uri().query().unwrap().contains("qp=") { + ok::<_, ()>(Response::Ok().finish()) + } else { + ok::<_, ()>(Response::BadRequest().finish()) + } + }) + .map(|_| ()) + }); + + let request = srv.request(http::Method::GET, srv.url("/?qp=5")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs new file mode 100644 index 000000000..0fdddaa1c --- /dev/null +++ b/actix-http/tests/test_openssl.rs @@ -0,0 +1,516 @@ +#![cfg(feature = "openssl")] +use std::io; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http_test::TestServer; +use actix_server::ssl::OpensslAcceptor; +use actix_server_config::ServerConfig; +use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{err, ok, ready}; +use futures::stream::{once, Stream, StreamExt}; +use open_ssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; + +use actix_http::error::{ErrorBadRequest, PayloadError}; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::httpmessage::HttpMessage; +use actix_http::{body, Error, HttpService, Request, Response}; + +async fn load_body(stream: S) -> Result +where + S: Stream>, +{ + let body = stream + .map(|res| match res { + Ok(chunk) => chunk, + Err(_) => panic!(), + }) + .fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + ready(body) + }) + .await; + + Ok(body) +} + +fn ssl_acceptor() -> io::Result> { + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(OpensslAcceptor::new(builder.build())) +} + +#[actix_rt::test] +async fn test_h2() -> io::Result<()> { + let openssl = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_1() -> io::Result<()> { + let openssl = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_body() -> io::Result<()> { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let openssl = ssl_acceptor()?; + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_content_length() { + let openssl = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[actix_rt::test] +async fn test_h2_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + let data = data.clone(); + pipeline_factory(openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e))) + .and_then( + HttpService::build().h2(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + ok::<_, ()>(builder.body(data.clone())) + }).map_err(|_| ())) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h2_body2() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_head_empty() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary2() { + let openssl = ssl_acceptor().unwrap(); + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[actix_rt::test] +async fn test_h2_body_length() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_body_chunked_explicit() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_response_http_error_handling() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(service_fn2(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); +} + +#[actix_rt::test] +async fn test_h2_service_error() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(|_| err::(ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} + +#[actix_rt::test] +async fn test_h2_on_connect() { + let openssl = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .on_connect(|_| 10usize) + .h2(|req: Request| { + assert!(req.extensions().contains::()); + ok::<_, ()>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs new file mode 100644 index 000000000..4a649ca37 --- /dev/null +++ b/actix-http/tests/test_rustls.rs @@ -0,0 +1,441 @@ +#![cfg(feature = "rustls")] +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::error::PayloadError; +use actix_http::http::header::{self, HeaderName, HeaderValue}; +use actix_http::http::{Method, StatusCode, Version}; +use actix_http::{body, error, Error, HttpService, Request, Response}; +use actix_http_test::TestServer; +use actix_server::ssl::RustlsAcceptor; +use actix_server_config::ServerConfig; +use actix_service::{factory_fn_cfg, pipeline_factory, service_fn2, ServiceFactory}; + +use bytes::{Bytes, BytesMut}; +use futures::future::{self, err, ok}; +use futures::stream::{once, Stream, StreamExt}; +use rust_tls::{ + internal::pemfile::{certs, pkcs8_private_keys}, + NoClientAuth, ServerConfig as RustlsServerConfig, +}; + +use std::fs::File; +use std::io::{self, BufReader}; + +async fn load_body(mut stream: S) -> Result +where + S: Stream> + Unpin, +{ + let mut body = BytesMut::new(); + while let Some(item) = stream.next().await { + body.extend_from_slice(&item?) + } + Ok(body) +} + +fn ssl_acceptor() -> io::Result> { + // load ssl keys + let mut config = RustlsServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("../tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("../tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let protos = vec![b"h2".to_vec()]; + config.set_protocols(&protos); + Ok(RustlsAcceptor::new(config)) +} + +#[actix_rt::test] +async fn test_h2() -> io::Result<()> { + let rustls = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_1() -> io::Result<()> { + let rustls = ssl_acceptor()?; + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), Version::HTTP_2); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_body1() -> io::Result<()> { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let rustls = ssl_acceptor()?; + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|mut req: Request<_>| { + async move { + let body = load_body(req.take_payload()).await?; + Ok::<_, Error>(Response::Ok().body(body)) + } + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + let body = srv.load_body(response).await.unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) +} + +#[actix_rt::test] +async fn test_h2_content_length() { + let rustls = ssl_acceptor().unwrap(); + + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = req.await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[actix_rt::test] +async fn test_h2_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + let data = data.clone(); + pipeline_factory(rustls + .clone() + .map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build().h2(move |_| { + let mut config = Response::Ok(); + for idx in 0..90 { + config.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(config.body(data.clone())) + }).map_err(|_| ())) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h2_body2() { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_head_empty() { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary() { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h2_head_binary2() { + let rustls = ssl_acceptor().unwrap(); + let srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.shead("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[actix_rt::test] +async fn test_h2_body_length() { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_body_chunked_explicit() { + let rustls = ssl_acceptor().unwrap(); + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| { + let body = + once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h2_response_http_error_handling() { + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(service_fn2(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); +} + +#[actix_rt::test] +async fn test_h2_service_error() { + let rustls = ssl_acceptor().unwrap(); + + let mut srv = TestServer::start(move || { + pipeline_factory(rustls.clone().map_err(|e| println!("Rustls error: {}", e))) + .and_then( + HttpService::build() + .h2(|_| err::(error::ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let response = srv.sget("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs new file mode 100644 index 000000000..a3ce3f9cb --- /dev/null +++ b/actix-http/tests/test_server.rs @@ -0,0 +1,619 @@ +use std::io::{Read, Write}; +use std::time::Duration; +use std::{net, thread}; + +use actix_http_test::TestServer; +use actix_rt::time::delay_for; +use actix_server_config::ServerConfig; +use actix_service::{factory_fn_cfg, pipeline, service_fn, ServiceFactory}; +use bytes::Bytes; +use futures::future::{self, err, ok, ready, FutureExt}; +use futures::stream::{once, StreamExt}; +use regex::Regex; + +use actix_http::httpmessage::HttpMessage; +use actix_http::{ + body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response, +}; + +#[actix_rt::test] +async fn test_h1() { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .h1(|req: Request| { + assert!(req.peer_addr().is_some()); + future::ok::<_, ()>(Response::Ok().finish()) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_h1_2() { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .finish(|req: Request| { + assert!(req.peer_addr().is_some()); + assert_eq!(req.version(), http::Version::HTTP_11); + future::ok::<_, ()>(Response::Ok().finish()) + }) + .map(|_| ()) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_expect_continue() { + let srv = TestServer::start(|| { + HttpService::build() + .expect(service_fn(|req: Request| { + if req.head().uri.query() == Some("yes=") { + ok(req) + } else { + err(error::ErrorPreconditionFailed("error")) + } + })) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); +} + +#[actix_rt::test] +async fn test_expect_continue_h1() { + let srv = TestServer::start(|| { + HttpService::build() + .expect(service_fn(|req: Request| { + delay_for(Duration::from_millis(20)).then(move |_| { + if req.head().uri.query() == Some("yes=") { + ok(req) + } else { + err(error::ErrorPreconditionFailed("error")) + } + }) + })) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); +} + +#[actix_rt::test] +async fn test_chunked_payload() { + let chunk_sizes = vec![32768, 32, 32768]; + let total_size: usize = chunk_sizes.iter().sum(); + + let srv = TestServer::start(|| { + HttpService::build().h1(service_fn(|mut request: Request| { + request + .take_payload() + .map(|res| match res { + Ok(pl) => pl, + Err(e) => panic!(format!("Error reading payload: {}", e)), + }) + .fold(0usize, |acc, chunk| ready(acc + chunk.len())) + .map(|req_size| { + Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size))) + }) + })) + }); + + let returned_size = { + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream + .write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + + for chunk_size in chunk_sizes.iter() { + let mut bytes = Vec::new(); + let random_bytes: Vec = + (0..*chunk_size).map(|_| rand::random::()).collect(); + + bytes.extend(format!("{:X}\r\n", chunk_size).as_bytes()); + bytes.extend(&random_bytes[..]); + bytes.extend(b"\r\n"); + let _ = stream.write_all(&bytes); + } + + let _ = stream.write_all(b"0\r\n\r\n"); + stream.shutdown(net::Shutdown::Write).unwrap(); + + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + + let re = Regex::new(r"size=(\d+)").unwrap(); + let size: usize = match re.captures(&data) { + Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), + None => panic!(format!("Failed to find size in HTTP Response: {}", data)), + }; + size + }; + + assert_eq!(returned_size, total_size); +} + +#[actix_rt::test] +async fn test_slow_request() { + let srv = TestServer::start(|| { + HttpService::build() + .client_timeout(100) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); +} + +#[actix_rt::test] +async fn test_http1_malformed_request() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 400 Bad Request")); +} + +#[actix_rt::test] +async fn test_http1_keepalive() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); +} + +#[actix_rt::test] +async fn test_http1_keepalive_timeout() { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(1) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + thread::sleep(Duration::from_millis(1100)); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[actix_rt::test] +async fn test_http1_keepalive_close() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[actix_rt::test] +async fn test_http10_keepalive_default_close() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[actix_rt::test] +async fn test_http10_keepalive() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream + .write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[actix_rt::test] +async fn test_http1_keepalive_disabled() { + let srv = TestServer::start(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[actix_rt::test] +async fn test_content_length() { + use actix_http::http::{ + header::{HeaderName, HeaderValue}, + StatusCode, + }; + + let srv = TestServer::start(|| { + HttpService::build().h1(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv.request(http::Method::HEAD, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv.request(http::Method::GET, srv.url(&format!("/{}", i))); + let response = req.send().await.unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[actix_rt::test] +async fn test_h1_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + + let mut srv = TestServer::start(move || { + let data = data.clone(); + HttpService::build().h1(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(builder.body(data.clone())) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_h1_body() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h1_head_empty() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); + + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h1_head_binary() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + }); + + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert!(bytes.is_empty()); +} + +#[actix_rt::test] +async fn test_h1_head_binary2() { + let srv = TestServer::start(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); + + let response = srv.head("/").send().await.unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[actix_rt::test] +async fn test_h1_body_length() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h1_body_chunked_explicit() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h1_body_chunked_implicit() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(|_| { + let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().streaming(body)) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_h1_response_http_error_handling() { + let mut srv = TestServer::start(|| { + HttpService::build().h1(factory_fn_cfg(|_: &ServerConfig| { + ok::<_, ()>(pipeline(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + })) + })) + }); + + let response = srv.get("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); +} + +#[actix_rt::test] +async fn test_h1_service_error() { + let mut srv = TestServer::start(|| { + HttpService::build() + .h1(|_| future::err::(error::ErrorBadRequest("error"))) + }); + + let response = srv.get("/").send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); + + // read response + let bytes = srv.load_body(response).await.unwrap(); + assert_eq!(bytes, Bytes::from_static(b"error")); +} + +#[actix_rt::test] +async fn test_h1_on_connect() { + let srv = TestServer::start(|| { + HttpService::build() + .on_connect(|_| 10usize) + .h1(|req: Request| { + assert!(req.extensions().contains::()); + future::ok::<_, ()>(Response::Ok().finish()) + }) + }); + + let response = srv.get("/").send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs new file mode 100644 index 000000000..aa81bc41b --- /dev/null +++ b/actix-http/tests/test_ws.rs @@ -0,0 +1,84 @@ +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; +use actix_http_test::TestServer; +use actix_utils::framed::FramedTransport; +use bytes::{Bytes, BytesMut}; +use futures::future; +use futures::{SinkExt, StreamExt}; + +async fn ws_service( + (req, mut framed): (Request, Framed), +) -> Result<(), Error> { + let res = ws::handshake(req.head()).unwrap().message_body(()); + + framed + .send((res, body::BodySize::None).into()) + .await + .unwrap(); + + FramedTransport::new(framed.into_framed(ws::Codec::new()), service) + .await + .map_err(|_| panic!()) +} + +async fn service(msg: ws::Frame) -> Result { + let msg = match msg { + ws::Frame::Ping(msg) => ws::Message::Pong(msg), + ws::Frame::Text(text) => { + ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + } + ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Close(reason) => ws::Message::Close(reason), + _ => panic!(), + }; + Ok(msg) +} + +#[actix_rt::test] +async fn test_simple() { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade(actix_service::service_fn(ws_service)) + .finish(|_| future::ok::<_, ()>(Response::NotFound())) + }); + + // client service + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Text(Some(BytesMut::from("text"))) + ); + + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let (item, mut framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Pong("text".to_string().into()) + ); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let (item, _framed) = framed.into_future().await; + assert_eq!( + item.unwrap().unwrap(), + ws::Frame::Close(Some(ws::CloseCode::Normal.into())) + ); +} diff --git a/actix-identity/CHANGES.md b/actix-identity/CHANGES.md new file mode 100644 index 000000000..74a204055 --- /dev/null +++ b/actix-identity/CHANGES.md @@ -0,0 +1,5 @@ +# Changes + +## [0.1.0] - 2019-06-xx + +* Move identity middleware to separate crate diff --git a/actix-identity/Cargo.toml b/actix-identity/Cargo.toml new file mode 100644 index 000000000..d05b37685 --- /dev/null +++ b/actix-identity/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "actix-identity" +version = "0.2.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Identity service for actix web framework." +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-identity/" +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[lib] +name = "actix_identity" +path = "src/lib.rs" + +[dependencies] +actix-web = { version = "2.0.0-alpha.1", default-features = false, features = ["secure-cookies"] } +actix-service = "1.0.0-alpha.1" +futures = "0.3.1" +serde = "1.0" +serde_json = "1.0" +time = "0.1.42" + +[dev-dependencies] +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" +bytes = "0.4" \ No newline at end of file diff --git a/actix-identity/LICENSE-APACHE b/actix-identity/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-identity/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-identity/LICENSE-MIT b/actix-identity/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-identity/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-identity/README.md b/actix-identity/README.md new file mode 100644 index 000000000..60b615c76 --- /dev/null +++ b/actix-identity/README.md @@ -0,0 +1,9 @@ +# Identity service for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-identity)](https://crates.io/crates/actix-identity) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-identity/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-session](https://crates.io/crates/actix-identity) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-identity/src/lib.rs b/actix-identity/src/lib.rs new file mode 100644 index 000000000..5dfd2ae65 --- /dev/null +++ b/actix-identity/src/lib.rs @@ -0,0 +1,1072 @@ +//! Request identity service for Actix applications. +//! +//! [**IdentityService**](struct.IdentityService.html) middleware can be +//! used with different policies types to store identity information. +//! +//! By default, only cookie identity policy is implemented. Other backend +//! implementations can be added separately. +//! +//! [**CookieIdentityPolicy**](struct.CookieIdentityPolicy.html) +//! uses cookies as identity storage. +//! +//! To access current request identity +//! [**Identity**](struct.Identity.html) extractor should be used. +//! +//! ```rust +//! use actix_web::*; +//! use actix_identity::{Identity, CookieIdentityPolicy, IdentityService}; +//! +//! async fn index(id: Identity) -> String { +//! // access request identity +//! if let Some(id) = id.identity() { +//! format!("Welcome! {}", id) +//! } else { +//! "Welcome Anonymous!".to_owned() +//! } +//! } +//! +//! async fn login(id: Identity) -> HttpResponse { +//! id.remember("User1".to_owned()); // <- remember identity +//! HttpResponse::Ok().finish() +//! } +//! +//! async fn logout(id: Identity) -> HttpResponse { +//! id.forget(); // <- remove identity +//! HttpResponse::Ok().finish() +//! } +//! +//! fn main() { +//! let app = App::new().wrap(IdentityService::new( +//! // <- create identity middleware +//! CookieIdentityPolicy::new(&[0; 32]) // <- create cookie identity policy +//! .name("auth-cookie") +//! .secure(false))) +//! .service(web::resource("/index.html").to(index)) +//! .service(web::resource("/login.html").to(login)) +//! .service(web::resource("/logout.html").to(logout)); +//! } +//! ``` +use std::cell::RefCell; +use std::future::Future; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +use actix_service::{Service, Transform}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use serde::{Deserialize, Serialize}; +use time::Duration; + +use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; +use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; +use actix_web::error::{Error, Result}; +use actix_web::http::header::{self, HeaderValue}; +use actix_web::{FromRequest, HttpMessage, HttpRequest}; + +/// The extractor type to obtain your identity from a request. +/// +/// ```rust +/// use actix_web::*; +/// use actix_identity::Identity; +/// +/// fn index(id: Identity) -> Result { +/// // access request identity +/// if let Some(id) = id.identity() { +/// Ok(format!("Welcome! {}", id)) +/// } else { +/// Ok("Welcome Anonymous!".to_owned()) +/// } +/// } +/// +/// fn login(id: Identity) -> HttpResponse { +/// id.remember("User1".to_owned()); // <- remember identity +/// HttpResponse::Ok().finish() +/// } +/// +/// fn logout(id: Identity) -> HttpResponse { +/// id.forget(); // <- remove identity +/// HttpResponse::Ok().finish() +/// } +/// # fn main() {} +/// ``` +#[derive(Clone)] +pub struct Identity(HttpRequest); + +impl Identity { + /// Return the claimed identity of the user associated request or + /// ``None`` if no identity can be found associated with the request. + pub fn identity(&self) -> Option { + Identity::get_identity(&self.0.extensions()) + } + + /// Remember identity. + pub fn remember(&self, identity: String) { + if let Some(id) = self.0.extensions_mut().get_mut::() { + id.id = Some(identity); + id.changed = true; + } + } + + /// This method is used to 'forget' the current identity on subsequent + /// requests. + pub fn forget(&self) { + if let Some(id) = self.0.extensions_mut().get_mut::() { + id.id = None; + id.changed = true; + } + } + + fn get_identity(extensions: &Extensions) -> Option { + if let Some(id) = extensions.get::() { + id.id.clone() + } else { + None + } + } +} + +struct IdentityItem { + id: Option, + changed: bool, +} + +/// Helper trait that allows to get Identity. +/// +/// It could be used in middleware but identity policy must be set before any other middleware that needs identity +/// RequestIdentity is implemented both for `ServiceRequest` and `HttpRequest`. +pub trait RequestIdentity { + fn get_identity(&self) -> Option; +} + +impl RequestIdentity for T +where + T: HttpMessage, +{ + fn get_identity(&self) -> Option { + Identity::get_identity(&self.extensions()) + } +} + +/// Extractor implementation for Identity type. +/// +/// ```rust +/// # use actix_web::*; +/// use actix_identity::Identity; +/// +/// fn index(id: Identity) -> String { +/// // access request identity +/// if let Some(id) = id.identity() { +/// format!("Welcome! {}", id) +/// } else { +/// "Welcome Anonymous!".to_owned() +/// } +/// } +/// # fn main() {} +/// ``` +impl FromRequest for Identity { + type Config = (); + type Error = Error; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(Identity(req.clone())) + } +} + +/// Identity policy definition. +pub trait IdentityPolicy: Sized + 'static { + /// The return type of the middleware + type Future: Future, Error>>; + + /// The return type of the middleware + type ResponseFuture: Future>; + + /// Parse the session from request and load data from a service identity. + fn from_request(&self, request: &mut ServiceRequest) -> Self::Future; + + /// Write changes to response + fn to_response( + &self, + identity: Option, + changed: bool, + response: &mut ServiceResponse, + ) -> Self::ResponseFuture; +} + +/// Request identity middleware +/// +/// ```rust +/// use actix_web::App; +/// use actix_identity::{CookieIdentityPolicy, IdentityService}; +/// +/// fn main() { +/// let app = App::new().wrap(IdentityService::new( +/// // <- create identity middleware +/// CookieIdentityPolicy::new(&[0; 32]) // <- create cookie session backend +/// .name("auth-cookie") +/// .secure(false), +/// )); +/// } +/// ``` +pub struct IdentityService { + backend: Rc, +} + +impl IdentityService { + /// Create new identity service with specified backend. + pub fn new(backend: T) -> Self { + IdentityService { + backend: Rc::new(backend), + } + } +} + +impl Transform for IdentityService +where + S: Service, Error = Error> + + 'static, + S::Future: 'static, + T: IdentityPolicy, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = IdentityServiceMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(IdentityServiceMiddleware { + backend: self.backend.clone(), + service: Rc::new(RefCell::new(service)), + }) + } +} + +#[doc(hidden)] +pub struct IdentityServiceMiddleware { + backend: Rc, + service: Rc>, +} + +impl Service for IdentityServiceMiddleware +where + B: 'static, + S: Service, Error = Error> + + 'static, + S::Future: 'static, + T: IdentityPolicy, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.borrow_mut().poll_ready(cx) + } + + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let srv = self.service.clone(); + let backend = self.backend.clone(); + let fut = self.backend.from_request(&mut req); + + async move { + match fut.await { + Ok(id) => { + req.extensions_mut() + .insert(IdentityItem { id, changed: false }); + + let mut res = srv.borrow_mut().call(req).await?; + let id = res.request().extensions_mut().remove::(); + + if let Some(id) = id { + match backend.to_response(id.id, id.changed, &mut res).await { + Ok(_) => Ok(res), + Err(e) => Ok(res.error_response(e)), + } + } else { + Ok(res) + } + } + Err(err) => Ok(req.error_response(err)), + } + } + .boxed_local() + } +} + +struct CookieIdentityInner { + key: Key, + key_v2: Key, + name: String, + path: String, + domain: Option, + secure: bool, + max_age: Option, + same_site: Option, + visit_deadline: Option, + login_deadline: Option, +} + +#[derive(Deserialize, Serialize, Debug)] +struct CookieValue { + identity: String, + #[serde(skip_serializing_if = "Option::is_none")] + login_timestamp: Option, + #[serde(skip_serializing_if = "Option::is_none")] + visit_timestamp: Option, +} + +#[derive(Debug)] +struct CookieIdentityExtention { + login_timestamp: Option, +} + +impl CookieIdentityInner { + fn new(key: &[u8]) -> CookieIdentityInner { + let key_v2: Vec = key.iter().chain([1, 0, 0, 0].iter()).cloned().collect(); + CookieIdentityInner { + key: Key::from_master(key), + key_v2: Key::from_master(&key_v2), + name: "actix-identity".to_owned(), + path: "/".to_owned(), + domain: None, + secure: true, + max_age: None, + same_site: None, + visit_deadline: None, + login_deadline: None, + } + } + + fn set_cookie( + &self, + resp: &mut ServiceResponse, + value: Option, + ) -> Result<()> { + let add_cookie = value.is_some(); + let val = value.map(|val| { + if !self.legacy_supported() { + serde_json::to_string(&val) + } else { + Ok(val.identity) + } + }); + let mut cookie = + Cookie::new(self.name.clone(), val.unwrap_or_else(|| Ok(String::new()))?); + cookie.set_path(self.path.clone()); + cookie.set_secure(self.secure); + cookie.set_http_only(true); + + if let Some(ref domain) = self.domain { + cookie.set_domain(domain.clone()); + } + + if let Some(max_age) = self.max_age { + cookie.set_max_age(max_age); + } + + if let Some(same_site) = self.same_site { + cookie.set_same_site(same_site); + } + + let mut jar = CookieJar::new(); + let key = if self.legacy_supported() { + &self.key + } else { + &self.key_v2 + }; + if add_cookie { + jar.private(&key).add(cookie); + } else { + jar.add_original(cookie.clone()); + jar.private(&key).remove(cookie); + } + for cookie in jar.delta() { + let val = HeaderValue::from_str(&cookie.to_string())?; + resp.headers_mut().append(header::SET_COOKIE, val); + } + Ok(()) + } + + fn load(&self, req: &ServiceRequest) -> Option { + let cookie = req.cookie(&self.name)?; + let mut jar = CookieJar::new(); + jar.add_original(cookie.clone()); + let res = if self.legacy_supported() { + jar.private(&self.key).get(&self.name).map(|n| CookieValue { + identity: n.value().to_string(), + login_timestamp: None, + visit_timestamp: None, + }) + } else { + None + }; + res.or_else(|| { + jar.private(&self.key_v2) + .get(&self.name) + .and_then(|c| self.parse(c)) + }) + } + + fn parse(&self, cookie: Cookie) -> Option { + let value: CookieValue = serde_json::from_str(cookie.value()).ok()?; + let now = SystemTime::now(); + if let Some(visit_deadline) = self.visit_deadline { + if now.duration_since(value.visit_timestamp?).ok()? + > visit_deadline.to_std().ok()? + { + return None; + } + } + if let Some(login_deadline) = self.login_deadline { + if now.duration_since(value.login_timestamp?).ok()? + > login_deadline.to_std().ok()? + { + return None; + } + } + Some(value) + } + + fn legacy_supported(&self) -> bool { + self.visit_deadline.is_none() && self.login_deadline.is_none() + } + + fn always_update_cookie(&self) -> bool { + self.visit_deadline.is_some() + } + + fn requires_oob_data(&self) -> bool { + self.login_deadline.is_some() + } +} + +/// Use cookies for request identity storage. +/// +/// The constructors take a key as an argument. +/// This is the private key for cookie - when this value is changed, +/// all identities are lost. The constructors will panic if the key is less +/// than 32 bytes in length. +/// +/// # Example +/// +/// ```rust +/// use actix_web::App; +/// use actix_identity::{CookieIdentityPolicy, IdentityService}; +/// +/// fn main() { +/// let app = App::new().wrap(IdentityService::new( +/// // <- create identity middleware +/// CookieIdentityPolicy::new(&[0; 32]) // <- construct cookie policy +/// .domain("www.rust-lang.org") +/// .name("actix_auth") +/// .path("/") +/// .secure(true), +/// )); +/// } +/// ``` +pub struct CookieIdentityPolicy(Rc); + +impl CookieIdentityPolicy { + /// Construct new `CookieIdentityPolicy` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn new(key: &[u8]) -> CookieIdentityPolicy { + CookieIdentityPolicy(Rc::new(CookieIdentityInner::new(key))) + } + + /// Sets the `path` field in the session cookie being built. + pub fn path>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().path = value.into(); + self + } + + /// Sets the `name` field in the session cookie being built. + pub fn name>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().name = value.into(); + self + } + + /// Sets the `domain` field in the session cookie being built. + pub fn domain>(mut self, value: S) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().domain = Some(value.into()); + self + } + + /// Sets the `secure` field in the session cookie being built. + /// + /// If the `secure` field is set, a cookie will only be transmitted when the + /// connection is secure - i.e. `https` + pub fn secure(mut self, value: bool) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().secure = value; + self + } + + /// Sets the `max-age` field in the session cookie being built with given number of seconds. + pub fn max_age(self, seconds: i64) -> CookieIdentityPolicy { + self.max_age_time(Duration::seconds(seconds)) + } + + /// Sets the `max-age` field in the session cookie being built with `chrono::Duration`. + pub fn max_age_time(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().max_age = Some(value); + self + } + + /// Sets the `same_site` field in the session cookie being built. + pub fn same_site(mut self, same_site: SameSite) -> Self { + Rc::get_mut(&mut self.0).unwrap().same_site = Some(same_site); + self + } + + /// Accepts only users whose cookie has been seen before the given deadline + /// + /// By default visit deadline is disabled. + pub fn visit_deadline(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().visit_deadline = Some(value); + self + } + + /// Accepts only users which has been authenticated before the given deadline + /// + /// By default login deadline is disabled. + pub fn login_deadline(mut self, value: Duration) -> CookieIdentityPolicy { + Rc::get_mut(&mut self.0).unwrap().login_deadline = Some(value); + self + } +} + +impl IdentityPolicy for CookieIdentityPolicy { + type Future = Ready, Error>>; + type ResponseFuture = Ready>; + + fn from_request(&self, req: &mut ServiceRequest) -> Self::Future { + ok(self.0.load(req).map( + |CookieValue { + identity, + login_timestamp, + .. + }| { + if self.0.requires_oob_data() { + req.extensions_mut() + .insert(CookieIdentityExtention { login_timestamp }); + } + identity + }, + )) + } + + fn to_response( + &self, + id: Option, + changed: bool, + res: &mut ServiceResponse, + ) -> Self::ResponseFuture { + let _ = if changed { + let login_timestamp = SystemTime::now(); + self.0.set_cookie( + res, + id.map(|identity| CookieValue { + identity, + login_timestamp: self.0.login_deadline.map(|_| login_timestamp), + visit_timestamp: self.0.visit_deadline.map(|_| login_timestamp), + }), + ) + } else if self.0.always_update_cookie() && id.is_some() { + let visit_timestamp = SystemTime::now(); + let login_timestamp = if self.0.requires_oob_data() { + let CookieIdentityExtention { + login_timestamp: lt, + } = res.request().extensions_mut().remove().unwrap(); + lt + } else { + None + }; + self.0.set_cookie( + res, + Some(CookieValue { + identity: id.unwrap(), + login_timestamp, + visit_timestamp: self.0.visit_deadline.map(|_| visit_timestamp), + }), + ) + } else { + Ok(()) + }; + ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Borrow; + + use super::*; + use actix_web::http::StatusCode; + use actix_web::test::{self, TestRequest}; + use actix_web::{web, App, Error, HttpResponse}; + + const COOKIE_KEY_MASTER: [u8; 32] = [0; 32]; + const COOKIE_NAME: &'static str = "actix_auth"; + const COOKIE_LOGIN: &'static str = "test"; + + #[actix_rt::test] + async fn test_identity() { + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .secure(true), + )) + .service(web::resource("/index").to(|id: Identity| { + if id.identity().is_some() { + HttpResponse::Created() + } else { + HttpResponse::Ok() + } + })) + .service(web::resource("/login").to(|id: Identity| { + id.remember(COOKIE_LOGIN.to_string()); + HttpResponse::Ok() + })) + .service(web::resource("/logout").to(|id: Identity| { + if id.identity().is_some() { + id.forget(); + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/index").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + let c = resp.response().cookies().next().unwrap().to_owned(); + + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/index") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = test::call_service( + &mut srv, + TestRequest::with_uri("/logout") + .cookie(c.clone()) + .to_request(), + ) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)) + } + + #[actix_rt::test] + async fn test_identity_max_age_time() { + let duration = Duration::days(1); + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age_time(duration) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(duration, c.max_age().unwrap()); + } + + #[actix_rt::test] + async fn test_identity_max_age() { + let seconds = 60; + let mut srv = test::init_service( + App::new() + .wrap(IdentityService::new( + CookieIdentityPolicy::new(&COOKIE_KEY_MASTER) + .domain("www.rust-lang.org") + .name(COOKIE_NAME) + .path("/") + .max_age(seconds) + .secure(true), + )) + .service(web::resource("/login").to(|id: Identity| { + id.remember("test".to_string()); + HttpResponse::Ok() + })), + ) + .await; + let resp = + test::call_service(&mut srv, TestRequest::with_uri("/login").to_request()) + .await; + assert_eq!(resp.status(), StatusCode::OK); + assert!(resp.headers().contains_key(header::SET_COOKIE)); + let c = resp.response().cookies().next().unwrap().to_owned(); + assert_eq!(Duration::seconds(seconds as i64), c.max_age().unwrap()); + } + + async fn create_identity_server< + F: Fn(CookieIdentityPolicy) -> CookieIdentityPolicy + Sync + Send + Clone + 'static, + >( + f: F, + ) -> impl actix_service::Service< + Request = actix_http::Request, + Response = ServiceResponse, + Error = Error, + > { + test::init_service( + App::new() + .wrap(IdentityService::new(f(CookieIdentityPolicy::new( + &COOKIE_KEY_MASTER, + ) + .secure(false) + .name(COOKIE_NAME)))) + .service(web::resource("/").to(|id: Identity| { + async move { + let identity = id.identity(); + if identity.is_none() { + id.remember(COOKIE_LOGIN.to_string()) + } + web::Json(identity) + } + })), + ) + .await + } + + fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { + let mut jar = CookieJar::new(); + jar.private(&Key::from_master(&COOKIE_KEY_MASTER)) + .add(Cookie::new(COOKIE_NAME, identity)); + jar.get(COOKIE_NAME).unwrap().clone() + } + + fn login_cookie( + identity: &'static str, + login_timestamp: Option, + visit_timestamp: Option, + ) -> Cookie<'static> { + let mut jar = CookieJar::new(); + let key: Vec = COOKIE_KEY_MASTER + .iter() + .chain([1, 0, 0, 0].iter()) + .map(|e| *e) + .collect(); + jar.private(&Key::from_master(&key)).add(Cookie::new( + COOKIE_NAME, + serde_json::to_string(&CookieValue { + identity: identity.to_string(), + login_timestamp, + visit_timestamp, + }) + .unwrap(), + )); + jar.get(COOKIE_NAME).unwrap().clone() + } + + async fn assert_logged_in(response: ServiceResponse, identity: Option<&str>) { + let bytes = test::read_body(response).await; + let resp: Option = serde_json::from_slice(&bytes[..]).unwrap(); + assert_eq!(resp.as_ref().map(|s| s.borrow()), identity); + } + + fn assert_legacy_login_cookie(response: &mut ServiceResponse, identity: &str) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + let cookie = cookies + .private(&Key::from_master(&COOKIE_KEY_MASTER)) + .get(COOKIE_NAME) + .unwrap(); + assert_eq!(cookie.value(), identity); + } + + enum LoginTimestampCheck { + NoTimestamp, + NewTimestamp, + OldTimestamp(SystemTime), + } + + enum VisitTimeStampCheck { + NoTimestamp, + NewTimestamp, + } + + fn assert_login_cookie( + response: &mut ServiceResponse, + identity: &str, + login_timestamp: LoginTimestampCheck, + visit_timestamp: VisitTimeStampCheck, + ) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + let key: Vec = COOKIE_KEY_MASTER + .iter() + .chain([1, 0, 0, 0].iter()) + .map(|e| *e) + .collect(); + let cookie = cookies + .private(&Key::from_master(&key)) + .get(COOKIE_NAME) + .unwrap(); + let cv: CookieValue = serde_json::from_str(cookie.value()).unwrap(); + assert_eq!(cv.identity, identity); + let now = SystemTime::now(); + let t30sec_ago = now - Duration::seconds(30).to_std().unwrap(); + match login_timestamp { + LoginTimestampCheck::NoTimestamp => assert_eq!(cv.login_timestamp, None), + LoginTimestampCheck::NewTimestamp => assert!( + t30sec_ago <= cv.login_timestamp.unwrap() + && cv.login_timestamp.unwrap() <= now + ), + LoginTimestampCheck::OldTimestamp(old_timestamp) => { + assert_eq!(cv.login_timestamp, Some(old_timestamp)) + } + } + match visit_timestamp { + VisitTimeStampCheck::NoTimestamp => assert_eq!(cv.visit_timestamp, None), + VisitTimeStampCheck::NewTimestamp => assert!( + t30sec_ago <= cv.visit_timestamp.unwrap() + && cv.visit_timestamp.unwrap() <= now + ), + } + } + + fn assert_no_login_cookie(response: &mut ServiceResponse) { + let mut cookies = CookieJar::new(); + for cookie in response.headers().get_all(header::SET_COOKIE) { + cookies.add(Cookie::parse(cookie.to_str().unwrap().to_string()).unwrap()); + } + assert!(cookies.get(COOKIE_NAME).is_none()); + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_is_set() { + let mut srv = create_identity_server(|c| c).await; + let mut resp = + test::call_service(&mut srv, TestRequest::with_uri("/").to_request()).await; + assert_legacy_login_cookie(&mut resp, COOKIE_LOGIN); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_works() { + let mut srv = create_identity_server(|c| c).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_legacy_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = legacy_login_cookie(COOKIE_LOGIN); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, None, Some(SystemTime::now())); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_needed() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_login_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), + None, + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NewTimestamp, + VisitTimeStampCheck::NoTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_rejected_if_visit_timestamp_too_old() { + let mut srv = + create_identity_server(|c| c.visit_deadline(Duration::days(90))).await; + let cookie = login_cookie( + COOKIE_LOGIN, + None, + Some(SystemTime::now() - Duration::days(180).to_std().unwrap()), + ); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::NoTimestamp, + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, None).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_not_updated_on_login_deadline() { + let mut srv = + create_identity_server(|c| c.login_deadline(Duration::days(90))).await; + let cookie = login_cookie(COOKIE_LOGIN, Some(SystemTime::now()), None); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_no_login_cookie(&mut resp); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } + + #[actix_rt::test] + async fn test_identity_cookie_updated_on_visit_deadline() { + let mut srv = create_identity_server(|c| { + c.visit_deadline(Duration::days(90)) + .login_deadline(Duration::days(90)) + }) + .await; + let timestamp = SystemTime::now() - Duration::days(1).to_std().unwrap(); + let cookie = login_cookie(COOKIE_LOGIN, Some(timestamp), Some(timestamp)); + let mut resp = test::call_service( + &mut srv, + TestRequest::with_uri("/") + .cookie(cookie.clone()) + .to_request(), + ) + .await; + assert_login_cookie( + &mut resp, + COOKIE_LOGIN, + LoginTimestampCheck::OldTimestamp(timestamp), + VisitTimeStampCheck::NewTimestamp, + ); + assert_logged_in(resp, Some(COOKIE_LOGIN)).await; + } +} diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md new file mode 100644 index 000000000..ca61176c7 --- /dev/null +++ b/actix-multipart/CHANGES.md @@ -0,0 +1,35 @@ +# Changes + +## [0.1.4] - 2019-09-12 + +* Multipart handling now parses requests which do not end in CRLF #1038 + +## [0.1.3] - 2019-08-18 + +* Fix ring dependency from actix-web default features for #741. + +## [0.1.2] - 2019-06-02 + +* Fix boundary parsing #876 + +## [0.1.1] - 2019-05-25 + +* Fix disconnect handling #834 + +## [0.1.0] - 2019-05-18 + +* Release + +## [0.1.0-beta.4] - 2019-05-12 + +* Handle cancellation of uploads #736 + +* Upgrade to actix-web 1.0.0-beta.4 + +## [0.1.0-beta.1] - 2019-04-21 + +* Do not support nested multipart + +* Split multipart support to separate crate + +* Optimize multipart handling #634, #769 \ No newline at end of file diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml new file mode 100644 index 000000000..52b33d582 --- /dev/null +++ b/actix-multipart/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "actix-multipart" +version = "0.2.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Multipart support for actix web framework." +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-multipart/" +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +workspace = ".." +edition = "2018" + +[lib] +name = "actix_multipart" +path = "src/lib.rs" + +[dependencies] +actix-web = { version = "2.0.0-alpha.1", default-features = false } +actix-service = "1.0.0-alpha.1" +actix-utils = "0.5.0-alpha.1" +bytes = "0.4" +derive_more = "0.99.2" +httparse = "1.3" +futures = "0.3.1" +log = "0.4" +mime = "0.3" +time = "0.1" +twoway = "0.2" + +[dev-dependencies] +actix-rt = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" \ No newline at end of file diff --git a/actix-multipart/LICENSE-APACHE b/actix-multipart/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-multipart/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-multipart/LICENSE-MIT b/actix-multipart/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-multipart/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-multipart/README.md b/actix-multipart/README.md new file mode 100644 index 000000000..ac0d05640 --- /dev/null +++ b/actix-multipart/README.md @@ -0,0 +1,8 @@ +# Multipart support for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-multipart)](https://crates.io/crates/actix-multipart) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [API Documentation](https://docs.rs/actix-multipart/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-multipart](https://crates.io/crates/actix-multipart) +* Minimum supported Rust version: 1.33 or later diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs new file mode 100644 index 000000000..6677f69c7 --- /dev/null +++ b/actix-multipart/src/error.rs @@ -0,0 +1,53 @@ +//! Error and Result module +use actix_web::error::{ParseError, PayloadError}; +use actix_web::http::StatusCode; +use actix_web::ResponseError; +use derive_more::{Display, From}; + +/// A set of errors that can occur during parsing multipart streams +#[derive(Debug, Display, From)] +pub enum MultipartError { + /// Content-Type header is not found + #[display(fmt = "No Content-type header found")] + NoContentType, + /// Can not parse Content-Type header + #[display(fmt = "Can not parse Content-Type header")] + ParseContentType, + /// Multipart boundary is not found + #[display(fmt = "Multipart boundary is not found")] + Boundary, + /// Nested multipart is not supported + #[display(fmt = "Nested multipart is not supported")] + Nested, + /// Multipart stream is incomplete + #[display(fmt = "Multipart stream is incomplete")] + Incomplete, + /// Error during field parsing + #[display(fmt = "{}", _0)] + Parse(ParseError), + /// Payload error + #[display(fmt = "{}", _0)] + Payload(PayloadError), + /// Not consumed + #[display(fmt = "Multipart stream is not consumed")] + NotConsumed, +} + +/// Return `BadRequest` for `MultipartError` +impl ResponseError for MultipartError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::HttpResponse; + + #[test] + fn test_multipart_error() { + let resp: HttpResponse = MultipartError::Boundary.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs new file mode 100644 index 000000000..71c815227 --- /dev/null +++ b/actix-multipart/src/extractor.rs @@ -0,0 +1,41 @@ +//! Multipart payload support +use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; +use futures::future::{ok, Ready}; + +use crate::server::Multipart; + +/// Get request's payload as multipart stream +/// +/// Content-type: multipart/form-data; +/// +/// ## Server example +/// +/// ```rust +/// use futures::{Stream, StreamExt}; +/// use actix_web::{web, HttpResponse, Error}; +/// use actix_multipart as mp; +/// +/// async fn index(mut payload: mp::Multipart) -> Result { +/// // iterate over multipart stream +/// while let Some(item) = payload.next().await { +/// let mut field = item?; +/// +/// // Field in turn is stream of *Bytes* object +/// while let Some(chunk) = field.next().await { +/// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?)); +/// } +/// } +/// Ok(HttpResponse::Ok().into()) +/// } +/// # fn main() {} +/// ``` +impl FromRequest for Multipart { + type Error = Error; + type Future = Ready>; + type Config = (); + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + ok(Multipart::new(req.headers(), payload.take())) + } +} diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs new file mode 100644 index 000000000..43eb048ca --- /dev/null +++ b/actix-multipart/src/lib.rs @@ -0,0 +1,8 @@ +#![allow(clippy::borrow_interior_mutable_const)] + +mod error; +mod extractor; +mod server; + +pub use self::error::MultipartError; +pub use self::server::{Field, Multipart}; diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs new file mode 100644 index 000000000..c49896761 --- /dev/null +++ b/actix-multipart/src/server.rs @@ -0,0 +1,1117 @@ +//! Multipart payload support +use std::cell::{Cell, RefCell, RefMut}; +use std::marker::PhantomData; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{cmp, fmt}; + +use bytes::{Bytes, BytesMut}; +use futures::stream::{LocalBoxStream, Stream, StreamExt}; +use httparse; +use mime; + +use actix_utils::task::LocalWaker; +use actix_web::error::{ParseError, PayloadError}; +use actix_web::http::header::{ + self, ContentDisposition, HeaderMap, HeaderName, HeaderValue, +}; +use actix_web::http::HttpTryFrom; + +use crate::error::MultipartError; + +const MAX_HEADERS: usize = 32; + +/// The server-side implementation of `multipart/form-data` requests. +/// +/// This will parse the incoming stream into `MultipartItem` instances via its +/// Stream implementation. +/// `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` +/// is used for nested multipart streams. +pub struct Multipart { + safety: Safety, + error: Option, + inner: Option>>, +} + +enum InnerMultipartItem { + None, + Field(Rc>), +} + +#[derive(PartialEq, Debug)] +enum InnerState { + /// Stream eof + Eof, + /// Skip data until first boundary + FirstBoundary, + /// Reading boundary + Boundary, + /// Reading Headers, + Headers, +} + +struct InnerMultipart { + payload: PayloadRef, + boundary: String, + state: InnerState, + item: InnerMultipartItem, +} + +impl Multipart { + /// Create multipart instance for boundary. + pub fn new(headers: &HeaderMap, stream: S) -> Multipart + where + S: Stream> + Unpin + 'static, + { + match Self::boundary(headers) { + Ok(boundary) => Multipart { + error: None, + safety: Safety::new(), + inner: Some(Rc::new(RefCell::new(InnerMultipart { + boundary, + payload: PayloadRef::new(PayloadBuffer::new(Box::new(stream))), + state: InnerState::FirstBoundary, + item: InnerMultipartItem::None, + }))), + }, + Err(err) => Multipart { + error: Some(err), + safety: Safety::new(), + inner: None, + }, + } + } + + /// Extract boundary info from headers. + fn boundary(headers: &HeaderMap) -> Result { + if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + if let Ok(ct) = content_type.parse::() { + if let Some(boundary) = ct.get_param(mime::BOUNDARY) { + Ok(boundary.as_str().to_owned()) + } else { + Err(MultipartError::Boundary) + } + } else { + Err(MultipartError::ParseContentType) + } + } else { + Err(MultipartError::ParseContentType) + } + } else { + Err(MultipartError::NoContentType) + } + } +} + +impl Stream for Multipart { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + if let Some(err) = self.error.take() { + Poll::Ready(Some(Err(err))) + } else if self.safety.current() { + let this = self.get_mut(); + let mut inner = this.inner.as_mut().unwrap().borrow_mut(); + if let Some(mut payload) = inner.payload.get_mut(&this.safety) { + payload.poll_stream(cx)?; + } + inner.poll(&this.safety, cx) + } else if !self.safety.is_clean() { + Poll::Ready(Some(Err(MultipartError::NotConsumed))) + } else { + Poll::Pending + } + } +} + +impl InnerMultipart { + fn read_headers( + payload: &mut PayloadBuffer, + ) -> Result, MultipartError> { + match payload.read_until(b"\r\n\r\n")? { + None => { + if payload.eof { + Err(MultipartError::Incomplete) + } else { + Ok(None) + } + } + Some(bytes) => { + let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; + match httparse::parse_headers(&bytes, &mut hdrs) { + Ok(httparse::Status::Complete((_, hdrs))) => { + // convert headers + let mut headers = HeaderMap::with_capacity(hdrs.len()); + for h in hdrs { + if let Ok(name) = HeaderName::try_from(h.name) { + if let Ok(value) = HeaderValue::try_from(h.value) { + headers.append(name, value); + } else { + return Err(ParseError::Header.into()); + } + } else { + return Err(ParseError::Header.into()); + } + } + Ok(Some(headers)) + } + Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), + Err(err) => Err(ParseError::from(err).into()), + } + } + } + } + + fn read_boundary( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Result, MultipartError> { + // TODO: need to read epilogue + match payload.readline_or_eof()? { + None => { + if payload.eof { + Ok(Some(true)) + } else { + Ok(None) + } + } + Some(chunk) => { + if chunk.len() < boundary.len() + 4 + || &chunk[..2] != b"--" + || &chunk[2..boundary.len() + 2] != boundary.as_bytes() + { + Err(MultipartError::Boundary) + } else if &chunk[boundary.len() + 2..] == b"\r\n" { + Ok(Some(false)) + } else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--" + && (chunk.len() == boundary.len() + 4 + || &chunk[boundary.len() + 4..] == b"\r\n") + { + Ok(Some(true)) + } else { + Err(MultipartError::Boundary) + } + } + } + } + + fn skip_until_boundary( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Result, MultipartError> { + let mut eof = false; + loop { + match payload.readline()? { + Some(chunk) => { + if chunk.is_empty() { + return Err(MultipartError::Boundary); + } + if chunk.len() < boundary.len() { + continue; + } + if &chunk[..2] == b"--" + && &chunk[2..chunk.len() - 2] == boundary.as_bytes() + { + break; + } else { + if chunk.len() < boundary.len() + 2 { + continue; + } + let b: &[u8] = boundary.as_ref(); + if &chunk[..boundary.len()] == b + && &chunk[boundary.len()..boundary.len() + 2] == b"--" + { + eof = true; + break; + } + } + } + None => { + return if payload.eof { + Err(MultipartError::Incomplete) + } else { + Ok(None) + }; + } + } + } + Ok(Some(eof)) + } + + fn poll( + &mut self, + safety: &Safety, + cx: &mut Context, + ) -> Poll>> { + if self.state == InnerState::Eof { + Poll::Ready(None) + } else { + // release field + loop { + // Nested multipart streams of fields has to be consumed + // before switching to next + if safety.current() { + let stop = match self.item { + InnerMultipartItem::Field(ref mut field) => { + match field.borrow_mut().poll(safety) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(_))) => continue, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => true, + } + } + InnerMultipartItem::None => false, + }; + if stop { + self.item = InnerMultipartItem::None; + } + if let InnerMultipartItem::None = self.item { + break; + } + } + } + + let headers = if let Some(mut payload) = self.payload.get_mut(safety) { + match self.state { + // read until first boundary + InnerState::FirstBoundary => { + match InnerMultipart::skip_until_boundary( + &mut *payload, + &self.boundary, + )? { + Some(eof) => { + if eof { + self.state = InnerState::Eof; + return Poll::Ready(None); + } else { + self.state = InnerState::Headers; + } + } + None => return Poll::Pending, + } + } + // read boundary + InnerState::Boundary => { + match InnerMultipart::read_boundary( + &mut *payload, + &self.boundary, + )? { + None => return Poll::Pending, + Some(eof) => { + if eof { + self.state = InnerState::Eof; + return Poll::Ready(None); + } else { + self.state = InnerState::Headers; + } + } + } + } + _ => (), + } + + // read field headers for next field + if self.state == InnerState::Headers { + if let Some(headers) = InnerMultipart::read_headers(&mut *payload)? { + self.state = InnerState::Boundary; + headers + } else { + return Poll::Pending; + } + } else { + unreachable!() + } + } else { + log::debug!("NotReady: field is in flight"); + return Poll::Pending; + }; + + // content type + let mut mt = mime::APPLICATION_OCTET_STREAM; + if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + if let Ok(ct) = content_type.parse::() { + mt = ct; + } + } + } + + self.state = InnerState::Boundary; + + // nested multipart stream + if mt.type_() == mime::MULTIPART { + Poll::Ready(Some(Err(MultipartError::Nested))) + } else { + let field = Rc::new(RefCell::new(InnerField::new( + self.payload.clone(), + self.boundary.clone(), + &headers, + )?)); + self.item = InnerMultipartItem::Field(Rc::clone(&field)); + + Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) + } + } + } +} + +impl Drop for InnerMultipart { + fn drop(&mut self) { + // InnerMultipartItem::Field has to be dropped first because of Safety. + self.item = InnerMultipartItem::None; + } +} + +/// A single field in a multipart stream +pub struct Field { + ct: mime::Mime, + headers: HeaderMap, + inner: Rc>, + safety: Safety, +} + +impl Field { + fn new( + safety: Safety, + headers: HeaderMap, + ct: mime::Mime, + inner: Rc>, + ) -> Self { + Field { + ct, + headers, + inner, + safety, + } + } + + /// Get a map of headers + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get the content type of the field + pub fn content_type(&self) -> &mime::Mime { + &self.ct + } + + /// Get the content disposition of the field, if it exists + pub fn content_disposition(&self) -> Option { + // RFC 7578: 'Each part MUST contain a Content-Disposition header field + // where the disposition type is "form-data".' + if let Some(content_disposition) = self.headers.get(&header::CONTENT_DISPOSITION) + { + ContentDisposition::from_raw(content_disposition).ok() + } else { + None + } + } +} + +impl Stream for Field { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if self.safety.current() { + let mut inner = self.inner.borrow_mut(); + if let Some(mut payload) = + inner.payload.as_ref().unwrap().get_mut(&self.safety) + { + payload.poll_stream(cx)?; + } + inner.poll(&self.safety) + } else if !self.safety.is_clean() { + Poll::Ready(Some(Err(MultipartError::NotConsumed))) + } else { + Poll::Pending + } + } +} + +impl fmt::Debug for Field { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "\nField: {}", self.ct)?; + writeln!(f, " boundary: {}", self.inner.borrow().boundary)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +struct InnerField { + payload: Option, + boundary: String, + eof: bool, + length: Option, +} + +impl InnerField { + fn new( + payload: PayloadRef, + boundary: String, + headers: &HeaderMap, + ) -> Result { + let len = if let Some(len) = headers.get(&header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + Some(len) + } else { + return Err(PayloadError::Incomplete(None)); + } + } else { + return Err(PayloadError::Incomplete(None)); + } + } else { + None + }; + + Ok(InnerField { + boundary, + payload: Some(payload), + eof: false, + length: len, + }) + } + + /// Reads body part content chunk of the specified size. + /// The body part must has `Content-Length` header with proper value. + fn read_len( + payload: &mut PayloadBuffer, + size: &mut u64, + ) -> Poll>> { + if *size == 0 { + Poll::Ready(None) + } else { + match payload.read_max(*size)? { + Some(mut chunk) => { + let len = cmp::min(chunk.len() as u64, *size); + *size -= len; + let ch = chunk.split_to(len as usize); + if !chunk.is_empty() { + payload.unprocessed(chunk); + } + Poll::Ready(Some(Ok(ch))) + } + None => { + if payload.eof && (*size != 0) { + Poll::Ready(Some(Err(MultipartError::Incomplete))) + } else { + Poll::Pending + } + } + } + } + } + + /// Reads content chunk of body part with unknown length. + /// The `Content-Length` header for body part is not necessary. + fn read_stream( + payload: &mut PayloadBuffer, + boundary: &str, + ) -> Poll>> { + let mut pos = 0; + + let len = payload.buf.len(); + if len == 0 { + return if payload.eof { + Poll::Ready(Some(Err(MultipartError::Incomplete))) + } else { + Poll::Pending + }; + } + + // check boundary + if len > 4 && payload.buf[0] == b'\r' { + let b_len = if &payload.buf[..2] == b"\r\n" && &payload.buf[2..4] == b"--" { + Some(4) + } else if &payload.buf[1..3] == b"--" { + Some(3) + } else { + None + }; + + if let Some(b_len) = b_len { + let b_size = boundary.len() + b_len; + if len < b_size { + return Poll::Pending; + } else if &payload.buf[b_len..b_size] == boundary.as_bytes() { + // found boundary + return Poll::Ready(None); + } + } + } + + loop { + return if let Some(idx) = twoway::find_bytes(&payload.buf[pos..], b"\r") { + let cur = pos + idx; + + // check if we have enough data for boundary detection + if cur + 4 > len { + if cur > 0 { + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) + } else { + Poll::Pending + } + } else { + // check boundary + if (&payload.buf[cur..cur + 2] == b"\r\n" + && &payload.buf[cur + 2..cur + 4] == b"--") + || (&payload.buf[cur..=cur] == b"\r" + && &payload.buf[cur + 1..cur + 3] == b"--") + { + if cur != 0 { + // return buffer + Poll::Ready(Some(Ok(payload.buf.split_to(cur).freeze()))) + } else { + pos = cur + 1; + continue; + } + } else { + // not boundary + pos = cur + 1; + continue; + } + } + } else { + Poll::Ready(Some(Ok(payload.buf.take().freeze()))) + }; + } + } + + fn poll(&mut self, s: &Safety) -> Poll>> { + if self.payload.is_none() { + return Poll::Ready(None); + } + + let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) + { + if !self.eof { + let res = if let Some(ref mut len) = self.length { + InnerField::read_len(&mut *payload, len) + } else { + InnerField::read_stream(&mut *payload, &self.boundary) + }; + + match res { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => self.eof = true, + } + } + + match payload.readline() { + Ok(None) => Poll::Ready(None), + Ok(Some(line)) => { + if line.as_ref() != b"\r\n" { + log::warn!("multipart field did not read all the data or it is malformed"); + } + Poll::Ready(None) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + } else { + Poll::Pending + }; + + if let Poll::Ready(None) = result { + self.payload.take(); + } + result + } +} + +struct PayloadRef { + payload: Rc>, +} + +impl PayloadRef { + fn new(payload: PayloadBuffer) -> PayloadRef { + PayloadRef { + payload: Rc::new(payload.into()), + } + } + + fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option> + where + 'a: 'b, + { + if s.current() { + Some(self.payload.borrow_mut()) + } else { + None + } + } +} + +impl Clone for PayloadRef { + fn clone(&self) -> PayloadRef { + PayloadRef { + payload: Rc::clone(&self.payload), + } + } +} + +/// Counter. It tracks of number of clones of payloads and give access to +/// payload only to top most task panics if Safety get destroyed and it not top +/// most task. +#[derive(Debug)] +struct Safety { + task: LocalWaker, + level: usize, + payload: Rc>, + clean: Rc>, +} + +impl Safety { + fn new() -> Safety { + let payload = Rc::new(PhantomData); + Safety { + task: LocalWaker::new(), + level: Rc::strong_count(&payload), + clean: Rc::new(Cell::new(true)), + payload, + } + } + + fn current(&self) -> bool { + Rc::strong_count(&self.payload) == self.level && self.clean.get() + } + + fn is_clean(&self) -> bool { + self.clean.get() + } + + fn clone(&self, cx: &mut Context) -> Safety { + let payload = Rc::clone(&self.payload); + let s = Safety { + task: LocalWaker::new(), + level: Rc::strong_count(&payload), + clean: self.clean.clone(), + payload, + }; + s.task.register(cx.waker()); + s + } +} + +impl Drop for Safety { + fn drop(&mut self) { + // parent task is dead + if Rc::strong_count(&self.payload) != self.level { + self.clean.set(true); + } + if let Some(task) = self.task.take() { + task.wake() + } + } +} + +/// Payload buffer +struct PayloadBuffer { + eof: bool, + buf: BytesMut, + stream: LocalBoxStream<'static, Result>, +} + +impl PayloadBuffer { + /// Create new `PayloadBuffer` instance + fn new(stream: S) -> Self + where + S: Stream> + 'static, + { + PayloadBuffer { + eof: false, + buf: BytesMut::new(), + stream: stream.boxed_local(), + } + } + + fn poll_stream(&mut self, cx: &mut Context) -> Result<(), PayloadError> { + loop { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), + Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(None) => { + self.eof = true; + return Ok(()); + } + Poll::Pending => return Ok(()), + } + } + } + + /// Read exact number of bytes + #[cfg(test)] + fn read_exact(&mut self, size: usize) -> Option { + if size <= self.buf.len() { + Some(self.buf.split_to(size).freeze()) + } else { + None + } + } + + fn read_max(&mut self, size: u64) -> Result, MultipartError> { + if !self.buf.is_empty() { + let size = std::cmp::min(self.buf.len() as u64, size) as usize; + Ok(Some(self.buf.split_to(size).freeze())) + } else if self.eof { + Err(MultipartError::Incomplete) + } else { + Ok(None) + } + } + + /// Read until specified ending + pub fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { + let res = twoway::find_bytes(&self.buf, line) + .map(|idx| self.buf.split_to(idx + line.len()).freeze()); + + if res.is_none() && self.eof { + Err(MultipartError::Incomplete) + } else { + Ok(res) + } + } + + /// Read bytes until new line delimiter + pub fn readline(&mut self) -> Result, MultipartError> { + self.read_until(b"\n") + } + + /// Read bytes until new line delimiter or eof + pub fn readline_or_eof(&mut self) -> Result, MultipartError> { + match self.readline() { + Err(MultipartError::Incomplete) if self.eof => { + Ok(Some(self.buf.take().freeze())) + } + line => line, + } + } + + /// Put unprocessed data back to the buffer + pub fn unprocessed(&mut self, data: Bytes) { + let buf = BytesMut::from(data); + let buf = std::mem::replace(&mut self.buf, buf); + self.buf.extend_from_slice(&buf); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use actix_http::h1::Payload; + use actix_utils::mpsc; + use actix_web::http::header::{DispositionParam, DispositionType}; + use bytes::Bytes; + use futures::future::lazy; + + #[actix_rt::test] + async fn test_boundary() { + let headers = HeaderMap::new(); + match Multipart::boundary(&headers) { + Err(MultipartError::NoContentType) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("test"), + ); + + match Multipart::boundary(&headers) { + Err(MultipartError::ParseContentType) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("multipart/mixed"), + ); + match Multipart::boundary(&headers) { + Err(MultipartError::Boundary) => (), + _ => unreachable!("should not happen"), + } + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"5c02368e880e436dab70ed54e1c58209\"", + ), + ); + + assert_eq!( + Multipart::boundary(&headers).unwrap(), + "5c02368e880e436dab70ed54e1c58209" + ); + } + + fn create_stream() -> ( + mpsc::Sender>, + impl Stream>, + ) { + let (tx, rx) = mpsc::channel(); + + (tx, rx.map(|res| res.map_err(|_| panic!()))) + } + + fn create_simple_request_with_header() -> (Bytes, HeaderMap) { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + data\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + (bytes, headers) + } + + #[actix_rt::test] + async fn test_multipart_no_end_crlf() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + let bytes_stripped = bytes.slice_to(bytes.len()); // strip crlf + + sender.send(Ok(bytes_stripped)).unwrap(); + drop(sender); // eof + + let mut multipart = Multipart::new(&headers, payload); + + match multipart.next().await.unwrap() { + Ok(_) => (), + _ => unreachable!(), + } + + match multipart.next().await.unwrap() { + Ok(_) => (), + _ => unreachable!(), + } + + match multipart.next().await { + None => (), + _ => unreachable!(), + } + } + + #[actix_rt::test] + async fn test_multipart() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + + sender.send(Ok(bytes)).unwrap(); + + let mut multipart = Multipart::new(&headers, payload); + match multipart.next().await { + Some(Ok(mut field)) => { + let cd = field.content_disposition().unwrap(); + assert_eq!(cd.disposition, DispositionType::FormData); + assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); + + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.next().await.unwrap() { + Ok(chunk) => assert_eq!(chunk, "test"), + _ => unreachable!(), + } + match field.next().await { + None => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + match multipart.next().await.unwrap() { + Ok(mut field) => { + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.next().await { + Some(Ok(chunk)) => assert_eq!(chunk, "data"), + _ => unreachable!(), + } + match field.next().await { + None => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + match multipart.next().await { + None => (), + _ => unreachable!(), + } + } + + #[actix_rt::test] + async fn test_stream() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + + sender.send(Ok(bytes)).unwrap(); + + let mut multipart = Multipart::new(&headers, payload); + match multipart.next().await.unwrap() { + Ok(mut field) => { + let cd = field.content_disposition().unwrap(); + assert_eq!(cd.disposition, DispositionType::FormData); + assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); + + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.next().await.unwrap() { + Ok(chunk) => assert_eq!(chunk, "test"), + _ => unreachable!(), + } + match field.next().await { + None => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + match multipart.next().await { + Some(Ok(mut field)) => { + assert_eq!(field.content_type().type_(), mime::TEXT); + assert_eq!(field.content_type().subtype(), mime::PLAIN); + + match field.next().await { + Some(Ok(chunk)) => assert_eq!(chunk, "data"), + _ => unreachable!(), + } + match field.next().await { + None => (), + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + match multipart.next().await { + None => (), + _ => unreachable!(), + } + } + + #[actix_rt::test] + async fn test_basic() { + let (_, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(payload.buf.len(), 0); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + assert_eq!(None, payload.read_max(1).unwrap()); + } + + #[actix_rt::test] + async fn test_eof() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(None, payload.read_max(4).unwrap()); + sender.feed_data(Bytes::from("data")); + sender.feed_eof(); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + + assert_eq!(Some(Bytes::from("data")), payload.read_max(4).unwrap()); + assert_eq!(payload.buf.len(), 0); + assert!(payload.read_max(1).is_err()); + assert!(payload.eof); + } + + #[actix_rt::test] + async fn test_err() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + assert_eq!(None, payload.read_max(1).unwrap()); + sender.set_error(PayloadError::Incomplete(None)); + lazy(|cx| payload.poll_stream(cx)).await.err().unwrap(); + } + + #[actix_rt::test] + async fn test_readmax() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + assert_eq!(payload.buf.len(), 10); + + assert_eq!(Some(Bytes::from("line1")), payload.read_max(5).unwrap()); + assert_eq!(payload.buf.len(), 5); + + assert_eq!(Some(Bytes::from("line2")), payload.read_max(5).unwrap()); + assert_eq!(payload.buf.len(), 0); + } + + #[actix_rt::test] + async fn test_readexactly() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(None, payload.read_exact(2)); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + + assert_eq!(Some(Bytes::from_static(b"li")), payload.read_exact(2)); + assert_eq!(payload.buf.len(), 8); + + assert_eq!(Some(Bytes::from_static(b"ne1l")), payload.read_exact(4)); + assert_eq!(payload.buf.len(), 4); + } + + #[actix_rt::test] + async fn test_readuntil() { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(None, payload.read_until(b"ne").unwrap()); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + lazy(|cx| payload.poll_stream(cx)).await.unwrap(); + + assert_eq!( + Some(Bytes::from("line")), + payload.read_until(b"ne").unwrap() + ); + assert_eq!(payload.buf.len(), 6); + + assert_eq!( + Some(Bytes::from("1line2")), + payload.read_until(b"2").unwrap() + ); + assert_eq!(payload.buf.len(), 0); + } +} diff --git a/actix-session/CHANGES.md b/actix-session/CHANGES.md new file mode 100644 index 000000000..d85f6d5f1 --- /dev/null +++ b/actix-session/CHANGES.md @@ -0,0 +1,52 @@ +# Changes + +## [0.2.0] - 2019-07-08 + +* Enhanced ``actix-session`` to facilitate state changes. Use ``Session.renew()`` + at successful login to cycle a session (new key/cookie but keeps state). + Use ``Session.purge()`` at logout to invalid a session cookie (and remove + from redis cache, if applicable). + +## [0.1.1] - 2019-06-03 + +* Fix optional cookie session support + +## [0.1.0] - 2019-05-18 + +* Use actix-web 1.0.0-rc + +## [0.1.0-beta.4] - 2019-05-12 + +* Use actix-web 1.0.0-beta.4 + +## [0.1.0-beta.2] - 2019-04-28 + +* Add helper trait `UserSession` which allows to get session for ServiceRequest and HttpRequest + +## [0.1.0-beta.1] - 2019-04-20 + +* Update actix-web to beta.1 + +* `CookieSession::max_age()` accepts value in seconds + +## [0.1.0-alpha.6] - 2019-04-14 + +* Update actix-web alpha.6 + +## [0.1.0-alpha.4] - 2019-04-08 + +* Update actix-web + +## [0.1.0-alpha.3] - 2019-04-02 + +* Update actix-web + +## [0.1.0-alpha.2] - 2019-03-29 + +* Update actix-web + +* Use new feature name for secure cookies + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-session/Cargo.toml b/actix-session/Cargo.toml new file mode 100644 index 000000000..a4c53e563 --- /dev/null +++ b/actix-session/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "actix-session" +version = "0.3.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Session for actix web framework." +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-session/" +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +workspace = ".." +edition = "2018" + +[lib] +name = "actix_session" +path = "src/lib.rs" + +[features] +default = ["cookie-session"] + +# sessions feature, session require "ring" crate and c compiler +cookie-session = ["actix-web/secure-cookies"] + +[dependencies] +actix-web = "2.0.0-alpha.1" +actix-service = "1.0.0-alpha.1" +bytes = "0.4" +derive_more = "0.99.2" +futures = "0.3.1" +hashbrown = "0.6.3" +serde = "1.0" +serde_json = "1.0" +time = "0.1.42" + +[dev-dependencies] +actix-rt = "1.0.0-alpha.1" diff --git a/actix-session/LICENSE-APACHE b/actix-session/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-session/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-session/LICENSE-MIT b/actix-session/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-session/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-session/README.md b/actix-session/README.md new file mode 100644 index 000000000..0aee756fd --- /dev/null +++ b/actix-session/README.md @@ -0,0 +1,9 @@ +# Session for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-session)](https://crates.io/crates/actix-session) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-session/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-session](https://crates.io/crates/actix-session) +* Minimum supported Rust version: 1.34 or later diff --git a/actix-session/src/cookie.rs b/actix-session/src/cookie.rs new file mode 100644 index 000000000..5d66d6537 --- /dev/null +++ b/actix-session/src/cookie.rs @@ -0,0 +1,480 @@ +//! Cookie session. +//! +//! [**CookieSession**](struct.CookieSession.html) +//! uses cookies as session storage. `CookieSession` creates sessions +//! which are limited to storing fewer than 4000 bytes of data, as the payload +//! must fit into a single cookie. An internal server error is generated if a +//! session contains more than 4000 bytes. +//! +//! A cookie may have a security policy of *signed* or *private*. Each has +//! a respective `CookieSession` constructor. +//! +//! A *signed* cookie may be viewed but not modified by the client. A *private* +//! cookie may neither be viewed nor modified by the client. +//! +//! The constructors take a key as an argument. This is the private key +//! for cookie session - when this value is changed, all session data is lost. + +use std::collections::HashMap; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::{Service, Transform}; +use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; +use actix_web::dev::{ServiceRequest, ServiceResponse}; +use actix_web::http::{header::SET_COOKIE, HeaderValue}; +use actix_web::{Error, HttpMessage, ResponseError}; +use derive_more::{Display, From}; +use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use serde_json::error::Error as JsonError; + +use crate::{Session, SessionStatus}; + +/// Errors that can occur during handling cookie session +#[derive(Debug, From, Display)] +pub enum CookieSessionError { + /// Size of the serialized session is greater than 4000 bytes. + #[display(fmt = "Size of the serialized session is greater than 4000 bytes.")] + Overflow, + /// Fail to serialize session. + #[display(fmt = "Fail to serialize session")] + Serialize(JsonError), +} + +impl ResponseError for CookieSessionError {} + +enum CookieSecurity { + Signed, + Private, +} + +struct CookieSessionInner { + key: Key, + security: CookieSecurity, + name: String, + path: String, + domain: Option, + secure: bool, + http_only: bool, + max_age: Option, + same_site: Option, +} + +impl CookieSessionInner { + fn new(key: &[u8], security: CookieSecurity) -> CookieSessionInner { + CookieSessionInner { + security, + key: Key::from_master(key), + name: "actix-session".to_owned(), + path: "/".to_owned(), + domain: None, + secure: true, + http_only: true, + max_age: None, + same_site: None, + } + } + + fn set_cookie( + &self, + res: &mut ServiceResponse, + state: impl Iterator, + ) -> Result<(), Error> { + let state: HashMap = state.collect(); + let value = + serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?; + if value.len() > 4064 { + return Err(CookieSessionError::Overflow.into()); + } + + let mut cookie = Cookie::new(self.name.clone(), value); + cookie.set_path(self.path.clone()); + cookie.set_secure(self.secure); + cookie.set_http_only(self.http_only); + + if let Some(ref domain) = self.domain { + cookie.set_domain(domain.clone()); + } + + if let Some(max_age) = self.max_age { + cookie.set_max_age(max_age); + } + + if let Some(same_site) = self.same_site { + cookie.set_same_site(same_site); + } + + let mut jar = CookieJar::new(); + + match self.security { + CookieSecurity::Signed => jar.signed(&self.key).add(cookie), + CookieSecurity::Private => jar.private(&self.key).add(cookie), + } + + for cookie in jar.delta() { + let val = HeaderValue::from_str(&cookie.encoded().to_string())?; + res.headers_mut().append(SET_COOKIE, val); + } + + Ok(()) + } + + /// invalidates session cookie + fn remove_cookie(&self, res: &mut ServiceResponse) -> Result<(), Error> { + let mut cookie = Cookie::named(self.name.clone()); + cookie.set_value(""); + cookie.set_max_age(time::Duration::seconds(0)); + cookie.set_expires(time::now() - time::Duration::days(365)); + + let val = HeaderValue::from_str(&cookie.to_string())?; + res.headers_mut().append(SET_COOKIE, val); + + Ok(()) + } + + fn load(&self, req: &ServiceRequest) -> (bool, HashMap) { + if let Ok(cookies) = req.cookies() { + for cookie in cookies.iter() { + if cookie.name() == self.name { + let mut jar = CookieJar::new(); + jar.add_original(cookie.clone()); + + let cookie_opt = match self.security { + CookieSecurity::Signed => jar.signed(&self.key).get(&self.name), + CookieSecurity::Private => { + jar.private(&self.key).get(&self.name) + } + }; + if let Some(cookie) = cookie_opt { + if let Ok(val) = serde_json::from_str(cookie.value()) { + return (false, val); + } + } + } + } + } + (true, HashMap::new()) + } +} + +/// Use cookies for session storage. +/// +/// `CookieSession` creates sessions which are limited to storing +/// fewer than 4000 bytes of data (as the payload must fit into a single +/// cookie). An Internal Server Error is generated if the session contains more +/// than 4000 bytes. +/// +/// A cookie may have a security policy of *signed* or *private*. Each has a +/// respective `CookieSessionBackend` constructor. +/// +/// A *signed* cookie is stored on the client as plaintext alongside +/// a signature such that the cookie may be viewed but not modified by the +/// client. +/// +/// A *private* cookie is stored on the client as encrypted text +/// such that it may neither be viewed nor modified by the client. +/// +/// The constructors take a key as an argument. +/// This is the private key for cookie session - when this value is changed, +/// all session data is lost. The constructors will panic if the key is less +/// than 32 bytes in length. +/// +/// The backend relies on `cookie` crate to create and read cookies. +/// By default all cookies are percent encoded, but certain symbols may +/// cause troubles when reading cookie, if they are not properly percent encoded. +/// +/// # Example +/// +/// ```rust +/// use actix_session::CookieSession; +/// use actix_web::{web, App, HttpResponse, HttpServer}; +/// +/// fn main() { +/// let app = App::new().wrap( +/// CookieSession::signed(&[0; 32]) +/// .domain("www.rust-lang.org") +/// .name("actix_session") +/// .path("/") +/// .secure(true)) +/// .service(web::resource("/").to(|| HttpResponse::Ok())); +/// } +/// ``` +pub struct CookieSession(Rc); + +impl CookieSession { + /// Construct new *signed* `CookieSessionBackend` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn signed(key: &[u8]) -> CookieSession { + CookieSession(Rc::new(CookieSessionInner::new( + key, + CookieSecurity::Signed, + ))) + } + + /// Construct new *private* `CookieSessionBackend` instance. + /// + /// Panics if key length is less than 32 bytes. + pub fn private(key: &[u8]) -> CookieSession { + CookieSession(Rc::new(CookieSessionInner::new( + key, + CookieSecurity::Private, + ))) + } + + /// Sets the `path` field in the session cookie being built. + pub fn path>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().path = value.into(); + self + } + + /// Sets the `name` field in the session cookie being built. + pub fn name>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().name = value.into(); + self + } + + /// Sets the `domain` field in the session cookie being built. + pub fn domain>(mut self, value: S) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().domain = Some(value.into()); + self + } + + /// Sets the `secure` field in the session cookie being built. + /// + /// If the `secure` field is set, a cookie will only be transmitted when the + /// connection is secure - i.e. `https` + pub fn secure(mut self, value: bool) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().secure = value; + self + } + + /// Sets the `http_only` field in the session cookie being built. + pub fn http_only(mut self, value: bool) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().http_only = value; + self + } + + /// Sets the `same_site` field in the session cookie being built. + pub fn same_site(mut self, value: SameSite) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().same_site = Some(value); + self + } + + /// Sets the `max-age` field in the session cookie being built. + pub fn max_age(self, seconds: i64) -> CookieSession { + self.max_age_time(time::Duration::seconds(seconds)) + } + + /// Sets the `max-age` field in the session cookie being built. + pub fn max_age_time(mut self, value: time::Duration) -> CookieSession { + Rc::get_mut(&mut self.0).unwrap().max_age = Some(value); + self + } +} + +impl Transform for CookieSession +where + S: Service>, + S::Future: 'static, + S::Error: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = S::Error; + type InitError = (); + type Transform = CookieSessionMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CookieSessionMiddleware { + service, + inner: self.0.clone(), + }) + } +} + +/// Cookie session middleware +pub struct CookieSessionMiddleware { + service: S, + inner: Rc, +} + +impl Service for CookieSessionMiddleware +where + S: Service>, + S::Future: 'static, + S::Error: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = S::Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) + } + + /// On first request, a new session cookie is returned in response, regardless + /// of whether any session state is set. With subsequent requests, if the + /// session state changes, then set-cookie is returned in response. As + /// a user logs out, call session.purge() to set SessionStatus accordingly + /// and this will trigger removal of the session cookie in the response. + fn call(&mut self, mut req: ServiceRequest) -> Self::Future { + let inner = self.inner.clone(); + let (is_new, state) = self.inner.load(&req); + Session::set_session(state.into_iter(), &mut req); + + let fut = self.service.call(req); + + async move { + fut.await.map(|mut res| { + match Session::get_changes(&mut res) { + (SessionStatus::Changed, Some(state)) + | (SessionStatus::Renewed, Some(state)) => { + res.checked_expr(|res| inner.set_cookie(res, state)) + } + (SessionStatus::Unchanged, _) => + // set a new session cookie upon first request (new client) + { + if is_new { + let state: HashMap = HashMap::new(); + res.checked_expr(|res| { + inner.set_cookie(res, state.into_iter()) + }) + } else { + res + } + } + (SessionStatus::Purged, _) => { + let _ = inner.remove_cookie(&mut res); + res + } + _ => res, + } + }) + } + .boxed_local() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::{test, web, App}; + use bytes::Bytes; + + #[actix_rt::test] + async fn cookie_session() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn private_cookie() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::private(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn cookie_session_extractor() { + let mut app = test::init_service( + App::new() + .wrap(CookieSession::signed(&[0; 32]).secure(false)) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + assert!(response + .response() + .cookies() + .find(|c| c.name() == "actix-session") + .is_some()); + } + + #[actix_rt::test] + async fn basics() { + let mut app = test::init_service( + App::new() + .wrap( + CookieSession::signed(&[0; 32]) + .path("/test/") + .name("actix-test") + .domain("localhost") + .http_only(true) + .same_site(SameSite::Lax) + .max_age(100), + ) + .service(web::resource("/").to(|ses: Session| { + async move { + let _ = ses.set("counter", 100); + "test" + } + })) + .service(web::resource("/test/").to(|ses: Session| { + async move { + let val: usize = ses.get("counter").unwrap().unwrap(); + format!("counter: {}", val) + } + })), + ) + .await; + + let request = test::TestRequest::get().to_request(); + let response = app.call(request).await.unwrap(); + let cookie = response + .response() + .cookies() + .find(|c| c.name() == "actix-test") + .unwrap() + .clone(); + assert_eq!(cookie.path().unwrap(), "/test/"); + + let request = test::TestRequest::with_uri("/test/") + .cookie(cookie) + .to_request(); + let body = test::read_response(&mut app, request).await; + assert_eq!(body, Bytes::from_static(b"counter: 100")); + } +} diff --git a/actix-session/src/lib.rs b/actix-session/src/lib.rs new file mode 100644 index 000000000..def35a1e9 --- /dev/null +++ b/actix-session/src/lib.rs @@ -0,0 +1,301 @@ +//! User sessions. +//! +//! Actix provides a general solution for session management. Session +//! middlewares could provide different implementations which could +//! be accessed via general session api. +//! +//! By default, only cookie session backend is implemented. Other +//! backend implementations can be added. +//! +//! In general, you insert a *session* middleware and initialize it +//! , such as a `CookieSessionBackend`. To access session data, +//! [*Session*](struct.Session.html) extractor must be used. Session +//! extractor allows us to get or set session data. +//! +//! ```rust +//! use actix_web::{web, App, HttpServer, HttpResponse, Error}; +//! use actix_session::{Session, CookieSession}; +//! +//! fn index(session: Session) -> Result<&'static str, Error> { +//! // access session data +//! if let Some(count) = session.get::("counter")? { +//! println!("SESSION value: {}", count); +//! session.set("counter", count+1)?; +//! } else { +//! session.set("counter", 1)?; +//! } +//! +//! Ok("Welcome!") +//! } +//! +//! fn main() -> std::io::Result<()> { +//! # std::thread::spawn(|| +//! HttpServer::new( +//! || App::new().wrap( +//! CookieSession::signed(&[0; 32]) // <- create cookie based session middleware +//! .secure(false) +//! ) +//! .service(web::resource("/").to(|| HttpResponse::Ok()))) +//! .bind("127.0.0.1:59880")? +//! .run() +//! # ); +//! # Ok(()) +//! } +//! ``` +use std::cell::RefCell; +use std::rc::Rc; + +use actix_web::dev::{Extensions, Payload, ServiceRequest, ServiceResponse}; +use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; +use futures::future::{ok, Ready}; +use hashbrown::HashMap; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json; + +#[cfg(feature = "cookie-session")] +mod cookie; +#[cfg(feature = "cookie-session")] +pub use crate::cookie::CookieSession; + +/// The high-level interface you use to modify session data. +/// +/// Session object could be obtained with +/// [`RequestSession::session`](trait.RequestSession.html#tymethod.session) +/// method. `RequestSession` trait is implemented for `HttpRequest`. +/// +/// ```rust +/// use actix_session::Session; +/// use actix_web::*; +/// +/// fn index(session: Session) -> Result<&'static str> { +/// // access session data +/// if let Some(count) = session.get::("counter")? { +/// session.set("counter", count + 1)?; +/// } else { +/// session.set("counter", 1)?; +/// } +/// +/// Ok("Welcome!") +/// } +/// # fn main() {} +/// ``` +pub struct Session(Rc>); + +/// Helper trait that allows to get session +pub trait UserSession { + fn get_session(&mut self) -> Session; +} + +impl UserSession for HttpRequest { + fn get_session(&mut self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +impl UserSession for ServiceRequest { + fn get_session(&mut self) -> Session { + Session::get_session(&mut *self.extensions_mut()) + } +} + +#[derive(PartialEq, Clone, Debug)] +pub enum SessionStatus { + Changed, + Purged, + Renewed, + Unchanged, +} +impl Default for SessionStatus { + fn default() -> SessionStatus { + SessionStatus::Unchanged + } +} + +#[derive(Default)] +struct SessionInner { + state: HashMap, + pub status: SessionStatus, +} + +impl Session { + /// Get a `value` from the session. + pub fn get(&self, key: &str) -> Result, Error> { + if let Some(s) = self.0.borrow().state.get(key) { + Ok(Some(serde_json::from_str(s)?)) + } else { + Ok(None) + } + } + + /// Set a `value` from the session. + pub fn set(&self, key: &str, value: T) -> Result<(), Error> { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner + .state + .insert(key.to_owned(), serde_json::to_string(&value)?); + } + Ok(()) + } + + /// Remove value from the session. + pub fn remove(&self, key: &str) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner.state.remove(key); + } + } + + /// Clear the session. + pub fn clear(&self) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Changed; + inner.state.clear() + } + } + + /// Removes session, both client and server side. + pub fn purge(&self) { + let mut inner = self.0.borrow_mut(); + inner.status = SessionStatus::Purged; + inner.state.clear(); + } + + /// Renews the session key, assigning existing session state to new key. + pub fn renew(&self) { + let mut inner = self.0.borrow_mut(); + if inner.status != SessionStatus::Purged { + inner.status = SessionStatus::Renewed; + } + } + + pub fn set_session( + data: impl Iterator, + req: &mut ServiceRequest, + ) { + let session = Session::get_session(&mut *req.extensions_mut()); + let mut inner = session.0.borrow_mut(); + inner.state.extend(data); + } + + pub fn get_changes( + res: &mut ServiceResponse, + ) -> ( + SessionStatus, + Option>, + ) { + if let Some(s_impl) = res + .request() + .extensions() + .get::>>() + { + let state = + std::mem::replace(&mut s_impl.borrow_mut().state, HashMap::new()); + (s_impl.borrow().status.clone(), Some(state.into_iter())) + } else { + (SessionStatus::Unchanged, None) + } + } + + fn get_session(extensions: &mut Extensions) -> Session { + if let Some(s_impl) = extensions.get::>>() { + return Session(Rc::clone(&s_impl)); + } + let inner = Rc::new(RefCell::new(SessionInner::default())); + extensions.insert(inner.clone()); + Session(inner) + } +} + +/// Extractor implementation for Session type. +/// +/// ```rust +/// # use actix_web::*; +/// use actix_session::Session; +/// +/// fn index(session: Session) -> Result<&'static str> { +/// // access session data +/// if let Some(count) = session.get::("counter")? { +/// session.set("counter", count + 1)?; +/// } else { +/// session.set("counter", 1)?; +/// } +/// +/// Ok("Welcome!") +/// } +/// # fn main() {} +/// ``` +impl FromRequest for Session { + type Error = Error; + type Future = Ready>; + type Config = (); + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(Session::get_session(&mut *req.extensions_mut())) + } +} + +#[cfg(test)] +mod tests { + use actix_web::{test, HttpResponse}; + + use super::*; + + #[test] + fn session() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + let session = Session::get_session(&mut *req.extensions_mut()); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + + session.set("key2", "value2".to_string()).unwrap(); + session.remove("key"); + + let mut res = req.into_response(HttpResponse::Ok().finish()); + let (_status, state) = Session::get_changes(&mut res); + let changes: Vec<_> = state.unwrap().collect(); + assert_eq!(changes, [("key2".to_string(), "\"value2\"".to_string())]); + } + + #[test] + fn get_session() { + let mut req = test::TestRequest::default().to_srv_request(); + + Session::set_session( + vec![("key".to_string(), "\"value\"".to_string())].into_iter(), + &mut req, + ); + + let session = req.get_session(); + let res = session.get::("key").unwrap(); + assert_eq!(res, Some("value".to_string())); + } + + #[test] + fn purge_session() { + let req = test::TestRequest::default().to_srv_request(); + let session = Session::get_session(&mut *req.extensions_mut()); + assert_eq!(session.0.borrow().status, SessionStatus::Unchanged); + session.purge(); + assert_eq!(session.0.borrow().status, SessionStatus::Purged); + } + + #[test] + fn renew_session() { + let req = test::TestRequest::default().to_srv_request(); + let session = Session::get_session(&mut *req.extensions_mut()); + assert_eq!(session.0.borrow().status, SessionStatus::Unchanged); + session.renew(); + assert_eq!(session.0.borrow().status, SessionStatus::Renewed); + } +} diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md new file mode 100644 index 000000000..c1417c9c4 --- /dev/null +++ b/actix-web-actors/CHANGES.md @@ -0,0 +1,32 @@ +# Changes + +## [1.0.3] - 2019-11-14 + +* Update actix-web and actix-http dependencies + +## [1.0.2] - 2019-07-20 + +* Add `ws::start_with_addr()`, returning the address of the created actor, along + with the `HttpResponse`. + +* Add support for specifying protocols on websocket handshake #835 + +## [1.0.1] - 2019-06-28 + +* Allow to use custom ws codec with `WebsocketContext` #925 + +## [1.0.0] - 2019-05-29 + +* Update actix-http and actix-web + +## [0.1.0-alpha.3] - 2019-04-02 + +* Update actix-http and actix-web + +## [0.1.0-alpha.2] - 2019-03-29 + +* Update actix-http and actix-web + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml new file mode 100644 index 000000000..d5a6ce2c4 --- /dev/null +++ b/actix-web-actors/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "actix-web-actors" +version = "1.0.3" +authors = ["Nikolay Kim "] +description = "Actix actors support for actix web framework." +readme = "README.md" +keywords = ["actix", "http", "web", "framework", "async"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/actix-web-actors/" +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +workspace = ".." +edition = "2018" + +[lib] +name = "actix_web_actors" +path = "src/lib.rs" + +[dependencies] +actix = "0.8.3" +actix-web = "1.0.9" +actix-http = "0.2.11" +actix-codec = "0.1.2" +bytes = "0.4" +futures = "0.1.25" + +[dev-dependencies] +env_logger = "0.6" +actix-http-test = { version = "0.2.4", features=["ssl"] } diff --git a/actix-web-actors/LICENSE-APACHE b/actix-web-actors/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-web-actors/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-web-actors/LICENSE-MIT b/actix-web-actors/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-web-actors/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-web-actors/README.md b/actix-web-actors/README.md new file mode 100644 index 000000000..6ff7ac67c --- /dev/null +++ b/actix-web-actors/README.md @@ -0,0 +1,8 @@ +Actix actors support for actix web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web-actors)](https://crates.io/crates/actix-web-actors) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +## Documentation & community resources + +* [API Documentation](https://docs.rs/actix-web-actors/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-web-actors](https://crates.io/crates/actix-web-actors) +* Minimum supported Rust version: 1.33 or later diff --git a/actix-web-actors/src/context.rs b/actix-web-actors/src/context.rs new file mode 100644 index 000000000..31b29500a --- /dev/null +++ b/actix-web-actors/src/context.rs @@ -0,0 +1,253 @@ +use std::collections::VecDeque; + +use actix::dev::{ + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, ToEnvelope, +}; +use actix::fut::ActorFuture; +use actix::{ + Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message, SpawnHandle, +}; +use actix_web::error::{Error, ErrorInternalServerError}; +use bytes::Bytes; +use futures::sync::oneshot::Sender; +use futures::{Async, Future, Poll, Stream}; + +/// Execution context for http actors +pub struct HttpContext +where + A: Actor>, +{ + inner: ContextParts, + stream: VecDeque>, +} + +impl ActorContext for HttpContext +where + A: Actor, +{ + fn stop(&mut self) { + self.inner.stop(); + } + fn terminate(&mut self) { + self.inner.terminate() + } + fn state(&self) -> ActorState { + self.inner.state() + } +} + +impl AsyncContext for HttpContext +where + A: Actor, +{ + #[inline] + fn spawn(&mut self, fut: F) -> SpawnHandle + where + F: ActorFuture + 'static, + { + self.inner.spawn(fut) + } + + #[inline] + fn wait(&mut self, fut: F) + where + F: ActorFuture + 'static, + { + self.inner.wait(fut) + } + + #[doc(hidden)] + #[inline] + fn waiting(&self) -> bool { + self.inner.waiting() + || self.inner.state() == ActorState::Stopping + || self.inner.state() == ActorState::Stopped + } + + #[inline] + fn cancel_future(&mut self, handle: SpawnHandle) -> bool { + self.inner.cancel_future(handle) + } + + #[inline] + fn address(&self) -> Addr { + self.inner.address() + } +} + +impl HttpContext +where + A: Actor, +{ + #[inline] + /// Create a new HTTP Context from a request and an actor + pub fn create(actor: A) -> impl Stream { + let mb = Mailbox::default(); + let ctx = HttpContext { + inner: ContextParts::new(mb.sender_producer()), + stream: VecDeque::new(), + }; + HttpContextFut::new(ctx, actor, mb) + } + + /// Create a new HTTP Context + pub fn with_factory(f: F) -> impl Stream + where + F: FnOnce(&mut Self) -> A + 'static, + { + let mb = Mailbox::default(); + let mut ctx = HttpContext { + inner: ContextParts::new(mb.sender_producer()), + stream: VecDeque::new(), + }; + + let act = f(&mut ctx); + HttpContextFut::new(ctx, act, mb) + } +} + +impl HttpContext +where + A: Actor, +{ + /// Write payload + #[inline] + pub fn write(&mut self, data: Bytes) { + self.stream.push_back(Some(data)); + } + + /// Indicate end of streaming payload. Also this method calls `Self::close`. + #[inline] + pub fn write_eof(&mut self) { + self.stream.push_back(None); + } + + /// Handle of the running future + /// + /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. + pub fn handle(&self) -> SpawnHandle { + self.inner.curr_handle() + } +} + +impl AsyncContextParts for HttpContext +where + A: Actor, +{ + fn parts(&mut self) -> &mut ContextParts { + &mut self.inner + } +} + +struct HttpContextFut +where + A: Actor>, +{ + fut: ContextFut>, +} + +impl HttpContextFut +where + A: Actor>, +{ + fn new(ctx: HttpContext, act: A, mailbox: Mailbox) -> Self { + let fut = ContextFut::new(ctx, act, mailbox); + HttpContextFut { fut } + } +} + +impl Stream for HttpContextFut +where + A: Actor>, +{ + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + if self.fut.alive() { + match self.fut.poll() { + Ok(Async::NotReady) | Ok(Async::Ready(())) => (), + Err(_) => return Err(ErrorInternalServerError("error")), + } + } + + // frames + if let Some(data) = self.fut.ctx().stream.pop_front() { + Ok(Async::Ready(data)) + } else if self.fut.alive() { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(None)) + } + } +} + +impl ToEnvelope for HttpContext +where + A: Actor> + Handler, + M: Message + Send + 'static, + M::Result: Send, +{ + fn pack(msg: M, tx: Option>) -> Envelope { + Envelope::new(msg, tx) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use actix::Actor; + use actix_web::http::StatusCode; + use actix_web::test::{block_on, call_service, init_service, TestRequest}; + use actix_web::{web, App, HttpResponse}; + use bytes::{Bytes, BytesMut}; + + use super::*; + + struct MyActor { + count: usize, + } + + impl Actor for MyActor { + type Context = HttpContext; + + fn started(&mut self, ctx: &mut Self::Context) { + ctx.run_later(Duration::from_millis(100), |slf, ctx| slf.write(ctx)); + } + } + + impl MyActor { + fn write(&mut self, ctx: &mut HttpContext) { + self.count += 1; + if self.count > 3 { + ctx.write_eof() + } else { + ctx.write(Bytes::from(format!("LINE-{}", self.count).as_bytes())); + ctx.run_later(Duration::from_millis(100), |slf, ctx| slf.write(ctx)); + } + } + } + + #[test] + fn test_default_resource() { + let mut srv = + init_service(App::new().service(web::resource("/test").to(|| { + HttpResponse::Ok().streaming(HttpContext::create(MyActor { count: 0 })) + }))); + + let req = TestRequest::with_uri("/test").to_request(); + let mut resp = call_service(&mut srv, req); + assert_eq!(resp.status(), StatusCode::OK); + + let body = block_on(resp.take_body().fold( + BytesMut::new(), + move |mut body, chunk| { + body.extend_from_slice(&chunk); + Ok::<_, Error>(body) + }, + )) + .unwrap(); + assert_eq!(body.freeze(), Bytes::from_static(b"LINE-1LINE-2LINE-3")); + } +} diff --git a/actix-web-actors/src/lib.rs b/actix-web-actors/src/lib.rs new file mode 100644 index 000000000..6360917cd --- /dev/null +++ b/actix-web-actors/src/lib.rs @@ -0,0 +1,6 @@ +#![allow(clippy::borrow_interior_mutable_const)] +//! Actix actors integration for Actix web framework +mod context; +pub mod ws; + +pub use self::context::HttpContext; diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs new file mode 100644 index 000000000..e25a7e6e4 --- /dev/null +++ b/actix-web-actors/src/ws.rs @@ -0,0 +1,740 @@ +//! Websocket integration +use std::collections::VecDeque; +use std::io; + +use actix::dev::{ + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, + ToEnvelope, +}; +use actix::fut::ActorFuture; +use actix::{ + Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, + Message as ActixMessage, SpawnHandle, +}; +use actix_codec::{Decoder, Encoder}; +use actix_http::ws::{hash_key, Codec}; +pub use actix_http::ws::{ + CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, +}; + +use actix_web::dev::HttpResponseBuilder; +use actix_web::error::{Error, ErrorInternalServerError, PayloadError}; +use actix_web::http::{header, Method, StatusCode}; +use actix_web::{HttpRequest, HttpResponse}; +use bytes::{Bytes, BytesMut}; +use futures::sync::oneshot::Sender; +use futures::{Async, Future, Poll, Stream}; + +/// Do websocket handshake and start ws actor. +pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result +where + A: Actor> + StreamHandler, + T: Stream + 'static, +{ + let mut res = handshake(req)?; + Ok(res.streaming(WebsocketContext::create(actor, stream))) +} + +/// Do websocket handshake and start ws actor. +/// +/// `req` is an HTTP Request that should be requesting a websocket protocol +/// change. `stream` should be a `Bytes` stream (such as +/// `actix_web::web::Payload`) that contains a stream of the body request. +/// +/// If there is a problem with the handshake, an error is returned. +/// +/// If successful, returns a pair where the first item is an address for the +/// created actor and the second item is the response that should be returned +/// from the websocket request. +pub fn start_with_addr( + actor: A, + req: &HttpRequest, + stream: T, +) -> Result<(Addr, HttpResponse), Error> +where + A: Actor> + StreamHandler, + T: Stream + 'static, +{ + let mut res = handshake(req)?; + let (addr, out_stream) = WebsocketContext::create_with_addr(actor, stream); + Ok((addr, res.streaming(out_stream))) +} + +/// Do websocket handshake and start ws actor. +/// +/// `protocols` is a sequence of known protocols. +pub fn start_with_protocols( + actor: A, + protocols: &[&str], + req: &HttpRequest, + stream: T, +) -> Result +where + A: Actor> + StreamHandler, + T: Stream + 'static, +{ + let mut res = handshake_with_protocols(req, protocols)?; + Ok(res.streaming(WebsocketContext::create(actor, stream))) +} + +/// Prepare `WebSocket` handshake response. +/// +/// This function returns handshake `HttpResponse`, ready to send to peer. +/// It does not perform any IO. +pub fn handshake(req: &HttpRequest) -> Result { + handshake_with_protocols(req, &[]) +} + +/// Prepare `WebSocket` handshake response. +/// +/// This function returns handshake `HttpResponse`, ready to send to peer. +/// It does not perform any IO. +/// +/// `protocols` is a sequence of known protocols. On successful handshake, +/// the returned response headers contain the first protocol in this list +/// which the server also knows. +pub fn handshake_with_protocols( + req: &HttpRequest, + protocols: &[&str], +) -> Result { + // WebSocket accepts only GET + if *req.method() != Method::GET { + return Err(HandshakeError::GetMethodRequired); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = req.headers().get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + return Err(HandshakeError::NoWebsocketUpgrade); + } + + // Upgrade connection + if !req.head().upgrade() { + return Err(HandshakeError::NoConnectionUpgrade); + } + + // check supported version + if !req.headers().contains_key(&header::SEC_WEBSOCKET_VERSION) { + return Err(HandshakeError::NoVersionHeader); + } + let supported_ver = { + if let Some(hdr) = req.headers().get(&header::SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false + } + }; + if !supported_ver { + return Err(HandshakeError::UnsupportedVersion); + } + + // check client handshake for validity + if !req.headers().contains_key(&header::SEC_WEBSOCKET_KEY) { + return Err(HandshakeError::BadWebsocketKey); + } + let key = { + let key = req.headers().get(&header::SEC_WEBSOCKET_KEY).unwrap(); + hash_key(key.as_ref()) + }; + + // check requested protocols + let protocol = + req.headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|req_protocols| { + let req_protocols = req_protocols.to_str().ok()?; + req_protocols + .split(", ") + .find(|req_p| protocols.iter().any(|p| p == req_p)) + }); + + let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) + .upgrade("websocket") + .header(header::TRANSFER_ENCODING, "chunked") + .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) + .take(); + + if let Some(protocol) = protocol { + response.header(&header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + Ok(response) +} + +/// Execution context for `WebSockets` actors +pub struct WebsocketContext +where + A: Actor>, +{ + inner: ContextParts, + messages: VecDeque>, +} + +impl ActorContext for WebsocketContext +where + A: Actor, +{ + fn stop(&mut self) { + self.inner.stop(); + } + + fn terminate(&mut self) { + self.inner.terminate() + } + + fn state(&self) -> ActorState { + self.inner.state() + } +} + +impl AsyncContext for WebsocketContext +where + A: Actor, +{ + fn spawn(&mut self, fut: F) -> SpawnHandle + where + F: ActorFuture + 'static, + { + self.inner.spawn(fut) + } + + fn wait(&mut self, fut: F) + where + F: ActorFuture + 'static, + { + self.inner.wait(fut) + } + + #[doc(hidden)] + #[inline] + fn waiting(&self) -> bool { + self.inner.waiting() + || self.inner.state() == ActorState::Stopping + || self.inner.state() == ActorState::Stopped + } + + fn cancel_future(&mut self, handle: SpawnHandle) -> bool { + self.inner.cancel_future(handle) + } + + #[inline] + fn address(&self) -> Addr { + self.inner.address() + } +} + +impl WebsocketContext +where + A: Actor, +{ + #[inline] + /// Create a new Websocket context from a request and an actor + pub fn create(actor: A, stream: S) -> impl Stream + where + A: StreamHandler, + S: Stream + 'static, + { + let (_, stream) = WebsocketContext::create_with_addr(actor, stream); + stream + } + + #[inline] + /// Create a new Websocket context from a request and an actor. + /// + /// Returns a pair, where the first item is an addr for the created actor, + /// and the second item is a stream intended to be set as part of the + /// response via `HttpResponseBuilder::streaming()`. + pub fn create_with_addr( + actor: A, + stream: S, + ) -> (Addr, impl Stream) + where + A: StreamHandler, + S: Stream + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream, Codec::new())); + + let addr = ctx.address(); + + (addr, WebsocketContextFut::new(ctx, actor, mb, Codec::new())) + } + + #[inline] + /// Create a new Websocket context from a request, an actor, and a codec + pub fn with_codec( + actor: A, + stream: S, + codec: Codec, + ) -> impl Stream + where + A: StreamHandler, + S: Stream + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream, codec)); + + WebsocketContextFut::new(ctx, actor, mb, codec) + } + + /// Create a new Websocket context + pub fn with_factory( + stream: S, + f: F, + ) -> impl Stream + where + F: FnOnce(&mut Self) -> A + 'static, + A: StreamHandler, + S: Stream + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream, Codec::new())); + + let act = f(&mut ctx); + + WebsocketContextFut::new(ctx, act, mb, Codec::new()) + } +} + +impl WebsocketContext +where + A: Actor, +{ + /// Write payload + /// + /// This is a low-level function that accepts framed messages that should + /// be created using `Frame::message()`. If you want to send text or binary + /// data you should prefer the `text()` or `binary()` convenience functions + /// that handle the framing for you. + #[inline] + pub fn write_raw(&mut self, msg: Message) { + self.messages.push_back(Some(msg)); + } + + /// Send text frame + #[inline] + pub fn text>(&mut self, text: T) { + self.write_raw(Message::Text(text.into())); + } + + /// Send binary frame + #[inline] + pub fn binary>(&mut self, data: B) { + self.write_raw(Message::Binary(data.into())); + } + + /// Send ping frame + #[inline] + pub fn ping(&mut self, message: &str) { + self.write_raw(Message::Ping(message.to_string())); + } + + /// Send pong frame + #[inline] + pub fn pong(&mut self, message: &str) { + self.write_raw(Message::Pong(message.to_string())); + } + + /// Send close frame + #[inline] + pub fn close(&mut self, reason: Option) { + self.write_raw(Message::Close(reason)); + } + + /// Handle of the running future + /// + /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method. + pub fn handle(&self) -> SpawnHandle { + self.inner.curr_handle() + } + + /// Set mailbox capacity + /// + /// By default mailbox capacity is 16 messages. + pub fn set_mailbox_capacity(&mut self, cap: usize) { + self.inner.set_mailbox_capacity(cap) + } +} + +impl AsyncContextParts for WebsocketContext +where + A: Actor, +{ + fn parts(&mut self) -> &mut ContextParts { + &mut self.inner + } +} + +struct WebsocketContextFut +where + A: Actor>, +{ + fut: ContextFut>, + encoder: Codec, + buf: BytesMut, + closed: bool, +} + +impl WebsocketContextFut +where + A: Actor>, +{ + fn new(ctx: WebsocketContext, act: A, mailbox: Mailbox, codec: Codec) -> Self { + let fut = ContextFut::new(ctx, act, mailbox); + WebsocketContextFut { + fut, + encoder: codec, + buf: BytesMut::new(), + closed: false, + } + } +} + +impl Stream for WebsocketContextFut +where + A: Actor>, +{ + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + if self.fut.alive() && self.fut.poll().is_err() { + return Err(ErrorInternalServerError("error")); + } + + // encode messages + while let Some(item) = self.fut.ctx().messages.pop_front() { + if let Some(msg) = item { + self.encoder.encode(msg, &mut self.buf)?; + } else { + self.closed = true; + break; + } + } + + if !self.buf.is_empty() { + Ok(Async::Ready(Some(self.buf.take().freeze()))) + } else if self.fut.alive() && !self.closed { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(None)) + } + } +} + +impl ToEnvelope for WebsocketContext +where + A: Actor> + Handler, + M: ActixMessage + Send + 'static, + M::Result: Send, +{ + fn pack(msg: M, tx: Option>) -> Envelope { + Envelope::new(msg, tx) + } +} + +struct WsStream { + stream: S, + decoder: Codec, + buf: BytesMut, + closed: bool, +} + +impl WsStream +where + S: Stream, +{ + fn new(stream: S, codec: Codec) -> Self { + Self { + stream, + decoder: codec, + buf: BytesMut::new(), + closed: false, + } + } +} + +impl Stream for WsStream +where + S: Stream, +{ + type Item = Message; + type Error = ProtocolError; + + fn poll(&mut self) -> Poll, Self::Error> { + if !self.closed { + loop { + match self.stream.poll() { + Ok(Async::Ready(Some(chunk))) => { + self.buf.extend_from_slice(&chunk[..]); + } + Ok(Async::Ready(None)) => { + self.closed = true; + break; + } + Ok(Async::NotReady) => break, + Err(e) => { + return Err(ProtocolError::Io(io::Error::new( + io::ErrorKind::Other, + format!("{}", e), + ))); + } + } + } + } + + match self.decoder.decode(&mut self.buf)? { + None => { + if self.closed { + Ok(Async::Ready(None)) + } else { + Ok(Async::NotReady) + } + } + Some(frm) => { + let msg = match frm { + Frame::Text(data) => { + if let Some(data) = data { + Message::Text( + std::str::from_utf8(&data) + .map_err(|_| ProtocolError::BadEncoding)? + .to_string(), + ) + } else { + Message::Text(String::new()) + } + } + Frame::Binary(data) => Message::Binary( + data.map(|b| b.freeze()).unwrap_or_else(Bytes::new), + ), + Frame::Ping(s) => Message::Ping(s), + Frame::Pong(s) => Message::Pong(s), + Frame::Close(reason) => Message::Close(reason), + }; + Ok(Async::Ready(Some(msg))) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::http::{header, Method}; + use actix_web::test::TestRequest; + + #[test] + fn test_handshake() { + let req = TestRequest::default() + .method(Method::POST) + .to_http_request(); + assert_eq!( + HandshakeError::GetMethodRequired, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default().to_http_request(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header(header::UPGRADE, header::HeaderValue::from_static("test")) + .to_http_request(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .to_http_request(); + assert_eq!( + HandshakeError::NoConnectionUpgrade, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .to_http_request(); + assert_eq!( + HandshakeError::NoVersionHeader, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("5"), + ) + .to_http_request(); + assert_eq!( + HandshakeError::UnsupportedVersion, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .to_http_request(); + assert_eq!( + HandshakeError::BadWebsocketKey, + handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .to_http_request(); + + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake(&req).unwrap().finish().status() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_PROTOCOL, + header::HeaderValue::from_static("graphql"), + ) + .to_http_request(); + + let protocols = ["graphql"]; + + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .status() + ); + assert_eq!( + Some(&header::HeaderValue::from_static("graphql")), + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_PROTOCOL, + header::HeaderValue::from_static("p1, p2, p3"), + ) + .to_http_request(); + + let protocols = vec!["p3", "p2"]; + + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .status() + ); + assert_eq!( + Some(&header::HeaderValue::from_static("p2")), + handshake_with_protocols(&req, &protocols) + .unwrap() + .finish() + .headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + ); + } +} diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs new file mode 100644 index 000000000..687cf4314 --- /dev/null +++ b/actix-web-actors/tests/test_ws.rs @@ -0,0 +1,68 @@ +use actix::prelude::*; +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_web::{web, App, HttpRequest}; +use actix_web_actors::*; +use bytes::{Bytes, BytesMut}; +use futures::{Sink, Stream}; + +struct Ws; + +impl Actor for Ws { + type Context = ws::WebsocketContext; +} + +impl StreamHandler for Ws { + fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { + match msg { + ws::Message::Ping(msg) => ctx.pong(&msg), + ws::Message::Text(text) => ctx.text(text), + ws::Message::Binary(bin) => ctx.binary(bin), + ws::Message::Close(reason) => ctx.close(reason), + _ => (), + } + } +} + +#[test] +fn test_simple() { + let mut srv = + TestServer::new(|| { + HttpService::new(App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| ws::start(Ws, &req, stream), + ))) + }); + + // client service + let framed = srv.ws().unwrap(); + let framed = srv + .block_on(framed.send(ws::Message::Text("text".to_string()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + + let framed = srv + .block_on(framed.send(ws::Message::Binary("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + ); + + let framed = srv + .block_on(framed.send(ws::Message::Ping("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + + let framed = srv + .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + .unwrap(); + + let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + ); +} diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md new file mode 100644 index 000000000..2beea62cf --- /dev/null +++ b/actix-web-codegen/CHANGES.md @@ -0,0 +1,31 @@ +# Changes + +## [0.1.3] - 2019-10-14 + +* Bump up `syn` & `quote` to 1.0 + +* Provide better error message + +## [0.1.2] - 2019-06-04 + +* Add macros for head, options, trace, connect and patch http methods + +## [0.1.1] - 2019-06-01 + +* Add syn "extra-traits" feature + +## [0.1.0] - 2019-05-18 + +* Release + +## [0.1.0-beta.1] - 2019-04-20 + +* Gen code for actix-web 1.0.0-beta.1 + +## [0.1.0-alpha.6] - 2019-04-14 + +* Gen code for actix-web 1.0.0-alpha.6 + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml new file mode 100644 index 000000000..95883363a --- /dev/null +++ b/actix-web-codegen/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "actix-web-codegen" +version = "0.2.0-alpha.1" +description = "Actix web proc macros" +readme = "README.md" +authors = ["Nikolay Kim "] +license = "MIT/Apache-2.0" +edition = "2018" +workspace = ".." + +[lib] +proc-macro = true + +[dependencies] +quote = "^1" +syn = { version = "^1", features = ["full", "parsing"] } +proc-macro2 = "^1" + +[dev-dependencies] +actix-rt = { version = "1.0.0-alpha.1" } +actix-web = { version = "2.0.0-alpha.1" } +actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +futures = { version = "0.3.1" } diff --git a/actix-web-codegen/LICENSE-APACHE b/actix-web-codegen/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-web-codegen/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-web-codegen/LICENSE-MIT b/actix-web-codegen/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-web-codegen/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-web-codegen/README.md b/actix-web-codegen/README.md new file mode 100644 index 000000000..c44a5fc7f --- /dev/null +++ b/actix-web-codegen/README.md @@ -0,0 +1 @@ +# Macros for actix-web framework [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/actix-web-codegen)](https://crates.io/crates/actix-web-codegen) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs new file mode 100644 index 000000000..0a727ed69 --- /dev/null +++ b/actix-web-codegen/src/lib.rs @@ -0,0 +1,186 @@ +#![recursion_limit = "512"] +//! Actix-web codegen module +//! +//! Generators for routes and scopes +//! +//! ## Route +//! +//! Macros: +//! +//! - [get](attr.get.html) +//! - [post](attr.post.html) +//! - [put](attr.put.html) +//! - [delete](attr.delete.html) +//! - [head](attr.head.html) +//! - [connect](attr.connect.html) +//! - [options](attr.options.html) +//! - [trace](attr.trace.html) +//! - [patch](attr.patch.html) +//! +//! ### Attributes: +//! +//! - `"path"` - Raw literal string with path for which to register handle. Mandatory. +//! - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` +//! +//! ## Notes +//! +//! Function name can be specified as any expression that is going to be accessible to the generate +//! code (e.g `my_guard` or `my_module::my_guard`) +//! +//! ## Example: +//! +//! ```rust +//! use actix_web::HttpResponse; +//! use actix_web_codegen::get; +//! use futures::{future, Future}; +//! +//! #[get("/test")] +//! async fn async_test() -> Result { +//! Ok(HttpResponse::Ok().finish()) +//! } +//! ``` + +extern crate proc_macro; + +mod route; + +use proc_macro::TokenStream; +use syn::parse_macro_input; + +/// Creates route handler with `GET` method guard. +/// +/// Syntax: `#[get("path"[, attributes])]` +/// +/// ## Attributes: +/// +/// - `"path"` - Raw literal string with path for which to register handler. Mandatory. +/// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` +#[proc_macro_attribute] +pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Get) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `POST` method guard. +/// +/// Syntax: `#[post("path"[, attributes])]` +/// +/// Attributes are the same as in [get](attr.get.html) +#[proc_macro_attribute] +pub fn post(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Post) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `PUT` method guard. +/// +/// Syntax: `#[put("path"[, attributes])]` +/// +/// Attributes are the same as in [get](attr.get.html) +#[proc_macro_attribute] +pub fn put(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Put) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `DELETE` method guard. +/// +/// Syntax: `#[delete("path"[, attributes])]` +/// +/// Attributes are the same as in [get](attr.get.html) +#[proc_macro_attribute] +pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Delete) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `HEAD` method guard. +/// +/// Syntax: `#[head("path"[, attributes])]` +/// +/// Attributes are the same as in [head](attr.head.html) +#[proc_macro_attribute] +pub fn head(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Head) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `CONNECT` method guard. +/// +/// Syntax: `#[connect("path"[, attributes])]` +/// +/// Attributes are the same as in [connect](attr.connect.html) +#[proc_macro_attribute] +pub fn connect(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Connect) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `OPTIONS` method guard. +/// +/// Syntax: `#[options("path"[, attributes])]` +/// +/// Attributes are the same as in [options](attr.options.html) +#[proc_macro_attribute] +pub fn options(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Options) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `TRACE` method guard. +/// +/// Syntax: `#[trace("path"[, attributes])]` +/// +/// Attributes are the same as in [trace](attr.trace.html) +#[proc_macro_attribute] +pub fn trace(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Trace) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} + +/// Creates route handler with `PATCH` method guard. +/// +/// Syntax: `#[patch("path"[, attributes])]` +/// +/// Attributes are the same as in [patch](attr.patch.html) +#[proc_macro_attribute] +pub fn patch(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as syn::AttributeArgs); + let gen = match route::Route::new(args, input, route::GuardType::Patch) { + Ok(gen) => gen, + Err(err) => return err.to_compile_error().into(), + }; + gen.generate() +} diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs new file mode 100644 index 000000000..16d3e8157 --- /dev/null +++ b/actix-web-codegen/src/route.rs @@ -0,0 +1,212 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{AttributeArgs, Ident, NestedMeta}; + +enum ResourceType { + Async, + Sync, +} + +impl ToTokens for ResourceType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = match self { + ResourceType::Async => "to", + ResourceType::Sync => "to", + }; + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); + } +} + +#[derive(PartialEq)] +pub enum GuardType { + Get, + Post, + Put, + Delete, + Head, + Connect, + Options, + Trace, + Patch, +} + +impl GuardType { + fn as_str(&self) -> &'static str { + match self { + GuardType::Get => "Get", + GuardType::Post => "Post", + GuardType::Put => "Put", + GuardType::Delete => "Delete", + GuardType::Head => "Head", + GuardType::Connect => "Connect", + GuardType::Options => "Options", + GuardType::Trace => "Trace", + GuardType::Patch => "Patch", + } + } +} + +impl ToTokens for GuardType { + fn to_tokens(&self, stream: &mut TokenStream2) { + let ident = self.as_str(); + let ident = Ident::new(ident, Span::call_site()); + stream.append(ident); + } +} + +struct Args { + path: syn::LitStr, + guards: Vec, +} + +impl Args { + fn new(args: AttributeArgs) -> syn::Result { + let mut path = None; + let mut guards = Vec::new(); + for arg in args { + match arg { + NestedMeta::Lit(syn::Lit::Str(lit)) => match path { + None => { + path = Some(lit); + } + _ => { + return Err(syn::Error::new_spanned( + lit, + "Multiple paths specified! Should be only one!", + )); + } + }, + NestedMeta::Meta(syn::Meta::NameValue(nv)) => { + if nv.path.is_ident("guard") { + if let syn::Lit::Str(lit) = nv.lit { + guards.push(Ident::new(&lit.value(), Span::call_site())); + } else { + return Err(syn::Error::new_spanned( + nv.lit, + "Attribute guard expects literal string!", + )); + } + } else { + return Err(syn::Error::new_spanned( + nv.path, + "Unknown attribute key is specified. Allowed: guard", + )); + } + } + arg => { + return Err(syn::Error::new_spanned(arg, "Unknown attribute")); + } + } + } + Ok(Args { + path: path.unwrap(), + guards, + }) + } +} + +pub struct Route { + name: syn::Ident, + args: Args, + ast: syn::ItemFn, + resource_type: ResourceType, + guard: GuardType, +} + +fn guess_resource_type(typ: &syn::Type) -> ResourceType { + let mut guess = ResourceType::Sync; + + if let syn::Type::ImplTrait(typ) = typ { + for bound in typ.bounds.iter() { + if let syn::TypeParamBound::Trait(bound) = bound { + for bound in bound.path.segments.iter() { + if bound.ident == "Future" { + guess = ResourceType::Async; + break; + } else if bound.ident == "Responder" { + guess = ResourceType::Sync; + break; + } + } + } + } + } + + guess +} + +impl Route { + pub fn new( + args: AttributeArgs, + input: TokenStream, + guard: GuardType, + ) -> syn::Result { + if args.is_empty() { + return Err(syn::Error::new( + Span::call_site(), + format!( + r#"invalid server definition, expected #[{}("")]"#, + guard.as_str().to_ascii_lowercase() + ), + )); + } + let ast: syn::ItemFn = syn::parse(input)?; + let name = ast.sig.ident.clone(); + + let args = Args::new(args)?; + + let resource_type = if ast.sig.asyncness.is_some() { + ResourceType::Async + } else { + match ast.sig.output { + syn::ReturnType::Default => { + return Err(syn::Error::new_spanned( + ast, + "Function has no return type. Cannot be used as handler", + )); + } + syn::ReturnType::Type(_, ref typ) => guess_resource_type(typ.as_ref()), + } + }; + + Ok(Self { + name, + args, + ast, + resource_type, + guard, + }) + } + + pub fn generate(&self) -> TokenStream { + let name = &self.name; + let resource_name = name.to_string(); + let guard = &self.guard; + let ast = &self.ast; + let path = &self.args.path; + let extra_guards = &self.args.guards; + let resource_type = &self.resource_type; + let stream = quote! { + #[allow(non_camel_case_types)] + pub struct #name; + + impl actix_web::dev::HttpServiceFactory for #name { + fn register(self, config: &mut actix_web::dev::AppService) { + #ast + let resource = actix_web::Resource::new(#path) + .name(#resource_name) + .guard(actix_web::guard::#guard()) + #(.guard(actix_web::guard::fn_guard(#extra_guards)))* + .#resource_type(#name); + + actix_web::dev::HttpServiceFactory::register(resource, config) + } + } + }; + stream.into() + } +} diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs new file mode 100644 index 000000000..b6ac6dd18 --- /dev/null +++ b/actix-web-codegen/tests/test_macro.rs @@ -0,0 +1,157 @@ +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_web::{http, web::Path, App, HttpResponse, Responder}; +use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; +use futures::{future, Future}; + +#[get("/test")] +async fn test() -> impl Responder { + HttpResponse::Ok() +} + +#[put("/test")] +async fn put_test() -> impl Responder { + HttpResponse::Created() +} + +#[patch("/test")] +async fn patch_test() -> impl Responder { + HttpResponse::Ok() +} + +#[post("/test")] +async fn post_test() -> impl Responder { + HttpResponse::NoContent() +} + +#[head("/test")] +async fn head_test() -> impl Responder { + HttpResponse::Ok() +} + +#[connect("/test")] +async fn connect_test() -> impl Responder { + HttpResponse::Ok() +} + +#[options("/test")] +async fn options_test() -> impl Responder { + HttpResponse::Ok() +} + +#[trace("/test")] +async fn trace_test() -> impl Responder { + HttpResponse::Ok() +} + +#[get("/test")] +fn auto_async() -> impl Future> { + future::ok(HttpResponse::Ok().finish()) +} + +#[get("/test")] +fn auto_sync() -> impl Future> { + future::ok(HttpResponse::Ok().finish()) +} + +#[put("/test/{param}")] +async fn put_param_test(_: Path) -> impl Responder { + HttpResponse::Created() +} + +#[delete("/test/{param}")] +async fn delete_param_test(_: Path) -> impl Responder { + HttpResponse::NoContent() +} + +#[get("/test/{param}")] +async fn get_param_test(_: Path) -> impl Responder { + HttpResponse::Ok() +} + +#[actix_rt::test] +async fn test_params() { + let srv = TestServer::start(|| { + HttpService::new( + App::new() + .service(get_param_test) + .service(put_param_test) + .service(delete_param_test), + ) + }); + + let request = srv.request(http::Method::GET, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::OK); + + let request = srv.request(http::Method::PUT, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::CREATED); + + let request = srv.request(http::Method::DELETE, srv.url("/test/it")); + let response = request.send().await.unwrap(); + assert_eq!(response.status(), http::StatusCode::NO_CONTENT); +} + +#[actix_rt::test] +async fn test_body() { + let srv = TestServer::start(|| { + HttpService::new( + App::new() + .service(post_test) + .service(put_test) + .service(head_test) + .service(connect_test) + .service(options_test) + .service(trace_test) + .service(patch_test) + .service(test), + ) + }); + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::HEAD, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::CONNECT, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::OPTIONS, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::TRACE, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::PATCH, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + + let request = srv.request(http::Method::PUT, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.status(), http::StatusCode::CREATED); + + let request = srv.request(http::Method::POST, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.status(), http::StatusCode::NO_CONTENT); + + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_auto_async() { + let srv = TestServer::start(|| HttpService::new(App::new().service(auto_async))); + + let request = srv.request(http::Method::GET, srv.url("/test")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/awc/CHANGES.md b/awc/CHANGES.md new file mode 100644 index 000000000..89423f80e --- /dev/null +++ b/awc/CHANGES.md @@ -0,0 +1,156 @@ +# Changes + +## [0.2.8] - 2019-11-06 + +* Add support for setting query from Serialize type for client request. + + +## [0.2.7] - 2019-09-25 + +### Added + +* Remaining getter methods for `ClientRequest`'s private `head` field #1101 + + +## [0.2.6] - 2019-09-12 + +### Added + +* Export frozen request related types. + + +## [0.2.5] - 2019-09-11 + +### Added + +* Add `FrozenClientRequest` to support retries for sending HTTP requests + +### Changed + +* Ensure that the `Host` header is set when initiating a WebSocket client connection. + + +## [0.2.4] - 2019-08-13 + +### Changed + +* Update percent-encoding to "2.1" + +* Update serde_urlencoded to "0.6.1" + + +## [0.2.3] - 2019-08-01 + +### Added + +* Add `rustls` support + + +## [0.2.2] - 2019-07-01 + +### Changed + +* Always append a colon after username in basic auth + +* Upgrade `rand` dependency version to 0.7 + + +## [0.2.1] - 2019-06-05 + +### Added + +* Add license files + +## [0.2.0] - 2019-05-12 + +### Added + +* Allow to send headers in `Camel-Case` form. + +### Changed + +* Upgrade actix-http dependency. + + +## [0.1.1] - 2019-04-19 + +### Added + +* Allow to specify server address for http and ws requests. + +### Changed + +* `ClientRequest::if_true()` and `ClientRequest::if_some()` use instance instead of ref + + +## [0.1.0] - 2019-04-16 + +* No changes + + +## [0.1.0-alpha.6] - 2019-04-14 + +### Changed + +* Do not set default headers for websocket request + + +## [0.1.0-alpha.5] - 2019-04-12 + +### Changed + +* Do not set any default headers + +### Added + +* Add Debug impl for BoxedSocket + + +## [0.1.0-alpha.4] - 2019-04-08 + +### Changed + +* Update actix-http dependency + + +## [0.1.0-alpha.3] - 2019-04-02 + +### Added + +* Export `MessageBody` type + +* `ClientResponse::json()` - Loads and parse `application/json` encoded body + + +### Changed + +* `ClientRequest::json()` accepts reference instead of object. + +* `ClientResponse::body()` does not consume response object. + +* Renamed `ClientRequest::close_connection()` to `ClientRequest::force_close()` + + +## [0.1.0-alpha.2] - 2019-03-29 + +### Added + +* Per request and session wide request timeout. + +* Session wide headers. + +* Session wide basic and bearer auth. + +* Re-export `actix_http::client::Connector`. + + +### Changed + +* Allow to override request's uri + +* Export `ws` sub-module with websockets related types + + +## [0.1.0-alpha.1] - 2019-03-28 + +* Initial impl diff --git a/awc/Cargo.toml b/awc/Cargo.toml new file mode 100644 index 000000000..e9268aac0 --- /dev/null +++ b/awc/Cargo.toml @@ -0,0 +1,75 @@ +[package] +name = "awc" +version = "0.3.0-alpha.1" +authors = ["Nikolay Kim "] +description = "Actix http client." +readme = "README.md" +keywords = ["actix", "http", "framework", "async", "web"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://docs.rs/awc/" +categories = ["network-programming", "asynchronous", + "web-programming::http-client", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +edition = "2018" +workspace = ".." + +[lib] +name = "awc" +path = "src/lib.rs" + +[package.metadata.docs.rs] +features = ["openssl", "brotli", "flate2-zlib"] + +[features] +default = ["brotli", "flate2-zlib"] + +# openssl +openssl = ["open-ssl", "actix-http/openssl"] + +# rustls +# rustls = ["rust-tls", "actix-http/rustls"] + +# brotli encoding, requires c compiler +brotli = ["actix-http/brotli"] + +# miniz-sys backend for flate2 crate +flate2-zlib = ["actix-http/flate2-zlib"] + +# rust backend for flate2 crate +flate2-rust = ["actix-http/flate2-rust"] + +[dependencies] +actix-codec = "0.2.0-alpha.1" +actix-service = "1.0.0-alpha.1" +actix-http = "0.3.0-alpha.1" +actix-rt = "1.0.0-alpha.1" + +base64 = "0.10.1" +bytes = "0.4" +derive_more = "0.99.2" +futures = "0.3.1" +log =" 0.4" +mime = "0.3" +percent-encoding = "2.1" +rand = "0.7" +serde = "1.0" +serde_json = "1.0" +serde_urlencoded = "0.6.1" +open-ssl = { version="0.10", package="openssl", optional = true } +# rust-tls = { version = "0.16.0", package="rustls", optional = true, features = ["dangerous_configuration"] } + +[dev-dependencies] +actix-connect = { version = "1.0.0-alpha.1", features=["openssl"] } +actix-web = { version = "2.0.0-alpha.1", features=["openssl"] } +actix-http = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-http-test = { version = "0.3.0-alpha.1", features=["openssl"] } +actix-utils = "0.5.0-alpha.1" +actix-server = { version = "0.8.0-alpha.1", features=["openssl"] } +brotli2 = { version="0.3.2" } +flate2 = { version="1.0.2" } +env_logger = "0.6" +rand = "0.7" +webpki = { version = "0.21" } diff --git a/awc/LICENSE-APACHE b/awc/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/awc/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/awc/LICENSE-MIT b/awc/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/awc/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/awc/README.md b/awc/README.md new file mode 100644 index 000000000..3b0034d76 --- /dev/null +++ b/awc/README.md @@ -0,0 +1,33 @@ +# Actix http client [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](https://meritbadge.herokuapp.com/awc)](https://crates.io/crates/awc) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +An HTTP Client + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/awc/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [awc](https://crates.io/crates/awc) +* Minimum supported Rust version: 1.33 or later + +## Example + +```rust +use actix_rt::System; +use awc::Client; +use futures::future::{Future, lazy}; + +fn main() { + System::new("test").block_on(lazy(|| { + let mut client = Client::default(); + + client.get("http://www.rust-lang.org") // <- Create request builder + .header("User-Agent", "Actix-web") + .send() // <- Send http request + .and_then(|response| { // <- server http response + println!("Response: {:?}", response); + Ok(()) + }) + })); +} +``` diff --git a/awc/src/builder.rs b/awc/src/builder.rs new file mode 100644 index 000000000..463f40303 --- /dev/null +++ b/awc/src/builder.rs @@ -0,0 +1,191 @@ +use std::cell::RefCell; +use std::fmt; +use std::rc::Rc; +use std::time::Duration; + +use actix_http::client::{Connect, ConnectError, Connection, Connector}; +use actix_http::http::{header, HeaderMap, HeaderName, HttpTryFrom}; +use actix_service::Service; + +use crate::connect::ConnectorWrapper; +use crate::{Client, ClientConfig}; + +/// An HTTP Client builder +/// +/// This type can be used to construct an instance of `Client` through a +/// builder-like pattern. +pub struct ClientBuilder { + config: ClientConfig, + default_headers: bool, + allow_redirects: bool, + max_redirects: usize, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + pub fn new() -> Self { + ClientBuilder { + default_headers: true, + allow_redirects: true, + max_redirects: 10, + config: ClientConfig { + headers: HeaderMap::new(), + timeout: Some(Duration::from_secs(5)), + connector: RefCell::new(Box::new(ConnectorWrapper( + Connector::new().finish(), + ))), + }, + } + } + + /// Use custom connector service. + pub fn connector(mut self, connector: T) -> Self + where + T: Service + 'static, + T::Response: Connection, + ::Future: 'static, + T::Future: 'static, + { + self.config.connector = RefCell::new(Box::new(ConnectorWrapper(connector))); + self + } + + /// Set request timeout + /// + /// Request timeout is the total time before a response must be received. + /// Default value is 5 seconds. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.config.timeout = Some(timeout); + self + } + + /// Disable request timeout. + pub fn disable_timeout(mut self) -> Self { + self.config.timeout = None; + self + } + + /// Do not follow redirects. + /// + /// Redirects are allowed by default. + pub fn disable_redirects(mut self) -> Self { + self.allow_redirects = false; + self + } + + /// Set max number of redirects. + /// + /// Max redirects is set to 10 by default. + pub fn max_redirects(mut self, num: usize) -> Self { + self.max_redirects = num; + self + } + + /// Do not add default request headers. + /// By default `Date` and `User-Agent` headers are set. + pub fn no_default_headers(mut self) -> Self { + self.default_headers = false; + self + } + + /// Add default header. Headers added by this method + /// get added to every request. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + >::Error: fmt::Debug, + V: header::IntoHeaderValue, + V::Error: fmt::Debug, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.config.headers.append(key, value); + } + Err(e) => log::error!("Header value error: {:?}", e), + }, + Err(e) => log::error!("Header name error: {:?}", e), + } + self + } + + /// Set client wide HTTP basic authorization header + pub fn basic_auth(self, username: U, password: Option<&str>) -> Self + where + U: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}:", username), + }; + self.header( + header::AUTHORIZATION, + format!("Basic {}", base64::encode(&auth)), + ) + } + + /// Set client wide HTTP bearer authentication header + pub fn bearer_auth(self, token: T) -> Self + where + T: fmt::Display, + { + self.header(header::AUTHORIZATION, format!("Bearer {}", token)) + } + + /// Finish build process and create `Client` instance. + pub fn finish(self) -> Client { + Client(Rc::new(self.config)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn client_basic_auth() { + let client = ClientBuilder::new().basic_auth("username", Some("password")); + assert_eq!( + client + .config + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + ); + + let client = ClientBuilder::new().basic_auth("username", None); + assert_eq!( + client + .config + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6" + ); + } + + #[test] + fn client_bearer_auth() { + let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n"); + assert_eq!( + client + .config + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Bearer someS3cr3tAutht0k3n" + ); + } +} diff --git a/awc/src/connect.rs b/awc/src/connect.rs new file mode 100644 index 000000000..cc92fdbb6 --- /dev/null +++ b/awc/src/connect.rs @@ -0,0 +1,236 @@ +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::{fmt, io, net}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::body::Body; +use actix_http::client::{ + Connect as ClientConnect, ConnectError, Connection, SendRequestError, +}; +use actix_http::h1::ClientCodec; +use actix_http::http::HeaderMap; +use actix_http::{RequestHead, RequestHeadType, ResponseHead}; +use actix_service::Service; +use futures::future::{FutureExt, LocalBoxFuture}; + +use crate::response::ClientResponse; + +pub(crate) struct ConnectorWrapper(pub T); + +pub(crate) trait Connect { + fn send_request( + &mut self, + head: RequestHead, + body: Body, + addr: Option, + ) -> LocalBoxFuture<'static, Result>; + + fn send_request_extra( + &mut self, + head: Rc, + extra_headers: Option, + body: Body, + addr: Option, + ) -> LocalBoxFuture<'static, Result>; + + /// Send request, returns Response and Framed + fn open_tunnel( + &mut self, + head: RequestHead, + addr: Option, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >; + + /// Send request and extra headers, returns Response and Framed + fn open_tunnel_extra( + &mut self, + head: Rc, + extra_headers: Option, + addr: Option, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + >; +} + +impl Connect for ConnectorWrapper +where + T: Service, + T::Response: Connection, + ::Io: 'static, + ::Future: 'static, + ::TunnelFuture: 'static, + T::Future: 'static, +{ + fn send_request( + &mut self, + head: RequestHead, + body: Body, + addr: Option, + ) -> LocalBoxFuture<'static, Result> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + connection + .send_request(RequestHeadType::from(head), body) + .await + .map(|(head, payload)| ClientResponse::new(head, payload)) + } + .boxed_local() + } + + fn send_request_extra( + &mut self, + head: Rc, + extra_headers: Option, + body: Body, + addr: Option, + ) -> LocalBoxFuture<'static, Result> { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, payload) = connection + .send_request(RequestHeadType::Rc(head, extra_headers), body) + .await?; + + Ok(ClientResponse::new(head, payload)) + } + .boxed_local() + } + + fn open_tunnel( + &mut self, + head: RequestHead, + addr: Option, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + > { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, framed) = + connection.open_tunnel(RequestHeadType::from(head)).await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + } + .boxed_local() + } + + fn open_tunnel_extra( + &mut self, + head: Rc, + extra_headers: Option, + addr: Option, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, + > { + // connect to the host + let fut = self.0.call(ClientConnect { + uri: head.uri.clone(), + addr, + }); + + async move { + let connection = fut.await?; + + // send request + let (head, framed) = connection + .open_tunnel(RequestHeadType::Rc(head, extra_headers)) + .await?; + + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok((head, framed)) + } + .boxed_local() + } +} + +trait AsyncSocket { + fn as_read(&self) -> &(dyn AsyncRead + Unpin); + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); +} + +struct Socket(T); + +impl AsyncSocket for Socket { + fn as_read(&self) -> &(dyn AsyncRead + Unpin) { + &self.0 + } + fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { + &mut self.0 + } + fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { + &mut self.0 + } +} + +pub struct BoxedSocket(Box); + +impl fmt::Debug for BoxedSocket { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "BoxedSocket") + } +} + +impl AsyncRead for BoxedSocket { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.as_read().prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) + } +} + +impl AsyncWrite for BoxedSocket { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) + } +} diff --git a/awc/src/error.rs b/awc/src/error.rs new file mode 100644 index 000000000..8816c4075 --- /dev/null +++ b/awc/src/error.rs @@ -0,0 +1,71 @@ +//! Http client errors +pub use actix_http::client::{ + ConnectError, FreezeRequestError, InvalidUrl, SendRequestError, +}; +pub use actix_http::error::PayloadError; +pub use actix_http::ws::HandshakeError as WsHandshakeError; +pub use actix_http::ws::ProtocolError as WsProtocolError; + +use actix_http::ResponseError; +use serde_json::error::Error as JsonError; + +use actix_http::http::{header::HeaderValue, Error as HttpError, StatusCode}; +use derive_more::{Display, From}; + +/// Websocket client error +#[derive(Debug, Display, From)] +pub enum WsClientError { + /// Invalid response status + #[display(fmt = "Invalid response status")] + InvalidResponseStatus(StatusCode), + /// Invalid upgrade header + #[display(fmt = "Invalid upgrade header")] + InvalidUpgradeHeader, + /// Invalid connection header + #[display(fmt = "Invalid connection header")] + InvalidConnectionHeader(HeaderValue), + /// Missing CONNECTION header + #[display(fmt = "Missing CONNECTION header")] + MissingConnectionHeader, + /// Missing SEC-WEBSOCKET-ACCEPT header + #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] + MissingWebSocketAcceptHeader, + /// Invalid challenge response + #[display(fmt = "Invalid challenge response")] + InvalidChallengeResponse(String, HeaderValue), + /// Protocol error + #[display(fmt = "{}", _0)] + Protocol(WsProtocolError), + /// Send request error + #[display(fmt = "{}", _0)] + SendRequest(SendRequestError), +} + +impl From for WsClientError { + fn from(err: InvalidUrl) -> Self { + WsClientError::SendRequest(err.into()) + } +} + +impl From for WsClientError { + fn from(err: HttpError) -> Self { + WsClientError::SendRequest(err.into()) + } +} + +/// A set of errors that can occur during parsing json payloads +#[derive(Debug, Display, From)] +pub enum JsonPayloadError { + /// Content type error + #[display(fmt = "Content type error")] + ContentType, + /// Deserialize error + #[display(fmt = "Json deserialize error: {}", _0)] + Deserialize(JsonError), + /// Payload error + #[display(fmt = "Error that occur during reading payload: {}", _0)] + Payload(PayloadError), +} + +/// Return `InternalServerError` for `JsonPayloadError` +impl ResponseError for JsonPayloadError {} diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs new file mode 100644 index 000000000..61ba87aad --- /dev/null +++ b/awc/src/frozen.rs @@ -0,0 +1,235 @@ +use std::net; +use std::rc::Rc; +use std::time::Duration; + +use bytes::Bytes; +use futures::Stream; +use serde::Serialize; + +use actix_http::body::Body; +use actix_http::http::header::IntoHeaderValue; +use actix_http::http::{ + Error as HttpError, HeaderMap, HeaderName, HttpTryFrom, Method, Uri, +}; +use actix_http::{Error, RequestHead}; + +use crate::sender::{RequestSender, SendClientRequest}; +use crate::ClientConfig; + +/// `FrozenClientRequest` struct represents clonable client request. +/// It could be used to send same request multiple times. +#[derive(Clone)] +pub struct FrozenClientRequest { + pub(crate) head: Rc, + pub(crate) addr: Option, + pub(crate) response_decompress: bool, + pub(crate) timeout: Option, + pub(crate) config: Rc, +} + +impl FrozenClientRequest { + /// Get HTTP URI of request + pub fn get_uri(&self) -> &Uri { + &self.head.uri + } + + /// Get HTTP method of this request + pub fn get_method(&self) -> &Method { + &self.head.method + } + + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Send a body. + pub fn send_body(&self, body: B) -> SendClientRequest + where + B: Into, + { + RequestSender::Rc(self.head.clone(), None).send_body( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + body, + ) + } + + /// Send a json body. + pub fn send_json(&self, value: &T) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send_json( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + value, + ) + } + + /// Send an urlencoded body. + pub fn send_form(&self, value: &T) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send_form( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + value, + ) + } + + /// Send a streaming body. + pub fn send_stream(&self, stream: S) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + RequestSender::Rc(self.head.clone(), None).send_stream( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + stream, + ) + } + + /// Send an empty body. + pub fn send(&self) -> SendClientRequest { + RequestSender::Rc(self.head.clone(), None).send( + self.addr, + self.response_decompress, + self.timeout, + self.config.as_ref(), + ) + } + + /// Create a `FrozenSendBuilder` with extra headers + pub fn extra_headers(&self, extra_headers: HeaderMap) -> FrozenSendBuilder { + FrozenSendBuilder::new(self.clone(), extra_headers) + } + + /// Create a `FrozenSendBuilder` with an extra header + pub fn extra_header(&self, key: K, value: V) -> FrozenSendBuilder + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + self.extra_headers(HeaderMap::new()) + .extra_header(key, value) + } +} + +/// Builder that allows to modify extra headers. +pub struct FrozenSendBuilder { + req: FrozenClientRequest, + extra_headers: HeaderMap, + err: Option, +} + +impl FrozenSendBuilder { + pub(crate) fn new(req: FrozenClientRequest, extra_headers: HeaderMap) -> Self { + Self { + req, + extra_headers, + err: None, + } + } + + /// Insert a header, it overrides existing header in `FrozenClientRequest`. + pub fn extra_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => self.extra_headers.insert(key, value), + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Complete request construction and send a body. + pub fn send_body(self, body: B) -> SendClientRequest + where + B: Into, + { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_body( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + body, + ) + } + + /// Complete request construction and send a json body. + pub fn send_json(self, value: &T) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_json( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + value, + ) + } + + /// Complete request construction and send an urlencoded body. + pub fn send_form(self, value: &T) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_form( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + value, + ) + } + + /// Complete request construction and send a streaming body. + pub fn send_stream(self, stream: S) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_stream( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + stream, + ) + } + + /// Complete request construction and send an empty body. + pub fn send(self) -> SendClientRequest { + if let Some(e) = self.err { + return e.into(); + } + + RequestSender::Rc(self.req.head, Some(self.extra_headers)).send( + self.req.addr, + self.req.response_decompress, + self.req.timeout, + self.req.config.as_ref(), + ) + } +} diff --git a/awc/src/lib.rs b/awc/src/lib.rs new file mode 100644 index 000000000..e995519ea --- /dev/null +++ b/awc/src/lib.rs @@ -0,0 +1,197 @@ +#![allow(clippy::borrow_interior_mutable_const)] +//! An HTTP Client +//! +//! ```rust +//! use futures::future::{lazy, Future}; +//! use actix_rt::System; +//! use awc::Client; +//! +//! #[actix_rt::main] +//! async fn main() { +//! let mut client = Client::default(); +//! +//! let response = client.get("http://www.rust-lang.org") // <- Create request builder +//! .header("User-Agent", "Actix-web") +//! .send() // <- Send http request +//! .await; +//! +//! println!("Response: {:?}", response); +//! } +//! ``` +use std::cell::RefCell; +use std::rc::Rc; +use std::time::Duration; + +pub use actix_http::{client::Connector, cookie, http}; + +use actix_http::http::{HeaderMap, HttpTryFrom, Method, Uri}; +use actix_http::RequestHead; + +mod builder; +mod connect; +pub mod error; +mod frozen; +mod request; +mod response; +mod sender; +pub mod test; +pub mod ws; + +pub use self::builder::ClientBuilder; +pub use self::connect::BoxedSocket; +pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; +pub use self::request::ClientRequest; +pub use self::response::{ClientResponse, JsonBody, MessageBody}; +pub use self::sender::SendClientRequest; + +use self::connect::{Connect, ConnectorWrapper}; + +/// An HTTP Client +/// +/// ```rust +/// use actix_rt::System; +/// use awc::Client; +/// +/// fn main() { +/// System::new("test").block_on(async { +/// let mut client = Client::default(); +/// +/// client.get("http://www.rust-lang.org") // <- Create request builder +/// .header("User-Agent", "Actix-web") +/// .send() // <- Send http request +/// .await +/// .and_then(|response| { // <- server http response +/// println!("Response: {:?}", response); +/// Ok(()) +/// }) +/// }); +/// } +/// ``` +#[derive(Clone)] +pub struct Client(Rc); + +pub(crate) struct ClientConfig { + pub(crate) connector: RefCell>, + pub(crate) headers: HeaderMap, + pub(crate) timeout: Option, +} + +impl Default for Client { + fn default() -> Self { + Client(Rc::new(ClientConfig { + connector: RefCell::new(Box::new(ConnectorWrapper( + Connector::new().finish(), + ))), + headers: HeaderMap::new(), + timeout: Some(Duration::from_secs(5)), + })) + } +} + +impl Client { + /// Create new client instance with default settings. + pub fn new() -> Client { + Client::default() + } + + /// Build client instance. + pub fn build() -> ClientBuilder { + ClientBuilder::new() + } + + /// Construct HTTP request. + pub fn request(&self, method: Method, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + let mut req = ClientRequest::new(method, url, self.0.clone()); + + for (key, value) in self.0.headers.iter() { + req = req.set_header_if_none(key.clone(), value.clone()); + } + req + } + + /// Create `ClientRequest` from `RequestHead` + /// + /// It is useful for proxy requests. This implementation + /// copies all headers and the method. + pub fn request_from(&self, url: U, head: &RequestHead) -> ClientRequest + where + Uri: HttpTryFrom, + { + let mut req = self.request(head.method.clone(), url); + for (key, value) in head.headers.iter() { + req = req.set_header_if_none(key.clone(), value.clone()); + } + req + } + + /// Construct HTTP *GET* request. + pub fn get(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::GET, url) + } + + /// Construct HTTP *HEAD* request. + pub fn head(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::HEAD, url) + } + + /// Construct HTTP *PUT* request. + pub fn put(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::PUT, url) + } + + /// Construct HTTP *POST* request. + pub fn post(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::POST, url) + } + + /// Construct HTTP *PATCH* request. + pub fn patch(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::PATCH, url) + } + + /// Construct HTTP *DELETE* request. + pub fn delete(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::DELETE, url) + } + + /// Construct HTTP *OPTIONS* request. + pub fn options(&self, url: U) -> ClientRequest + where + Uri: HttpTryFrom, + { + self.request(Method::OPTIONS, url) + } + + /// Construct WebSockets request. + pub fn ws(&self, url: U) -> ws::WebsocketsRequest + where + Uri: HttpTryFrom, + { + let mut req = ws::WebsocketsRequest::new(url, self.0.clone()); + for (key, value) in self.0.headers.iter() { + req.head.headers.insert(key.clone(), value.clone()); + } + req + } +} diff --git a/awc/src/request.rs b/awc/src/request.rs new file mode 100644 index 000000000..3660f8086 --- /dev/null +++ b/awc/src/request.rs @@ -0,0 +1,717 @@ +use std::fmt::Write as FmtWrite; +use std::io::Write; +use std::rc::Rc; +use std::time::Duration; +use std::{fmt, net}; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::Stream; +use percent_encoding::percent_encode; +use serde::Serialize; + +use actix_http::body::Body; +use actix_http::cookie::{Cookie, CookieJar, USERINFO}; +use actix_http::http::header::{self, Header, IntoHeaderValue}; +use actix_http::http::{ + uri, ConnectionType, Error as HttpError, HeaderMap, HeaderName, HeaderValue, + HttpTryFrom, Method, Uri, Version, +}; +use actix_http::{Error, RequestHead}; + +use crate::error::{FreezeRequestError, InvalidUrl}; +use crate::frozen::FrozenClientRequest; +use crate::sender::{PrepForSendingError, RequestSender, SendClientRequest}; +use crate::ClientConfig; + +#[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))] +const HTTPS_ENCODING: &str = "br, gzip, deflate"; +#[cfg(all( + any(feature = "flate2-zlib", feature = "flate2-rust"), + not(feature = "brotli") +))] +const HTTPS_ENCODING: &str = "gzip, deflate"; + +/// An HTTP Client request builder +/// +/// This type can be used to construct an instance of `ClientRequest` through a +/// builder-like pattern. +/// +/// ```rust +/// use actix_rt::System; +/// +/// #[actix_rt::main] +/// async fn main() { +/// let response = awc::Client::new() +/// .get("http://www.rust-lang.org") // <- Create request builder +/// .header("User-Agent", "Actix-web") +/// .send() // <- Send http request +/// .await; +/// +/// response.and_then(|response| { // <- server http response +/// println!("Response: {:?}", response); +/// Ok(()) +/// }); +/// } +/// ``` +pub struct ClientRequest { + pub(crate) head: RequestHead, + err: Option, + addr: Option, + cookies: Option, + response_decompress: bool, + timeout: Option, + config: Rc, +} + +impl ClientRequest { + /// Create new client request builder. + pub(crate) fn new(method: Method, uri: U, config: Rc) -> Self + where + Uri: HttpTryFrom, + { + ClientRequest { + config, + head: RequestHead::default(), + err: None, + addr: None, + cookies: None, + timeout: None, + response_decompress: true, + } + .method(method) + .uri(uri) + } + + /// Set HTTP URI of request. + #[inline] + pub fn uri(mut self, uri: U) -> Self + where + Uri: HttpTryFrom, + { + match Uri::try_from(uri) { + Ok(uri) => self.head.uri = uri, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Get HTTP URI of request. + pub fn get_uri(&self) -> &Uri { + &self.head.uri + } + + /// Set socket address of the server. + /// + /// This address is used for connection. If address is not + /// provided url's host name get resolved. + pub fn address(mut self, addr: net::SocketAddr) -> Self { + self.addr = Some(addr); + self + } + + /// Set HTTP method of this request. + #[inline] + pub fn method(mut self, method: Method) -> Self { + self.head.method = method; + self + } + + /// Get HTTP method of this request + pub fn get_method(&self) -> &Method { + &self.head.method + } + + #[doc(hidden)] + /// Set HTTP version of this request. + /// + /// By default requests's HTTP version depends on network stream + #[inline] + pub fn version(mut self, version: Version) -> Self { + self.head.version = version; + self + } + + /// Get HTTP version of this request. + pub fn get_version(&self) -> &Version { + &self.head.version + } + + /// Get peer address of this request. + pub fn get_peer_addr(&self) -> &Option { + &self.head.peer_addr + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + #[inline] + /// Returns request's mutable headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head.headers + } + + /// Set a header. + /// + /// ```rust + /// fn main() { + /// # actix_rt::System::new("test").block_on(futures::future::lazy(|_| { + /// let req = awc::Client::new() + /// .get("http://www.rust-lang.org") + /// .set(awc::http::header::Date::now()) + /// .set(awc::http::header::ContentType(mime::TEXT_HTML)); + /// # Ok::<_, ()>(()) + /// # })); + /// } + /// ``` + pub fn set(mut self, hdr: H) -> Self { + match hdr.try_into() { + Ok(value) => { + self.head.headers.insert(H::name(), value); + } + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Append a header. + /// + /// Header gets appended to existing header. + /// To override header use `set_header()` method. + /// + /// ```rust + /// use awc::{http, Client}; + /// + /// fn main() { + /// # actix_rt::System::new("test").block_on(async { + /// let req = Client::new() + /// .get("http://www.rust-lang.org") + /// .header("X-TEST", "value") + /// .header(http::header::CONTENT_TYPE, "application/json"); + /// # Ok::<_, ()>(()) + /// # }); + /// } + /// ``` + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => self.head.headers.append(key, value), + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header, replaces existing header. + pub fn set_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => self.head.headers.insert(key, value), + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header only if it is not yet set. + pub fn set_header_if_none(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => { + if !self.head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => self.head.headers.insert(key, value), + Err(e) => self.err = Some(e.into()), + } + } + } + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Send headers in `Camel-Case` form. + #[inline] + pub fn camel_case(mut self) -> Self { + self.head.set_camel_case_headers(true); + self + } + + /// Force close connection instead of returning it back to connections pool. + /// This setting affect only http/1 connections. + #[inline] + pub fn force_close(mut self) -> Self { + self.head.set_connection_type(ConnectionType::Close); + self + } + + /// Set request's content type + #[inline] + pub fn content_type(mut self, value: V) -> Self + where + HeaderValue: HttpTryFrom, + { + match HeaderValue::try_from(value) { + Ok(value) => self.head.headers.insert(header::CONTENT_TYPE, value), + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set content length + #[inline] + pub fn content_length(self, len: u64) -> Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + + /// Set HTTP basic authorization header + pub fn basic_auth(self, username: U, password: Option<&str>) -> Self + where + U: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}:", username), + }; + self.header( + header::AUTHORIZATION, + format!("Basic {}", base64::encode(&auth)), + ) + } + + /// Set HTTP bearer authentication header + pub fn bearer_auth(self, token: T) -> Self + where + T: fmt::Display, + { + self.header(header::AUTHORIZATION, format!("Bearer {}", token)) + } + + /// Set a cookie + /// + /// ```rust + /// #[actix_rt::main] + /// async fn main() { + /// let resp = awc::Client::new().get("https://www.rust-lang.org") + /// .cookie( + /// awc::http::Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .send() + /// .await; + /// + /// println!("Response: {:?}", resp); + /// } + /// ``` + pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Disable automatic decompress of response's body + pub fn no_decompress(mut self) -> Self { + self.response_decompress = false; + self + } + + /// Set request timeout. Overrides client wide timeout setting. + /// + /// Request timeout is the total time before a response must be received. + /// Default value is 5 seconds. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// This method calls provided closure with builder reference if + /// value is `true`. + pub fn if_true(self, value: bool, f: F) -> Self + where + F: FnOnce(ClientRequest) -> ClientRequest, + { + if value { + f(self) + } else { + self + } + } + + /// This method calls provided closure with builder reference if + /// value is `Some`. + pub fn if_some(self, value: Option, f: F) -> Self + where + F: FnOnce(T, ClientRequest) -> ClientRequest, + { + if let Some(val) = value { + f(val, self) + } else { + self + } + } + + /// Sets the query part of the request + pub fn query( + mut self, + query: &T, + ) -> Result { + let mut parts = self.head.uri.clone().into_parts(); + + if let Some(path_and_query) = parts.path_and_query { + let query = serde_urlencoded::to_string(query)?; + let path = path_and_query.path(); + parts.path_and_query = format!("{}?{}", path, query).parse().ok(); + + match Uri::from_parts(parts) { + Ok(uri) => self.head.uri = uri, + Err(e) => self.err = Some(e.into()), + } + } + + Ok(self) + } + + /// Freeze request builder and construct `FrozenClientRequest`, + /// which could be used for sending same request multiple times. + pub fn freeze(self) -> Result { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return Err(e.into()), + }; + + let request = FrozenClientRequest { + head: Rc::new(slf.head), + addr: slf.addr, + response_decompress: slf.response_decompress, + timeout: slf.timeout, + config: slf.config, + }; + + Ok(request) + } + + /// Complete request construction and send body. + pub fn send_body(self, body: B) -> SendClientRequest + where + B: Into, + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head).send_body( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + body, + ) + } + + /// Set a JSON body and generate `ClientRequest` + pub fn send_json(self, value: &T) -> SendClientRequest { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head).send_json( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + value, + ) + } + + /// Set a urlencoded body and generate `ClientRequest` + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn send_form(self, value: &T) -> SendClientRequest { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head).send_form( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + value, + ) + } + + /// Set an streaming body and generate `ClientRequest`. + pub fn send_stream(self, stream: S) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head).send_stream( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + stream, + ) + } + + /// Set an empty body and generate `ClientRequest`. + pub fn send(self) -> SendClientRequest { + let slf = match self.prep_for_sending() { + Ok(slf) => slf, + Err(e) => return e.into(), + }; + + RequestSender::Owned(slf.head).send( + slf.addr, + slf.response_decompress, + slf.timeout, + slf.config.as_ref(), + ) + } + + fn prep_for_sending(mut self) -> Result { + if let Some(e) = self.err { + return Err(e.into()); + } + + // validate uri + let uri = &self.head.uri; + if uri.host().is_none() { + return Err(InvalidUrl::MissingHost.into()); + } else if uri.scheme_part().is_none() { + return Err(InvalidUrl::MissingScheme.into()); + } else if let Some(scheme) = uri.scheme_part() { + match scheme.as_str() { + "http" | "ws" | "https" | "wss" => (), + _ => return Err(InvalidUrl::UnknownScheme.into()), + } + } else { + return Err(InvalidUrl::UnknownScheme.into()); + } + + // set cookies + if let Some(ref mut jar) = self.cookies { + let mut cookie = String::new(); + for c in jar.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO); + let value = percent_encode(c.value().as_bytes(), USERINFO); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + self.head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + + let mut slf = self; + + // enable br only for https + #[cfg(any( + feature = "brotli", + feature = "flate2-zlib", + feature = "flate2-rust" + ))] + { + if slf.response_decompress { + let https = slf + .head + .uri + .scheme_part() + .map(|s| s == &uri::Scheme::HTTPS) + .unwrap_or(true); + + if https { + slf = slf.set_header_if_none(header::ACCEPT_ENCODING, HTTPS_ENCODING) + } else { + #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + { + slf = slf + .set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate") + } + }; + } + } + + Ok(slf) + } +} + +impl fmt::Debug for ClientRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nClientRequest {:?} {}:{}", + self.head.version, self.head.method, self.head.uri + )?; + writeln!(f, " headers:")?; + for (key, val) in self.head.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use super::*; + use crate::Client; + + #[test] + fn test_debug() { + let request = Client::new().get("/").header("x-test", "111"); + let repr = format!("{:?}", request); + assert!(repr.contains("ClientRequest")); + assert!(repr.contains("x-test")); + } + + #[test] + fn test_basics() { + let mut req = Client::new() + .put("/") + .version(Version::HTTP_2) + .set(header::Date(SystemTime::now().into())) + .content_type("plain/text") + .if_true(true, |req| req.header(header::SERVER, "awc")) + .if_true(false, |req| req.header(header::EXPECT, "awc")) + .if_some(Some("server"), |val, req| { + req.header(header::USER_AGENT, val) + }) + .if_some(Option::<&str>::None, |_, req| { + req.header(header::ALLOW, "1") + }) + .content_length(100); + assert!(req.headers().contains_key(header::CONTENT_TYPE)); + assert!(req.headers().contains_key(header::DATE)); + assert!(req.headers().contains_key(header::SERVER)); + assert!(req.headers().contains_key(header::USER_AGENT)); + assert!(!req.headers().contains_key(header::ALLOW)); + assert!(!req.headers().contains_key(header::EXPECT)); + assert_eq!(req.head.version, Version::HTTP_2); + let _ = req.headers_mut(); + let _ = req.send_body(""); + } + + #[test] + fn test_client_header() { + let req = Client::build() + .header(header::CONTENT_TYPE, "111") + .finish() + .get("/"); + + assert_eq!( + req.head + .headers + .get(header::CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "111" + ); + } + + #[test] + fn test_client_header_override() { + let req = Client::build() + .header(header::CONTENT_TYPE, "111") + .finish() + .get("/") + .set_header(header::CONTENT_TYPE, "222"); + + assert_eq!( + req.head + .headers + .get(header::CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "222" + ); + } + + #[test] + fn client_basic_auth() { + let req = Client::new() + .get("/") + .basic_auth("username", Some("password")); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + ); + + let req = Client::new().get("/").basic_auth("username", None); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6" + ); + } + + #[test] + fn client_bearer_auth() { + let req = Client::new().get("/").bearer_auth("someS3cr3tAutht0k3n"); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Bearer someS3cr3tAutht0k3n" + ); + } + + #[test] + fn client_query() { + let req = Client::new() + .get("/") + .query(&[("key1", "val1"), ("key2", "val2")]) + .unwrap(); + assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2"); + } +} diff --git a/awc/src/response.rs b/awc/src/response.rs new file mode 100644 index 000000000..00ab4cee1 --- /dev/null +++ b/awc/src/response.rs @@ -0,0 +1,465 @@ +use std::cell::{Ref, RefMut}; +use std::fmt; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Bytes, BytesMut}; +use futures::{ready, Future, Stream}; + +use actix_http::cookie::Cookie; +use actix_http::error::{CookieParseError, PayloadError}; +use actix_http::http::header::{CONTENT_LENGTH, SET_COOKIE}; +use actix_http::http::{HeaderMap, StatusCode, Version}; +use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; +use serde::de::DeserializeOwned; + +use crate::error::JsonPayloadError; + +/// Client Response +pub struct ClientResponse { + pub(crate) head: ResponseHead, + pub(crate) payload: Payload, +} + +impl HttpMessage for ClientResponse { + type Stream = S; + + fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + fn extensions(&self) -> Ref { + self.head.extensions() + } + + fn extensions_mut(&self) -> RefMut { + self.head.extensions_mut() + } + + fn take_payload(&mut self) -> Payload { + std::mem::replace(&mut self.payload, Payload::None) + } + + /// Load request cookies. + #[inline] + fn cookies(&self) -> Result>>, CookieParseError> { + struct Cookies(Vec>); + + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(&SET_COOKIE) { + let s = std::str::from_utf8(hdr.as_bytes()) + .map_err(CookieParseError::from)?; + cookies.push(Cookie::parse_encoded(s)?.into_owned()); + } + self.extensions_mut().insert(Cookies(cookies)); + } + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } +} + +impl ClientResponse { + /// Create new Request instance + pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { + ClientResponse { head, payload } + } + + #[inline] + pub(crate) fn head(&self) -> &ResponseHead { + &self.head + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + /// Get the status from the server. + #[inline] + pub fn status(&self) -> StatusCode { + self.head().status + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Set a body and return previous body value + pub fn map_body(mut self, f: F) -> ClientResponse + where + F: FnOnce(&mut ResponseHead, Payload) -> Payload, + { + let payload = f(&mut self.head, self.payload); + + ClientResponse { + payload, + head: self.head, + } + } +} + +impl ClientResponse +where + S: Stream>, +{ + /// Loads http response's body. + pub fn body(&mut self) -> MessageBody { + MessageBody::new(self) + } + + /// Loads and parse `application/json` encoded body. + /// Return `JsonBody` future. It resolves to a `T` value. + /// + /// Returns error: + /// + /// * content type is not `application/json` + /// * content length is greater than 256k + pub fn json(&mut self) -> JsonBody { + JsonBody::new(self) + } +} + +impl Stream for ClientResponse +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.get_mut().payload).poll_next(cx) + } +} + +impl fmt::Debug for ClientResponse { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +/// Future that resolves to a complete http message body. +pub struct MessageBody { + length: Option, + err: Option, + fut: Option>, +} + +impl MessageBody +where + S: Stream>, +{ + /// Create `MessageBody` for request. + pub fn new(res: &mut ClientResponse) -> MessageBody { + let mut len = None; + if let Some(l) = res.headers().get(&CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } else { + return Self::err(PayloadError::UnknownLength); + } + } else { + return Self::err(PayloadError::UnknownLength); + } + } + + MessageBody { + length: len, + err: None, + fut: Some(ReadBody::new(res.take_payload(), 262_144)), + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + if let Some(ref mut fut) = self.fut { + fut.limit = limit; + } + self + } + + fn err(e: PayloadError) -> Self { + MessageBody { + fut: None, + err: Some(e), + length: None, + } + } +} + +impl Future for MessageBody +where + S: Stream> + Unpin, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = this.length.take() { + if len > this.fut.as_ref().unwrap().limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + } + + Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) + } +} + +/// Response's payload json parser, it resolves to a deserialized `T` value. +/// +/// Returns error: +/// +/// * content type is not `application/json` +/// * content length is greater than 64k +pub struct JsonBody { + length: Option, + err: Option, + fut: Option>, + _t: PhantomData, +} + +impl JsonBody +where + S: Stream>, + U: DeserializeOwned, +{ + /// Create `JsonBody` for request. + pub fn new(req: &mut ClientResponse) -> Self { + // check content-type + let json = if let Ok(Some(mime)) = req.mime_type() { + mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) + } else { + false + }; + if !json { + return JsonBody { + length: None, + fut: None, + err: Some(JsonPayloadError::ContentType), + _t: PhantomData, + }; + } + + let mut len = None; + if let Some(l) = req.headers().get(&CONTENT_LENGTH) { + if let Ok(s) = l.to_str() { + if let Ok(l) = s.parse::() { + len = Some(l) + } + } + } + + JsonBody { + length: len, + err: None, + fut: Some(ReadBody::new(req.take_payload(), 65536)), + _t: PhantomData, + } + } + + /// Change max size of payload. By default max size is 64Kb + pub fn limit(mut self, limit: usize) -> Self { + if let Some(ref mut fut) = self.fut { + fut.limit = limit; + } + self + } +} + +impl Unpin for JsonBody +where + T: Stream> + Unpin, + U: DeserializeOwned, +{ +} + +impl Future for JsonBody +where + T: Stream> + Unpin, + U: DeserializeOwned, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Some(err) = self.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = self.length.take() { + if len > self.fut.as_ref().unwrap().limit { + return Poll::Ready(Err(JsonPayloadError::Payload( + PayloadError::Overflow, + ))); + } + } + + let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) + } +} + +struct ReadBody { + stream: Payload, + buf: BytesMut, + limit: usize, +} + +impl ReadBody { + fn new(stream: Payload, limit: usize) -> Self { + Self { + stream, + buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)), + limit, + } + } +} + +impl Future for ReadBody +where + S: Stream> + Unpin, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + loop { + return match Pin::new(&mut this.stream).poll_next(cx)? { + Poll::Ready(Some(chunk)) => { + if (this.buf.len() + chunk.len()) > this.limit { + Poll::Ready(Err(PayloadError::Overflow)) + } else { + this.buf.extend_from_slice(&chunk); + continue; + } + } + Poll::Ready(None) => Poll::Ready(Ok(this.buf.take().freeze())), + Poll::Pending => Poll::Pending, + }; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + use crate::{http::header, test::TestResponse}; + + #[actix_rt::test] + async fn test_body() { + let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); + match req.body().await.err().unwrap() { + PayloadError::UnknownLength => (), + _ => unreachable!("error"), + } + + let mut req = + TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); + match req.body().await.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"test")) + .finish(); + assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"11111111111111")) + .finish(); + match req.body().limit(5).await.err().unwrap() { + PayloadError::Overflow => (), + _ => unreachable!("error"), + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { + match err { + JsonPayloadError::Payload(PayloadError::Overflow) => match other { + JsonPayloadError::Payload(PayloadError::Overflow) => true, + _ => false, + }, + JsonPayloadError::ContentType => match other { + JsonPayloadError::ContentType => true, + _ => false, + }, + _ => false, + } + } + + #[actix_rt::test] + async fn test_json_body() { + let mut req = TestResponse::default().finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + ) + .finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + ) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Payload(PayloadError::Overflow) + )); + + let mut req = TestResponse::default() + .header( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ) + .header( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + ) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + } +} diff --git a/awc/src/sender.rs b/awc/src/sender.rs new file mode 100644 index 000000000..9cf158c0d --- /dev/null +++ b/awc/src/sender.rs @@ -0,0 +1,294 @@ +use std::net; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_rt::time::{delay_for, Delay}; +use bytes::Bytes; +use derive_more::From; +use futures::{future::LocalBoxFuture, ready, Future, Stream}; +use serde::Serialize; +use serde_json; + +use actix_http::body::{Body, BodyStream}; +use actix_http::encoding::Decoder; +use actix_http::http::header::{self, ContentEncoding, IntoHeaderValue}; +use actix_http::http::{Error as HttpError, HeaderMap, HeaderName}; +use actix_http::{Error, Payload, PayloadStream, RequestHead}; + +use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; +use crate::response::ClientResponse; +use crate::ClientConfig; + +#[derive(Debug, From)] +pub(crate) enum PrepForSendingError { + Url(InvalidUrl), + Http(HttpError), +} + +impl Into for PrepForSendingError { + fn into(self) -> FreezeRequestError { + match self { + PrepForSendingError::Url(e) => FreezeRequestError::Url(e), + PrepForSendingError::Http(e) => FreezeRequestError::Http(e), + } + } +} + +impl Into for PrepForSendingError { + fn into(self) -> SendRequestError { + match self { + PrepForSendingError::Url(e) => SendRequestError::Url(e), + PrepForSendingError::Http(e) => SendRequestError::Http(e), + } + } +} + +/// Future that sends request's payload and resolves to a server response. +#[must_use = "futures do nothing unless polled"] +pub enum SendClientRequest { + Fut( + LocalBoxFuture<'static, Result>, + Option, + bool, + ), + Err(Option), +} + +impl SendClientRequest { + pub(crate) fn new( + send: LocalBoxFuture<'static, Result>, + response_decompress: bool, + timeout: Option, + ) -> SendClientRequest { + let delay = timeout.map(|t| delay_for(t)); + SendClientRequest::Fut(send, delay, response_decompress) + } +} + +impl Future for SendClientRequest { + type Output = + Result>>, SendRequestError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + match this { + SendClientRequest::Fut(send, delay, response_decompress) => { + if delay.is_some() { + match Pin::new(delay.as_mut().unwrap()).poll(cx) { + Poll::Pending => (), + _ => return Poll::Ready(Err(SendRequestError::Timeout)), + } + } + + let res = ready!(Pin::new(send).poll(cx)).map(|res| { + res.map_body(|head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers( + payload, + &head.headers, + )) + } else { + Payload::Stream(Decoder::new( + payload, + ContentEncoding::Identity, + )) + } + }) + }); + + Poll::Ready(res) + } + SendClientRequest::Err(ref mut e) => match e.take() { + Some(e) => Poll::Ready(Err(e)), + None => panic!("Attempting to call completed future"), + }, + } + } +} + +impl From for SendClientRequest { + fn from(e: SendRequestError) -> Self { + SendClientRequest::Err(Some(e)) + } +} + +impl From for SendClientRequest { + fn from(e: Error) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +impl From for SendClientRequest { + fn from(e: HttpError) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +impl From for SendClientRequest { + fn from(e: PrepForSendingError) -> Self { + SendClientRequest::Err(Some(e.into())) + } +} + +#[derive(Debug)] +pub(crate) enum RequestSender { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestSender { + pub(crate) fn send_body( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + body: B, + ) -> SendClientRequest + where + B: Into, + { + let mut connector = config.connector.borrow_mut(); + + let fut = match self { + RequestSender::Owned(head) => { + connector.send_request(head, body.into(), addr) + } + RequestSender::Rc(head, extra_headers) => { + connector.send_request_extra(head, extra_headers, body.into(), addr) + } + }; + + SendClientRequest::new( + fut, + response_decompress, + timeout.or_else(|| config.timeout), + ) + } + + pub(crate) fn send_json( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendClientRequest { + let body = match serde_json::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") + { + return e.into(); + } + + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::Bytes(Bytes::from(body)), + ) + } + + pub(crate) fn send_form( + mut self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + value: &T, + ) -> SendClientRequest { + let body = match serde_urlencoded::to_string(value) { + Ok(body) => body, + Err(e) => return Error::from(e).into(), + }; + + // set content-type + if let Err(e) = self.set_header_if_none( + header::CONTENT_TYPE, + "application/x-www-form-urlencoded", + ) { + return e.into(); + } + + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::Bytes(Bytes::from(body)), + ) + } + + pub(crate) fn send_stream( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + stream: S, + ) -> SendClientRequest + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + self.send_body( + addr, + response_decompress, + timeout, + config, + Body::from_message(BodyStream::new(stream)), + ) + } + + pub(crate) fn send( + self, + addr: Option, + response_decompress: bool, + timeout: Option, + config: &ClientConfig, + ) -> SendClientRequest { + self.send_body(addr, response_decompress, timeout, config, Body::Empty) + } + + fn set_header_if_none( + &mut self, + key: HeaderName, + value: V, + ) -> Result<(), HttpError> + where + V: IntoHeaderValue, + { + match self { + RequestSender::Owned(head) => { + if !head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => head.headers.insert(key, value), + Err(e) => return Err(e.into()), + } + } + } + RequestSender::Rc(head, extra_headers) => { + if !head.headers.contains_key(&key) + && !extra_headers.iter().any(|h| h.contains_key(&key)) + { + match value.try_into() { + Ok(v) => { + let h = extra_headers.get_or_insert(HeaderMap::new()); + h.insert(key, v) + } + Err(e) => return Err(e.into()), + }; + } + } + } + + Ok(()) + } +} diff --git a/awc/src/test.rs b/awc/src/test.rs new file mode 100644 index 000000000..641ecaa88 --- /dev/null +++ b/awc/src/test.rs @@ -0,0 +1,127 @@ +//! Test helpers for actix http client to use during testing. +use std::fmt::Write as FmtWrite; + +use actix_http::cookie::{Cookie, CookieJar, USERINFO}; +use actix_http::http::header::{self, Header, HeaderValue, IntoHeaderValue}; +use actix_http::http::{HeaderName, HttpTryFrom, StatusCode, Version}; +use actix_http::{h1, Payload, ResponseHead}; +use bytes::Bytes; +use percent_encoding::percent_encode; + +use crate::ClientResponse; + +/// Test `ClientResponse` builder +pub struct TestResponse { + head: ResponseHead, + cookies: CookieJar, + payload: Option, +} + +impl Default for TestResponse { + fn default() -> TestResponse { + TestResponse { + head: ResponseHead::new(StatusCode::OK), + cookies: CookieJar::new(), + payload: None, + } + } +} + +impl TestResponse { + /// Create TestResponse and set header + pub fn with_header(key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + Self::default().header(key, value) + } + + /// Set HTTP version of this response + pub fn version(mut self, ver: Version) -> Self { + self.head.version = ver; + self + } + + /// Set a header + pub fn set(mut self, hdr: H) -> Self { + if let Ok(value) = hdr.try_into() { + self.head.headers.append(H::name(), value); + return self; + } + panic!("Can not set header"); + } + + /// Append a header + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Ok(key) = HeaderName::try_from(key) { + if let Ok(value) = value.try_into() { + self.head.headers.append(key, value); + return self; + } + } + panic!("Can not create header"); + } + + /// Set cookie for this response + pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { + self.cookies.add(cookie.into_owned()); + self + } + + /// Set response's payload + pub fn set_payload>(mut self, data: B) -> Self { + let mut payload = h1::Payload::empty(); + payload.unread_data(data.into()); + self.payload = Some(payload.into()); + self + } + + /// Complete response creation and generate `ClientResponse` instance + pub fn finish(self) -> ClientResponse { + let mut head = self.head; + + let mut cookie = String::new(); + for c in self.cookies.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO); + let value = percent_encode(c.value().as_bytes(), USERINFO); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + if !cookie.is_empty() { + head.headers.insert( + header::SET_COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + + if let Some(pl) = self.payload { + ClientResponse::new(head, pl) + } else { + ClientResponse::new(head, h1::Payload::empty().into()) + } + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use super::*; + use crate::{cookie, http::header}; + + #[test] + fn test_basics() { + let res = TestResponse::default() + .version(Version::HTTP_2) + .set(header::Date(SystemTime::now().into())) + .cookie(cookie::Cookie::build("name", "value").finish()) + .finish(); + assert!(res.headers().contains_key(header::SET_COOKIE)); + assert!(res.headers().contains_key(header::DATE)); + assert_eq!(res.version(), Version::HTTP_2); + } +} diff --git a/awc/src/ws.rs b/awc/src/ws.rs new file mode 100644 index 000000000..075c83562 --- /dev/null +++ b/awc/src/ws.rs @@ -0,0 +1,493 @@ +//! Websockets client +use std::fmt::Write as FmtWrite; +use std::net::SocketAddr; +use std::rc::Rc; +use std::{fmt, str}; + +use actix_codec::Framed; +use actix_http::cookie::{Cookie, CookieJar}; +use actix_http::{ws, Payload, RequestHead}; +use actix_rt::time::Timeout; +use percent_encoding::percent_encode; + +use actix_http::cookie::USERINFO; +pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; + +use crate::connect::BoxedSocket; +use crate::error::{InvalidUrl, SendRequestError, WsClientError}; +use crate::http::header::{ + self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION, +}; +use crate::http::{ + ConnectionType, Error as HttpError, HttpTryFrom, Method, StatusCode, Uri, Version, +}; +use crate::response::ClientResponse; +use crate::ClientConfig; + +/// `WebSocket` connection +pub struct WebsocketsRequest { + pub(crate) head: RequestHead, + err: Option, + origin: Option, + protocols: Option, + addr: Option, + max_size: usize, + server_mode: bool, + cookies: Option, + config: Rc, +} + +impl WebsocketsRequest { + /// Create new websocket connection + pub(crate) fn new(uri: U, config: Rc) -> Self + where + Uri: HttpTryFrom, + { + let mut err = None; + let mut head = RequestHead::default(); + head.method = Method::GET; + head.version = Version::HTTP_11; + + match Uri::try_from(uri) { + Ok(uri) => head.uri = uri, + Err(e) => err = Some(e.into()), + } + + WebsocketsRequest { + head, + err, + config, + addr: None, + origin: None, + protocols: None, + max_size: 65_536, + server_mode: false, + cookies: None, + } + } + + /// Set socket address of the server. + /// + /// This address is used for connection. If address is not + /// provided url's host name get resolved. + pub fn address(mut self, addr: SocketAddr) -> Self { + self.addr = Some(addr); + self + } + + /// Set supported websocket protocols + pub fn protocols(mut self, protos: U) -> Self + where + U: IntoIterator, + V: AsRef, + { + let mut protos = protos + .into_iter() + .fold(String::new(), |acc, s| acc + s.as_ref() + ","); + protos.pop(); + self.protocols = Some(protos); + self + } + + /// Set a cookie + pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Set request Origin + pub fn origin(mut self, origin: V) -> Self + where + HeaderValue: HttpTryFrom, + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Disable payload masking. By default ws client masks frame payload. + pub fn server_mode(mut self) -> Self { + self.server_mode = true; + self + } + + /// Append a header. + /// + /// Header gets appended to existing header. + /// To override header use `set_header()` method. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.head.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header, replaces existing header. + pub fn set_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header only if it is not yet set. + pub fn set_header_if_none(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => { + if !self.head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => { + self.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + } + } + } + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set HTTP basic authorization header + pub fn basic_auth(self, username: U, password: Option<&str>) -> Self + where + U: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}:", username), + }; + self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) + } + + /// Set HTTP bearer authentication header + pub fn bearer_auth(self, token: T) -> Self + where + T: fmt::Display, + { + self.header(AUTHORIZATION, format!("Bearer {}", token)) + } + + /// Complete request construction and connect to a websockets server. + pub async fn connect( + mut self, + ) -> Result<(ClientResponse, Framed), WsClientError> { + if let Some(e) = self.err.take() { + return Err(e.into()); + } + + // validate uri + let uri = &self.head.uri; + if uri.host().is_none() { + return Err(InvalidUrl::MissingHost.into()); + } else if uri.scheme_part().is_none() { + return Err(InvalidUrl::MissingScheme.into()); + } else if let Some(scheme) = uri.scheme_part() { + match scheme.as_str() { + "http" | "ws" | "https" | "wss" => (), + _ => return Err(InvalidUrl::UnknownScheme.into()), + } + } else { + return Err(InvalidUrl::UnknownScheme.into()); + } + + if !self.head.headers.contains_key(header::HOST) { + self.head.headers.insert( + header::HOST, + HeaderValue::from_str(uri.host().unwrap()).unwrap(), + ); + } + + // set cookies + if let Some(ref mut jar) = self.cookies { + let mut cookie = String::new(); + for c in jar.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO); + let value = percent_encode(c.value().as_bytes(), USERINFO); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + self.head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + + // origin + if let Some(origin) = self.origin.take() { + self.head.headers.insert(header::ORIGIN, origin); + } + + self.head.set_connection_type(ConnectionType::Upgrade); + self.head + .headers + .insert(header::UPGRADE, HeaderValue::from_static("websocket")); + self.head.headers.insert( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ); + + if let Some(protocols) = self.protocols.take() { + self.head.headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(protocols.as_str()).unwrap(), + ); + } + + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let sec_key: [u8; 16] = rand::random(); + let key = base64::encode(&sec_key); + + self.head.headers.insert( + header::SEC_WEBSOCKET_KEY, + HeaderValue::try_from(key.as_str()).unwrap(), + ); + + let head = self.head; + let max_size = self.max_size; + let server_mode = self.server_mode; + + let fut = self + .config + .connector + .borrow_mut() + .open_tunnel(head, self.addr); + + // set request timeout + let (head, framed) = if let Some(timeout) = self.config.timeout { + Timeout::new(fut, timeout) + .await + .map_err(|_| SendRequestError::Timeout.into()) + .and_then(|res| res)? + } else { + fut.await? + }; + + // verify response + if head.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus(head.status)); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + log::trace!("Invalid upgrade header"); + return Err(WsClientError::InvalidUpgradeHeader); + } + + // Check for "CONNECTION" header + if let Some(conn) = head.headers.get(&header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_ascii_lowercase().contains("upgrade") { + log::trace!("Invalid connection header: {}", s); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Invalid connection header: {:?}", conn); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Missing connection header"); + return Err(WsClientError::MissingConnectionHeader); + } + + if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { + let encoded = ws::hash_key(key.as_ref()); + if hdr_key.as_bytes() != encoded.as_bytes() { + log::trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(WsClientError::InvalidChallengeResponse( + encoded, + hdr_key.clone(), + )); + } + } else { + log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(WsClientError::MissingWebSocketAcceptHeader); + }; + + // response and ws framed + Ok(( + ClientResponse::new(head, Payload::None), + framed.map_codec(|_| { + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + } + }), + )) + } +} + +impl fmt::Debug for WebsocketsRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nWebsocketsRequest {}:{}", + self.head.method, self.head.uri + )?; + writeln!(f, " headers:")?; + for (key, val) in self.head.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Client; + + #[actix_rt::test] + async fn test_debug() { + let request = Client::new().ws("/").header("x-test", "111"); + let repr = format!("{:?}", request); + assert!(repr.contains("WebsocketsRequest")); + assert!(repr.contains("x-test")); + } + + #[actix_rt::test] + async fn test_header_override() { + let req = Client::build() + .header(header::CONTENT_TYPE, "111") + .finish() + .ws("/") + .set_header(header::CONTENT_TYPE, "222"); + + assert_eq!( + req.head + .headers + .get(header::CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "222" + ); + } + + #[actix_rt::test] + async fn basic_auth() { + let req = Client::new() + .ws("/") + .basic_auth("username", Some("password")); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + ); + + let req = Client::new().ws("/").basic_auth("username", None); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6" + ); + } + + #[actix_rt::test] + async fn bearer_auth() { + let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n"); + assert_eq!( + req.head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Bearer someS3cr3tAutht0k3n" + ); + let _ = req.connect(); + } + + #[actix_rt::test] + async fn basics() { + let req = Client::new() + .ws("http://localhost/") + .origin("test-origin") + .max_frame_size(100) + .server_mode() + .protocols(&["v1", "v2"]) + .set_header_if_none(header::CONTENT_TYPE, "json") + .set_header_if_none(header::CONTENT_TYPE, "text") + .cookie(Cookie::build("cookie1", "value1").finish()); + assert_eq!( + req.origin.as_ref().unwrap().to_str().unwrap(), + "test-origin" + ); + assert_eq!(req.max_size, 100); + assert_eq!(req.server_mode, true); + assert_eq!(req.protocols, Some("v1,v2".to_string())); + assert_eq!( + req.head.headers.get(header::CONTENT_TYPE).unwrap(), + header::HeaderValue::from_static("json") + ); + + let _ = req.connect().await; + + assert!(Client::new().ws("/").connect().await.is_err()); + assert!(Client::new().ws("http:///test").connect().await.is_err()); + assert!(Client::new().ws("hmm://test.com/").connect().await.is_err()); + } +} diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs new file mode 100644 index 000000000..15e9a07ac --- /dev/null +++ b/awc/tests/test_client.rs @@ -0,0 +1,805 @@ +use std::collections::HashMap; +use std::io::{Read, Write}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use brotli2::write::BrotliEncoder; +use bytes::Bytes; +use flate2::read::GzDecoder; +use flate2::write::GzEncoder; +use flate2::Compression; +use futures::future::ok; +use rand::Rng; + +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_service::pipeline_factory; +use actix_web::http::Cookie; +use actix_web::middleware::{BodyEncoding, Compress}; +use actix_web::{http::header, web, App, Error, HttpMessage, HttpRequest, HttpResponse}; +use awc::error::SendRequestError; + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[actix_rt::test] +async fn test_simple() { + let srv = + TestServer::start(|| { + HttpService::new(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), + )) + }); + + let request = srv.get("/").header("x-test", "111").send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + // camel case + let response = srv.post("/").camel_case().send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_json() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service( + web::resource("/").route(web::to(|_: web::Json| HttpResponse::Ok())), + )) + }); + + let request = srv + .get("/") + .header("x-test", "111") + .send_json(&"TEST".to_string()); + let response = request.await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_form() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |_: web::Form>| HttpResponse::Ok(), + )))) + }); + + let mut data = HashMap::new(); + let _ = data.insert("key".to_string(), "TEST".to_string()); + + let request = srv.get("/").header("x-test", "111").send_form(&data); + let response = request.await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn test_timeout() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to(|| { + async { + actix_rt::time::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + })))) + }); + + let connector = awc::Connector::new() + .connector(actix_connect::new_connector( + actix_connect::start_default_resolver(), + )) + .timeout(Duration::from_secs(15)) + .finish(); + + let client = awc::Client::build() + .connector(connector) + .timeout(Duration::from_millis(50)) + .finish(); + + let request = client.get(srv.url("/")).send(); + match request.await { + Err(SendRequestError::Timeout) => (), + _ => panic!(), + } +} + +#[actix_rt::test] +async fn test_timeout_override() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to(|| { + async { + actix_rt::time::delay_for(Duration::from_millis(200)).await; + Ok::<_, Error>(HttpResponse::Ok().body(STR)) + } + })))) + }); + + let client = awc::Client::build() + .timeout(Duration::from_millis(50000)) + .finish(); + let request = client + .get(srv.url("/")) + .timeout(Duration::from_millis(50)) + .send(); + match request.await { + Err(SendRequestError::Timeout) => (), + _ => panic!(), + } +} + +#[actix_rt::test] +async fn test_connection_reuse() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), + )) + }); + + let client = awc::Client::default(); + + // req 1 + let request = client.get(srv.url("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req = client.post(srv.url("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); +} + +#[actix_rt::test] +async fn test_connection_force_close() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok()))), + )) + }); + + let client = awc::Client::default(); + + // req 1 + let request = client.get(srv.url("/")).force_close().send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req = client.post(srv.url("/")).force_close(); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); +} + +#[actix_rt::test] +async fn test_connection_server_close() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().finish())), + ), + )) + }); + + let client = awc::Client::default(); + + // req 1 + let request = client.get(srv.url("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req = client.post(srv.url("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); +} + +#[actix_rt::test] +async fn test_connection_wait_queue() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new(App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR))), + ))) + }); + + let client = awc::Client::build() + .connector(awc::Connector::new().limit(1).finish()) + .finish(); + + // req 1 + let request = client.get(srv.url("/")).send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req2 = client.post(srv.url("/")); + let req2_fut = req2.send(); + + // read response 1 + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + // req 2 + let response = req2_fut.await.unwrap(); + assert!(response.status().is_success()); + + // two connection + assert_eq!(num.load(Ordering::Relaxed), 1); +} + +#[actix_rt::test] +async fn test_connection_wait_queue_force_close() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then(HttpService::new( + App::new().service( + web::resource("/") + .route(web::to(|| HttpResponse::Ok().force_close().body(STR))), + ), + )) + }); + + let client = awc::Client::build() + .connector(awc::Connector::new().limit(1).finish()) + .finish(); + + // req 1 + let request = client.get(srv.url("/")).send(); + let mut response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req2 = client.post(srv.url("/")); + let req2_fut = req2.send(); + + // read response 1 + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + // req 2 + let response = req2_fut.await.unwrap(); + assert!(response.status().is_success()); + + // two connection + assert_eq!(num.load(Ordering::Relaxed), 2); +} + +#[actix_rt::test] +async fn test_with_query_parameter() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").to( + |req: HttpRequest| { + if req.query_string().contains("qp") { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }, + ))) + }); + + let res = awc::Client::new() + .get(srv.url("/?qp=5")) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); +} + +#[actix_rt::test] +async fn test_no_decompress() { + let srv = TestServer::start(|| { + HttpService::new(App::new().wrap(Compress::default()).service( + web::resource("/").route(web::to(|| { + let mut res = HttpResponse::Ok().body(STR); + res.encoding(header::ContentEncoding::Gzip); + res + })), + )) + }); + + let mut res = awc::Client::new() + .get(srv.url("/")) + .no_decompress() + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + + // read response + let bytes = res.body().await.unwrap(); + + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + // POST + let mut res = awc::Client::new() + .post(srv.url("/")) + .no_decompress() + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + + let bytes = res.body().await.unwrap(); + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_client_gzip_encoding() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to(|| { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.as_ref()).unwrap(); + let data = e.finish().unwrap(); + + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + })))) + }); + + // client request + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[actix_rt::test] +async fn test_client_gzip_encoding_large() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to(|| { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(STR.repeat(10).as_ref()).unwrap(); + let data = e.finish().unwrap(); + + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + })))) + }); + + // client request + let mut response = srv.post("/").send().await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(STR.repeat(10))); +} + +#[actix_rt::test] +async fn test_client_gzip_encoding_large_random() { + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(100_000) + .collect::(); + + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |data: Bytes| { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }, + )))) + }); + + // client request + let mut response = srv.post("/").send_body(data.clone()).await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from(data)); +} + +#[actix_rt::test] +async fn test_client_brotli_encoding() { + let srv = TestServer::start(|| { + HttpService::new(App::new().service(web::resource("/").route(web::to( + |data: Bytes| { + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "br") + .body(data) + }, + )))) + }); + + // client request + let mut response = srv.post("/").send_body(STR).await.unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = response.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +// #[actix_rt::test] +// async fn test_client_brotli_encoding_large_random() { +// let data = rand::thread_rng() +// .sample_iter(&rand::distributions::Alphanumeric) +// .take(70_000) +// .collect::(); + +// let srv = test::TestServer::start(|app| { +// app.handler(|req: &HttpRequest| { +// req.body() +// .and_then(move |bytes: Bytes| { +// Ok(HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Gzip) +// .body(bytes)) +// }) +// .responder() +// }) +// }); + +// // client request +// let request = srv +// .client(http::Method::POST, "/") +// .content_encoding(http::ContentEncoding::Br) +// .body(data.clone()) +// .unwrap(); +// let response = request.send().await.unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = response.body().await.unwrap(); +// assert_eq!(bytes.len(), data.len()); +// assert_eq!(bytes, Bytes::from(data)); +// } + +// #[cfg(feature = "brotli")] +// #[actix_rt::test] +// async fn test_client_deflate_encoding() { +// let srv = test::TestServer::start(|app| { +// app.handler(|req: &HttpRequest| { +// req.body() +// .and_then(|bytes: Bytes| { +// Ok(HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Br) +// .body(bytes)) +// }) +// .responder() +// }) +// }); + +// // client request +// let request = srv +// .post() +// .content_encoding(http::ContentEncoding::Deflate) +// .body(STR) +// .unwrap(); +// let response = srv.execute(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = srv.execute(response.body()).unwrap(); +// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +// } + +// #[actix_rt::test] +// async fn test_client_deflate_encoding_large_random() { +// let data = rand::thread_rng() +// .sample_iter(&rand::distributions::Alphanumeric) +// .take(70_000) +// .collect::(); + +// let srv = test::TestServer::start(|app| { +// app.handler(|req: &HttpRequest| { +// req.body() +// .and_then(|bytes: Bytes| { +// Ok(HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Br) +// .body(bytes)) +// }) +// .responder() +// }) +// }); + +// // client request +// let request = srv +// .post() +// .content_encoding(http::ContentEncoding::Deflate) +// .body(data.clone()) +// .unwrap(); +// let response = srv.execute(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = srv.execute(response.body()).unwrap(); +// assert_eq!(bytes, Bytes::from(data)); +// } + +// #[actix_rt::test] +// async fn test_client_streaming_explicit() { +// let srv = test::TestServer::start(|app| { +// app.handler(|req: &HttpRequest| { +// req.body() +// .map_err(Error::from) +// .and_then(|body| { +// Ok(HttpResponse::Ok() +// .chunked() +// .content_encoding(http::ContentEncoding::Identity) +// .body(body)) +// }) +// .responder() +// }) +// }); + +// let body = once(Ok(Bytes::from_static(STR.as_ref()))); + +// let request = srv.get("/").body(Body::Streaming(Box::new(body))).unwrap(); +// let response = srv.execute(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = srv.execute(response.body()).unwrap(); +// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +// } + +// #[actix_rt::test] +// async fn test_body_streaming_implicit() { +// let srv = test::TestServer::start(|app| { +// app.handler(|_| { +// let body = once(Ok(Bytes::from_static(STR.as_ref()))); +// HttpResponse::Ok() +// .content_encoding(http::ContentEncoding::Gzip) +// .body(Body::Streaming(Box::new(body))) +// }) +// }); + +// let request = srv.get("/").finish().unwrap(); +// let response = srv.execute(request.send()).unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = srv.execute(response.body()).unwrap(); +// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +// } + +#[actix_rt::test] +async fn test_client_cookie_handling() { + use std::io::{Error as IoError, ErrorKind}; + + let cookie1 = Cookie::build("cookie1", "value1").finish(); + let cookie2 = Cookie::build("cookie2", "value2") + .domain("www.example.org") + .path("/") + .secure(true) + .http_only(true) + .finish(); + // Q: are all these clones really necessary? A: Yes, possibly + let cookie1b = cookie1.clone(); + let cookie2b = cookie2.clone(); + + let srv = TestServer::start(move || { + let cookie1 = cookie1b.clone(); + let cookie2 = cookie2b.clone(); + + HttpService::new(App::new().route( + "/", + web::to(move |req: HttpRequest| { + let cookie1 = cookie1.clone(); + let cookie2 = cookie2.clone(); + + async move { + // Check cookies were sent correctly + let res: Result<(), Error> = req + .cookie("cookie1") + .ok_or(()) + .and_then(|c1| { + if c1.value() == "value1" { + Ok(()) + } else { + Err(()) + } + }) + .and_then(|()| req.cookie("cookie2").ok_or(())) + .and_then(|c2| { + if c2.value() == "value2" { + Ok(()) + } else { + Err(()) + } + }) + .map_err(|_| Error::from(IoError::from(ErrorKind::NotFound))); + + if let Err(e) = res { + Err(e) + } else { + // Send some cookies back + Ok::<_, Error>( + HttpResponse::Ok().cookie(cookie1).cookie(cookie2).finish(), + ) + } + } + }), + )) + }); + + let request = srv.get("/").cookie(cookie1.clone()).cookie(cookie2.clone()); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); + let c1 = response.cookie("cookie1").expect("Missing cookie1"); + assert_eq!(c1, cookie1); + let c2 = response.cookie("cookie2").expect("Missing cookie2"); + assert_eq!(c2, cookie2); +} + +// #[actix_rt::test] +// fn client_read_until_eof() { +// let addr = test::TestServer::unused_addr(); + +// thread::spawn(move || { +// let lst = net::TcpListener::bind(addr).unwrap(); + +// for stream in lst.incoming() { +// let mut stream = stream.unwrap(); +// let mut b = [0; 1000]; +// let _ = stream.read(&mut b).unwrap(); +// let _ = stream +// .write_all(b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\nwelcome!"); +// } +// }); + +// let mut sys = actix::System::new("test"); + +// // client request +// let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) +// .finish() +// .unwrap(); +// let response = req.send().await.unwrap(); +// assert!(response.status().is_success()); + +// // read response +// let bytes = response.body().await.unwrap(); +// assert_eq!(bytes, Bytes::from_static(b"welcome!")); +// } + +#[actix_rt::test] +async fn client_basic_auth() { + let srv = TestServer::start(|| { + HttpService::new(App::new().route( + "/", + web::to(|req: HttpRequest| { + if req + .headers() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap() + == "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + }); + + // set authorization header to Basic + let request = srv.get("/").basic_auth("username", Some("password")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} + +#[actix_rt::test] +async fn client_bearer_auth() { + let srv = TestServer::start(|| { + HttpService::new(App::new().route( + "/", + web::to(|req: HttpRequest| { + if req + .headers() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap() + == "Bearer someS3cr3tAutht0k3n" + { + HttpResponse::Ok() + } else { + HttpResponse::BadRequest() + } + }), + )) + }); + + // set authorization header to Bearer + let request = srv.get("/").bearer_auth("someS3cr3tAutht0k3n"); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs new file mode 100644 index 000000000..ac60d8e83 --- /dev/null +++ b/awc/tests/test_rustls_client.rs @@ -0,0 +1,106 @@ +#![cfg(feature = "rustls")] +use rust_tls::ClientConfig; + +use std::io::Result; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_server::ssl::OpensslAcceptor; +use actix_service::{pipeline_factory, ServiceFactory}; +use actix_web::http::Version; +use actix_web::{web, App, HttpResponse}; +use futures::future::ok; +use open_ssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode}; + +fn ssl_acceptor() -> Result> { + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder.set_verify_callback(SslVerifyMode::NONE, |_, _| true); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(open_ssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(actix_server::ssl::OpensslAcceptor::new(builder.build())) +} + +mod danger { + pub struct NoCertificateVerification {} + + impl rust_tls::ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _roots: &rust_tls::RootCertStore, + _presented_certs: &[rust_tls::Certificate], + _dns_name: webpki::DNSNameRef<'_>, + _ocsp: &[u8], + ) -> Result { + Ok(rust_tls::ServerCertVerified::assertion()) + } + } +} + +// #[actix_rt::test] +async fn _test_connection_reuse_h2() { + let openssl = ssl_acceptor().unwrap(); + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) + .map_err(|_| ()), + ) + }); + + // disable ssl verification + let mut config = ClientConfig::new(); + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + config.set_protocols(&protos); + config + .dangerous() + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification {})); + + let client = awc::Client::build() + .connector(awc::Connector::new().rustls(Arc::new(config)).finish()) + .finish(); + + // req 1 + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req = client.post(srv.surl("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); +} diff --git a/awc/tests/test_ssl_client.rs b/awc/tests/test_ssl_client.rs new file mode 100644 index 000000000..1abb071a4 --- /dev/null +++ b/awc/tests/test_ssl_client.rs @@ -0,0 +1,87 @@ +#![cfg(feature = "openssl")] +use open_ssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode}; + +use std::io::Result; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http::HttpService; +use actix_http_test::TestServer; +use actix_server::ssl::OpensslAcceptor; +use actix_service::{pipeline_factory, ServiceFactory}; +use actix_web::http::Version; +use actix_web::{web, App, HttpResponse}; +use futures::future::ok; + +fn ssl_acceptor() -> Result> { + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(open_ssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(actix_server::ssl::OpensslAcceptor::new(builder.build())) +} + +#[actix_rt::test] +async fn test_connection_reuse_h2() { + let openssl = ssl_acceptor().unwrap(); + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let srv = TestServer::start(move || { + let num2 = num2.clone(); + pipeline_factory(move |io| { + num2.fetch_add(1, Ordering::Relaxed); + ok(io) + }) + .and_then( + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)), + ) + .and_then( + HttpService::build() + .h2(App::new() + .service(web::resource("/").route(web::to(|| HttpResponse::Ok())))) + .map_err(|_| ()), + ) + }); + + // disable ssl verification + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + + let client = awc::Client::build() + .connector(awc::Connector::new().ssl(builder.build()).finish()) + .finish(); + + // req 1 + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req 2 + let req = client.post(srv.surl("/")); + let response = req.send().await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + // one connection + assert_eq!(num.load(Ordering::Relaxed), 1); +} diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs new file mode 100644 index 000000000..2e1d3981e --- /dev/null +++ b/awc/tests/test_ws.rs @@ -0,0 +1,81 @@ +use std::io; + +use actix_codec::Framed; +use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; +use actix_http_test::TestServer; +use bytes::{Bytes, BytesMut}; +use futures::future::ok; +use futures::{SinkExt, StreamExt}; + +async fn ws_service(req: ws::Frame) -> Result { + match req { + ws::Frame::Ping(msg) => Ok(ws::Message::Pong(msg)), + ws::Frame::Text(text) => { + let text = if let Some(pl) = text { + String::from_utf8(Vec::from(pl.as_ref())).unwrap() + } else { + String::new() + }; + Ok(ws::Message::Text(text)) + } + ws::Frame::Binary(bin) => Ok(ws::Message::Binary( + bin.map(|e| e.freeze()) + .unwrap_or_else(|| Bytes::from("")) + .into(), + )), + ws::Frame::Close(reason) => Ok(ws::Message::Close(reason)), + _ => Ok(ws::Message::Close(None)), + } +} + +#[actix_rt::test] +async fn test_simple() { + let mut srv = TestServer::start(|| { + HttpService::build() + .upgrade(|(req, mut framed): (Request, Framed<_, _>)| { + async move { + let res = ws::handshake_response(req.head()).finish(); + // send handshake response + framed + .send(h1::Message::Item((res.drop_body(), BodySize::None))) + .await?; + + // start websocket service + let framed = framed.into_framed(ws::Codec::new()); + ws::Transport::with(framed, ws_service).await + } + }) + .finish(|_| ok::<_, Error>(Response::NotFound())) + }); + + // client service + let mut framed = srv.ws().await.unwrap(); + framed + .send(ws::Message::Text("text".to_string())) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Text(Some(BytesMut::from("text")))); + + framed + .send(ws::Message::Binary("text".into())) + .await + .unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!( + item, + ws::Frame::Binary(Some(Bytes::from_static(b"text").into())) + ); + + framed.send(ws::Message::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Pong("text".to_string().into())); + + framed + .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .await + .unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); +} diff --git a/build.rs b/build.rs deleted file mode 100644 index bf2355e24..000000000 --- a/build.rs +++ /dev/null @@ -1,49 +0,0 @@ -extern crate skeptic; -extern crate version_check; - -use std::{env, fs}; - - -#[cfg(unix)] -fn main() { - println!("cargo:rerun-if-env-changed=USE_SKEPTIC"); - let f = env::var("OUT_DIR").unwrap() + "/skeptic-tests.rs"; - if env::var("USE_SKEPTIC").is_ok() { - let _ = fs::remove_file(f); - // generates doc tests for `README.md`. - skeptic::generate_doc_tests( - &[// "README.md", - "guide/src/qs_1.md", - "guide/src/qs_2.md", - "guide/src/qs_3.md", - "guide/src/qs_3_5.md", - "guide/src/qs_4.md", - "guide/src/qs_4_5.md", - "guide/src/qs_5.md", - "guide/src/qs_7.md", - "guide/src/qs_8.md", - "guide/src/qs_9.md", - "guide/src/qs_10.md", - "guide/src/qs_12.md", - "guide/src/qs_13.md", - "guide/src/qs_14.md", - ]); - } else { - let _ = fs::File::create(f); - } - - match version_check::is_nightly() { - Some(true) => println!("cargo:rustc-cfg=actix_nightly"), - Some(false) => (), - None => (), - }; -} - -#[cfg(not(unix))] -fn main() { - match version_check::is_nightly() { - Some(true) => println!("cargo:rustc-cfg=actix_nightly"), - Some(false) => (), - None => (), - }; -} diff --git a/examples/basic.rs b/examples/basic.rs new file mode 100644 index 000000000..b5b69fce2 --- /dev/null +++ b/examples/basic.rs @@ -0,0 +1,47 @@ +use actix_web::{get, middleware, web, App, HttpRequest, HttpResponse, HttpServer}; + +#[get("/resource1/{name}/index.html")] +async fn index(req: HttpRequest, name: web::Path) -> String { + println!("REQ: {:?}", req); + format!("Hello: {}!\r\n", name) +} + +async fn index_async(req: HttpRequest) -> &'static str { + println!("REQ: {:?}", req); + "Hello world!\r\n" +} + +#[get("/")] +async fn no_params() -> &'static str { + "Hello world!\r\n" +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + std::env::set_var("RUST_LOG", "actix_server=info,actix_web=info"); + env_logger::init(); + + HttpServer::new(|| { + App::new() + .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) + .wrap(middleware::Compress::default()) + .wrap(middleware::Logger::default()) + .service(index) + .service(no_params) + .service( + web::resource("/resource2/index.html") + .wrap( + middleware::DefaultHeaders::new().header("X-Version-R2", "0.3"), + ) + .default_service( + web::route().to(|| HttpResponse::MethodNotAllowed()), + ) + .route(web::get().to(index_async)), + ) + .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) + }) + .bind("127.0.0.1:8080")? + .workers(1) + .start() + .await +} diff --git a/examples/basics/Cargo.toml b/examples/basics/Cargo.toml deleted file mode 100644 index 76bfa52be..000000000 --- a/examples/basics/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "basics" -version = "0.1.0" -authors = ["Nikolay Kim "] -workspace = "../.." - -[dependencies] -futures = "*" -env_logger = "0.5" -actix = "0.5" -actix-web = { path="../.." } diff --git a/examples/basics/README.md b/examples/basics/README.md deleted file mode 100644 index 82e35e06e..000000000 --- a/examples/basics/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# basics - -## Usage - -### server - -```bash -cd actix-web/examples/basics -cargo run -# Started http server: 127.0.0.1:8080 -``` - -### web client - -- [http://localhost:8080/index.html](http://localhost:8080/index.html) -- [http://localhost:8080/async/bob](http://localhost:8080/async/bob) -- [http://localhost:8080/user/bob/](http://localhost:8080/user/bob/) plain/text download -- [http://localhost:8080/test](http://localhost:8080/test) (return status switch GET or POST or other) -- [http://localhost:8080/static/index.html](http://localhost:8080/static/index.html) -- [http://localhost:8080/static/notexit](http://localhost:8080/static/notexit) display 404 page diff --git a/examples/basics/src/main.rs b/examples/basics/src/main.rs deleted file mode 100644 index 750fc7640..000000000 --- a/examples/basics/src/main.rs +++ /dev/null @@ -1,154 +0,0 @@ -#![allow(unused_variables)] -#![cfg_attr(feature="cargo-clippy", allow(needless_pass_by_value))] - -extern crate actix; -extern crate actix_web; -extern crate env_logger; -extern crate futures; -use futures::Stream; - -use std::{io, env}; -use actix_web::{error, fs, pred, - App, HttpRequest, HttpResponse, HttpServer, Result, Error}; -use actix_web::http::{Method, StatusCode}; -use actix_web::middleware::{self, RequestSession}; -use futures::future::{FutureResult, result}; - -/// favicon handler -fn favicon(req: HttpRequest) -> Result { - Ok(fs::NamedFile::open("../static/favicon.ico")?) -} - -/// simple index handler -fn index(mut req: HttpRequest) -> Result { - println!("{:?}", req); - - // example of ... - if let Ok(ch) = req.poll() { - if let futures::Async::Ready(Some(d)) = ch { - println!("{}", String::from_utf8_lossy(d.as_ref())); - } - } - - // session - let mut counter = 1; - if let Some(count) = req.session().get::("counter")? { - println!("SESSION value: {}", count); - counter = count + 1; - req.session().set("counter", counter)?; - } else { - req.session().set("counter", counter)?; - } - - // html - let html = format!(r#"actix - basics - -