diff --git a/.cargo/config.toml b/.cargo/config.toml index db47ca46d..4425e0dda 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,9 +1,14 @@ [alias] -chk = "check --workspace --all-features --tests --examples --bins" -lint = "clippy --workspace --all-features --tests --examples --bins" -ci-min = "hack check --workspace --no-default-features" -ci-min-test = "hack check --workspace --no-default-features --tests --examples" -ci-default = "check --workspace --bins --tests --examples" -ci-full = "check --workspace --all-features --bins --tests --examples" -ci-test = "test --workspace --all-features --lib --tests --no-fail-fast -- --nocapture" -ci-doctest = "hack test --workspace --all-features --doc --no-fail-fast -- --nocapture" +lint = "clippy --workspace --tests --examples --bins -- -Dclippy::todo" +lint-all = "clippy --workspace --all-features --tests --examples --bins -- -Dclippy::todo" + +# lib checking +ci-check-min = "hack --workspace check --no-default-features" +ci-check-default = "hack --workspace check" +ci-check-default-tests = "check --workspace --tests" +ci-check-all-feature-powerset="hack --workspace --feature-powerset --skip=__compress,io-uring check" +ci-check-all-feature-powerset-linux="hack --workspace --feature-powerset --skip=__compress check" + +# testing +ci-doctest-default = "test --workspace --doc --no-fail-fast -- --nocapture" +ci-doctest = "test --workspace --all-features --doc --no-fail-fast -- --nocapture" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 22b92759a..d9b98a7b8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,9 +14,9 @@ jobs: target: - { name: Linux, os: ubuntu-latest, triple: x86_64-unknown-linux-gnu } - { name: macOS, os: macos-latest, triple: x86_64-apple-darwin } - - { name: Windows, os: windows-latest, triple: x86_64-pc-windows-msvc } + - { name: Windows, os: windows-2022, triple: x86_64-pc-windows-msvc } version: - - 1.46.0 # MSRV + - 1.52.0 # MSRV - stable - nightly @@ -24,12 +24,16 @@ jobs: runs-on: ${{ matrix.target.os }} env: + CI: 1 + CARGO_INCREMENTAL: 0 VCPKGRS_DYNAMIC: 1 steps: - uses: actions/checkout@v2 # install OpenSSL on Windows + # TODO: GitHub actions docs state that OpenSSL is + # already installed on these Windows machines somewhere - name: Set vcpkg root if: matrix.target.triple == 'x86_64-pc-windows-msvc' run: echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append @@ -46,8 +50,7 @@ jobs: - name: Generate Cargo.lock uses: actions-rs/cargo@v1 - with: - command: generate-lockfile + with: { command: generate-lockfile } - name: Cache Dependencies uses: Swatinem/rust-cache@v1.2.0 @@ -59,52 +62,122 @@ jobs: - name: check minimal uses: actions-rs/cargo@v1 - with: { command: ci-min } - - - name: check minimal + tests - uses: actions-rs/cargo@v1 - with: { command: ci-min-test } + with: { command: ci-check-min } - name: check default uses: actions-rs/cargo@v1 - with: { command: ci-default } - - - name: check full - uses: actions-rs/cargo@v1 - with: { command: ci-full } + with: { command: ci-check-default } - name: tests - uses: actions-rs/cargo@v1 - timeout-minutes: 40 - with: - command: ci-test - args: --skip=test_reading_deflate_encoding_large_random_rustls - - - name: doc tests - # due to unknown issue with running doc tests on macOS - if: matrix.target.os == 'ubuntu-latest' - uses: actions-rs/cargo@v1 - timeout-minutes: 40 - with: { command: ci-doctest } - - - name: Generate coverage file - if: > - matrix.target.os == 'ubuntu-latest' - && matrix.version == 'stable' - && github.ref == 'refs/heads/master' + timeout-minutes: 60 run: | - cargo install cargo-tarpaulin --vers "^0.13" - cargo tarpaulin --out Xml --verbose - - name: Upload to Codecov - if: > - matrix.target.os == 'ubuntu-latest' - && matrix.version == 'stable' - && github.ref == 'refs/heads/master' - uses: codecov/codecov-action@v1 - with: - file: cobertura.xml + cargo test --lib --tests -p=actix-router --all-features + cargo test --lib --tests -p=actix-http --all-features + cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls + cargo test --lib --tests -p=actix-web-codegen --all-features + cargo test --lib --tests -p=awc --all-features + cargo test --lib --tests -p=actix-http-test --all-features + cargo test --lib --tests -p=actix-test --all-features + cargo test --lib --tests -p=actix-files + cargo test --lib --tests -p=actix-multipart --all-features + cargo test --lib --tests -p=actix-web-actors --all-features + + - name: tests (io-uring) + if: matrix.target.os == 'ubuntu-latest' + timeout-minutes: 60 + run: > + sudo bash -c "ulimit -Sl 512 + && ulimit -Hl 512 + && PATH=$PATH:/usr/share/rust/.cargo/bin + && RUSTUP_TOOLCHAIN=${{ matrix.version }} cargo test --lib --tests -p=actix-files --all-features" - name: Clear the cargo caches run: | - cargo install cargo-cache --version 0.6.2 --no-default-features --features ci-autoclean + cargo install cargo-cache --version 0.6.3 --no-default-features --features ci-autoclean cargo-cache + + ci_feature_powerset_check: + name: Verify Feature Combinations + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + + - name: Install cargo-hack + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-hack + + - name: check feature combinations + uses: actions-rs/cargo@v1 + with: { command: ci-check-all-feature-powerset } + + - name: check feature combinations + uses: actions-rs/cargo@v1 + with: { command: ci-check-all-feature-powerset-linux } + + coverage: + name: coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + + - name: Generate coverage file + if: github.ref == 'refs/heads/master' + run: | + cargo install cargo-tarpaulin --vers "^0.13" + cargo tarpaulin --workspace --features=rustls,openssl --out Xml --verbose + - name: Upload to Codecov + if: github.ref == 'refs/heads/master' + uses: codecov/codecov-action@v1 + with: { file: cobertura.xml } + + rustdoc: + name: doc tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install Rust (nightly) + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.3.0 + + - name: doc tests + uses: actions-rs/cargo@v1 + timeout-minutes: 60 + with: { command: ci-doctest } diff --git a/CHANGES.md b/CHANGES.md index 6aff8d2ea..b33c0371c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,14 +2,113 @@ ## Unreleased - 2021-xx-xx ### Added -* Re-export actix-service `ServiceFactory` in `dev` module. [#2325] - -### Changed +* Methods on `AcceptLanguage`: `ranked` and `preference`. [#2480] +* `AcceptEncoding` typed header. [#2482] +* `Range` typed header. [#2485] +* `HttpResponse::map_into_{left,right}_body` and `HttpResponse::map_into_boxed_body`. [#2468] +* `ServiceResponse::map_into_{left,right}_body` and `HttpResponse::map_into_boxed_body`. [#2468] * `HttpServer::on_connect` now receives a `CloneableExtensions` object. [#2327] [#2325]: https://github.com/actix/actix-web/pull/2325 [#2327]: https://github.com/actix/actix-web/pull/2327 +### Changed +* Rename `Accept::{mime_precedence => ranked}`. [#2480] +* Rename `Accept::{mime_preference => preference}`. [#2480] +* Un-deprecate `App::data_factory`. [#2484] +* `HttpRequest::url_for` no longer constructs URLs with query or fragment components. [#2430] +* `HttpServer::on_connect` now receives a `CloneableExtensions` object. [#2327] + +### Fixed +* Accept wildcard `*` items in `AcceptLanguage`. [#2480] +* Re-exports `dev::{BodySize, MessageBody, SizedStream}`. They are exposed through the `body` module. [#2468] +* Typed headers containing lists that require one or more items now enforce this minimum. [#2482] + +[#2327]: https://github.com/actix/actix-web/pull/2327 +[#2430]: https://github.com/actix/actix-web/pull/2430 +[#2468]: https://github.com/actix/actix-web/pull/2468 +[#2480]: https://github.com/actix/actix-web/pull/2480 +[#2482]: https://github.com/actix/actix-web/pull/2482 +[#2484]: https://github.com/actix/actix-web/pull/2484 +[#2485]: https://github.com/actix/actix-web/pull/2485 + + +## 4.0.0-beta.13 - 2021-11-30 +### Changed +* Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 4.0.0-beta.12 - 2021-11-22 +### Changed +* Compress middleware's response type is now `AnyBody>`. [#2448] + +### Fixed +* Relax `Unpin` bound on `S` (stream) parameter of `HttpResponseBuilder::streaming`. [#2448] + +### Removed +* `dev::ResponseBody` re-export; is function is replaced by the new `dev::AnyBody` enum. [#2446] + +[#2446]: https://github.com/actix/actix-web/pull/2446 +[#2448]: https://github.com/actix/actix-web/pull/2448 + + +## 4.0.0-beta.11 - 2021-11-15 +### Added +* Re-export `dev::ServerHandle` from `actix-server`. [#2442] + +### Changed +* `ContentType::html` now produces `text/html; charset=utf-8` instead of `text/html`. [#2423] +* Update `actix-server` to `2.0.0-beta.9`. [#2442] + +[#2423]: https://github.com/actix/actix-web/pull/2423 +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 4.0.0-beta.10 - 2021-10-20 +### Added +* Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362] +* `#[actix_web::test]` macro for setting up tests with a runtime. [#2409] + +### Changed +* Associated type `FromRequest::Config` was removed. [#2233] +* Inner field made private on `web::Payload`. [#2384] +* `Data::into_inner` and `Data::get_ref` no longer requires `T: Sized`. [#2403] +* Updated rustls to v0.20. [#2414] +* Minimum supported Rust version (MSRV) is now 1.52. + +### Removed +* Useless `ServiceResponse::checked_expr` method. [#2401] + +[#2233]: https://github.com/actix/actix-web/pull/2233 +[#2362]: https://github.com/actix/actix-web/pull/2362 +[#2384]: https://github.com/actix/actix-web/pull/2384 +[#2401]: https://github.com/actix/actix-web/pull/2401 +[#2403]: https://github.com/actix/actix-web/pull/2403 +[#2409]: https://github.com/actix/actix-web/pull/2409 +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 4.0.0-beta.9 - 2021-09-09 +### Added +* Re-export actix-service `ServiceFactory` in `dev` module. [#2325] + +### Changed +* Compress middleware will return 406 Not Acceptable when no content encoding is acceptable to the client. [#2344] +* Move `BaseHttpResponse` to `dev::Response`. [#2379] +* Enable `TestRequest::param` to accept more than just static strings. [#2172] +* Minimum supported Rust version (MSRV) is now 1.51. + +### Fixed +* Fix quality parse error in Accept-Encoding header. [#2344] +* Re-export correct type at `web::HttpResponse`. [#2379] + +[#2172]: https://github.com/actix/actix-web/pull/2172 +[#2325]: https://github.com/actix/actix-web/pull/2325 +[#2344]: https://github.com/actix/actix-web/pull/2344 +[#2379]: https://github.com/actix/actix-web/pull/2379 + ## 4.0.0-beta.8 - 2021-06-26 ### Added diff --git a/Cargo.toml b/Cargo.toml index 7556bd8d7..425bdbbb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "4.0.0-beta.8" +version = "4.0.0-beta.13" authors = ["Nikolay Kim "] description = "Actix Web is a powerful, pragmatic, and extremely fast web framework for Rust" keywords = ["actix", "http", "web", "framework", "async"] @@ -11,19 +11,21 @@ categories = [ "web-programming::websocket" ] homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web" +repository = "https://github.com/actix/actix-web.git" license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] # features that docs.rs will build with features = ["openssl", "rustls", "compress-brotli", "compress-gzip", "compress-zstd", "cookies", "secure-cookies"] +rustdoc-args = ["--cfg", "docsrs"] [lib] name = "actix_web" path = "src/lib.rs" [workspace] +resolver = "2" members = [ ".", "awc", @@ -34,9 +36,8 @@ members = [ "actix-web-codegen", "actix-http-test", "actix-test", + "actix-router", ] -# enable when MSRV is 1.51+ -# resolver = "2" [features] default = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] @@ -60,22 +61,25 @@ openssl = ["actix-http/openssl", "actix-tls/accept", "actix-tls/openssl"] # rustls rustls = ["actix-http/rustls", "actix-tls/accept", "actix-tls/rustls"] -# Internal (PRIVATE!) features used to aid testing and cheking feature status. +# Internal (PRIVATE!) features used to aid testing and checking feature status. # Don't rely on these whatsoever. They may disappear at anytime. __compress = [] +# io-uring feature only avaiable for Linux OSes. +experimental-io-uring = ["actix-server/io-uring"] + [dependencies] -actix-codec = "0.4.0" -actix-macros = "0.2.1" -actix-router = "0.2.7" -actix-rt = "2.2" -actix-server = "2.0.0-beta.3" +actix-codec = "0.4.1" +actix-macros = "0.2.3" +actix-rt = "2.3" +actix-server = "2.0.0-beta.9" actix-service = "2.0.0" actix-utils = "3.0.0" -actix-tls = { version = "3.0.0-beta.5", default-features = false, optional = true } +actix-tls = { version = "3.0.0-rc.1", default-features = false, optional = true } -actix-web-codegen = "0.5.0-beta.2" -actix-http = "3.0.0-beta.8" +actix-http = "3.0.0-beta.14" +actix-router = "0.5.0-beta.2" +actix-web-codegen = "0.5.0-beta.5" ahash = "0.7" bytes = "1" @@ -92,29 +96,35 @@ once_cell = "1.5" log = "0.4" mime = "0.3" paste = "1" -pin-project = "1.0.0" +pin-project-lite = "0.2.7" regex = "1.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" -smallvec = "1.6" +smallvec = "1.6.1" socket2 = "0.4.0" -time = { version = "0.2.23", default-features = false, features = ["std"] } +time = { version = "0.3", default-features = false, features = ["formatting"] } url = "2.1" [dev-dependencies] -actix-test = { version = "0.1.0-beta.3", features = ["openssl", "rustls"] } -awc = { version = "3.0.0-beta.7", features = ["openssl"] } +actix-test = { version = "0.1.0-beta.7", features = ["openssl", "rustls"] } +awc = { version = "3.0.0-beta.11", features = ["openssl"] } brotli2 = "0.3.2" criterion = { version = "0.3", features = ["html_reports"] } -env_logger = "0.8" +env_logger = "0.9" flate2 = "1.0.13" -zstd = "0.7" +futures-util = { version = "0.3.7", default-features = false, features = ["std"] } rand = "0.8" rcgen = "0.8" +rustls-pemfile = "0.2" tls-openssl = { package = "openssl", version = "0.10.9" } -tls-rustls = { package = "rustls", version = "0.19.0" } +tls-rustls = { package = "rustls", version = "0.20.0" } +zstd = "0.9" + +[profile.dev] +# Disabling debug info speeds up builds a bunch and we don't rely on it for debugging that much. +debug = 0 [profile.release] lto = true @@ -126,12 +136,22 @@ actix-files = { path = "actix-files" } actix-http = { path = "actix-http" } actix-http-test = { path = "actix-http-test" } actix-multipart = { path = "actix-multipart" } +actix-router = { path = "actix-router" } actix-test = { path = "actix-test" } actix-web = { path = "." } actix-web-actors = { path = "actix-web-actors" } actix-web-codegen = { path = "actix-web-codegen" } awc = { path = "awc" } +# uncomment for quick testing against local actix-net repo +# actix-service = { path = "../actix-net/actix-service" } +# actix-macros = { path = "../actix-net/actix-macros" } +# actix-rt = { path = "../actix-net/actix-rt" } +# actix-codec = { path = "../actix-net/actix-codec" } +# actix-utils = { path = "../actix-net/actix-utils" } +# actix-tls = { path = "../actix-net/actix-tls" } +# actix-server = { path = "../actix-net/actix-server" } + [[test]] name = "test_server" required-features = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] diff --git a/MIGRATION.md b/MIGRATION.md index 785974366..d53bd7bf8 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -3,13 +3,16 @@ * The default `NormalizePath` behavior now strips trailing slashes by default. This was previously documented to be the case in v3 but the behavior now matches. The effect is that routes defined with trailing slashes will become inaccessible when - using `NormalizePath::default()`. + using `NormalizePath::default()`. As such, calling `NormalizePath::default()` will log a warning. + It is advised that the `new` method be used instead. Before: `#[get("/test/")]` After: `#[get("/test")]` Alternatively, explicitly require trailing slashes: `NormalizePath::new(TrailingSlash::Always)`. +* The `type Config` of `FromRequest` was removed. + * Feature flag `compress` has been split into its supported algorithm (brotli, gzip, zstd). By default all compression algorithms are enabled. To select algorithm you want to include with `middleware::Compress` use following flags: diff --git a/README.md b/README.md index 309a18466..c363ece9b 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,10 @@

[![crates.io](https://img.shields.io/crates/v/actix-web?label=latest)](https://crates.io/crates/actix-web) -[![Documentation](https://docs.rs/actix-web/badge.svg?version=4.0.0-beta.8)](https://docs.rs/actix-web/4.0.0-beta.8) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web/badge.svg?version=4.0.0-beta.13)](https://docs.rs/actix-web/4.0.0-beta.13) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-web.svg) -[![Dependency Status](https://deps.rs/crate/actix-web/4.0.0-beta.8/status.svg)](https://deps.rs/crate/actix-web/4.0.0-beta.8) +[![Dependency Status](https://deps.rs/crate/actix-web/4.0.0-beta.13/status.svg)](https://deps.rs/crate/actix-web/4.0.0-beta.13)
[![build status](https://github.com/actix/actix-web/workflows/CI%20%28Linux%29/badge.svg?branch=master&event=push)](https://github.com/actix/actix-web/actions) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) @@ -32,7 +32,7 @@ * SSL support using OpenSSL or Rustls * Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) * Includes an async [HTTP client](https://docs.rs/awc/) -* Runs on stable Rust 1.46+ +* Runs on stable Rust 1.52+ ## Documentation diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md index db047c44c..63d8efc3f 100644 --- a/actix-files/CHANGES.md +++ b/actix-files/CHANGES.md @@ -3,6 +3,26 @@ ## Unreleased - 2021-xx-xx +## 0.6.0-beta.9 - 2021-11-22 +* Add crate feature `experimental-io-uring`, enabling async file I/O to be utilized. This feature is only available on Linux OSes with recent kernel versions. This feature is semver-exempt. [#2408] +* Add `NamedFile::open_async`. [#2408] +* Fix 304 Not Modified responses to omit the Content-Length header, as per the spec. [#2453] +* The `Responder` impl for `NamedFile` now has a boxed future associated type. [#2408] +* The `Service` impl for `NamedFileService` now has a boxed future associated type. [#2408] +* Add `impl Clone` for `FilesService`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 +[#2453]: https://github.com/actix/actix-web/pull/2453 + + +## 0.6.0-beta.8 - 2021-10-20 +* Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.6.0-beta.7 - 2021-09-09 +* Minimum supported Rust version (MSRV) is now 1.51. + + ## 0.6.0-beta.6 - 2021-06-26 * Added `Files::path_filter()`. [#2274] * `Files::show_files_listing()` can now be used with `Files::index_file()` to show files listing as a fallback when the index file is not found. [#2228] diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml index ef288215b..6b6d6d245 100644 --- a/actix-files/Cargo.toml +++ b/actix-files/Cargo.toml @@ -1,7 +1,11 @@ [package] name = "actix-files" -version = "0.6.0-beta.6" -authors = ["Nikolay Kim "] +version = "0.6.0-beta.9" +authors = [ + "Nikolay Kim ", + "fakeshadow <24548779@qq.com>", + "Rob Ede ", +] description = "Static file serving for Actix Web" keywords = ["actix", "http", "async", "futures"] homepage = "https://actix.rs" @@ -14,11 +18,14 @@ edition = "2018" name = "actix_files" path = "src/lib.rs" +[features] +experimental-io-uring = ["actix-web/experimental-io-uring", "tokio-uring"] + [dependencies] -actix-web = { version = "4.0.0-beta.8", default-features = false } -actix-http = "3.0.0-beta.8" -actix-service = "2.0.0" -actix-utils = "3.0.0" +actix-web = { version = "4.0.0-beta.11", default-features = false } +actix-http = "3.0.0-beta.14" +actix-service = "2" +actix-utils = "3" askama_escape = "0.10" bitflags = "1" @@ -30,8 +37,11 @@ log = "0.4" mime = "0.3" mime_guess = "2.0.1" percent-encoding = "2.1" +pin-project-lite = "0.2.7" + +tokio-uring = { version = "0.1", optional = true } [dev-dependencies] actix-rt = "2.2" -actix-web = "4.0.0-beta.8" -actix-test = "0.1.0-beta.3" +actix-web = "4.0.0-beta.11" +actix-test = "0.1.0-beta.7" diff --git a/actix-files/README.md b/actix-files/README.md index 13c301c56..84e556fa9 100644 --- a/actix-files/README.md +++ b/actix-files/README.md @@ -3,11 +3,11 @@ > Static file serving for Actix Web [![crates.io](https://img.shields.io/crates/v/actix-files?label=latest)](https://crates.io/crates/actix-files) -[![Documentation](https://docs.rs/actix-files/badge.svg?version=0.6.0-beta.6)](https://docs.rs/actix-files/0.6.0-beta.6) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-files/badge.svg?version=0.6.0-beta.9)](https://docs.rs/actix-files/0.6.0-beta.9) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![License](https://img.shields.io/crates/l/actix-files.svg)
-[![dependency status](https://deps.rs/crate/actix-files/0.6.0-beta.6/status.svg)](https://deps.rs/crate/actix-files/0.6.0-beta.6) +[![dependency status](https://deps.rs/crate/actix-files/0.6.0-beta.9/status.svg)](https://deps.rs/crate/actix-files/0.6.0-beta.9) [![Download](https://img.shields.io/crates/d/actix-files.svg)](https://crates.io/crates/actix-files) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) @@ -15,4 +15,4 @@ - [API Documentation](https://docs.rs/actix-files/) - [Example Project](https://github.com/actix/examples/tree/master/basics/static_index) -- Minimum supported Rust version: 1.46 or later +- Minimum Supported Rust Version (MSRV): 1.52 diff --git a/actix-files/src/chunked.rs b/actix-files/src/chunked.rs index f639848c9..68221ccc3 100644 --- a/actix-files/src/chunked.rs +++ b/actix-files/src/chunked.rs @@ -1,98 +1,277 @@ use std::{ cmp, fmt, - fs::File, future::Future, - io::{self, Read, Seek}, + io, pin::Pin, task::{Context, Poll}, }; -use actix_web::{ - error::{BlockingError, Error}, - rt::task::{spawn_blocking, JoinHandle}, -}; -use bytes::Bytes; +use actix_web::{error::Error, web::Bytes}; use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; -#[doc(hidden)] -/// A helper created from a `std::fs::File` which reads the file -/// chunk-by-chunk on a `ThreadPool`. -pub struct ChunkedReadFile { - size: u64, - offset: u64, - state: ChunkedReadFileState, - counter: u64, -} +use super::named::File; -enum ChunkedReadFileState { - File(Option), - Future(JoinHandle>), -} - -impl ChunkedReadFile { - pub(crate) fn new(size: u64, offset: u64, file: File) -> Self { - Self { - size, - offset, - state: ChunkedReadFileState::File(Some(file)), - counter: 0, - } +pin_project! { + /// Adapter to read a `std::file::File` in chunks. + #[doc(hidden)] + pub struct ChunkedReadFile { + size: u64, + offset: u64, + #[pin] + state: ChunkedReadFileState, + counter: u64, + callback: F, } } -impl fmt::Debug for ChunkedReadFile { +#[cfg(not(feature = "experimental-io-uring"))] +pin_project! { + #[project = ChunkedReadFileStateProj] + #[project_replace = ChunkedReadFileStateProjReplace] + enum ChunkedReadFileState { + File { file: Option, }, + Future { #[pin] fut: Fut }, + } +} + +#[cfg(feature = "experimental-io-uring")] +pin_project! { + #[project = ChunkedReadFileStateProj] + #[project_replace = ChunkedReadFileStateProjReplace] + enum ChunkedReadFileState { + File { file: Option<(File, BytesMut)> }, + Future { #[pin] fut: Fut }, + } +} + +impl fmt::Debug for ChunkedReadFile { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("ChunkedReadFile") } } -impl Stream for ChunkedReadFile { +pub(crate) fn new_chunked_read( + size: u64, + offset: u64, + file: File, +) -> impl Stream> { + ChunkedReadFile { + size, + offset, + #[cfg(not(feature = "experimental-io-uring"))] + state: ChunkedReadFileState::File { file: Some(file) }, + #[cfg(feature = "experimental-io-uring")] + state: ChunkedReadFileState::File { + file: Some((file, BytesMut::new())), + }, + counter: 0, + callback: chunked_read_file_callback, + } +} + +#[cfg(not(feature = "experimental-io-uring"))] +async fn chunked_read_file_callback( + mut file: File, + offset: u64, + max_bytes: usize, +) -> Result<(File, Bytes), Error> { + use io::{Read as _, Seek as _}; + + let res = actix_web::rt::task::spawn_blocking(move || { + let mut buf = Vec::with_capacity(max_bytes); + + file.seek(io::SeekFrom::Start(offset))?; + + let n_bytes = file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + + if n_bytes == 0 { + Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + Ok((file, Bytes::from(buf))) + } + }) + .await + .map_err(|_| actix_web::error::BlockingError)??; + + Ok(res) +} + +#[cfg(feature = "experimental-io-uring")] +async fn chunked_read_file_callback( + file: File, + offset: u64, + max_bytes: usize, + mut bytes_mut: BytesMut, +) -> io::Result<(File, Bytes, BytesMut)> { + bytes_mut.reserve(max_bytes); + + let (res, mut bytes_mut) = file.read_at(bytes_mut, offset).await; + let n_bytes = res?; + + if n_bytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + + let bytes = bytes_mut.split_to(n_bytes).freeze(); + + Ok((file, bytes, bytes_mut)) +} + +#[cfg(feature = "experimental-io-uring")] +impl Stream for ChunkedReadFile +where + F: Fn(File, u64, usize, BytesMut) -> Fut, + Fut: Future>, +{ type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().get_mut(); - match this.state { - ChunkedReadFileState::File(ref mut file) => { - let size = this.size; - let offset = this.offset; - let counter = this.counter; + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + ChunkedReadFileStateProj::File { file } => { + let size = *this.size; + let offset = *this.offset; + let counter = *this.counter; if size == counter { Poll::Ready(None) } else { - let mut file = file + let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + + let (file, bytes_mut) = file .take() .expect("ChunkedReadFile polled after completion"); - let fut = spawn_blocking(move || { - let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + let fut = (this.callback)(file, offset, max_bytes, bytes_mut); - let mut buf = Vec::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; + this.state + .project_replace(ChunkedReadFileState::Future { fut }); - let n_bytes = - file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; - - if n_bytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - - Ok((file, Bytes::from(buf))) - }); - this.state = ChunkedReadFileState::Future(fut); self.poll_next(cx) } } - ChunkedReadFileState::Future(ref mut fut) => { - let (file, bytes) = - ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; - this.state = ChunkedReadFileState::File(Some(file)); + ChunkedReadFileStateProj::Future { fut } => { + let (file, bytes, bytes_mut) = ready!(fut.poll(cx))?; - this.offset += bytes.len() as u64; - this.counter += bytes.len() as u64; + this.state.project_replace(ChunkedReadFileState::File { + file: Some((file, bytes_mut)), + }); + + *this.offset += bytes.len() as u64; + *this.counter += bytes.len() as u64; Poll::Ready(Some(Ok(bytes))) } } } } + +#[cfg(not(feature = "experimental-io-uring"))] +impl Stream for ChunkedReadFile +where + F: Fn(File, u64, usize) -> Fut, + Fut: Future>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + ChunkedReadFileStateProj::File { file } => { + let size = *this.size; + let offset = *this.offset; + let counter = *this.counter; + + if size == counter { + Poll::Ready(None) + } else { + let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + + let file = file + .take() + .expect("ChunkedReadFile polled after completion"); + + let fut = (this.callback)(file, offset, max_bytes); + + this.state + .project_replace(ChunkedReadFileState::Future { fut }); + + self.poll_next(cx) + } + } + ChunkedReadFileStateProj::Future { fut } => { + let (file, bytes) = ready!(fut.poll(cx))?; + + this.state + .project_replace(ChunkedReadFileState::File { file: Some(file) }); + + *this.offset += bytes.len() as u64; + *this.counter += bytes.len() as u64; + + Poll::Ready(Some(Ok(bytes))) + } + } + } +} + +#[cfg(feature = "experimental-io-uring")] +use bytes_mut::BytesMut; + +// TODO: remove new type and use bytes::BytesMut directly +#[doc(hidden)] +#[cfg(feature = "experimental-io-uring")] +mod bytes_mut { + use std::ops::{Deref, DerefMut}; + + use tokio_uring::buf::{IoBuf, IoBufMut}; + + #[derive(Debug)] + pub struct BytesMut(bytes::BytesMut); + + impl BytesMut { + pub(super) fn new() -> Self { + Self(bytes::BytesMut::new()) + } + } + + impl Deref for BytesMut { + type Target = bytes::BytesMut; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for BytesMut { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + unsafe impl IoBuf for BytesMut { + fn stable_ptr(&self) -> *const u8 { + self.0.as_ptr() + } + + fn bytes_init(&self) -> usize { + self.0.len() + } + + fn bytes_total(&self) -> usize { + self.0.capacity() + } + } + + unsafe impl IoBufMut for BytesMut { + fn stable_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() + } + + unsafe fn set_init(&mut self, init_len: usize) { + if self.len() < init_len { + self.0.set_len(init_len); + } + } + } +} diff --git a/actix-files/src/error.rs b/actix-files/src/error.rs index e5f2d4779..f8e32eef7 100644 --- a/actix-files/src/error.rs +++ b/actix-files/src/error.rs @@ -21,6 +21,7 @@ impl ResponseError for FilesError { } } +#[allow(clippy::enum_variant_names)] #[derive(Display, Debug, PartialEq)] pub enum UriSegmentError { /// The segment started with the wrapped invalid character. diff --git a/actix-files/src/files.rs b/actix-files/src/files.rs index 49d81eb03..06909bf08 100644 --- a/actix-files/src/files.rs +++ b/actix-files/src/files.rs @@ -6,7 +6,6 @@ use std::{ }; use actix_service::{boxed, IntoServiceFactory, ServiceFactory, ServiceFactoryExt}; -use actix_utils::future::ok; use actix_web::{ dev::{ AppService, HttpServiceFactory, RequestHead, ResourceDef, ServiceRequest, @@ -20,8 +19,9 @@ use actix_web::{ use futures_core::future::LocalBoxFuture; use crate::{ - directory_listing, named, Directory, DirectoryRenderer, FilesService, HttpNewService, - MimeOverride, PathFilter, + directory_listing, named, + service::{FilesService, FilesServiceInner}, + Directory, DirectoryRenderer, HttpNewService, MimeOverride, PathFilter, }; /// Static files handling service. @@ -106,7 +106,7 @@ impl Files { }; Files { - path: mount_path.to_owned(), + path: mount_path.trim_end_matches('/').to_owned(), directory: dir, index: None, show_index: false, @@ -283,11 +283,17 @@ impl Files { /// Setting a fallback static file handler: /// ``` /// use actix_files::{Files, NamedFile}; + /// use actix_web::dev::{ServiceRequest, ServiceResponse, fn_service}; /// /// # fn run() -> Result<(), actix_web::Error> { /// let files = Files::new("/", "./static") /// .index_file("index.html") - /// .default_handler(NamedFile::open("./static/404.html")?); + /// .default_handler(fn_service(|req: ServiceRequest| async { + /// let (req, _) = req.into_parts(); + /// let file = NamedFile::open_async("./static/404.html").await?; + /// let res = file.into_response(&req); + /// Ok(ServiceResponse::new(req, res)) + /// })); /// # Ok(()) /// # } /// ``` @@ -353,7 +359,7 @@ impl ServiceFactory for Files { type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - let mut srv = FilesService { + let mut inner = FilesServiceInner { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, @@ -372,14 +378,14 @@ impl ServiceFactory for Files { Box::pin(async { match fut.await { Ok(default) => { - srv.default = Some(default); - Ok(srv) + inner.default = Some(default); + Ok(FilesService(Rc::new(inner))) } Err(_) => Err(()), } }) } else { - Box::pin(ok(srv)) + Box::pin(async move { Ok(FilesService(Rc::new(inner))) }) } } } diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index 1eb091aaf..3af5282f1 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -33,12 +33,12 @@ mod path_buf; mod range; mod service; -pub use crate::chunked::ChunkedReadFile; -pub use crate::directory::Directory; -pub use crate::files::Files; -pub use crate::named::NamedFile; -pub use crate::range::HttpRange; -pub use crate::service::FilesService; +pub use self::chunked::ChunkedReadFile; +pub use self::directory::Directory; +pub use self::files::Files; +pub use self::named::NamedFile; +pub use self::range::HttpRange; +pub use self::service::FilesService; use self::directory::{directory_listing, DirectoryRenderer}; use self::error::FilesError; @@ -62,13 +62,12 @@ type PathFilter = dyn Fn(&Path, &RequestHead) -> bool; #[cfg(test)] mod tests { use std::{ - fs::{self, File}, + fs::{self}, ops::Add, time::{Duration, SystemTime}, }; use actix_service::ServiceFactory; - use actix_utils::future::ok; use actix_web::{ guard, http::{ @@ -82,8 +81,9 @@ mod tests { }; use super::*; + use crate::named::File; - #[actix_rt::test] + #[actix_web::test] async fn test_file_extension_to_mime() { let m = file_extension_to_mime(""); assert_eq!(m, mime::APPLICATION_OCTET_STREAM); @@ -100,7 +100,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_without_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() @@ -112,7 +112,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_without_if_none_match_same() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = file.last_modified().unwrap(); let req = TestRequest::default() @@ -124,7 +124,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_with_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() @@ -137,7 +137,7 @@ mod tests { #[actix_rt::test] async fn test_if_unmodified_since() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = file.last_modified().unwrap(); let req = TestRequest::default() @@ -149,7 +149,7 @@ mod tests { #[actix_rt::test] async fn test_if_unmodified_since_failed() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::UNIX_EPOCH); let req = TestRequest::default() @@ -161,8 +161,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_text() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); + assert!(NamedFile::open_async("test--").await.is_err()); + let mut file = NamedFile::open_async("Cargo.toml").await.unwrap(); { file.file(); let _f: &File = &file; @@ -185,8 +185,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_content_disposition() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); + assert!(NamedFile::open_async("test--").await.is_err()); + let mut file = NamedFile::open_async("Cargo.toml").await.unwrap(); { file.file(); let _f: &File = &file; @@ -202,7 +202,8 @@ mod tests { "inline; filename=\"Cargo.toml\"" ); - let file = NamedFile::open("Cargo.toml") + let file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .disable_content_disposition(); let req = TestRequest::default().to_http_request(); @@ -212,8 +213,19 @@ mod tests { #[actix_rt::test] async fn test_named_file_non_ascii_file_name() { - let mut file = - NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml").unwrap(); + let file = { + #[cfg(feature = "experimental-io-uring")] + { + crate::named::File::open("Cargo.toml").await.unwrap() + } + + #[cfg(not(feature = "experimental-io-uring"))] + { + crate::named::File::open("Cargo.toml").unwrap() + } + }; + + let mut file = NamedFile::from_file(file, "貨物.toml").unwrap(); { file.file(); let _f: &File = &file; @@ -236,7 +248,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_set_content_type() { - let mut file = NamedFile::open("Cargo.toml") + let mut file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_type(mime::TEXT_XML); { @@ -261,7 +274,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_image() { - let mut file = NamedFile::open("tests/test.png").unwrap(); + let mut file = NamedFile::open_async("tests/test.png").await.unwrap(); { file.file(); let _f: &File = &file; @@ -284,7 +297,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_javascript() { - let file = NamedFile::open("tests/test.js").unwrap(); + let file = NamedFile::open_async("tests/test.js").await.unwrap(); let req = TestRequest::default().to_http_request(); let resp = file.respond_to(&req).await.unwrap(); @@ -304,7 +317,8 @@ mod tests { disposition: DispositionType::Attachment, parameters: vec![DispositionParam::Filename(String::from("test.png"))], }; - let mut file = NamedFile::open("tests/test.png") + let mut file = NamedFile::open_async("tests/test.png") + .await .unwrap() .set_content_disposition(cd); { @@ -329,7 +343,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_binary() { - let mut file = NamedFile::open("tests/test.binary").unwrap(); + let mut file = NamedFile::open_async("tests/test.binary").await.unwrap(); { file.file(); let _f: &File = &file; @@ -352,7 +366,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_status_code_text() { - let mut file = NamedFile::open("Cargo.toml") + let mut file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_status_code(StatusCode::NOT_FOUND); { @@ -568,7 +583,8 @@ mod tests { async fn test_named_file_content_encoding() { let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { - NamedFile::open("Cargo.toml") + NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_encoding(header::ContentEncoding::Identity) }), @@ -588,7 +604,8 @@ mod tests { async fn test_named_file_content_encoding_gzip() { let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { - NamedFile::open("Cargo.toml") + NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_encoding(header::ContentEncoding::Gzip) }), @@ -614,7 +631,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_allowed_method() { let req = TestRequest::default().method(Method::GET).to_http_request(); - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let resp = file.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -705,8 +722,8 @@ mod tests { #[actix_rt::test] async fn test_default_handler_file_missing() { let st = Files::new("/", ".") - .default_handler(|req: ServiceRequest| { - ok(req.into_response(HttpResponse::Ok().body("default content"))) + .default_handler(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().body("default content"))) }) .new_service(()) .await @@ -789,9 +806,8 @@ mod tests { #[actix_rt::test] async fn test_serve_named_file() { - let srv = - test::init_service(App::new().service(NamedFile::open("Cargo.toml").unwrap())) - .await; + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = test::init_service(App::new().service(factory)).await; let req = TestRequest::get().uri("/Cargo.toml").to_request(); let res = test::call_service(&srv, req).await; @@ -808,11 +824,9 @@ mod tests { #[actix_rt::test] async fn test_serve_named_file_prefix() { - let srv = test::init_service( - App::new() - .service(web::scope("/test").service(NamedFile::open("Cargo.toml").unwrap())), - ) - .await; + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = + test::init_service(App::new().service(web::scope("/test").service(factory))).await; let req = TestRequest::get().uri("/test/Cargo.toml").to_request(); let res = test::call_service(&srv, req).await; @@ -829,10 +843,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_default_service() { - let srv = test::init_service( - App::new().default_service(NamedFile::open("Cargo.toml").unwrap()), - ) - .await; + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = test::init_service(App::new().default_service(factory)).await; for route in ["/foobar", "/baz", "/"].iter() { let req = TestRequest::get().uri(route).to_request(); @@ -847,8 +859,9 @@ mod tests { #[actix_rt::test] async fn test_default_handler_named_file() { + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); let st = Files::new("/", ".") - .default_handler(NamedFile::open("Cargo.toml").unwrap()) + .default_handler(factory) .new_service(()) .await .unwrap(); @@ -926,8 +939,8 @@ mod tests { #[actix_rt::test] async fn test_default_handler_filter() { let st = Files::new("/", ".") - .default_handler(|req: ServiceRequest| { - ok(req.into_response(HttpResponse::Ok().body("default content"))) + .default_handler(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().body("default content"))) }) .path_filter(|path, _| path.extension() == Some("png".as_ref())) .new_service(()) diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs index 241e78cf0..89775c6b3 100644 --- a/actix-files/src/named.rs +++ b/actix-files/src/named.rs @@ -1,17 +1,22 @@ -use actix_service::{Service, ServiceFactory}; -use actix_utils::future::{ok, ready, Ready}; -use actix_web::dev::{AppService, HttpServiceFactory, ResourceDef}; -use std::fs::{File, Metadata}; -use std::io; -use std::ops::{Deref, DerefMut}; -use std::path::{Path, PathBuf}; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + fmt, + fs::Metadata, + io, + ops::{Deref, DerefMut}, + path::{Path, PathBuf}, + time::{SystemTime, UNIX_EPOCH}, +}; #[cfg(unix)] use std::os::unix::fs::MetadataExt; +use actix_service::{Service, ServiceFactory}; use actix_web::{ - dev::{BodyEncoding, ServiceRequest, ServiceResponse, SizedStream}, + body::{self, BoxBody, SizedStream}, + dev::{ + AppService, BodyEncoding, HttpServiceFactory, ResourceDef, ServiceRequest, + ServiceResponse, + }, http::{ header::{ self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, @@ -21,9 +26,9 @@ use actix_web::{ Error, HttpMessage, HttpRequest, HttpResponse, Responder, }; use bitflags::bitflags; +use futures_core::future::LocalBoxFuture; use mime_guess::from_path; -use crate::ChunkedReadFile; use crate::{encoding::equiv_utf8_text, range::HttpRange}; bitflags! { @@ -48,9 +53,9 @@ impl Default for Flags { /// use actix_web::App; /// use actix_files::NamedFile; /// -/// # fn run() -> Result<(), Box> { -/// let app = App::new() -/// .service(NamedFile::open("./static/index.html")?); +/// # async fn run() -> Result<(), Box> { +/// let file = NamedFile::open_async("./static/index.html").await?; +/// let app = App::new().service(file); /// # Ok(()) /// # } /// ``` @@ -62,10 +67,9 @@ impl Default for Flags { /// /// #[get("/")] /// async fn index() -> impl Responder { -/// NamedFile::open("./static/index.html") +/// NamedFile::open_async("./static/index.html").await /// } /// ``` -#[derive(Debug)] pub struct NamedFile { path: PathBuf, file: File, @@ -78,6 +82,39 @@ pub struct NamedFile { pub(crate) encoding: Option, } +impl fmt::Debug for NamedFile { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NamedFile") + .field("path", &self.path) + .field( + "file", + #[cfg(feature = "experimental-io-uring")] + { + &"tokio_uring::File" + }, + #[cfg(not(feature = "experimental-io-uring"))] + { + &self.file + }, + ) + .field("modified", &self.modified) + .field("md", &self.md) + .field("flags", &self.flags) + .field("status_code", &self.status_code) + .field("content_type", &self.content_type) + .field("content_disposition", &self.content_disposition) + .field("encoding", &self.encoding) + .finish() + } +} + +#[cfg(not(feature = "experimental-io-uring"))] +pub(crate) use std::fs::File; +#[cfg(feature = "experimental-io-uring")] +pub(crate) use tokio_uring::fs::File; + +use super::chunked; + impl NamedFile { /// Creates an instance from a previously opened file. /// @@ -85,8 +122,7 @@ impl NamedFile { /// `ContentDisposition` headers. /// /// # Examples - /// - /// ``` + /// ```ignore /// use actix_files::NamedFile; /// use std::io::{self, Write}; /// use std::env; @@ -147,7 +183,30 @@ impl NamedFile { (ct, cd) }; - let md = file.metadata()?; + let md = { + #[cfg(not(feature = "experimental-io-uring"))] + { + file.metadata()? + } + + #[cfg(feature = "experimental-io-uring")] + { + use std::os::unix::prelude::{AsRawFd, FromRawFd}; + + let fd = file.as_raw_fd(); + + // SAFETY: fd is borrowed and lives longer than the unsafe block + unsafe { + let file = std::fs::File::from_raw_fd(fd); + let md = file.metadata(); + // SAFETY: forget the fd before exiting block in success or error case but don't + // run destructor (that would close file handle) + std::mem::forget(file); + md? + } + } + }; + let modified = md.modified().ok(); let encoding = None; @@ -164,17 +223,45 @@ impl NamedFile { }) } + #[cfg(not(feature = "experimental-io-uring"))] /// Attempts to open a file in read-only mode. /// /// # Examples - /// /// ``` /// use actix_files::NamedFile; - /// /// let file = NamedFile::open("foo.txt"); /// ``` pub fn open>(path: P) -> io::Result { - Self::from_file(File::open(&path)?, path) + let file = File::open(&path)?; + Self::from_file(file, path) + } + + /// Attempts to open a file asynchronously in read-only mode. + /// + /// When the `experimental-io-uring` crate feature is enabled, this will be async. + /// Otherwise, it will be just like [`open`][Self::open]. + /// + /// # Examples + /// ``` + /// use actix_files::NamedFile; + /// # async fn open() { + /// let file = NamedFile::open_async("foo.txt").await.unwrap(); + /// # } + /// ``` + pub async fn open_async>(path: P) -> io::Result { + let file = { + #[cfg(not(feature = "experimental-io-uring"))] + { + File::open(&path)? + } + + #[cfg(feature = "experimental-io-uring")] + { + File::open(&path).await? + } + }; + + Self::from_file(file, path) } /// Returns reference to the underlying `File` object. @@ -186,13 +273,12 @@ impl NamedFile { /// Retrieve the path of this file. /// /// # Examples - /// /// ``` /// # use std::io; /// use actix_files::NamedFile; /// - /// # fn path() -> io::Result<()> { - /// let file = NamedFile::open("test.txt")?; + /// # async fn path() -> io::Result<()> { + /// let file = NamedFile::open_async("test.txt").await?; /// assert_eq!(file.path().as_os_str(), "foo.txt"); /// # Ok(()) /// # } @@ -310,7 +396,7 @@ impl NamedFile { } /// Creates an `HttpResponse` with file as a streaming body. - pub fn into_response(self, req: &HttpRequest) -> HttpResponse { + pub fn into_response(self, req: &HttpRequest) -> HttpResponse { if self.status_code != StatusCode::OK { let mut res = HttpResponse::build(self.status_code); @@ -332,7 +418,7 @@ impl NamedFile { res.encoding(current_encoding); } - let reader = ChunkedReadFile::new(self.md.len(), 0, self.file); + let reader = chunked::new_chunked_read(self.md.len(), 0, self.file); return res.streaming(reader); } @@ -443,10 +529,13 @@ impl NamedFile { if precondition_failed { return resp.status(StatusCode::PRECONDITION_FAILED).finish(); } else if not_modified { - return resp.status(StatusCode::NOT_MODIFIED).finish(); + return resp + .status(StatusCode::NOT_MODIFIED) + .body(body::None::new()) + .map_into_boxed_body(); } - let reader = ChunkedReadFile::new(length, offset, self.file); + let reader = chunked::new_chunked_read(length, offset, self.file); if offset != 0 || length != self.md.len() { resp.status(StatusCode::PARTIAL_CONTENT); @@ -456,20 +545,6 @@ impl NamedFile { } } -impl Deref for NamedFile { - type Target = File; - - fn deref(&self) -> &File { - &self.file - } -} - -impl DerefMut for NamedFile { - fn deref_mut(&mut self) -> &mut File { - &mut self.file - } -} - /// Returns true if `req` has no `If-Match` header or one which matches `etag`. fn any_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { match req.get_header::() { @@ -510,8 +585,24 @@ fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { } } +impl Deref for NamedFile { + type Target = File; + + fn deref(&self) -> &Self::Target { + &self.file + } +} + +impl DerefMut for NamedFile { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.file + } +} + impl Responder for NamedFile { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { self.into_response(req) } } @@ -520,14 +611,16 @@ impl ServiceFactory for NamedFile { type Response = ServiceResponse; type Error = Error; type Config = (); - type InitError = (); type Service = NamedFileService; - type Future = Ready>; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - ok(NamedFileService { + let service = NamedFileService { path: self.path.clone(), - }) + }; + + Box::pin(async move { Ok(service) }) } } @@ -540,18 +633,19 @@ pub struct NamedFileService { impl Service for NamedFileService { type Response = ServiceResponse; type Error = Error; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); fn call(&self, req: ServiceRequest) -> Self::Future { let (req, _) = req.into_parts(); - ready( - NamedFile::open(&self.path) - .map_err(|e| e.into()) - .map(|f| f.into_response(&req)) - .map(|res| ServiceResponse::new(req, res)), - ) + + let path = self.path.clone(); + Box::pin(async move { + let file = NamedFile::open_async(path).await?; + let res = file.into_response(&req); + Ok(ServiceResponse::new(req, res)) + }) } } diff --git a/actix-files/src/path_buf.rs b/actix-files/src/path_buf.rs index 8a87acd5d..0e0d4f51d 100644 --- a/actix-files/src/path_buf.rs +++ b/actix-files/src/path_buf.rs @@ -8,7 +8,7 @@ use actix_web::{dev::Payload, FromRequest, HttpRequest}; use crate::error::UriSegmentError; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub(crate) struct PathBufWrap(PathBuf); impl FromStr for PathBufWrap { @@ -21,6 +21,8 @@ impl FromStr for PathBufWrap { impl PathBufWrap { /// Parse a path, giving the choice of allowing hidden files to be considered valid segments. + /// + /// Path traversal is guarded by this method. pub fn parse_path(path: &str, hidden_files: bool) -> Result { let mut buf = PathBuf::new(); @@ -59,7 +61,6 @@ impl AsRef for PathBufWrap { impl FromRequest for PathBufWrap { type Error = UriSegmentError; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ready(req.match_info().path().parse()) @@ -116,4 +117,24 @@ mod tests { PathBuf::from_iter(vec!["test", ".tt"]) ); } + + #[test] + fn path_traversal() { + assert_eq!( + PathBufWrap::parse_path("/../README.md", false).unwrap().0, + PathBuf::from_iter(vec!["README.md"]) + ); + + assert_eq!( + PathBufWrap::parse_path("/../README.md", true).unwrap().0, + PathBuf::from_iter(vec!["README.md"]) + ); + + assert_eq!( + PathBufWrap::parse_path("/../../../../../../../../../../etc/passwd", false) + .unwrap() + .0, + PathBuf::from_iter(vec!["etc/passwd"]) + ); + } } diff --git a/actix-files/src/service.rs b/actix-files/src/service.rs index 09122c63e..f6e1c2e11 100644 --- a/actix-files/src/service.rs +++ b/actix-files/src/service.rs @@ -1,7 +1,6 @@ -use std::{fmt, io, path::PathBuf, rc::Rc}; +use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc}; use actix_service::Service; -use actix_utils::future::ok; use actix_web::{ dev::{ServiceRequest, ServiceResponse}, error::Error, @@ -17,7 +16,18 @@ use crate::{ }; /// Assembled file serving service. -pub struct FilesService { +#[derive(Clone)] +pub struct FilesService(pub(crate) Rc); + +impl Deref for FilesService { + type Target = FilesServiceInner; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +pub struct FilesServiceInner { pub(crate) directory: PathBuf, pub(crate) index: Option, pub(crate) show_index: bool, @@ -31,20 +41,50 @@ pub struct FilesService { pub(crate) hidden_files: bool, } +impl fmt::Debug for FilesServiceInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("FilesServiceInner") + } +} + impl FilesService { - fn handle_err( + async fn handle_err( &self, err: io::Error, req: ServiceRequest, - ) -> LocalBoxFuture<'static, Result> { + ) -> Result { log::debug!("error handling {}: {}", req.path(), err); if let Some(ref default) = self.default { - Box::pin(default.call(req)) + default.call(req).await } else { - Box::pin(ok(req.error_response(err))) + Ok(req.error_response(err)) } } + + fn serve_named_file( + &self, + req: ServiceRequest, + mut named_file: NamedFile, + ) -> ServiceResponse { + if let Some(ref mime_override) = self.mime_override { + let new_disposition = mime_override(&named_file.content_type.type_()); + named_file.content_disposition.disposition = new_disposition; + } + named_file.flags = self.file_flags; + + let (req, _) = req.into_parts(); + let res = named_file.into_response(&req); + ServiceResponse::new(req, res) + } + + fn show_index(&self, req: ServiceRequest, path: PathBuf) -> ServiceResponse { + let dir = Directory::new(self.directory.clone(), path); + + let (req, _) = req.into_parts(); + + (self.renderer)(&dir, &req).unwrap_or_else(|e| ServiceResponse::from_err(e, req)) + } } impl fmt::Debug for FilesService { @@ -56,7 +96,7 @@ impl fmt::Debug for FilesService { impl Service for FilesService { type Response = ServiceResponse; type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); @@ -69,103 +109,87 @@ impl Service for FilesService { matches!(*req.method(), Method::HEAD | Method::GET) }; - if !is_method_valid { - return Box::pin(ok(req.into_response( - actix_web::HttpResponse::MethodNotAllowed() - .insert_header(header::ContentType(mime::TEXT_PLAIN_UTF_8)) - .body("Request did not meet this resource's requirements."), - ))); - } + let this = self.clone(); - let real_path = - match PathBufWrap::parse_path(req.match_info().path(), self.hidden_files) { - Ok(item) => item, - Err(e) => return Box::pin(ok(req.error_response(e))), - }; + Box::pin(async move { + if !is_method_valid { + return Ok(req.into_response( + actix_web::HttpResponse::MethodNotAllowed() + .insert_header(header::ContentType(mime::TEXT_PLAIN_UTF_8)) + .body("Request did not meet this resource's requirements."), + )); + } - if let Some(filter) = &self.path_filter { - if !filter(real_path.as_ref(), req.head()) { - if let Some(ref default) = self.default { - return Box::pin(default.call(req)); - } else { - return Box::pin(ok( - req.into_response(actix_web::HttpResponse::NotFound().finish()) + let real_path = + match PathBufWrap::parse_path(req.match_info().path(), this.hidden_files) { + Ok(item) => item, + Err(e) => return Ok(req.error_response(e)), + }; + + if let Some(filter) = &this.path_filter { + if !filter(real_path.as_ref(), req.head()) { + if let Some(ref default) = this.default { + return default.call(req).await; + } else { + return Ok( + req.into_response(actix_web::HttpResponse::NotFound().finish()) + ); + } + } + } + + // full file path + let path = this.directory.join(&real_path); + if let Err(err) = path.canonicalize() { + return this.handle_err(err, req).await; + } + + if path.is_dir() { + if this.redirect_to_slash + && !req.path().ends_with('/') + && (this.index.is_some() || this.show_index) + { + let redirect_to = format!("{}/", req.path()); + + return Ok(req.into_response( + HttpResponse::Found() + .insert_header((header::LOCATION, redirect_to)) + .finish(), )); } - } - } - // full file path - let path = self.directory.join(&real_path); - if let Err(err) = path.canonicalize() { - return Box::pin(self.handle_err(err, req)); - } - - if path.is_dir() { - if self.redirect_to_slash - && !req.path().ends_with('/') - && (self.index.is_some() || self.show_index) - { - let redirect_to = format!("{}/", req.path()); - - return Box::pin(ok(req.into_response( - HttpResponse::Found() - .insert_header((header::LOCATION, redirect_to)) - .finish(), - ))); - } - - let serve_named_file = |req: ServiceRequest, mut named_file: NamedFile| { - if let Some(ref mime_override) = self.mime_override { - let new_disposition = mime_override(&named_file.content_type.type_()); - named_file.content_disposition.disposition = new_disposition; - } - named_file.flags = self.file_flags; - - let (req, _) = req.into_parts(); - let res = named_file.into_response(&req); - Box::pin(ok(ServiceResponse::new(req, res))) - }; - - let show_index = |req: ServiceRequest| { - let dir = Directory::new(self.directory.clone(), path.clone()); - - let (req, _) = req.into_parts(); - let x = (self.renderer)(&dir, &req); - - Box::pin(match x { - Ok(resp) => ok(resp), - Err(err) => ok(ServiceResponse::from_err(err, req)), - }) - }; - - match self.index { - Some(ref index) => match NamedFile::open(path.join(index)) { - Ok(named_file) => serve_named_file(req, named_file), - Err(_) if self.show_index => show_index(req), - Err(err) => self.handle_err(err, req), - }, - None if self.show_index => show_index(req), - _ => Box::pin(ok(ServiceResponse::from_err( - FilesError::IsDirectory, - req.into_parts().0, - ))), - } - } else { - match NamedFile::open(path) { - Ok(mut named_file) => { - if let Some(ref mime_override) = self.mime_override { - let new_disposition = mime_override(&named_file.content_type.type_()); - named_file.content_disposition.disposition = new_disposition; + match this.index { + Some(ref index) => { + let named_path = path.join(index); + match NamedFile::open_async(named_path).await { + Ok(named_file) => Ok(this.serve_named_file(req, named_file)), + Err(_) if this.show_index => Ok(this.show_index(req, path)), + Err(err) => this.handle_err(err, req).await, + } } - named_file.flags = self.file_flags; - - let (req, _) = req.into_parts(); - let res = named_file.into_response(&req); - Box::pin(ok(ServiceResponse::new(req, res))) + None if this.show_index => Ok(this.show_index(req, path)), + _ => Ok(ServiceResponse::from_err( + FilesError::IsDirectory, + req.into_parts().0, + )), + } + } else { + match NamedFile::open_async(&path).await { + Ok(mut named_file) => { + if let Some(ref mime_override) = this.mime_override { + let new_disposition = + mime_override(&named_file.content_type.type_()); + named_file.content_disposition.disposition = new_disposition; + } + named_file.flags = this.file_flags; + + let (req, _) = req.into_parts(); + let res = named_file.into_response(&req); + Ok(ServiceResponse::new(req, res)) + } + Err(err) => this.handle_err(err, req).await, } - Err(err) => self.handle_err(err, req), } - } + }) } } diff --git a/actix-files/tests/encoding.rs b/actix-files/tests/encoding.rs index d21d4f8fd..652a7c12b 100644 --- a/actix-files/tests/encoding.rs +++ b/actix-files/tests/encoding.rs @@ -8,7 +8,7 @@ use actix_web::{ App, }; -#[actix_rt::test] +#[actix_web::test] async fn test_utf8_file_contents() { // use default ISO-8859-1 encoding let srv = test::init_service(App::new().service(Files::new("/", "./tests"))).await; diff --git a/actix-files/tests/guard.rs b/actix-files/tests/guard.rs index 8b1785e7f..d053f3fdc 100644 --- a/actix-files/tests/guard.rs +++ b/actix-files/tests/guard.rs @@ -7,7 +7,7 @@ use actix_web::{ }; use bytes::Bytes; -#[actix_rt::test] +#[actix_web::test] async fn test_guard_filter() { let srv = test::init_service( App::new() diff --git a/actix-files/tests/traversal.rs b/actix-files/tests/traversal.rs new file mode 100644 index 000000000..c890b3fe4 --- /dev/null +++ b/actix-files/tests/traversal.rs @@ -0,0 +1,27 @@ +use actix_files::Files; +use actix_web::{ + http::StatusCode, + test::{self, TestRequest}, + App, +}; + +#[actix_rt::test] +async fn test_directory_traversal_prevention() { + let srv = test::init_service(App::new().service(Files::new("/", "./tests"))).await; + + let req = + TestRequest::with_uri("/../../../../../../../../../../../etc/passwd").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri( + "/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/etc/passwd", + ) + .to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/%00/etc/passwd%00").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} diff --git a/actix-http-test/CHANGES.md b/actix-http-test/CHANGES.md index 1dbd9a15b..6984e5962 100644 --- a/actix-http-test/CHANGES.md +++ b/actix-http-test/CHANGES.md @@ -3,6 +3,30 @@ ## Unreleased - 2021-xx-xx +## 3.0.0-beta.8 - 2021-11-30 +* Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.7 - 2021-11-22 +* Fix compatibility with experimental `io-uring` feature of `actix-rt`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 + + +## 3.0.0-beta.6 - 2021-11-15 +* `TestServer::stop` is now async and will wait for the server and system to shutdown. [#2442] +* Update `actix-server` to `2.0.0-beta.9`. [#2442] +* Minimum supported Rust version (MSRV) is now 1.52. + +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 3.0.0-beta.5 - 2021-09-09 +* Minimum supported Rust version (MSRV) is now 1.51. + + ## 3.0.0-beta.4 - 2021-04-02 * Added `TestServer::client_headers` method. [#2097] diff --git a/actix-http-test/Cargo.toml b/actix-http-test/Cargo.toml index c04b5da49..8d347d4e9 100644 --- a/actix-http-test/Cargo.toml +++ b/actix-http-test/Cargo.toml @@ -1,18 +1,18 @@ [package] name = "actix-http-test" -version = "3.0.0-beta.4" +version = "3.0.0-beta.8" authors = ["Nikolay Kim "] description = "Various helpers for Actix applications to use during testing" -readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-http-test/" -categories = ["network-programming", "asynchronous", - "web-programming::http-server", - "web-programming::websocket"] +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] license = "MIT OR Apache-2.0" -exclude = [".gitignore", ".cargo/config"] edition = "2018" [package.metadata.docs.rs] @@ -30,26 +30,26 @@ openssl = ["tls-openssl", "awc/openssl"] [dependencies] actix-service = "2.0.0" -actix-codec = "0.4.0" -actix-tls = "3.0.0-beta.5" +actix-codec = "0.4.1" +actix-tls = "3.0.0-rc.1" actix-utils = "3.0.0" actix-rt = "2.2" -actix-server = "2.0.0-beta.3" -awc = { version = "3.0.0-beta.7", default-features = false } +actix-server = "2.0.0-beta.9" +awc = { version = "3.0.0-beta.11", default-features = false } base64 = "0.13" bytes = "1" futures-core = { version = "0.3.7", default-features = false } -http = "0.2.2" +http = "0.2.5" log = "0.4" socket2 = "0.4" serde = "1.0" serde_json = "1.0" slab = "0.4" serde_urlencoded = "0.7" -time = { version = "0.2.23", default-features = false, features = ["std"] } tls-openssl = { version = "0.10.9", package = "openssl", optional = true } +tokio = { version = "1.2", features = ["sync"] } [dev-dependencies] -actix-web = { version = "4.0.0-beta.8", default-features = false, features = ["cookies"] } -actix-http = "3.0.0-beta.8" +actix-web = { version = "4.0.0-beta.11", default-features = false, features = ["cookies"] } +actix-http = "3.0.0-beta.14" diff --git a/actix-http-test/README.md b/actix-http-test/README.md index 74260a352..c3e99d259 100644 --- a/actix-http-test/README.md +++ b/actix-http-test/README.md @@ -3,15 +3,15 @@ > Various helpers for Actix applications to use during testing. [![crates.io](https://img.shields.io/crates/v/actix-http-test?label=latest)](https://crates.io/crates/actix-http-test) -[![Documentation](https://docs.rs/actix-http-test/badge.svg?version=3.0.0-beta.4)](https://docs.rs/actix-http-test/3.0.0-beta.4) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-http-test/badge.svg?version=3.0.0-beta.8)](https://docs.rs/actix-http-test/3.0.0-beta.8) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http-test)
-[![Dependency Status](https://deps.rs/crate/actix-http-test/3.0.0-beta.4/status.svg)](https://deps.rs/crate/actix-http-test/3.0.0-beta.4) +[![Dependency Status](https://deps.rs/crate/actix-http-test/3.0.0-beta.8/status.svg)](https://deps.rs/crate/actix-http-test/3.0.0-beta.8) [![Download](https://img.shields.io/crates/d/actix-http-test.svg)](https://crates.io/crates/actix-http-test) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-http-test) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.52 diff --git a/actix-http-test/src/lib.rs b/actix-http-test/src/lib.rs index 0f126c99a..7f55a0bf4 100644 --- a/actix-http-test/src/lib.rs +++ b/actix-http-test/src/lib.rs @@ -7,8 +7,7 @@ #[cfg(feature = "openssl")] extern crate tls_openssl as openssl; -use std::sync::mpsc; -use std::{net, thread, time}; +use std::{net, thread, time::Duration}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::{net::TcpStream, System}; @@ -20,29 +19,28 @@ use bytes::Bytes; use futures_core::stream::Stream; use http::Method; use socket2::{Domain, Protocol, Socket, Type}; +use tokio::sync::mpsc; -/// Start test server +/// Start test server. /// -/// `TestServer` is very simple test server that simplify process of writing -/// integration tests cases for actix web applications. +/// `TestServer` is very simple test server that simplify process of writing integration tests cases +/// for HTTP applications. /// /// # Examples -/// -/// ``` +/// ```no_run /// use actix_http::HttpService; -/// use actix_http_test::TestServer; +/// use actix_http_test::test_server; /// use actix_web::{web, App, HttpResponse, Error}; /// /// async fn my_handler() -> Result { /// Ok(HttpResponse::Ok().into()) /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_example() { -/// let mut srv = TestServer::start( -/// || HttpService::new( -/// App::new().service( -/// web::resource("/").to(my_handler)) +/// let mut srv = TestServer::start(|| +/// HttpService::new( +/// App::new().service(web::resource("/").to(my_handler)) /// ) /// ); /// @@ -56,72 +54,86 @@ pub async fn test_server>(factory: F) -> TestServer test_server_with_addr(tcp, factory).await } -/// Start [`test server`](test_server()) on a concrete Address +/// Start [`test server`](test_server()) on an existing address binding. pub async fn test_server_with_addr>( tcp: net::TcpListener, factory: F, ) -> TestServer { - let (tx, rx) = mpsc::channel(); + let (started_tx, started_rx) = std::sync::mpsc::channel(); + let (thread_stop_tx, thread_stop_rx) = mpsc::channel(1); // run server in separate thread thread::spawn(move || { - let sys = System::new(); - let local_addr = tcp.local_addr().unwrap(); + System::new().block_on(async move { + let local_addr = tcp.local_addr().unwrap(); - let srv = Server::build() - .listen("test", tcp, factory)? - .workers(1) - .disable_signals(); + let srv = Server::build() + .workers(1) + .disable_signals() + .system_exit() + .listen("test", tcp, factory) + .expect("test server could not be created"); - sys.block_on(async { - srv.run(); - tx.send((System::current(), local_addr)).unwrap(); + let srv = srv.run(); + started_tx + .send((System::current(), srv.handle(), local_addr)) + .unwrap(); + + // drive server loop + srv.await.unwrap(); }); - sys.run() + // notify TestServer that server and system have shut down + // all thread managed resources should be dropped at this point + let _ = thread_stop_tx.send(()); }); - let (system, addr) = rx.recv().unwrap(); + let (system, server, addr) = started_rx.recv().unwrap(); let client = { + #[cfg(feature = "openssl")] let connector = { - #[cfg(feature = "openssl")] - { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder - .set_alpn_protos(b"\x02h2\x08http/1.1") - .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - .ssl(builder.build()) - } - #[cfg(not(feature = "openssl"))] - { - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - } + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) + .ssl(builder.build()) + }; + + #[cfg(not(feature = "openssl"))] + let connector = { + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) }; Client::builder().connector(connector).finish() }; TestServer { - addr, + server, client, system, + addr, + thread_stop_rx, } } /// Test server controller pub struct TestServer { + server: actix_server::ServerHandle, + client: awc::Client, + system: actix_rt::System, addr: net::SocketAddr, - client: Client, - system: System, + thread_stop_rx: mpsc::Receiver<()>, } impl TestServer { @@ -258,15 +270,32 @@ impl TestServer { self.client.headers() } - /// Stop HTTP server - fn stop(&mut self) { + /// Stop HTTP server. + /// + /// Waits for spawned `Server` and `System` to (force) shutdown. + pub async fn stop(&mut self) { + // signal server to stop + self.server.stop(false).await; + + // also signal system to stop + // though this is handled by `ServerBuilder::exit_system` too self.system.stop(); + + // wait for thread to be stopped but don't care about result + let _ = self.thread_stop_rx.recv().await; } } impl Drop for TestServer { fn drop(&mut self) { - self.stop() + // calls in this Drop impl should be enough to shut down the server, system, and thread + // without needing to await anything + + // signal server to stop + let _ = self.server.stop(true); + + // signal system to stop + self.system.stop(); } } diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index e90f9490d..bc5e93b42 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -2,12 +2,123 @@ ## Unreleased - 2021-xx-xx ### Added +* Add timeout for canceling HTTP/2 server side connection handshake. Default to 5 seconds. [#2483] +* HTTP/2 handshake timeout can be configured with `ServiceConfig::client_timeout`. [#2483] +* `Response::map_into_boxed_body`. [#2468] +* `body::EitherBody` enum. [#2468] +* `body::None` struct. [#2468] +* Impl `MessageBody` for `bytestring::ByteString`. [#2468] +* `impl Clone for ws::HandshakeError`. [#2468] +* `#[must_use]` for `ws::Codec` to prevent subtle bugs. [#1920] +* `impl Default ` for `ws::Codec`. [#1920] +* `header::QualityItem::{max, min}`. [#2486] +* `header::Quality::{MAX, MIN}`. [#2486] +* `impl Display` for `header::Quality`. [#2486] * `CloneableExtensions` object for use in `on_connect` handlers. [#2327] ### Changed +* Rename `body::BoxBody::{from_body => new}`. [#2468] +* Body type for `Responses` returned from `Response::{new, ok, etc...}` is now `BoxBody`. [#2468] +* The `Error` associated type on `MessageBody` type now requires `impl Error` (or similar). [#2468] +* Error types using in service builders now require `Into>`. [#2468] +* `From` implementations on error types now return a `Response`. [#2468] +* `ResponseBuilder::body(B)` now returns `Response>`. [#2468] +* `ResponseBuilder::finish()` now returns `Response>`. [#2468] * `on_connect_ext` methods now receive a `CloneableExtensions` object. [#2327] +### Removed +* `ResponseBuilder::streaming`. [#2468] +* `impl Future` for `ResponseBuilder`. [#2468] +* Remove unnecessary `MessageBody` bound on types passed to `body::AnyBody::new`. [#2468] +* Move `body::AnyBody` to `awc`. Replaced with `EitherBody` and `BoxBody`. [#2468] +* `impl Copy` for `ws::Codec`. [#1920] +* `header::qitem` helper. Replaced with `header::QualityItem::max` [#2486] +* `impl TryFrom` for `header::Quality` [#2486] + [#2327]: https://github.com/actix/actix-web/pull/2327 +[#2483]: https://github.com/actix/actix-web/pull/2483 +[#2468]: https://github.com/actix/actix-web/pull/2468 +[#1920]: https://github.com/actix/actix-web/pull/1920 +[#2486]: https://github.com/actix/actix-web/pull/2486 + + +## 3.0.0-beta.14 - 2021-11-30 +### Changed +* Guarantee ordering of `header::GetAll` iterator to be same as insertion order. [#2467] +* Expose `header::map` module. [#2467] +* Implement `ExactSizeIterator` and `FusedIterator` for all `HeaderMap` iterators. [#2470] +* Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2467]: https://github.com/actix/actix-web/pull/2467 +[#2470]: https://github.com/actix/actix-web/pull/2470 +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.13 - 2021-11-22 +### Added +* `body::AnyBody::empty` for quickly creating an empty body. [#2446] +* `body::AnyBody::none` for quickly creating a "none" body. [#2456] +* `impl Clone` for `body::AnyBody where S: Clone`. [#2448] +* `body::AnyBody::into_boxed` for quickly converting to a type-erased, boxed body type. [#2448] + +### Changed +* Rename `body::AnyBody::{Message => Body}`. [#2446] +* Rename `body::AnyBody::{from_message => new_boxed}`. [#2448] +* Rename `body::AnyBody::{from_slice => copy_from_slice}`. [#2448] +* Rename `body::{BoxAnyBody => BoxBody}`. [#2448] +* Change representation of `AnyBody` to include a type parameter in `Body` variant. Defaults to `BoxBody`. [#2448] +* `Encoder::response` now returns `AnyBody>`. [#2448] + +### Removed +* `body::AnyBody::Empty`; an empty body can now only be represented as a zero-length `Bytes` variant. [#2446] +* `body::BodySize::Empty`; an empty body can now only be represented as a `Sized(0)` variant. [#2446] +* `EncoderError::Boxed`; it is no longer required. [#2446] +* `body::ResponseBody`; is function is replaced by the new `body::AnyBody` enum. [#2446] + +[#2446]: https://github.com/actix/actix-web/pull/2446 +[#2448]: https://github.com/actix/actix-web/pull/2448 +[#2456]: https://github.com/actix/actix-web/pull/2456 + + +## 3.0.0-beta.12 - 2021-11-15 +### Changed +* Update `actix-server` to `2.0.0-beta.9`. [#2442] + +### Removed +* `client` module. [#2425] +* `trust-dns` feature. [#2425] + +[#2425]: https://github.com/actix/actix-web/pull/2425 +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 3.0.0-beta.11 - 2021-10-20 +### Changed +* Updated rustls to v0.20. [#2414] +* Minimum supported Rust version (MSRV) is now 1.52. + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 3.0.0-beta.10 - 2021-09-09 +### Changed +* `ContentEncoding` is now marked `#[non_exhaustive]`. [#2377] +* Minimum supported Rust version (MSRV) is now 1.51. + +### Fixed +* Remove slice creation pointing to potential uninitialized data on h1 encoder. [#2364] +* Remove `Into` bound on `Encoder` body types. [#2375] +* Fix quality parse error in Accept-Encoding header. [#2344] + +[#2364]: https://github.com/actix/actix-web/pull/2364 +[#2375]: https://github.com/actix/actix-web/pull/2375 +[#2344]: https://github.com/actix/actix-web/pull/2344 +[#2377]: https://github.com/actix/actix-web/pull/2377 + + +## 3.0.0-beta.9 - 2021-08-09 +### Fixed +* Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977) ## 3.0.0-beta.8 - 2021-06-26 @@ -217,6 +328,11 @@ [#1878]: https://github.com/actix/actix-web/pull/1878 +## 2.2.1 - 2021-08-09 +### Fixed +* Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977) + + ## 2.2.0 - 2020-11-25 ### Added * HttpResponse builders for 1xx status codes. [#1768] diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index a12fed4b9..967f04d03 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,14 +1,17 @@ [package] name = "actix-http" -version = "3.0.0-beta.8" +version = "3.0.0-beta.14" authors = ["Nikolay Kim "] description = "HTTP primitives for the Actix ecosystem" keywords = ["actix", "http", "framework", "async", "futures"] homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web" -categories = ["network-programming", "asynchronous", - "web-programming::http-server", - "web-programming::websocket"] +repository = "https://github.com/actix/actix-web.git" +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] license = "MIT OR Apache-2.0" edition = "2018" @@ -24,29 +27,25 @@ path = "src/lib.rs" default = [] # openssl -openssl = ["actix-tls/openssl"] +openssl = ["actix-tls/accept", "actix-tls/openssl"] # rustls support -rustls = ["actix-tls/rustls"] +rustls = ["actix-tls/accept", "actix-tls/rustls"] # enable compression support compress-brotli = ["brotli2", "__compress"] compress-gzip = ["flate2", "__compress"] compress-zstd = ["zstd", "__compress"] -# trust-dns as client dns resolver -trust-dns = ["trust-dns-resolver"] - # Internal (PRIVATE!) features used to aid testing and cheking feature status. # Don't rely on these whatsoever. They may disappear at anytime. __compress = [] [dependencies] actix-service = "2.0.0" -actix-codec = "0.4.0" +actix-codec = "0.4.1" actix-utils = "3.0.0" actix-rt = "2.2" -actix-tls = { version = "3.0.0-beta.5", features = ["accept", "connect"] } ahash = "0.7" base64 = "0.13" @@ -58,45 +57,45 @@ encoding_rs = "0.8" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] } h2 = "0.3.1" -http = "0.2.2" -httparse = "1.3" +http = "0.2.5" +httparse = "1.5.1" +httpdate = "1.0.1" itoa = "0.4" language-tags = "0.3" local-channel = "0.1" -once_cell = "1.5" log = "0.4" mime = "0.3" percent-encoding = "2.1" pin-project = "1.0.0" pin-project-lite = "0.2" rand = "0.8" -regex = "1.3" -serde = "1.0" sha-1 = "0.9" -smallvec = "1.6" -time = { version = "0.2.23", default-features = false, features = ["std"] } -tokio = { version = "1.2", features = ["sync"] } +smallvec = "1.6.1" + +# tls +actix-tls = { version = "3.0.0-rc.1", default-features = false, optional = true } # compression brotli2 = { version="0.3.2", optional = true } flate2 = { version = "1.0.13", optional = true } -zstd = { version = "0.7", optional = true } - -trust-dns-resolver = { version = "0.20.0", optional = true } +zstd = { version = "0.9", optional = true } [dev-dependencies] -actix-server = "2.0.0-beta.3" -actix-http-test = { version = "3.0.0-beta.4", features = ["openssl"] } -actix-tls = { version = "3.0.0-beta.5", features = ["openssl"] } +actix-server = "2.0.0-beta.9" +actix-http-test = { version = "3.0.0-beta.7", features = ["openssl"] } +actix-tls = { version = "3.0.0-rc.1", features = ["openssl"] } async-stream = "0.3" criterion = { version = "0.3", features = ["html_reports"] } -env_logger = "0.8" +env_logger = "0.9" rcgen = "0.8" +regex = "1.3" +rustls-pemfile = "0.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -tls-openssl = { version = "0.10", package = "openssl" } -tls-rustls = { version = "0.19", package = "rustls" } -webpki = { version = "0.21.0" } +static_assertions = "1" +tls-openssl = { package = "openssl", version = "0.10.9" } +tls-rustls = { package = "rustls", version = "0.20.0" } +tokio = { version = "1.2", features = ["net", "rt"] } [[example]] name = "ws" @@ -113,3 +112,7 @@ harness = false [[bench]] name = "uninit-headers" harness = false + +[[bench]] +name = "quality-value" +harness = false diff --git a/actix-http/README.md b/actix-http/README.md index de1ef0a9b..92b86d2a3 100644 --- a/actix-http/README.md +++ b/actix-http/README.md @@ -3,18 +3,18 @@ > HTTP primitives for the Actix ecosystem. [![crates.io](https://img.shields.io/crates/v/actix-http?label=latest)](https://crates.io/crates/actix-http) -[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.8)](https://docs.rs/actix-http/3.0.0-beta.8) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.14)](https://docs.rs/actix-http/3.0.0-beta.14) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http.svg)
-[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.8/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.8) +[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.14/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.14) [![Download](https://img.shields.io/crates/d/actix-http.svg)](https://crates.io/crates/actix-http) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-http) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.52 ## Example diff --git a/actix-http/benches/quality-value.rs b/actix-http/benches/quality-value.rs new file mode 100644 index 000000000..31b67f999 --- /dev/null +++ b/actix-http/benches/quality-value.rs @@ -0,0 +1,90 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; + +const CODES: &[u16] = &[0, 1000, 201, 800, 550]; + +fn bench_quality_display_impls(c: &mut Criterion) { + let mut group = c.benchmark_group("quality value display impls"); + + for i in CODES.iter() { + group.bench_with_input(BenchmarkId::new("New (fast?)", i), i, |b, &i| { + b.iter(|| _new::Quality(i).to_string()) + }); + + group.bench_with_input(BenchmarkId::new("Naive", i), i, |b, &i| { + b.iter(|| _naive::Quality(i).to_string()) + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_quality_display_impls); +criterion_main!(benches); + +mod _new { + use std::fmt; + + pub struct Quality(pub(crate) u16); + + impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + 1000 => f.write_str("1"), + + // some number in the range 1–999 + x => { + f.write_str("0.")?; + + // this implementation avoids string allocation otherwise required + // for `.trim_end_matches('0')` + + if x < 10 { + f.write_str("00")?; + // 0 is handled so it's not possible to have a trailing 0, we can just return + itoa::fmt(f, x) + } else if x < 100 { + f.write_str("0")?; + if x % 10 == 0 { + // trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } else { + // x is in range 101–999 + + if x % 100 == 0 { + // two trailing 0s, divide by 100 and write + itoa::fmt(f, x / 100) + } else if x % 10 == 0 { + // one trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } + } + } + } + } +} + +mod _naive { + use std::fmt; + + pub struct Quality(pub(crate) u16); + + impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + 1000 => f.write_str("1"), + + x => { + write!(f, "{}", format!("{:03}", x).trim_end_matches('0')) + } + } + } + } +} diff --git a/actix-http/benches/write-camel-case.rs b/actix-http/benches/write-camel-case.rs index fa4930eb9..ccf09b37e 100644 --- a/actix-http/benches/write-camel-case.rs +++ b/actix-http/benches/write-camel-case.rs @@ -18,7 +18,8 @@ fn bench_write_camel_case(c: &mut Criterion) { group.bench_with_input(BenchmarkId::new("New", i), bts, |b, bts| { b.iter(|| { let mut buf = black_box([0; 24]); - _new::write_camel_case(black_box(bts), &mut buf) + let len = black_box(bts.len()); + _new::write_camel_case(black_box(bts), buf.as_mut_ptr(), len) }); }); } @@ -30,9 +31,12 @@ criterion_group!(benches, bench_write_camel_case); criterion_main!(benches); mod _new { - pub fn write_camel_case(value: &[u8], buffer: &mut [u8]) { + pub fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) { // first copy entire (potentially wrong) slice to output - buffer[..value.len()].copy_from_slice(value); + let buffer = unsafe { + std::ptr::copy_nonoverlapping(value.as_ptr(), buf, len); + std::slice::from_raw_parts_mut(buf, len) + }; let mut iter = value.iter(); diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index db195d65b..6092c01ce 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -1,12 +1,14 @@ use std::io; -use actix_http::{body::Body, http::HeaderValue, http::StatusCode}; -use actix_http::{Error, HttpService, Request, Response}; +use actix_http::{ + body::MessageBody, http::HeaderValue, http::StatusCode, Error, HttpService, Request, + Response, +}; use actix_server::Server; use bytes::BytesMut; use futures_util::StreamExt as _; -async fn handle_request(mut req: Request) -> Result, Error> { +async fn handle_request(mut req: Request) -> Result, Error> { let mut body = BytesMut::new(); while let Some(item) = req.payload().next().await { body.extend_from_slice(&item?) diff --git a/actix-http/examples/ws.rs b/actix-http/examples/ws.rs index d3cedf870..b6be4d2f1 100644 --- a/actix-http/examples/ws.rs +++ b/actix-http/examples/ws.rs @@ -85,22 +85,31 @@ impl Stream for Heartbeat { fn tls_config() -> rustls::ServerConfig { use std::io::BufReader; - use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig, - }; + use rustls::{Certificate, PrivateKey}; + use rustls_pemfile::{certs, pkcs8_private_keys}; let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = ServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let mut config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap(); + + config.alpn_protocols.push(b"http/1.1".to_vec()); + config.alpn_protocols.push(b"h2".to_vec()); config } diff --git a/actix-http/src/body/body.rs b/actix-http/src/body/body.rs deleted file mode 100644 index f04837d07..000000000 --- a/actix-http/src/body/body.rs +++ /dev/null @@ -1,233 +0,0 @@ -use std::{ - borrow::Cow, - error::Error as StdError, - fmt, mem, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::{Bytes, BytesMut}; -use futures_core::{ready, Stream}; - -use crate::error::Error; - -use super::{BodySize, BodyStream, MessageBody, MessageBodyMapErr, SizedStream}; - -pub type Body = AnyBody; - -/// Represents various types of HTTP message body. -pub enum AnyBody { - /// Empty response. `Content-Length` header is not set. - None, - - /// Zero sized response body. `Content-Length` header is set to `0`. - Empty, - - /// Specific response body. - Bytes(Bytes), - - /// Generic message body. - Message(BoxAnyBody), -} - -impl AnyBody { - /// Create body from slice (copy) - pub fn from_slice(s: &[u8]) -> Self { - Self::Bytes(Bytes::copy_from_slice(s)) - } - - /// Create body from generic message body. - pub fn from_message(body: B) -> Self - where - B: MessageBody + 'static, - B::Error: Into>, - { - Self::Message(BoxAnyBody::from_body(body)) - } -} - -impl MessageBody for AnyBody { - type Error = Error; - - fn size(&self) -> BodySize { - match self { - AnyBody::None => BodySize::None, - AnyBody::Empty => BodySize::Empty, - AnyBody::Bytes(ref bin) => BodySize::Sized(bin.len() as u64), - AnyBody::Message(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.get_mut() { - AnyBody::None => Poll::Ready(None), - AnyBody::Empty => Poll::Ready(None), - AnyBody::Bytes(ref mut bin) => { - let len = bin.len(); - if len == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(bin)))) - } - } - - // TODO: MSRV 1.51: poll_map_err - AnyBody::Message(body) => match ready!(body.as_pin_mut().poll_next(cx)) { - Some(Err(err)) => { - Poll::Ready(Some(Err(Error::new_body().with_cause(err)))) - } - Some(Ok(val)) => Poll::Ready(Some(Ok(val))), - None => Poll::Ready(None), - }, - } - } -} - -impl PartialEq for AnyBody { - fn eq(&self, other: &Body) -> bool { - match *self { - AnyBody::None => matches!(*other, AnyBody::None), - AnyBody::Empty => matches!(*other, AnyBody::Empty), - AnyBody::Bytes(ref b) => match *other { - AnyBody::Bytes(ref b2) => b == b2, - _ => false, - }, - AnyBody::Message(_) => false, - } - } -} - -impl fmt::Debug for AnyBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - AnyBody::None => write!(f, "AnyBody::None"), - AnyBody::Empty => write!(f, "AnyBody::Empty"), - AnyBody::Bytes(ref b) => write!(f, "AnyBody::Bytes({:?})", b), - AnyBody::Message(_) => write!(f, "AnyBody::Message(_)"), - } - } -} - -impl From<&'static str> for AnyBody { - fn from(s: &'static str) -> Body { - AnyBody::Bytes(Bytes::from_static(s.as_ref())) - } -} - -impl From<&'static [u8]> for AnyBody { - fn from(s: &'static [u8]) -> Body { - AnyBody::Bytes(Bytes::from_static(s)) - } -} - -impl From> for AnyBody { - fn from(vec: Vec) -> Body { - AnyBody::Bytes(Bytes::from(vec)) - } -} - -impl From for AnyBody { - fn from(s: String) -> Body { - s.into_bytes().into() - } -} - -impl From<&'_ String> for AnyBody { - fn from(s: &String) -> Body { - AnyBody::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) - } -} - -impl From> for AnyBody { - fn from(s: Cow<'_, str>) -> Body { - match s { - Cow::Owned(s) => AnyBody::from(s), - Cow::Borrowed(s) => { - AnyBody::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(s))) - } - } - } -} - -impl From for AnyBody { - fn from(s: Bytes) -> Body { - AnyBody::Bytes(s) - } -} - -impl From for AnyBody { - fn from(s: BytesMut) -> Body { - AnyBody::Bytes(s.freeze()) - } -} - -impl From> for AnyBody -where - S: Stream> + 'static, - E: Into> + 'static, -{ - fn from(s: SizedStream) -> Body { - AnyBody::from_message(s) - } -} - -impl From> for AnyBody -where - S: Stream> + 'static, - E: Into> + 'static, -{ - fn from(s: BodyStream) -> Body { - AnyBody::from_message(s) - } -} - -/// A boxed message body with boxed errors. -pub struct BoxAnyBody(Pin>>>); - -impl BoxAnyBody { - /// Boxes a `MessageBody` and any errors it generates. - pub fn from_body(body: B) -> Self - where - B: MessageBody + 'static, - B::Error: Into>, - { - let body = MessageBodyMapErr::new(body, Into::into); - Self(Box::pin(body)) - } - - /// Returns a mutable pinned reference to the inner message body type. - pub fn as_pin_mut( - &mut self, - ) -> Pin<&mut (dyn MessageBody>)> { - self.0.as_mut() - } -} - -impl fmt::Debug for BoxAnyBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("BoxAnyBody(dyn MessageBody)") - } -} - -impl MessageBody for BoxAnyBody { - type Error = Error; - - fn size(&self) -> BodySize { - self.0.size() - } - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - // TODO: MSRV 1.51: poll_map_err - match ready!(self.0.as_mut().poll_next(cx)) { - Some(Err(err)) => Poll::Ready(Some(Err(Error::new_body().with_cause(err)))), - Some(Ok(val)) => Poll::Ready(Some(Ok(val))), - None => Poll::Ready(None), - } - } -} diff --git a/actix-http/src/body/body_stream.rs b/actix-http/src/body/body_stream.rs index f726f4475..1da7a848a 100644 --- a/actix-http/src/body/body_stream.rs +++ b/actix-http/src/body/body_stream.rs @@ -20,6 +20,8 @@ pin_project! { } } +// TODO: from_infallible method + impl BodyStream where S: Stream>, @@ -75,10 +77,23 @@ mod tests { use derive_more::{Display, Error}; use futures_core::ready; use futures_util::{stream, FutureExt as _}; + use pin_project_lite::pin_project; + use static_assertions::{assert_impl_all, assert_not_impl_all}; use super::*; use crate::body::to_bytes; + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + + assert_not_impl_all!(BodyStream>: MessageBody); + assert_not_impl_all!(BodyStream>: MessageBody); + // crate::Error is not Clone + assert_not_impl_all!(BodyStream>>: MessageBody); + #[actix_rt::test] async fn skips_empty_chunks() { let body = BodyStream::new(stream::iter( @@ -124,18 +139,44 @@ mod tests { assert!(matches!(to_bytes(body).await, Err(StreamErr))); } + #[actix_rt::test] + async fn stream_string_error() { + // `&'static str` does not impl `Error` + // but it does impl `Into>` + + let body = BodyStream::new(stream::once(async { Err("stringy error") })); + assert!(matches!(to_bytes(body).await, Err("stringy error"))); + } + + #[actix_rt::test] + async fn stream_boxed_error() { + // `Box` does not impl `Error` + // but it does impl `Into>` + + let body = BodyStream::new(stream::once(async { + Err(Box::::from("stringy error")) + })); + + assert_eq!( + to_bytes(body).await.unwrap_err().to_string(), + "stringy error" + ); + } + #[actix_rt::test] async fn stream_delayed_error() { let body = BodyStream::new(stream::iter(vec![Ok(Bytes::from("1")), Err(StreamErr)])); assert!(matches!(to_bytes(body).await, Err(StreamErr))); - #[pin_project::pin_project(project = TimeDelayStreamProj)] - #[derive(Debug)] - enum TimeDelayStream { - Start, - Sleep(Pin>), - Done, + pin_project! { + #[derive(Debug)] + #[project = TimeDelayStreamProj] + enum TimeDelayStream { + Start, + Sleep { delay: Pin> }, + Done, + } } impl Stream for TimeDelayStream { @@ -148,12 +189,14 @@ mod tests { match self.as_mut().get_mut() { TimeDelayStream::Start => { let sleep = sleep(Duration::from_millis(1)); - self.as_mut().set(TimeDelayStream::Sleep(Box::pin(sleep))); + self.as_mut().set(TimeDelayStream::Sleep { + delay: Box::pin(sleep), + }); cx.waker().wake_by_ref(); Poll::Pending } - TimeDelayStream::Sleep(ref mut delay) => { + TimeDelayStream::Sleep { ref mut delay } => { ready!(delay.poll_unpin(cx)); self.set(TimeDelayStream::Done); cx.waker().wake_by_ref(); diff --git a/actix-http/src/body/boxed.rs b/actix-http/src/body/boxed.rs new file mode 100644 index 000000000..9442bd1df --- /dev/null +++ b/actix-http/src/body/boxed.rs @@ -0,0 +1,80 @@ +use std::{ + error::Error as StdError, + fmt, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use super::{BodySize, MessageBody, MessageBodyMapErr}; +use crate::Error; + +/// A boxed message body with boxed errors. +pub struct BoxBody(Pin>>>); + +impl BoxBody { + /// Boxes a `MessageBody` and any errors it generates. + pub fn new(body: B) -> Self + where + B: MessageBody + 'static, + { + let body = MessageBodyMapErr::new(body, Into::into); + Self(Box::pin(body)) + } + + /// Returns a mutable pinned reference to the inner message body type. + pub fn as_pin_mut( + &mut self, + ) -> Pin<&mut (dyn MessageBody>)> { + self.0.as_mut() + } +} + +impl fmt::Debug for BoxBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("BoxBody(dyn MessageBody)") + } +} + +impl MessageBody for BoxBody { + type Error = Error; + + fn size(&self) -> BodySize { + self.0.size() + } + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.0 + .as_mut() + .poll_next(cx) + .map_err(|err| Error::new_body().with_cause(err)) + } +} + +#[cfg(test)] +mod tests { + + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + use crate::body::to_bytes; + + assert_impl_all!(BoxBody: MessageBody, fmt::Debug, Unpin); + + assert_not_impl_all!(BoxBody: Send, Sync, Unpin); + + #[actix_rt::test] + async fn nested_boxed_body() { + let body = Bytes::from_static(&[1, 2, 3]); + let boxed_body = BoxBody::new(BoxBody::new(body)); + + assert_eq!( + to_bytes(boxed_body).await.unwrap(), + Bytes::from(vec![1, 2, 3]), + ); + } +} diff --git a/actix-http/src/body/either.rs b/actix-http/src/body/either.rs new file mode 100644 index 000000000..6169ee627 --- /dev/null +++ b/actix-http/src/body/either.rs @@ -0,0 +1,83 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use pin_project_lite::pin_project; + +use super::{BodySize, BoxBody, MessageBody}; +use crate::Error; + +pin_project! { + #[project = EitherBodyProj] + #[derive(Debug, Clone)] + pub enum EitherBody { + /// A body of type `L`. + Left { #[pin] body: L }, + + /// A body of type `R`. + Right { #[pin] body: R }, + } +} + +impl EitherBody { + /// Creates new `EitherBody` using left variant and boxed right variant. + pub fn new(body: L) -> Self { + Self::Left { body } + } +} + +impl EitherBody { + /// Creates new `EitherBody` using left variant. + pub fn left(body: L) -> Self { + Self::Left { body } + } + + /// Creates new `EitherBody` using right variant. + pub fn right(body: R) -> Self { + Self::Right { body } + } +} + +impl MessageBody for EitherBody +where + L: MessageBody + 'static, + R: MessageBody + 'static, +{ + type Error = Error; + + fn size(&self) -> BodySize { + match self { + EitherBody::Left { body } => body.size(), + EitherBody::Right { body } => body.size(), + } + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + EitherBodyProj::Left { body } => body + .poll_next(cx) + .map_err(|err| Error::new_body().with_cause(err)), + EitherBodyProj::Right { body } => body + .poll_next(cx) + .map_err(|err| Error::new_body().with_cause(err)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn type_parameter_inference() { + let _body: EitherBody<(), _> = EitherBody::new(()); + + let _body: EitherBody<_, ()> = EitherBody::left(()); + let _body: EitherBody<(), _> = EitherBody::right(()); + } +} diff --git a/actix-http/src/body/message_body.rs b/actix-http/src/body/message_body.rs index 2d2642ba7..053b6f286 100644 --- a/actix-http/src/body/message_body.rs +++ b/actix-http/src/body/message_body.rs @@ -2,6 +2,7 @@ use std::{ convert::Infallible, + error::Error as StdError, mem, pin::Pin, task::{Context, Poll}, @@ -11,13 +12,14 @@ use bytes::{Bytes, BytesMut}; use futures_core::ready; use pin_project_lite::pin_project; -use crate::error::Error; - use super::BodySize; -/// An interface for response bodies. +/// An interface types that can converted to bytes and used as response bodies. +// TODO: examples pub trait MessageBody { - type Error; + // TODO: consider this bound to only fmt::Display since the error type is not really used + // and there is an impl for Into> on String + type Error: Into>; /// Body size hint. fn size(&self) -> BodySize; @@ -29,154 +31,218 @@ pub trait MessageBody { ) -> Poll>>; } -impl MessageBody for () { - type Error = Infallible; +mod foreign_impls { + use super::*; - fn size(&self) -> BodySize { - BodySize::Empty - } + impl MessageBody for Infallible { + type Error = Infallible; - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - Poll::Ready(None) - } -} + #[inline] + fn size(&self) -> BodySize { + match *self {} + } -impl MessageBody for Box -where - B: MessageBody + Unpin, - B::Error: Into, -{ - type Error = B::Error; - - fn size(&self) -> BodySize { - self.as_ref().size() - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(self.get_mut().as_mut()).poll_next(cx) - } -} - -impl MessageBody for Pin> -where - B: MessageBody, - B::Error: Into, -{ - type Error = B::Error; - - fn size(&self) -> BodySize { - self.as_ref().size() - } - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - self.as_mut().poll_next(cx) - } -} - -impl MessageBody for Bytes { - type Error = Infallible; - - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut())))) + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + match *self {} } } -} -impl MessageBody for BytesMut { - type Error = Infallible; + impl MessageBody for () { + type Error = Infallible; - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(0) + } - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze()))) } } -} -impl MessageBody for &'static str { - type Error = Infallible; + impl MessageBody for Box + where + B: MessageBody + Unpin, + { + type Error = B::Error; - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } + #[inline] + fn size(&self) -> BodySize { + self.as_ref().size() + } - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from_static( - mem::take(self.get_mut()).as_ref(), - )))) + #[inline] + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + Pin::new(self.get_mut().as_mut()).poll_next(cx) } } -} -impl MessageBody for Vec { - type Error = Infallible; + impl MessageBody for Pin> + where + B: MessageBody, + { + type Error = B::Error; - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } + #[inline] + fn size(&self) -> BodySize { + self.as_ref().size() + } - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from(mem::take(self.get_mut()))))) + #[inline] + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.as_mut().poll_next(cx) } } -} -impl MessageBody for String { - type Error = Infallible; + impl MessageBody for &'static [u8] { + type Error = Infallible; - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let bytes = mem::take(self.get_mut()); + let bytes = Bytes::from_static(bytes); + Poll::Ready(Some(Ok(bytes))) + } + } } - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from( - mem::take(self.get_mut()).into_bytes(), - )))) + impl MessageBody for Bytes { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let bytes = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(bytes))) + } + } + } + + impl MessageBody for BytesMut { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let bytes = mem::take(self.get_mut()).freeze(); + Poll::Ready(Some(Ok(bytes))) + } + } + } + + impl MessageBody for Vec { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let bytes = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(Bytes::from(bytes)))) + } + } + } + + impl MessageBody for &'static str { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let string = mem::take(self.get_mut()); + let bytes = Bytes::from_static(string.as_bytes()); + Poll::Ready(Some(Ok(bytes))) + } + } + } + + impl MessageBody for String { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let string = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(Bytes::from(string)))) + } + } + } + + impl MessageBody for bytestring::ByteString { + type Error = Infallible; + + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + let string = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(string.into_bytes()))) } } } @@ -206,6 +272,7 @@ impl MessageBody for MessageBodyMapErr where B: MessageBody, F: FnOnce(B::Error) -> E, + E: Into>, { type Error = E; @@ -230,3 +297,129 @@ where } } } + +#[cfg(test)] +mod tests { + use actix_rt::pin; + use actix_utils::future::poll_fn; + use bytes::{Bytes, BytesMut}; + + use super::*; + + macro_rules! assert_poll_next { + ($pin:expr, $exp:expr) => { + assert_eq!( + poll_fn(|cx| $pin.as_mut().poll_next(cx)) + .await + .unwrap() // unwrap option + .unwrap(), // unwrap result + $exp + ); + }; + } + + macro_rules! assert_poll_next_none { + ($pin:expr) => { + assert!(poll_fn(|cx| $pin.as_mut().poll_next(cx)).await.is_none()); + }; + } + + #[actix_rt::test] + async fn boxing_equivalence() { + assert_eq!(().size(), BodySize::Sized(0)); + assert_eq!(().size(), Box::new(()).size()); + assert_eq!(().size(), Box::pin(()).size()); + + let pl = Box::new(()); + pin!(pl); + assert_poll_next_none!(pl); + + let mut pl = Box::pin(()); + assert_poll_next_none!(pl); + } + + #[actix_rt::test] + async fn test_unit() { + let pl = (); + assert_eq!(pl.size(), BodySize::Sized(0)); + pin!(pl); + assert_poll_next_none!(pl); + } + + #[actix_rt::test] + async fn test_static_str() { + assert_eq!("".size(), BodySize::Sized(0)); + assert_eq!("test".size(), BodySize::Sized(4)); + + let pl = "test"; + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_static_bytes() { + assert_eq!(b"".as_ref().size(), BodySize::Sized(0)); + assert_eq!(b"test".as_ref().size(), BodySize::Sized(4)); + + let pl = b"test".as_ref(); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_vec() { + assert_eq!(vec![0; 0].size(), BodySize::Sized(0)); + assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); + + let pl = Vec::from("test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_bytes() { + assert_eq!(Bytes::new().size(), BodySize::Sized(0)); + assert_eq!(Bytes::from_static(b"test").size(), BodySize::Sized(4)); + + let pl = Bytes::from_static(b"test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_bytes_mut() { + assert_eq!(BytesMut::new().size(), BodySize::Sized(0)); + assert_eq!(BytesMut::from(b"test".as_ref()).size(), BodySize::Sized(4)); + + let pl = BytesMut::from("test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_string() { + assert_eq!(String::new().size(), BodySize::Sized(0)); + assert_eq!("test".to_owned().size(), BodySize::Sized(4)); + + let pl = "test".to_owned(); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + // down-casting used to be done with a method on MessageBody trait + // test is kept to demonstrate equivalence of Any trait + #[actix_rt::test] + async fn test_body_casting() { + let mut body = String::from("hello cast"); + // let mut resp_body: &mut dyn MessageBody = &mut body; + let resp_body: &mut dyn std::any::Any = &mut body; + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast"); + let body = &mut resp_body.downcast_mut::().unwrap(); + body.push('!'); + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast!"); + let not_body = resp_body.downcast_ref::<()>(); + assert!(not_body.is_none()); + } +} diff --git a/actix-http/src/body/mod.rs b/actix-http/src/body/mod.rs index 8a08dbd2b..af7c4626f 100644 --- a/actix-http/src/body/mod.rs +++ b/actix-http/src/body/mod.rs @@ -1,263 +1,20 @@ //! Traits and structures to aid consuming and writing HTTP payloads. -use std::task::Poll; - -use actix_rt::pin; -use actix_utils::future::poll_fn; -use bytes::{Bytes, BytesMut}; -use futures_core::ready; - -#[allow(clippy::module_inception)] -mod body; mod body_stream; +mod boxed; +mod either; mod message_body; -mod response_body; +mod none; mod size; mod sized_stream; +mod utils; -pub use self::body::{AnyBody, Body, BoxAnyBody}; pub use self::body_stream::BodyStream; +pub use self::boxed::BoxBody; +pub use self::either::EitherBody; pub use self::message_body::MessageBody; pub(crate) use self::message_body::MessageBodyMapErr; -pub use self::response_body::ResponseBody; +pub use self::none::None; pub use self::size::BodySize; pub use self::sized_stream::SizedStream; - -/// Collects the body produced by a `MessageBody` implementation into `Bytes`. -/// -/// Any errors produced by the body stream are returned immediately. -/// -/// # Examples -/// ``` -/// use actix_http::body::{Body, to_bytes}; -/// use bytes::Bytes; -/// -/// # async fn test_to_bytes() { -/// let body = Body::Empty; -/// let bytes = to_bytes(body).await.unwrap(); -/// assert!(bytes.is_empty()); -/// -/// let body = Body::Bytes(Bytes::from_static(b"123")); -/// let bytes = to_bytes(body).await.unwrap(); -/// assert_eq!(bytes, b"123"[..]); -/// # } -/// ``` -pub async fn to_bytes(body: B) -> Result { - let cap = match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => return Ok(Bytes::new()), - BodySize::Sized(size) => size as usize, - BodySize::Stream => 32_768, - }; - - let mut buf = BytesMut::with_capacity(cap); - - pin!(body); - - poll_fn(|cx| loop { - let body = body.as_mut(); - - match ready!(body.poll_next(cx)) { - Some(Ok(bytes)) => buf.extend_from_slice(&*bytes), - None => return Poll::Ready(Ok(())), - Some(Err(err)) => return Poll::Ready(Err(err)), - } - }) - .await?; - - Ok(buf.freeze()) -} - -#[cfg(test)] -mod tests { - use std::pin::Pin; - - use actix_rt::pin; - use actix_utils::future::poll_fn; - use bytes::{Bytes, BytesMut}; - - use super::*; - - impl Body { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - Body::Bytes(ref bin) => &bin, - _ => panic!(), - } - } - } - - #[actix_rt::test] - async fn test_static_str() { - assert_eq!(Body::from("").size(), BodySize::Sized(0)); - assert_eq!(Body::from("test").size(), BodySize::Sized(4)); - assert_eq!(Body::from("test").get_ref(), b"test"); - - assert_eq!("test".size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| Pin::new(&mut "test").poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_static_bytes() { - assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); - assert_eq!( - Body::from_slice(b"test".as_ref()).size(), - BodySize::Sized(4) - ); - assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); - let sb = Bytes::from(&b"test"[..]); - pin!(sb); - - assert_eq!(sb.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| sb.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_vec() { - assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); - assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); - let test_vec = Vec::from("test"); - pin!(test_vec); - - assert_eq!(test_vec.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| test_vec.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes() { - let b = Bytes::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes_mut() { - let b = BytesMut::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_string() { - let b = "test".to_owned(); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - assert_eq!(Body::from(&b).size(), BodySize::Sized(4)); - assert_eq!(Body::from(&b).get_ref(), b"test"); - pin!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_unit() { - assert_eq!(().size(), BodySize::Empty); - assert!(poll_fn(|cx| Pin::new(&mut ()).poll_next(cx)) - .await - .is_none()); - } - - #[actix_rt::test] - async fn test_box_and_pin() { - let val = Box::new(()); - pin!(val); - assert_eq!(val.size(), BodySize::Empty); - assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none()); - - let mut val = Box::pin(()); - assert_eq!(val.size(), BodySize::Empty); - assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none()); - } - - #[actix_rt::test] - async fn test_body_eq() { - assert!( - Body::Bytes(Bytes::from_static(b"1")) - == Body::Bytes(Bytes::from_static(b"1")) - ); - assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); - } - - #[actix_rt::test] - async fn test_body_debug() { - assert!(format!("{:?}", Body::None).contains("Body::None")); - assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); - assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1')); - } - - #[actix_rt::test] - async fn test_serde_json() { - use serde_json::{json, Value}; - assert_eq!( - Body::from(serde_json::to_vec(&Value::String("test".to_owned())).unwrap()) - .size(), - BodySize::Sized(6) - ); - assert_eq!( - Body::from(serde_json::to_vec(&json!({"test-key":"test-value"})).unwrap()) - .size(), - BodySize::Sized(25) - ); - } - - // down-casting used to be done with a method on MessageBody trait - // test is kept to demonstrate equivalence of Any trait - #[actix_rt::test] - async fn test_body_casting() { - let mut body = String::from("hello cast"); - // let mut resp_body: &mut dyn MessageBody = &mut body; - let resp_body: &mut dyn std::any::Any = &mut body; - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast"); - let body = &mut resp_body.downcast_mut::().unwrap(); - body.push('!'); - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast!"); - let not_body = resp_body.downcast_ref::<()>(); - assert!(not_body.is_none()); - } - - #[actix_rt::test] - async fn test_to_bytes() { - let body = Body::Empty; - let bytes = to_bytes(body).await.unwrap(); - assert!(bytes.is_empty()); - - let body = Body::Bytes(Bytes::from_static(b"123")); - let bytes = to_bytes(body).await.unwrap(); - assert_eq!(bytes, b"123"[..]); - } -} +pub use self::utils::to_bytes; diff --git a/actix-http/src/body/none.rs b/actix-http/src/body/none.rs new file mode 100644 index 000000000..0fc7c8c9f --- /dev/null +++ b/actix-http/src/body/none.rs @@ -0,0 +1,43 @@ +use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use super::{BodySize, MessageBody}; + +/// Body type for responses that forbid payloads. +/// +/// Distinct from an empty response which would contain a Content-Length header. +/// +/// For an "empty" body, use `()` or `Bytes::new()`. +#[derive(Debug, Clone, Copy, Default)] +#[non_exhaustive] +pub struct None; + +impl None { + /// Constructs new "none" body. + #[inline] + pub fn new() -> Self { + None + } +} + +impl MessageBody for None { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::None + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(Option::None) + } +} diff --git a/actix-http/src/body/response_body.rs b/actix-http/src/body/response_body.rs deleted file mode 100644 index 855c742f2..000000000 --- a/actix-http/src/body/response_body.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::{ - mem, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures_core::{ready, Stream}; -use pin_project::pin_project; - -use crate::error::Error; - -use super::{Body, BodySize, MessageBody}; - -#[pin_project(project = ResponseBodyProj)] -pub enum ResponseBody { - Body(#[pin] B), - Other(Body), -} - -impl ResponseBody { - pub fn into_body(self) -> ResponseBody { - match self { - ResponseBody::Body(b) => ResponseBody::Other(b), - ResponseBody::Other(b) => ResponseBody::Other(b), - } - } -} - -impl ResponseBody { - pub fn take_body(&mut self) -> ResponseBody { - mem::replace(self, ResponseBody::Other(Body::None)) - } -} - -impl ResponseBody { - pub fn as_ref(&self) -> Option<&B> { - if let ResponseBody::Body(ref b) = self { - Some(b) - } else { - None - } - } -} - -impl MessageBody for ResponseBody -where - B: MessageBody, - B::Error: Into, -{ - type Error = Error; - - fn size(&self) -> BodySize { - match self { - ResponseBody::Body(ref body) => body.size(), - ResponseBody::Other(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Stream::poll_next(self, cx) - } -} - -impl Stream for ResponseBody -where - B: MessageBody, - B::Error: Into, -{ - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - // TODO: MSRV 1.51: poll_map_err - ResponseBodyProj::Body(body) => match ready!(body.poll_next(cx)) { - Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), - Some(Ok(val)) => Poll::Ready(Some(Ok(val))), - None => Poll::Ready(None), - }, - ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), - } - } -} diff --git a/actix-http/src/body/size.rs b/actix-http/src/body/size.rs index 775d5b8f1..d64af9d44 100644 --- a/actix-http/src/body/size.rs +++ b/actix-http/src/body/size.rs @@ -6,14 +6,9 @@ pub enum BodySize { /// Will skip writing Content-Length header. None, - /// Zero size body. - /// - /// Will write `Content-Length: 0` header. - Empty, - /// Known size body. /// - /// Will write `Content-Length: N` header. `Sized(0)` is treated the same as `Empty`. + /// Will write `Content-Length: N` header. Sized(u64), /// Unknown size body. @@ -23,18 +18,19 @@ pub enum BodySize { } impl BodySize { - /// Returns true if size hint indicates no or empty body. + /// Returns true if size hint indicates omitted or empty body. + /// + /// Streams will return false because it cannot be known without reading the stream. /// /// ``` /// # use actix_http::body::BodySize; /// assert!(BodySize::None.is_eof()); - /// assert!(BodySize::Empty.is_eof()); /// assert!(BodySize::Sized(0).is_eof()); /// /// assert!(!BodySize::Sized(64).is_eof()); /// assert!(!BodySize::Stream.is_eof()); /// ``` pub fn is_eof(&self) -> bool { - matches!(self, BodySize::None | BodySize::Empty | BodySize::Sized(0)) + matches!(self, BodySize::None | BodySize::Sized(0)) } } diff --git a/actix-http/src/body/sized_stream.rs b/actix-http/src/body/sized_stream.rs index b6ceb32fe..c8606897d 100644 --- a/actix-http/src/body/sized_stream.rs +++ b/actix-http/src/body/sized_stream.rs @@ -32,6 +32,8 @@ where } } +// TODO: from_infallible method + impl MessageBody for SizedStream where S: Stream>, @@ -72,10 +74,22 @@ mod tests { use actix_rt::pin; use actix_utils::future::poll_fn; use futures_util::stream; + use static_assertions::{assert_impl_all, assert_not_impl_all}; use super::*; use crate::body::to_bytes; + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + + assert_not_impl_all!(SizedStream>: MessageBody); + assert_not_impl_all!(SizedStream>: MessageBody); + // crate::Error is not Clone + assert_not_impl_all!(SizedStream>>: MessageBody); + #[actix_rt::test] async fn skips_empty_chunks() { let body = SizedStream::new( @@ -119,4 +133,37 @@ mod tests { assert_eq!(to_bytes(body).await.ok(), Some(Bytes::from("12"))); } + + #[actix_rt::test] + async fn stream_string_error() { + // `&'static str` does not impl `Error` + // but it does impl `Into>` + + let body = SizedStream::new(0, stream::once(async { Err("stringy error") })); + assert_eq!(to_bytes(body).await, Ok(Bytes::new())); + + let body = SizedStream::new(1, stream::once(async { Err("stringy error") })); + assert!(matches!(to_bytes(body).await, Err("stringy error"))); + } + + #[actix_rt::test] + async fn stream_boxed_error() { + // `Box` does not impl `Error` + // but it does impl `Into>` + + let body = SizedStream::new( + 0, + stream::once(async { Err(Box::::from("stringy error")) }), + ); + assert_eq!(to_bytes(body).await.unwrap(), Bytes::new()); + + let body = SizedStream::new( + 1, + stream::once(async { Err(Box::::from("stringy error")) }), + ); + assert_eq!( + to_bytes(body).await.unwrap_err().to_string(), + "stringy error" + ); + } } diff --git a/actix-http/src/body/utils.rs b/actix-http/src/body/utils.rs new file mode 100644 index 000000000..a421ffd76 --- /dev/null +++ b/actix-http/src/body/utils.rs @@ -0,0 +1,78 @@ +use std::task::Poll; + +use actix_rt::pin; +use actix_utils::future::poll_fn; +use bytes::{Bytes, BytesMut}; +use futures_core::ready; + +use super::{BodySize, MessageBody}; + +/// Collects the body produced by a `MessageBody` implementation into `Bytes`. +/// +/// Any errors produced by the body stream are returned immediately. +/// +/// # Examples +/// ``` +/// use actix_http::body::{self, to_bytes}; +/// use bytes::Bytes; +/// +/// # async fn test_to_bytes() { +/// let body = body::None::new(); +/// let bytes = to_bytes(body).await.unwrap(); +/// assert!(bytes.is_empty()); +/// +/// let body = Bytes::from_static(b"123"); +/// let bytes = to_bytes(body).await.unwrap(); +/// assert_eq!(bytes, b"123"[..]); +/// # } +/// ``` +pub async fn to_bytes(body: B) -> Result { + let cap = match body.size() { + BodySize::None | BodySize::Sized(0) => return Ok(Bytes::new()), + BodySize::Sized(size) => size as usize, + // good enough first guess for chunk size + BodySize::Stream => 32_768, + }; + + let mut buf = BytesMut::with_capacity(cap); + + pin!(body); + + poll_fn(|cx| loop { + let body = body.as_mut(); + + match ready!(body.poll_next(cx)) { + Some(Ok(bytes)) => buf.extend_from_slice(&*bytes), + None => return Poll::Ready(Ok(())), + Some(Err(err)) => return Poll::Ready(Err(err)), + } + }) + .await?; + + Ok(buf.freeze()) +} + +#[cfg(test)] +mod test { + use futures_util::{stream, StreamExt as _}; + + use super::*; + use crate::{body::BodyStream, Error}; + + #[actix_rt::test] + async fn test_to_bytes() { + let bytes = to_bytes(()).await.unwrap(); + assert!(bytes.is_empty()); + + let body = Bytes::from_static(b"123"); + let bytes = to_bytes(body).await.unwrap(); + assert_eq!(bytes, b"123"[..]); + + let stream = + stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")]) + .map(Ok::<_, Error>); + let body = BodyStream::new(stream); + let bytes = to_bytes(body).await.unwrap(); + assert_eq!(bytes, b"123abc"[..]); + } +} diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 04e8282e8..1e2c0ec2d 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -1,10 +1,10 @@ -use std::{error::Error as StdError, fmt, marker::PhantomData, net, rc::Rc}; +use std::{fmt, marker::PhantomData, net, rc::Rc}; use actix_codec::Framed; use actix_service::{IntoServiceFactory, Service, ServiceFactory}; use crate::{ - body::{AnyBody, MessageBody}, + body::{BoxBody, MessageBody}, config::{KeepAlive, ServiceConfig}, extensions::CloneableExtensions, h1::{self, ExpectHandler, H1Service, UpgradeHandler}, @@ -32,7 +32,7 @@ pub struct HttpServiceBuilder { impl HttpServiceBuilder where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, >::Future: 'static, { @@ -55,11 +55,11 @@ where impl HttpServiceBuilder where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, >::Future: 'static, X: ServiceFactory, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display, @@ -121,7 +121,7 @@ where where F: IntoServiceFactory, X1: ServiceFactory, - X1::Error: Into>, + X1::Error: Into>, X1::InitError: fmt::Debug, { HttpServiceBuilder { @@ -179,7 +179,7 @@ where where B: MessageBody, F: IntoServiceFactory, - S::Error: Into>, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, { @@ -201,12 +201,11 @@ where pub fn h2(self, service: F) -> H2Service where F: IntoServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, B: MessageBody + 'static, - B::Error: Into>, { let cfg = ServiceConfig::new( self.keep_alive, @@ -224,12 +223,11 @@ where pub fn finish(self, service: F) -> HttpService where F: IntoServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, B: MessageBody + 'static, - B::Error: Into>, { let cfg = ServiceConfig::new( self.keep_alive, diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs index 97750ff76..5d020edfc 100644 --- a/actix-http/src/config.rs +++ b/actix-http/src/config.rs @@ -1,26 +1,29 @@ -use std::cell::Cell; -use std::fmt::Write; -use std::rc::Rc; -use std::time::Duration; -use std::{fmt, net}; +use std::{ + cell::Cell, + fmt::{self, Write}, + net, + rc::Rc, + time::{Duration, SystemTime}, +}; use actix_rt::{ task::JoinHandle, time::{interval, sleep_until, Instant, Sleep}, }; use bytes::BytesMut; -use time::OffsetDateTime; /// "Sun, 06 Nov 1994 08:49:37 GMT".len() -const DATE_VALUE_LENGTH: usize = 29; +pub(crate) const DATE_VALUE_LENGTH: usize = 29; #[derive(Debug, PartialEq, Clone, Copy)] /// Server keep-alive setting pub enum KeepAlive { /// Keep alive in seconds Timeout(usize), + /// Rely on OS to shutdown tcp connection Os, + /// Disabled Disabled, } @@ -206,12 +209,7 @@ impl Date { fn update(&mut self) { self.pos = 0; - write!( - self, - "{}", - OffsetDateTime::now_utc().format("%a, %d %b %Y %H:%M:%S GMT") - ) - .unwrap(); + write!(self, "{}", httpdate::fmt_http_date(SystemTime::now())).unwrap(); } } @@ -269,11 +267,11 @@ impl DateService { } // TODO: move to a util module for testing all spawn handle drop style tasks. -#[cfg(test)] /// Test Module for checking the drop state of certain async tasks that are spawned /// with `actix_rt::spawn` /// /// The target task must explicitly generate `NotifyOnDrop` when spawn the task +#[cfg(test)] mod notify_on_drop { use std::cell::RefCell; @@ -283,9 +281,8 @@ mod notify_on_drop { /// Check if the spawned task is dropped. /// - /// # Panic: - /// - /// When there was no `NotifyOnDrop` instance on current thread + /// # Panics + /// Panics when there was no `NotifyOnDrop` instance on current thread. pub(crate) fn is_dropped() -> bool { NOTIFY_DROPPED.with(|bool| { bool.borrow() diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs index d3e304836..c32983fc7 100644 --- a/actix-http/src/encoding/decoder.rs +++ b/actix-http/src/encoding/decoder.rs @@ -80,7 +80,7 @@ where let encoding = headers .get(&CONTENT_ENCODING) .and_then(|val| val.to_str().ok()) - .map(ContentEncoding::from) + .and_then(|x| x.parse().ok()) .unwrap_or(ContentEncoding::Identity); Self::new(stream, encoding) diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index 1e69990a0..49e5663dc 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -12,7 +12,7 @@ use actix_rt::task::{spawn_blocking, JoinHandle}; use bytes::Bytes; use derive_more::Display; use futures_core::ready; -use pin_project::pin_project; +use pin_project_lite::pin_project; #[cfg(feature = "compress-brotli")] use brotli2::write::BrotliEncoder; @@ -23,99 +23,103 @@ use flate2::write::{GzEncoder, ZlibEncoder}; #[cfg(feature = "compress-zstd")] use zstd::stream::write::Encoder as ZstdEncoder; +use super::Writer; use crate::{ - body::{Body, BodySize, BoxAnyBody, MessageBody, ResponseBody}, + body::{BodySize, MessageBody}, + error::BlockingError, http::{ header::{ContentEncoding, CONTENT_ENCODING}, HeaderValue, StatusCode, }, - Error, ResponseHead, + ResponseHead, }; -use super::Writer; -use crate::error::BlockingError; - const MAX_CHUNK_SIZE_ENCODE_IN_PLACE: usize = 1024; -#[pin_project] -pub struct Encoder { - eof: bool, - #[pin] - body: EncoderBody, - encoder: Option, - fut: Option>>, +pin_project! { + pub struct Encoder { + #[pin] + body: EncoderBody, + encoder: Option, + fut: Option>>, + eof: bool, + } } impl Encoder { + fn none() -> Self { + Encoder { + body: EncoderBody::None, + encoder: None, + fut: None, + eof: true, + } + } + pub fn response( encoding: ContentEncoding, head: &mut ResponseHead, - body: ResponseBody, - ) -> ResponseBody> { + body: B, + ) -> Self { let can_encode = !(head.headers().contains_key(&CONTENT_ENCODING) || head.status == StatusCode::SWITCHING_PROTOCOLS || head.status == StatusCode::NO_CONTENT || encoding == ContentEncoding::Identity || encoding == ContentEncoding::Auto); - let body = match body { - ResponseBody::Other(b) => match b { - Body::None => return ResponseBody::Other(Body::None), - Body::Empty => return ResponseBody::Other(Body::Empty), - Body::Bytes(buf) => { - if can_encode { - EncoderBody::Bytes(buf) - } else { - return ResponseBody::Other(Body::Bytes(buf)); - } - } - Body::Message(stream) => EncoderBody::BoxedStream(stream), - }, - ResponseBody::Body(stream) => EncoderBody::Stream(stream), - }; + match body.size() { + // no need to compress an empty body + BodySize::None => return Self::none(), + + // we cannot assume that Sized is not a stream + BodySize::Sized(_) | BodySize::Stream => {} + } + + // TODO potentially some optimisation for single-chunk responses here by trying to read the + // payload eagerly, stopping after 2 polls if the first is a chunk and the second is None if can_encode { - // Modify response body only if encoder is not None + // Modify response body only if encoder is set if let Some(enc) = ContentEncoder::encoder(encoding) { update_head(encoding, head); head.no_chunking(false); - return ResponseBody::Body(Encoder { - body, - eof: false, - fut: None, + + return Encoder { + body: EncoderBody::Stream { body }, encoder: Some(enc), - }); + fut: None, + eof: false, + }; } } - ResponseBody::Body(Encoder { - body, - eof: false, - fut: None, + Encoder { + body: EncoderBody::Stream { body }, encoder: None, - }) + fut: None, + eof: false, + } } } -#[pin_project(project = EncoderBodyProj)] -enum EncoderBody { - Bytes(Bytes), - Stream(#[pin] B), - BoxedStream(BoxAnyBody), +pin_project! { + #[project = EncoderBodyProj] + enum EncoderBody { + None, + Stream { #[pin] body: B }, + } } impl MessageBody for EncoderBody where B: MessageBody, - B::Error: Into, { - type Error = EncoderError; + type Error = EncoderError; fn size(&self) -> BodySize { match self { - EncoderBody::Bytes(ref b) => b.size(), - EncoderBody::Stream(ref b) => b.size(), - EncoderBody::BoxedStream(ref b) => b.size(), + EncoderBody::None => BodySize::None, + EncoderBody::Stream { body } => body.size(), } } @@ -124,26 +128,11 @@ where cx: &mut Context<'_>, ) -> Poll>> { match self.project() { - EncoderBodyProj::Bytes(b) => { - if b.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(std::mem::take(b)))) - } - } - // TODO: MSRV 1.51: poll_map_err - EncoderBodyProj::Stream(b) => match ready!(b.poll_next(cx)) { - Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Body(err)))), - Some(Ok(val)) => Poll::Ready(Some(Ok(val))), - None => Poll::Ready(None), - }, - EncoderBodyProj::BoxedStream(ref mut b) => { - match ready!(b.as_pin_mut().poll_next(cx)) { - Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Boxed(err)))), - Some(Ok(val)) => Poll::Ready(Some(Ok(val))), - None => Poll::Ready(None), - } - } + EncoderBodyProj::None => Poll::Ready(None), + + EncoderBodyProj::Stream { body } => body + .poll_next(cx) + .map_err(|err| EncoderError::Body(err.into())), } } } @@ -151,9 +140,8 @@ where impl MessageBody for Encoder where B: MessageBody, - B::Error: Into, { - type Error = EncoderError; + type Error = EncoderError; fn size(&self) -> BodySize { if self.encoder.is_none() { @@ -216,6 +204,7 @@ where None => { if let Some(encoder) = this.encoder.take() { let chunk = encoder.finish().map_err(EncoderError::Io)?; + if chunk.is_empty() { return Poll::Ready(None); } else { @@ -241,12 +230,15 @@ fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { enum ContentEncoder { #[cfg(feature = "compress-gzip")] Deflate(ZlibEncoder), + #[cfg(feature = "compress-gzip")] Gzip(GzEncoder), + #[cfg(feature = "compress-brotli")] Br(BrotliEncoder), - // We need explicit 'static lifetime here because ZstdEncoder need lifetime - // argument, and we use `spawn_blocking` in `Encoder::poll_next` that require `FnOnce() -> R + Send + 'static` + + // Wwe need explicit 'static lifetime here because ZstdEncoder needs a lifetime argument and we + // use `spawn_blocking` in `Encoder::poll_next` that requires `FnOnce() -> R + Send + 'static`. #[cfg(feature = "compress-zstd")] Zstd(ZstdEncoder<'static, Writer>), } @@ -259,20 +251,24 @@ impl ContentEncoder { Writer::new(), flate2::Compression::fast(), ))), + #[cfg(feature = "compress-gzip")] ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new( Writer::new(), flate2::Compression::fast(), ))), + #[cfg(feature = "compress-brotli")] ContentEncoding::Br => { Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) } + #[cfg(feature = "compress-zstd")] ContentEncoding::Zstd => { let encoder = ZstdEncoder::new(Writer::new(), 3).ok()?; Some(ContentEncoder::Zstd(encoder)) } + _ => None, } } @@ -282,10 +278,13 @@ impl ContentEncoder { match *self { #[cfg(feature = "compress-brotli")] ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), + #[cfg(feature = "compress-zstd")] ContentEncoder::Zstd(ref mut encoder) => encoder.get_mut().take(), } @@ -298,16 +297,19 @@ impl ContentEncoder { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + #[cfg(feature = "compress-zstd")] ContentEncoder::Zstd(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), @@ -326,6 +328,7 @@ impl ContentEncoder { Err(err) } }, + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), @@ -334,6 +337,7 @@ impl ContentEncoder { Err(err) } }, + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), @@ -342,6 +346,7 @@ impl ContentEncoder { Err(err) } }, + #[cfg(feature = "compress-zstd")] ContentEncoder::Zstd(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), @@ -356,12 +361,9 @@ impl ContentEncoder { #[derive(Debug, Display)] #[non_exhaustive] -pub enum EncoderError { +pub enum EncoderError { #[display(fmt = "body")] - Body(E), - - #[display(fmt = "boxed")] - Boxed(Box), + Body(Box), #[display(fmt = "blocking")] Blocking(BlockingError), @@ -370,19 +372,18 @@ pub enum EncoderError { Io(io::Error), } -impl StdError for EncoderError { +impl StdError for EncoderError { fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { - EncoderError::Body(err) => Some(err), - EncoderError::Boxed(err) => Some(&**err), + EncoderError::Body(err) => Some(&**err), EncoderError::Blocking(err) => Some(err), EncoderError::Io(err) => Some(err), } } } -impl From> for crate::Error { - fn from(err: EncoderError) -> Self { +impl From for crate::Error { + fn from(err: EncoderError) -> Self { crate::Error::new_encoder().with_cause(err) } } diff --git a/actix-http/src/encoding/mod.rs b/actix-http/src/encoding/mod.rs index cb271c638..d51dd66c0 100644 --- a/actix-http/src/encoding/mod.rs +++ b/actix-http/src/encoding/mod.rs @@ -10,6 +10,9 @@ mod encoder; pub use self::decoder::Decoder; pub use self::encoder::Encoder; +/// Special-purpose writer for streaming (de-)compression. +/// +/// Pre-allocates 8KiB of capacity. pub(self) struct Writer { buf: BytesMut, } diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 54666e072..231e90e57 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -5,10 +5,7 @@ use std::{error::Error as StdError, fmt, io, str::Utf8Error, string::FromUtf8Err use derive_more::{Display, Error, From}; use http::{uri::InvalidUri, StatusCode}; -use crate::{ - body::{AnyBody, Body}, - ws, Response, -}; +use crate::{body::BoxBody, ws, Response}; pub use http::Error as HttpError; @@ -29,6 +26,11 @@ impl Error { } } + pub(crate) fn with_cause(mut self, cause: impl Into>) -> Self { + self.inner.cause = Some(cause.into()); + self + } + pub(crate) fn new_http() -> Self { Self::new(Kind::Http) } @@ -49,14 +51,12 @@ impl Error { Self::new(Kind::SendResponse) } - // TODO: remove allow - #[allow(dead_code)] + #[allow(unused)] // reserved for future use (TODO: remove allow when being used) pub(crate) fn new_io() -> Self { Self::new(Kind::Io) } - // used in encoder behind feature flag so ignore unused warning - #[allow(unused)] + #[allow(unused)] // used in encoder behind feature flag so ignore unused warning pub(crate) fn new_encoder() -> Self { Self::new(Kind::Encoder) } @@ -64,26 +64,22 @@ impl Error { pub(crate) fn new_ws() -> Self { Self::new(Kind::Ws) } - - pub(crate) fn with_cause(mut self, cause: impl Into>) -> Self { - self.inner.cause = Some(cause.into()); - self - } } -impl From for Response { +impl From for Response { fn from(err: Error) -> Self { + // TODO: more appropriate error status codes, usage assessment needed let status_code = match err.inner.kind { Kind::Parse => StatusCode::BAD_REQUEST, _ => StatusCode::INTERNAL_SERVER_ERROR, }; - Response::new(status_code).set_body(Body::from(err.to_string())) + Response::new(status_code).set_body(BoxBody::new(err.to_string())) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] -pub enum Kind { +pub(crate) enum Kind { #[display(fmt = "error processing HTTP")] Http, @@ -137,12 +133,6 @@ impl From for Error { } } -impl From for Error { - fn from(err: ws::ProtocolError) -> Self { - Self::new_ws().with_cause(err) - } -} - impl From for Error { fn from(err: HttpError) -> Self { Self::new_http().with_cause(err) @@ -155,6 +145,12 @@ impl From for Error { } } +impl From for Error { + fn from(err: ws::ProtocolError) -> Self { + Self::new_ws().with_cause(err) + } +} + /// A set of errors that can occur during parsing HTTP streams. #[derive(Debug, Display, Error)] #[non_exhaustive] @@ -196,7 +192,7 @@ pub enum ParseError { #[display(fmt = "IO error: {}", _0)] Io(io::Error), - /// Parsing a field as string failed + /// Parsing a field as string failed. #[display(fmt = "UTF8 error: {}", _0)] Utf8(Utf8Error), } @@ -245,7 +241,7 @@ impl From for Error { } } -impl From for Response { +impl From for Response { fn from(err: ParseError) -> Self { Error::from(err).into() } @@ -342,7 +338,7 @@ pub enum DispatchError { /// Service error // FIXME: display and error type #[display(fmt = "Service Error")] - Service(#[error(not(source))] Response), + Service(#[error(not(source))] Response), /// Body error // FIXME: display and error type @@ -426,11 +422,11 @@ mod tests { #[test] fn test_into_response() { - let resp: Response = ParseError::Incomplete.into(); + let resp: Response = ParseError::Incomplete.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); - let resp: Response = Error::new_http().with_cause(err).into(); + let resp: Response = Error::new_http().with_cause(err).into(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } @@ -455,7 +451,7 @@ mod tests { fn test_error_http_response() { let orig = io::Error::new(io::ErrorKind::Other, "other"); let err = Error::new_io().with_cause(orig); - let resp: Response = err.into(); + let resp: Response = err.into(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } diff --git a/actix-http/src/h1/chunked.rs b/actix-http/src/h1/chunked.rs new file mode 100644 index 000000000..e5b734fff --- /dev/null +++ b/actix-http/src/h1/chunked.rs @@ -0,0 +1,432 @@ +use std::{io, task::Poll}; + +use bytes::{Buf as _, Bytes, BytesMut}; + +macro_rules! byte ( + ($rdr:ident) => ({ + if $rdr.len() > 0 { + let b = $rdr[0]; + $rdr.advance(1); + b + } else { + return Poll::Pending + } + }) +); + +#[derive(Debug, PartialEq, Clone)] +pub(super) enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl ChunkedState { + pub(super) fn step( + &self, + body: &mut BytesMut, + size: &mut u64, + buf: &mut Option, + ) -> Poll> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(body, size), + SizeLws => ChunkedState::read_size_lws(body), + Extension => ChunkedState::read_extension(body), + SizeLf => ChunkedState::read_size_lf(body, *size), + Body => ChunkedState::read_body(body, size, buf), + BodyCr => ChunkedState::read_body_cr(body), + BodyLf => ChunkedState::read_body_lf(body), + EndCr => ChunkedState::read_end_cr(body), + EndLf => ChunkedState::read_end_lf(body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + + fn read_size( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll> { + let radix = 16; + + let rem = match byte!(rdr) { + b @ b'0'..=b'9' => b - b'0', + b @ b'a'..=b'f' => b + 10 - b'a', + b @ b'A'..=b'F' => b + 10 - b'A', + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + }; + + match size.checked_mul(radix) { + Some(n) => { + *size = n as u64; + *size += rem as u64; + + Poll::Ready(Ok(ChunkedState::Size)) + } + None => { + log::debug!("chunk size would overflow u64"); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Size is too big", + ))) + } + } + } + + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + // strictly 0x20 (space) should be disallowed but we don't parse quoted strings here + 0x00..=0x08 | 0x0a..=0x1f | 0x7f => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid character in chunk extension", + ))), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf( + rdr: &mut BytesMut, + size: u64, + ) -> Poll> { + match byte!(rdr) { + b'\n' if size > 0 => Poll::Ready(Ok(ChunkedState::Body)), + b'\n' if size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body( + rdr: &mut BytesMut, + rem: &mut u64, + buf: &mut Option, + ) -> Poll> { + log::trace!("Chunked read, remaining={:?}", rem); + + let len = rdr.len() as u64; + if len == 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + let slice; + if *rem > len { + slice = rdr.split().freeze(); + *rem -= len; + } else { + slice = rdr.split_to(*rem as usize).freeze(); + *rem = 0; + } + *buf = Some(slice); + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + } + + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + ))), + } + } + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[cfg(test)] +mod tests { + use actix_codec::Decoder as _; + use bytes::{Bytes, BytesMut}; + use http::Method; + + use crate::{ + error::ParseError, + h1::decoder::{MessageDecoder, PayloadItem}, + HttpMessage as _, Request, + }; + + macro_rules! parse_ready { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Ok(Some((msg, _))) => msg, + Ok(_) => unreachable!("Eof during parsing http request"), + Err(err) => unreachable!("Error during parsing http request: {:?}", err), + } + }}; + } + + macro_rules! expect_parse_err { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Err(err) => match err { + ParseError::Io(_) => unreachable!("Parse error expected"), + _ => {} + }, + _ => unreachable!("Error expected"), + } + }}; + } + + #[test] + fn test_parse_chunked_payload_chunk_extension() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(msg.chunked().unwrap()); + + buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"data")); + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"line")); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + } + + #[test] + fn test_request_chunked() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(val); + } else { + unreachable!("Error"); + } + + // intentional typo in "chunked" + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chnked\r\n\r\n", + ); + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_chunked_payload() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"data" + ); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"line" + ); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert!(req.chunked().unwrap()); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } + + #[test] + fn test_http_request_chunked_payload_chunks() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\n1111\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"1111"); + + buf.extend(b"4\r\ndata\r"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + buf.extend(b"\n4"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + buf.extend(b"\n"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"li"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"li"); + + //trailers + //buf.feed_data("test: test\r\n"); + //not_ready!(reader.parse(&mut buf, &mut readbuf)); + + buf.extend(b"ne\r\n0\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"ne"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn chunk_extension_quoted() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Host: localhost:8080\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 2;hello=b;one=\"1 2 3\"\r\n\ + xx", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let chunk = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"xx"))); + } + + #[test] + fn hrs_chunk_extension_invalid() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: localhost:8080\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 2;x\nx\r\n\ + 4c\r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let err = pl.decode(&mut buf).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid character in chunk extension")); + } + + #[test] + fn hrs_chunk_size_overflow() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + f0000000000000003\r\n\ + abc\r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let err = pl.decode(&mut buf).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid chunk size line: Size is too big")); + } +} diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index 4a6104688..bec167971 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -120,7 +120,7 @@ impl Decoder for ClientCodec { debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); if let Some((req, payload)) = self.inner.decoder.decode(src)? { - if let Some(ctype) = req.ctype() { + if let Some(ctype) = req.conn_type() { // do not use peer's keep-alive self.inner.ctype = if ctype == ConnectionType::KeepAlive { self.inner.ctype diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs index 634ca25e8..29f6f4170 100644 --- a/actix-http/src/h1/codec.rs +++ b/actix-http/src/h1/codec.rs @@ -29,7 +29,7 @@ pub struct Codec { decoder: decoder::MessageDecoder, payload: Option, version: Version, - ctype: ConnectionType, + conn_type: ConnectionType, // encoder part flags: Flags, @@ -65,7 +65,7 @@ impl Codec { decoder: decoder::MessageDecoder::default(), payload: None, version: Version::HTTP_11, - ctype: ConnectionType::Close, + conn_type: ConnectionType::Close, encoder: encoder::MessageEncoder::default(), } } @@ -73,13 +73,13 @@ impl Codec { /// Check if request is upgrade. #[inline] pub fn upgrade(&self) -> bool { - self.ctype == ConnectionType::Upgrade + self.conn_type == ConnectionType::Upgrade } /// Check if last response is keep-alive. #[inline] pub fn keepalive(&self) -> bool { - self.ctype == ConnectionType::KeepAlive + self.conn_type == ConnectionType::KeepAlive } /// Check if keep-alive enabled on server level. @@ -124,11 +124,11 @@ impl Decoder for Codec { let head = req.head(); self.flags.set(Flags::HEAD, head.method == Method::HEAD); self.version = head.version; - self.ctype = head.connection_type(); - if self.ctype == ConnectionType::KeepAlive + self.conn_type = head.connection_type(); + if self.conn_type == ConnectionType::KeepAlive && !self.flags.contains(Flags::KEEPALIVE_ENABLED) { - self.ctype = ConnectionType::Close + self.conn_type = ConnectionType::Close } match payload { PayloadType::None => self.payload = None, @@ -159,14 +159,14 @@ impl Encoder, BodySize)>> for Codec { res.head_mut().version = self.version; // connection status - self.ctype = if let Some(ct) = res.head().ctype() { + self.conn_type = if let Some(ct) = res.head().conn_type() { if ct == ConnectionType::KeepAlive { - self.ctype + self.conn_type } else { ct } } else { - self.ctype + self.conn_type }; // encode message @@ -177,10 +177,9 @@ impl Encoder, BodySize)>> for Codec { self.flags.contains(Flags::STREAM), self.version, length, - self.ctype, + self.conn_type, &self.config, )?; - // self.headers_size = (dst.len() - len) as u32; } Message::Chunk(Some(bytes)) => { self.encoder.encode_chunk(bytes.as_ref(), dst)?; @@ -189,6 +188,7 @@ impl Encoder, BodySize)>> for Codec { self.encoder.encode_eof(dst)?; } } + Ok(()) } } diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index f240710c2..f25c35a76 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -1,18 +1,18 @@ -use std::convert::TryFrom; -use std::io; -use std::marker::PhantomData; -use std::task::Poll; +use std::{convert::TryFrom, io, marker::PhantomData, mem::MaybeUninit, task::Poll}; use actix_codec::Decoder; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use http::header::{HeaderName, HeaderValue}; use http::{header, Method, StatusCode, Uri, Version}; use log::{debug, error, trace}; -use crate::error::ParseError; -use crate::header::HeaderMap; -use crate::message::{ConnectionType, ResponseHead}; -use crate::request::Request; +use super::chunked::ChunkedState; +use crate::{ + error::ParseError, + header::HeaderMap, + message::{ConnectionType, ResponseHead}, + request::Request, +}; pub(crate) const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; @@ -67,6 +67,7 @@ pub(crate) trait MessageType: Sized { let mut has_upgrade_websocket = false; let mut expect = false; let mut chunked = false; + let mut seen_te = false; let mut content_length = None; { @@ -85,8 +86,17 @@ pub(crate) trait MessageType: Sized { }; match name { - header::CONTENT_LENGTH => { - if let Ok(s) = value.to_str() { + header::CONTENT_LENGTH if content_length.is_some() => { + debug!("multiple Content-Length"); + return Err(ParseError::Header); + } + + header::CONTENT_LENGTH => match value.to_str() { + Ok(s) if s.trim().starts_with('+') => { + debug!("illegal Content-Length: {:?}", s); + return Err(ParseError::Header); + } + Ok(s) => { if let Ok(len) = s.parse::() { if len != 0 { content_length = Some(len); @@ -95,15 +105,31 @@ pub(crate) trait MessageType: Sized { debug!("illegal Content-Length: {:?}", s); return Err(ParseError::Header); } - } else { + } + Err(_) => { debug!("illegal Content-Length: {:?}", value); return Err(ParseError::Header); } - } + }, + // transfer-encoding + header::TRANSFER_ENCODING if seen_te => { + debug!("multiple Transfer-Encoding not allowed"); + return Err(ParseError::Header); + } + header::TRANSFER_ENCODING => { + seen_te = true; + if let Ok(s) = value.to_str().map(str::trim) { - chunked = s.eq_ignore_ascii_case("chunked"); + if s.eq_ignore_ascii_case("chunked") { + chunked = true; + } else if s.eq_ignore_ascii_case("identity") { + // allow silently since multiple TE headers are already checked + } else { + debug!("illegal Transfer-Encoding: {:?}", s); + return Err(ParseError::Header); + } } else { return Err(ParseError::Header); } @@ -148,7 +174,7 @@ pub(crate) trait MessageType: Sized { self.set_expect() } - // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if chunked { // Chunked encoding Ok(PayloadLength::Payload(PayloadType::Payload( @@ -186,10 +212,17 @@ impl MessageType for Request { let mut headers: [HeaderIndex; MAX_HEADERS] = EMPTY_HEADER_INDEX_ARRAY; let (len, method, uri, ver, h_len) = { - let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = EMPTY_HEADER_ARRAY; + // SAFETY: + // Create an uninitialized array of `MaybeUninit`. The `assume_init` is + // safe because the type we are claiming to have initialized here is a + // bunch of `MaybeUninit`s, which do not require initialization. + let mut parsed = unsafe { + MaybeUninit::<[MaybeUninit>; MAX_HEADERS]>::uninit() + .assume_init() + }; - let mut req = httparse::Request::new(&mut parsed); - match req.parse(src)? { + let mut req = httparse::Request::new(&mut []); + match req.parse_with_uninit_headers(src, &mut parsed)? { httparse::Status::Complete(len) => { let method = Method::from_bytes(req.method.unwrap().as_bytes()) .map_err(|_| ParseError::Method)?; @@ -408,20 +441,6 @@ enum Kind { Eof, } -#[derive(Debug, PartialEq, Clone)] -enum ChunkedState { - Size, - SizeLws, - Extension, - SizeLf, - Body, - BodyCr, - BodyLf, - EndCr, - EndLf, - End, -} - impl Decoder for PayloadDecoder { type Item = PayloadItem; type Error = io::Error; @@ -451,19 +470,23 @@ impl Decoder for PayloadDecoder { Kind::Chunked(ref mut state, ref mut size) => { loop { let mut buf = None; + // advances the chunked state *state = match state.step(src, size, &mut buf) { Poll::Pending => return Ok(None), Poll::Ready(Ok(state)) => state, Poll::Ready(Err(e)) => return Err(e), }; + if *state == ChunkedState::End { trace!("End of chunked stream"); return Ok(Some(PayloadItem::Eof)); } + if let Some(buf) = buf { return Ok(Some(PayloadItem::Chunk(buf))); } + if src.is_empty() { return Ok(None); } @@ -480,201 +503,40 @@ impl Decoder for PayloadDecoder { } } -macro_rules! byte ( - ($rdr:ident) => ({ - if $rdr.len() > 0 { - let b = $rdr[0]; - $rdr.advance(1); - b - } else { - return Poll::Pending - } - }) -); - -impl ChunkedState { - fn step( - &self, - body: &mut BytesMut, - size: &mut u64, - buf: &mut Option, - ) -> Poll> { - use self::ChunkedState::*; - match *self { - Size => ChunkedState::read_size(body, size), - SizeLws => ChunkedState::read_size_lws(body), - Extension => ChunkedState::read_extension(body), - SizeLf => ChunkedState::read_size_lf(body, size), - Body => ChunkedState::read_body(body, size, buf), - BodyCr => ChunkedState::read_body_cr(body), - BodyLf => ChunkedState::read_body_lf(body), - EndCr => ChunkedState::read_end_cr(body), - EndLf => ChunkedState::read_end_lf(body), - End => Poll::Ready(Ok(ChunkedState::End)), - } - } - - fn read_size( - rdr: &mut BytesMut, - size: &mut u64, - ) -> Poll> { - let radix = 16; - match byte!(rdr) { - b @ b'0'..=b'9' => { - *size *= radix; - *size += u64::from(b - b'0'); - } - b @ b'a'..=b'f' => { - *size *= radix; - *size += u64::from(b + 10 - b'a'); - } - b @ b'A'..=b'F' => { - *size *= radix; - *size += u64::from(b + 10 - b'A'); - } - b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), - b';' => return Poll::Ready(Ok(ChunkedState::Extension)), - b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size line: Invalid Size", - ))); - } - } - Poll::Ready(Ok(ChunkedState::Size)) - } - - fn read_size_lws(rdr: &mut BytesMut) -> Poll> { - trace!("read_size_lws"); - match byte!(rdr) { - // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), - b';' => Poll::Ready(Ok(ChunkedState::Extension)), - b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size linear white space", - ))), - } - } - fn read_extension(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions - } - } - fn read_size_lf( - rdr: &mut BytesMut, - size: &mut u64, - ) -> Poll> { - match byte!(rdr) { - b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), - b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size LF", - ))), - } - } - - fn read_body( - rdr: &mut BytesMut, - rem: &mut u64, - buf: &mut Option, - ) -> Poll> { - trace!("Chunked read, remaining={:?}", rem); - - let len = rdr.len() as u64; - if len == 0 { - Poll::Ready(Ok(ChunkedState::Body)) - } else { - let slice; - if *rem > len { - slice = rdr.split().freeze(); - *rem -= len; - } else { - slice = rdr.split_to(*rem as usize).freeze(); - *rem = 0; - } - *buf = Some(slice); - if *rem > 0 { - Poll::Ready(Ok(ChunkedState::Body)) - } else { - Poll::Ready(Ok(ChunkedState::BodyCr)) - } - } - } - - fn read_body_cr(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk body CR", - ))), - } - } - fn read_body_lf(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\n' => Poll::Ready(Ok(ChunkedState::Size)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk body LF", - ))), - } - } - fn read_end_cr(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk end CR", - ))), - } - } - fn read_end_lf(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\n' => Poll::Ready(Ok(ChunkedState::End)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk end LF", - ))), - } - } -} - #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; use http::{Method, Version}; use super::*; - use crate::error::ParseError; - use crate::http::header::{HeaderName, SET_COOKIE}; - use crate::HttpMessage; + use crate::{ + error::ParseError, + http::header::{HeaderName, SET_COOKIE}, + HttpMessage as _, + }; impl PayloadType { - fn unwrap(self) -> PayloadDecoder { + pub(crate) fn unwrap(self) -> PayloadDecoder { match self { PayloadType::Payload(pl) => pl, _ => panic!(), } } - fn is_unhandled(&self) -> bool { + pub(crate) fn is_unhandled(&self) -> bool { matches!(self, PayloadType::Stream(_)) } } impl PayloadItem { - fn chunk(self) -> Bytes { + pub(crate) fn chunk(self) -> Bytes { match self { PayloadItem::Chunk(chunk) => chunk, _ => panic!("error"), } } - fn eof(&self) -> bool { + + pub(crate) fn eof(&self) -> bool { matches!(*self, PayloadItem::Eof) } } @@ -967,34 +829,6 @@ mod tests { assert!(req.upgrade()); } - #[test] - fn test_request_chunked() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(val); - } else { - unreachable!("Error"); - } - - // intentional typo in "chunked" - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chnked\r\n\r\n", - ); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(!val); - } else { - unreachable!("Error"); - } - } - #[test] fn test_headers_content_length_err_1() { let mut buf = BytesMut::from( @@ -1112,126 +946,6 @@ mod tests { expect_parse_err!(&mut buf); } - #[test] - fn test_http_request_chunked_payload() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); - assert_eq!( - pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), - b"data" - ); - assert_eq!( - pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), - b"line" - ); - assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); - } - - #[test] - fn test_http_request_chunked_payload_and_next_message() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend( - b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ - POST /test2 HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n" - .iter(), - ); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"data"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"line"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert!(msg.eof()); - - let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); - assert!(req.chunked().unwrap()); - assert_eq!(*req.method(), Method::POST); - assert!(req.chunked().unwrap()); - } - - #[test] - fn test_http_request_chunked_payload_chunks() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend(b"4\r\n1111\r\n"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"1111"); - - buf.extend(b"4\r\ndata\r"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"data"); - - buf.extend(b"\n4"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"\r"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - buf.extend(b"\n"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"li"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"li"); - - //trailers - //buf.feed_data("test: test\r\n"); - //not_ready!(reader.parse(&mut buf, &mut readbuf)); - - buf.extend(b"ne\r\n0\r\n"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"ne"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"\r\n"); - assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); - } - - #[test] - fn test_parse_chunked_payload_chunk_extension() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\ - \r\n", - ); - - let mut reader = MessageDecoder::::default(); - let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(msg.chunked().unwrap()); - - buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") - let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); - assert_eq!(chunk, Bytes::from_static(b"data")); - let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); - assert_eq!(chunk, Bytes::from_static(b"line")); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert!(msg.eof()); - } - #[test] fn test_response_http10_read_until_eof() { let mut buf = BytesMut::from("HTTP/1.0 200 Ok\r\n\r\ntest data"); @@ -1243,4 +957,84 @@ mod tests { let chunk = pl.decode(&mut buf).unwrap().unwrap(); assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data"))); } + + #[test] + fn hrs_multiple_content_length() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 4\r\n\ + Content-Length: 2\r\n\ + \r\n\ + abcd", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_content_length_plus() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: +3\r\n\ + \r\n\ + 000", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_unknown_transfer_encoding() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Transfer-Encoding: JUNK\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_multiple_transfer_encoding() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 51\r\n\ + Transfer-Encoding: identity\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 0\r\n\ + \r\n\ + GET /forbidden HTTP/1.1\r\n\ + Host: example.com\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn transfer_encoding_agrees() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 3\r\n\ + Transfer-Encoding: identity\r\n\ + \r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let chunk = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"0\r\n"))); + } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index deb25763c..6695d1bf3 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -1,6 +1,5 @@ use std::{ collections::VecDeque, - error::Error as StdError, fmt, future::Future, io, mem, net, @@ -19,7 +18,7 @@ use log::{error, trace}; use pin_project::pin_project; use crate::{ - body::{AnyBody, BodySize, MessageBody}, + body::{BodySize, BoxBody, MessageBody}, config::ServiceConfig, error::{DispatchError, ParseError, PayloadError}, service::HttpFlow, @@ -51,13 +50,12 @@ bitflags! { pub struct Dispatcher where S: Service, - S::Error: Into>, + S::Error: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -73,13 +71,12 @@ where enum DispatcherState where S: Service, - S::Error: Into>, + S::Error: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -92,13 +89,12 @@ where struct InnerDispatcher where S: Service, - S::Error: Into>, + S::Error: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -137,13 +133,12 @@ where X: Service, B: MessageBody, - B::Error: Into>, { None, ExpectCall(#[pin] X::Future), ServiceCall(#[pin] S::Future), SendPayload(#[pin] B), - SendErrorPayload(#[pin] AnyBody), + SendErrorPayload(#[pin] BoxBody), } impl State @@ -153,7 +148,6 @@ where X: Service, B: MessageBody, - B::Error: Into>, { fn is_empty(&self) -> bool { matches!(self, State::None) @@ -171,14 +165,13 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -232,14 +225,13 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -303,9 +295,9 @@ where body: &impl MessageBody, ) -> Result { let size = body.size(); - let mut this = self.project(); + let this = self.project(); this.codec - .encode(Message::Item((message, size)), &mut this.write_buf) + .encode(Message::Item((message, size)), this.write_buf) .map_err(|err| { if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); @@ -325,7 +317,7 @@ where ) -> Result<(), DispatchError> { let size = self.as_mut().send_response_inner(message, &body)?; let state = match size { - BodySize::None | BodySize::Empty => State::None, + BodySize::None | BodySize::Sized(0) => State::None, _ => State::SendPayload(body), }; self.project().state.set(state); @@ -335,11 +327,11 @@ where fn send_error_response( mut self: Pin<&mut Self>, message: Response<()>, - body: AnyBody, + body: BoxBody, ) -> Result<(), DispatchError> { let size = self.as_mut().send_response_inner(message, &body)?; let state = match size { - BodySize::None | BodySize::Empty => State::None, + BodySize::None | BodySize::Sized(0) => State::None, _ => State::SendErrorPayload(body), }; self.project().state.set(state); @@ -380,7 +372,7 @@ where // send_response would update InnerDispatcher state to SendPayload or // None(If response body is empty). // continue loop to poll it. - self.as_mut().send_error_response(res, AnyBody::Empty)?; + self.as_mut().send_error_response(res, BoxBody::new(()))?; } // return with upgrade request and poll it exclusively. @@ -400,7 +392,7 @@ where // send service call error as response Poll::Ready(Err(err)) => { - let res: Response = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); self.as_mut().send_error_response(res, body)?; } @@ -425,13 +417,13 @@ where Poll::Ready(Some(Ok(item))) => { this.codec.encode( Message::Chunk(Some(item)), - &mut this.write_buf, + this.write_buf, )?; } Poll::Ready(None) => { this.codec - .encode(Message::Chunk(None), &mut this.write_buf)?; + .encode(Message::Chunk(None), this.write_buf)?; // payload stream finished. // set state to None and handle next message this.state.set(State::None); @@ -460,13 +452,13 @@ where Poll::Ready(Some(Ok(item))) => { this.codec.encode( Message::Chunk(Some(item)), - &mut this.write_buf, + this.write_buf, )?; } Poll::Ready(None) => { this.codec - .encode(Message::Chunk(None), &mut this.write_buf)?; + .encode(Message::Chunk(None), this.write_buf)?; // payload stream finished. // set state to None and handle next message this.state.set(State::None); @@ -497,7 +489,7 @@ where // send expect error as response Poll::Ready(Err(err)) => { - let res: Response = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); self.as_mut().send_error_response(res, body)?; } @@ -546,7 +538,7 @@ where // to notify the dispatcher a new state is set and the outer loop // should be continue. Poll::Ready(Err(err)) => { - let res: Response = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); return self.send_error_response(res, body); } @@ -566,7 +558,7 @@ where Poll::Pending => Ok(()), // see the comment on ExpectCall state branch's Ready(Err(err)). Poll::Ready(Err(err)) => { - let res: Response = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); self.send_error_response(res, body) } @@ -592,7 +584,7 @@ where let mut updated = false; let mut this = self.as_mut().project(); loop { - match this.codec.decode(&mut this.read_buf) { + match this.codec.decode(this.read_buf) { Ok(Some(msg)) => { updated = true; this.flags.insert(Flags::STARTED); @@ -772,7 +764,7 @@ where trace!("Slow request timeout"); let _ = self.as_mut().send_error_response( Response::with_body(StatusCode::REQUEST_TIMEOUT, ()), - AnyBody::Empty, + BoxBody::new(()), ); this = self.project(); this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); @@ -909,14 +901,13 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -1060,24 +1051,26 @@ mod tests { fn stabilize_date_header(payload: &mut [u8]) { let mut from = 0; - while let Some(pos) = find_slice(&payload, b"date", from) { + while let Some(pos) = find_slice(payload, b"date", from) { payload[(from + pos)..(from + pos + 35)] .copy_from_slice(b"date: Thu, 01 Jan 1970 12:34:56 UTC"); from += 35; } } - fn ok_service() -> impl Service, Error = Error> + fn ok_service( + ) -> impl Service, Error = Error> { fn_service(|_req: Request| ready(Ok::<_, Error>(Response::ok()))) } fn echo_path_service( - ) -> impl Service, Error = Error> { + ) -> impl Service, Error = Error> + { fn_service(|req: Request| { let path = req.path().as_bytes(); ready(Ok::<_, Error>( - Response::ok().set_body(AnyBody::from_slice(path)), + Response::ok().set_body(Bytes::copy_from_slice(path)), )) }) } diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 254981123..60880cd7d 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -20,6 +20,7 @@ const AVERAGE_HEADER_SIZE: usize = 30; #[derive(Debug)] pub(crate) struct MessageEncoder { + #[allow(dead_code)] pub length: BodySize, pub te: TransferEncoding, _phantom: PhantomData, @@ -55,7 +56,7 @@ pub(crate) trait MessageType: Sized { dst: &mut BytesMut, version: Version, mut length: BodySize, - ctype: ConnectionType, + conn_type: ConnectionType, config: &ServiceConfig, ) -> io::Result<()> { let chunked = self.chunked(); @@ -70,17 +71,28 @@ pub(crate) trait MessageType: Sized { | StatusCode::PROCESSING | StatusCode::NO_CONTENT => { // skip content-length and transfer-encoding headers - // See https://tools.ietf.org/html/rfc7230#section-3.3.1 - // and https://tools.ietf.org/html/rfc7230#section-3.3.2 + // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1 + // and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2 skip_len = true; length = BodySize::None } + + StatusCode::NOT_MODIFIED => { + // 304 responses should never have a body but should retain a manually set + // content-length header + // see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 + skip_len = false; + length = BodySize::None; + } + _ => {} } } + match length { BodySize::Stream => { if chunked { + skip_len = true; if camel_case { dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n") } else { @@ -91,19 +103,16 @@ pub(crate) trait MessageType: Sized { dst.put_slice(b"\r\n"); } } - BodySize::Empty => { - if camel_case { - dst.put_slice(b"\r\nContent-Length: 0\r\n"); - } else { - dst.put_slice(b"\r\ncontent-length: 0\r\n"); - } + BodySize::Sized(0) if camel_case => { + dst.put_slice(b"\r\nContent-Length: 0\r\n") } + BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"), BodySize::Sized(len) => helpers::write_content_length(len, dst), BodySize::None => dst.put_slice(b"\r\n"), } // Connection - match ctype { + match conn_type { ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"), ConnectionType::KeepAlive if version < Version::HTTP_11 => { if camel_case { @@ -174,7 +183,7 @@ pub(crate) trait MessageType: Sized { unsafe { if camel_case { // use Camel-Case headers - write_camel_case(k, from_raw_parts_mut(buf, k_len)); + write_camel_case(k, buf, k_len); } else { write_data(k, buf, k_len); } @@ -328,13 +337,13 @@ impl MessageEncoder { stream: bool, version: Version, length: BodySize, - ctype: ConnectionType, + conn_type: ConnectionType, config: &ServiceConfig, ) -> io::Result<()> { // transfer encoding if !head { self.te = match length { - BodySize::Empty => TransferEncoding::empty(), + BodySize::Sized(0) => TransferEncoding::empty(), BodySize::Sized(len) => TransferEncoding::length(len), BodySize::Stream => { if message.chunked() && !stream { @@ -350,7 +359,7 @@ impl MessageEncoder { } message.encode_status(dst)?; - message.encode_headers(dst, version, length, ctype, config) + message.encode_headers(dst, version, length, conn_type, config) } } @@ -364,10 +373,12 @@ pub(crate) struct TransferEncoding { enum TransferEncodingKind { /// An Encoder for when Transfer-Encoding includes `chunked`. Chunked(bool), + /// An Encoder for when Content-Length is set. /// /// Enforces that the body is not longer than the Content-Length header. Length(u64), + /// An Encoder for when Content-Length is not known. /// /// Application decides when to stop writing. @@ -472,15 +483,22 @@ impl TransferEncoding { } /// # Safety -/// Callers must ensure that the given length matches given value length. +/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is +/// valid for writes of at least `len` bytes. unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) { debug_assert_eq!(value.len(), len); copy_nonoverlapping(value.as_ptr(), buf, len); } -fn write_camel_case(value: &[u8], buffer: &mut [u8]) { +/// # Safety +/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is +/// valid for writes of at least `len` bytes. +unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) { // first copy entire (potentially wrong) slice to output - buffer[..value.len()].copy_from_slice(value); + write_data(value, buf, len); + + // SAFETY: We just initialized the buffer with `value` + let buffer = from_raw_parts_mut(buf, len); let mut iter = value.iter(); @@ -544,7 +562,7 @@ mod tests { let _ = head.encode_headers( &mut bytes, Version::HTTP_11, - BodySize::Empty, + BodySize::Sized(0), ConnectionType::Close, &ServiceConfig::default(), ); @@ -615,7 +633,7 @@ mod tests { let _ = head.encode_headers( &mut bytes, Version::HTTP_11, - BodySize::Empty, + BodySize::Sized(0), ConnectionType::Close, &ServiceConfig::default(), ); diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs index 7e6df6ceb..17cbfb90f 100644 --- a/actix-http/src/h1/mod.rs +++ b/actix-http/src/h1/mod.rs @@ -1,6 +1,8 @@ //! HTTP/1 protocol implementation. + use bytes::{Bytes, BytesMut}; +mod chunked; mod client; mod codec; mod decoder; diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index dbad8cfac..70e83901c 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,5 +1,4 @@ use std::{ - error::Error as StdError, fmt, marker::PhantomData, net, @@ -16,7 +15,7 @@ use actix_utils::future::ready; use futures_core::future::LocalBoxFuture; use crate::{ - body::{AnyBody, MessageBody}, + body::{BoxBody, MessageBody}, config::ServiceConfig, error::DispatchError, service::HttpServiceHandler, @@ -38,7 +37,7 @@ pub struct H1Service { impl H1Service where S: ServiceFactory, - S::Error: Into>, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, @@ -63,21 +62,20 @@ impl H1Service where S: ServiceFactory, S::Future: 'static, - S::Error: Into>, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { /// Create simple tcp stream service @@ -102,9 +100,11 @@ where mod openssl { use super::*; - use actix_service::ServiceFactoryExt; use actix_tls::accept::{ - openssl::{Acceptor, SslAcceptor, SslError, TlsStream}, + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, TlsError, }; @@ -112,16 +112,15 @@ mod openssl { where S: ServiceFactory, S::Future: 'static, - S::Error: Into>, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory< @@ -130,10 +129,10 @@ mod openssl { Response = (), >, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create openssl based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -145,11 +144,13 @@ mod openssl { InitError = (), > { Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) - .and_then(|io: TlsStream| { + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); - ready(Ok((io, peer_addr))) + (io, peer_addr) }) .and_then(self.map_err(TlsError::Service)) } @@ -158,30 +159,30 @@ mod openssl { #[cfg(feature = "rustls")] mod rustls { - use super::*; use std::io; - use actix_service::ServiceFactoryExt; + use actix_service::ServiceFactoryExt as _; use actix_tls::accept::{ - rustls::{Acceptor, ServerConfig, TlsStream}, + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, TlsError, }; + use super::*; + impl H1Service, S, B, X, U> where S: ServiceFactory, S::Future: 'static, - S::Error: Into>, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory< @@ -190,10 +191,10 @@ mod rustls { Response = (), >, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create rustls based service + /// Create Rustls based service. pub fn rustls( self, config: ServerConfig, @@ -205,11 +206,13 @@ mod rustls { InitError = (), > { Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) - .and_then(|io: TlsStream| { + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { let peer_addr = io.get_ref().0.peer_addr().ok(); - ready(Ok((io, peer_addr))) + (io, peer_addr) }) .and_then(self.map_err(TlsError::Service)) } @@ -219,7 +222,7 @@ mod rustls { impl H1Service where S: ServiceFactory, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, @@ -227,7 +230,7 @@ where pub fn expect(self, expect: X1) -> H1Service where X1: ServiceFactory, - X1::Error: Into>, + X1::Error: Into>, X1::InitError: fmt::Debug, { H1Service { @@ -270,21 +273,20 @@ where S: ServiceFactory, S::Future: 'static, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { type Response = (); @@ -340,17 +342,16 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into>, + S::Error: Into>, S::Response: Into>, B: MessageBody, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, { type Response = (); type Error = DispatchError; diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index 523e652fd..905585a32 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -1,22 +1,30 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use pin_project_lite::pin_project; -use crate::body::{BodySize, MessageBody}; -use crate::error::Error; -use crate::h1::{Codec, Message}; -use crate::response::Response; +use crate::{ + body::{BodySize, MessageBody}, + error::Error, + h1::{Codec, Message}, + response::Response, +}; -/// Send HTTP/1 response -#[pin_project::pin_project] -pub struct SendResponse { - res: Option, BodySize)>>, - #[pin] - body: Option, - #[pin] - framed: Option>, +pin_project! { + /// Send HTTP/1 response + pub struct SendResponse { + res: Option, BodySize)>>, + + #[pin] + body: Option, + + #[pin] + framed: Option>, + } } impl SendResponse @@ -63,7 +71,6 @@ where .is_write_buf_full() { let next = - // TODO: MSRV 1.51: poll_map_err match this.body.as_mut().as_pin_mut().unwrap().poll_next(cx) { Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(item)), Poll::Ready(Some(Err(err))) => { diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index ea149b1e0..6d2f4579a 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -10,17 +10,21 @@ use std::{ }; use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::{sleep, Sleep}; use actix_service::Service; use actix_utils::future::poll_fn; use bytes::{Bytes, BytesMut}; use futures_core::ready; -use h2::server::{Connection, SendResponse}; +use h2::{ + server::{Connection, SendResponse}, + Ping, PingPong, +}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use log::{error, trace}; use pin_project_lite::pin_project; use crate::{ - body::{AnyBody, BodySize, MessageBody}, + body::{BodySize, BoxBody, MessageBody}, config::ServiceConfig, service::HttpFlow, OnConnectData, Payload, Request, Response, ResponseHead, @@ -36,40 +40,63 @@ pin_project! { on_connect_data: OnConnectData, config: ServiceConfig, peer_addr: Option, - _phantom: PhantomData, + ping_pong: Option, + _phantom: PhantomData } } -impl Dispatcher { +impl Dispatcher +where + T: AsyncRead + AsyncWrite + Unpin, +{ pub(crate) fn new( flow: Rc>, - connection: Connection, + mut conn: Connection, on_connect_data: OnConnectData, config: ServiceConfig, peer_addr: Option, + timer: Option>>, ) -> Self { + let ping_pong = config.keep_alive().map(|dur| H2PingPong { + timer: timer + .map(|mut timer| { + // reset timer if it's received from new function. + timer.as_mut().reset(config.now() + dur); + timer + }) + .unwrap_or_else(|| Box::pin(sleep(dur))), + on_flight: false, + ping_pong: conn.ping_pong().unwrap(), + }); + Self { flow, config, peer_addr, - connection, + connection: conn, on_connect_data, + ping_pong, _phantom: PhantomData, } } } +struct H2PingPong { + timer: Pin>, + on_flight: bool, + ping_pong: PingPong, +} + impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into>, + S::Error: Into>, S::Future: 'static, S::Response: Into>, B: MessageBody, - B::Error: Into>, { type Output = Result<(), crate::error::DispatchError>; @@ -77,54 +104,92 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - while let Some((req, tx)) = - ready!(Pin::new(&mut this.connection).poll_accept(cx)?) - { - let (parts, body) = req.into_parts(); - let pl = crate::h2::Payload::new(body); - let pl = Payload::::H2(pl); - let mut req = Request::with_payload(pl); + loop { + match Pin::new(&mut this.connection).poll_accept(cx)? { + Poll::Ready(Some((req, tx))) => { + let (parts, body) = req.into_parts(); + let pl = crate::h2::Payload::new(body); + let pl = Payload::::H2(pl); + let mut req = Request::with_payload(pl); - let head = req.head_mut(); - head.uri = parts.uri; - head.method = parts.method; - head.version = parts.version; - head.headers = parts.headers.into(); - head.peer_addr = this.peer_addr; + let head = req.head_mut(); + head.uri = parts.uri; + head.method = parts.method; + head.version = parts.version; + head.headers = parts.headers.into(); + head.peer_addr = this.peer_addr; - // merge on_connect_ext data into request extensions - this.on_connect_data.merge_into(&mut req); + // merge on_connect_ext data into request extensions + this.on_connect_data.merge_into(&mut req); - let fut = this.flow.service.call(req); - let config = this.config.clone(); + let fut = this.flow.service.call(req); + let config = this.config.clone(); - // multiplex request handling with spawn task - actix_rt::spawn(async move { - // resolve service call and send response. - let res = match fut.await { - Ok(res) => handle_response(res.into(), tx, config).await, - Err(err) => { - let res: Response = err.into(); - handle_response(res, tx, config).await - } - }; + // multiplex request handling with spawn task + actix_rt::spawn(async move { + // resolve service call and send response. + let res = match fut.await { + Ok(res) => handle_response(res.into(), tx, config).await, + Err(err) => { + let res: Response = err.into(); + handle_response(res, tx, config).await + } + }; - // log error. - if let Err(err) = res { - match err { - DispatchError::SendResponse(err) => { - trace!("Error sending HTTP/2 response: {:?}", err) + // log error. + if let Err(err) = res { + match err { + DispatchError::SendResponse(err) => { + trace!("Error sending HTTP/2 response: {:?}", err) + } + DispatchError::SendData(err) => warn!("{:?}", err), + DispatchError::ResponseBody(err) => { + error!("Response payload stream error: {:?}", err) + } + } } - DispatchError::SendData(err) => warn!("{:?}", err), - DispatchError::ResponseBody(err) => { - error!("Response payload stream error: {:?}", err) - } - } + }); } - }); - } + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => match this.ping_pong.as_mut() { + Some(ping_pong) => loop { + if ping_pong.on_flight { + // When have on flight ping pong. poll pong and and keep alive timer. + // on success pong received update keep alive timer to determine the next timing of + // ping pong. + match ping_pong.ping_pong.poll_pong(cx)? { + Poll::Ready(_) => { + ping_pong.on_flight = false; - Poll::Ready(Ok(())) + let dead_line = + this.config.keep_alive_expire().unwrap(); + ping_pong.timer.as_mut().reset(dead_line); + } + Poll::Pending => { + return ping_pong + .timer + .as_mut() + .poll(cx) + .map(|_| Ok(())) + } + } + } else { + // When there is no on flight ping pong. keep alive timer is used to wait for next + // timing of ping pong. Therefore at this point it serves as an interval instead. + ready!(ping_pong.timer.as_mut().poll(cx)); + + ping_pong.ping_pong.send_ping(Ping::opaque())?; + + let dead_line = this.config.keep_alive_expire().unwrap(); + ping_pong.timer.as_mut().reset(dead_line); + + ping_pong.on_flight = true; + } + }, + None => return Poll::Pending, + }, + } + } } } @@ -141,7 +206,6 @@ async fn handle_response( ) -> Result<(), DispatchError> where B: MessageBody, - B::Error: Into>, { let (res, body) = res.replace_body(()); @@ -226,9 +290,11 @@ fn prepare_response( let _ = match size { BodySize::None | BodySize::Stream => None, - BodySize::Empty => res + + BodySize::Sized(0) => res .headers_mut() .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => { let mut buf = itoa::Buffer::new(); @@ -243,7 +309,7 @@ fn prepare_response( for (key, value) in head.headers.iter() { match *key { // TODO: consider skipping other headers according to: - // https://tools.ietf.org/html/rfc7540#section-8.1.2.2 + // https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2 // omit HTTP/1.x only headers CONNECTION | TRANSFER_ENCODING => continue, CONTENT_LENGTH if skip_len => continue, diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index 7eff44ac1..25d53403e 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,20 +1,30 @@ //! HTTP/2 protocol. use std::{ + future::Future, pin::Pin, task::{Context, Poll}, }; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::Sleep; use bytes::Bytes; use futures_core::{ready, Stream}; -use h2::RecvStream; +use h2::{ + server::{handshake, Connection, Handshake}, + RecvStream, +}; mod dispatcher; mod service; pub use self::dispatcher::Dispatcher; pub use self::service::H2Service; -use crate::error::PayloadError; + +use crate::{ + config::ServiceConfig, + error::{DispatchError, PayloadError}, +}; /// HTTP/2 peer stream. pub struct Payload { @@ -50,3 +60,44 @@ impl Stream for Payload { } } } + +pub(crate) fn handshake_with_timeout( + io: T, + config: &ServiceConfig, +) -> HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + HandshakeWithTimeout { + handshake: handshake(io), + timer: config.client_timer().map(Box::pin), + } +} + +pub(crate) struct HandshakeWithTimeout { + handshake: Handshake, + timer: Option>>, +} + +impl Future for HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(Connection, Option>>), DispatchError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match Pin::new(&mut this.handshake).poll(cx)? { + // return the timer on success handshake. It can be re-used for h2 ping-pong. + Poll::Ready(conn) => Poll::Ready(Ok((conn, this.timer.take()))), + Poll::Pending => match this.timer.as_mut() { + Some(timer) => { + ready!(timer.as_mut().poll(cx)); + Poll::Ready(Err(DispatchError::SlowRequestTimeout)) + } + None => Poll::Pending, + }, + } + } +} diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 09e24045b..8a9061b94 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,5 +1,4 @@ use std::{ - error::Error as StdError, future::Future, marker::PhantomData, net, @@ -15,20 +14,18 @@ use actix_service::{ ServiceFactoryExt as _, }; use actix_utils::future::ready; -use bytes::Bytes; use futures_core::{future::LocalBoxFuture, ready}; -use h2::server::{handshake as h2_handshake, Handshake as H2Handshake}; use log::error; use crate::{ - body::{AnyBody, MessageBody}, + body::{BoxBody, MessageBody}, config::ServiceConfig, error::DispatchError, service::HttpFlow, ConnectCallback, OnConnectData, Request, Response, }; -use super::dispatcher::Dispatcher; +use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout}; /// `ServiceFactory` implementation for HTTP/2 transport pub struct H2Service { @@ -41,12 +38,11 @@ pub struct H2Service { impl H2Service where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { /// Create new `H2Service` instance with config. pub(crate) fn with_config>( @@ -72,12 +68,11 @@ impl H2Service where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { /// Create plain TCP based service pub fn tcp( @@ -101,9 +96,14 @@ where #[cfg(feature = "openssl")] mod openssl { - use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, + TlsError, + }; use super::*; @@ -111,14 +111,13 @@ mod openssl { where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { - /// Create OpenSSL based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -130,16 +129,14 @@ mod openssl { InitError = S::InitError, > { Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) - .and_then(fn_factory(|| { - ready(Ok::<_, S::InitError>(fn_service( - |io: TlsStream| { - let peer_addr = io.get_ref().peer_addr().ok(); - ready(Ok((io, peer_addr))) - }, - ))) - })) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { + let peer_addr = io.get_ref().peer_addr().ok(); + (io, peer_addr) + }) .and_then(self.map_err(TlsError::Service)) } } @@ -147,24 +144,27 @@ mod openssl { #[cfg(feature = "rustls")] mod rustls { - use super::*; - use actix_service::ServiceFactoryExt; - use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream}; - use actix_tls::accept::TlsError; use std::io; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, + TlsError, + }; + + use super::*; + impl H2Service, S, B> where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { - /// Create Rustls based service + /// Create Rustls based service. pub fn rustls( self, mut config: ServerConfig, @@ -177,19 +177,17 @@ mod rustls { > { let mut protos = vec![b"h2".to_vec()]; protos.extend_from_slice(&config.alpn_protocols); - config.set_protocols(&protos); + config.alpn_protocols = protos; Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) - .and_then(fn_factory(|| { - ready(Ok::<_, S::InitError>(fn_service( - |io: TlsStream| { - let peer_addr = io.get_ref().0.peer_addr().ok(); - ready(Ok((io, peer_addr))) - }, - ))) - })) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { + let peer_addr = io.get_ref().0.peer_addr().ok(); + (io, peer_addr) + }) .and_then(self.map_err(TlsError::Service)) } } @@ -201,12 +199,11 @@ where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { type Response = (); type Error = DispatchError; @@ -241,7 +238,7 @@ where impl H2ServiceHandler where S: Service, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -264,11 +261,10 @@ impl Service<(T, Option)> for H2ServiceHandler, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - B::Error: Into>, { type Response = (); type Error = DispatchError; @@ -292,7 +288,7 @@ where Some(self.cfg.clone()), addr, on_connect_data, - h2_handshake(io), + handshake_with_timeout(io, &self.cfg), ), } } @@ -309,7 +305,7 @@ where Option, Option, OnConnectData, - H2Handshake, + HandshakeWithTimeout, ), } @@ -317,7 +313,7 @@ pub struct H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -329,11 +325,10 @@ impl Future for H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, - B::Error: Into>, { type Output = Result<(), DispatchError>; @@ -347,7 +342,7 @@ where ref mut on_connect_data, ref mut handshake, ) => match ready!(Pin::new(handshake).poll(cx)) { - Ok(conn) => { + Ok((conn, timer)) => { let on_connect_data = std::mem::take(on_connect_data); self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), @@ -355,12 +350,13 @@ where on_connect_data, config.take().unwrap(), *peer_addr, + timer, )); self.poll(cx) } Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } }, } diff --git a/actix-http/src/header/as_name.rs b/actix-http/src/header/as_name.rs index 5ce321566..04d32c41d 100644 --- a/actix-http/src/header/as_name.rs +++ b/actix-http/src/header/as_name.rs @@ -1,11 +1,12 @@ -//! Helper trait for types that can be effectively borrowed as a [HeaderValue]. -//! -//! [HeaderValue]: crate::http::HeaderValue +//! Sealed [`AsHeaderName`] trait and implementations. -use std::{borrow::Cow, str::FromStr}; +use std::{borrow::Cow, str::FromStr as _}; use http::header::{HeaderName, InvalidHeaderName}; +/// Sealed trait implemented for types that can be effectively borrowed as a [`HeaderValue`]. +/// +/// [`HeaderValue`]: crate::http::HeaderValue pub trait AsHeaderName: Sealed {} pub struct Seal; diff --git a/actix-http/src/header/into_pair.rs b/actix-http/src/header/into_pair.rs index d0d6e7324..472700548 100644 --- a/actix-http/src/header/into_pair.rs +++ b/actix-http/src/header/into_pair.rs @@ -1,4 +1,6 @@ -use std::convert::TryFrom; +//! [`IntoHeaderPair`] trait and implementations. + +use std::convert::TryFrom as _; use http::{ header::{HeaderName, InvalidHeaderName, InvalidHeaderValue}, @@ -7,7 +9,10 @@ use http::{ use super::{Header, IntoHeaderValue}; -/// Transforms structures into header K/V pairs for inserting into `HeaderMap`s. +/// An interface for types that can be converted into a [`HeaderName`]/[`HeaderValue`] pair for +/// insertion into a [`HeaderMap`]. +/// +/// [`HeaderMap`]: crate::http::HeaderMap pub trait IntoHeaderPair: Sized { type Error: Into; diff --git a/actix-http/src/header/into_value.rs b/actix-http/src/header/into_value.rs index 4ba58e726..bad05db64 100644 --- a/actix-http/src/header/into_value.rs +++ b/actix-http/src/header/into_value.rs @@ -1,10 +1,12 @@ -use std::convert::TryFrom; +//! [`IntoHeaderValue`] trait and implementations. + +use std::convert::TryFrom as _; use bytes::Bytes; use http::{header::InvalidHeaderValue, Error as HttpError, HeaderValue}; use mime::Mime; -/// A trait for any object that can be Converted to a `HeaderValue` +/// An interface for types that can be converted into a [`HeaderValue`]. pub trait IntoHeaderValue: Sized { /// The type returned in the event of a conversion error. type Error: Into; diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs index 634d9282f..dd852b021 100644 --- a/actix-http/src/header/map.rs +++ b/actix-http/src/header/map.rs @@ -1,6 +1,6 @@ //! A multi-value [`HeaderMap`] and its iterators. -use std::{borrow::Cow, collections::hash_map, ops}; +use std::{borrow::Cow, collections::hash_map, iter, ops}; use ahash::AHashMap; use http::header::{HeaderName, HeaderValue}; @@ -288,7 +288,7 @@ impl HeaderMap { /// Returns an iterator over all values associated with a header name. /// /// The returned iterator does not incur any allocations and will yield no items if there are no - /// values associated with the key. Iteration order is **not** guaranteed to be the same as + /// values associated with the key. Iteration order is guaranteed to be the same as /// insertion order. /// /// # Examples @@ -355,6 +355,19 @@ impl HeaderMap { /// /// assert_eq!(map.len(), 1); /// ``` + /// + /// A convenience method is provided on the returned iterator to check if the insertion replaced + /// any values. + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// assert!(removed.is_empty()); + /// + /// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/html")); + /// assert!(!removed.is_empty()); + /// ``` pub fn insert(&mut self, key: HeaderName, val: HeaderValue) -> Removed { let value = self.inner.insert(key, Value::one(val)); Removed::new(value) @@ -393,6 +406,9 @@ impl HeaderMap { /// Removes all headers for a particular header name from the map. /// + /// Providing an invalid header names (as a string argument) will have no effect and return + /// without error. + /// /// # Examples /// ``` /// # use actix_http::http::{header, HeaderMap, HeaderValue}; @@ -409,6 +425,21 @@ impl HeaderMap { /// assert!(removed.next().is_none()); /// /// assert!(map.is_empty()); + /// ``` + /// + /// A convenience method is provided on the returned iterator to check if the `remove` call + /// actually removed any values. + /// ``` + /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let removed = map.remove("accept"); + /// assert!(removed.is_empty()); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/html")); + /// let removed = map.remove("accept"); + /// assert!(!removed.is_empty()); + /// ``` pub fn remove(&mut self, key: impl AsHeaderName) -> Removed { let value = match key.try_as_name(super::as_name::Seal) { Ok(Cow::Borrowed(name)) => self.inner.remove(name), @@ -550,7 +581,8 @@ impl HeaderMap { } } -/// Note that this implementation will clone a [HeaderName] for each value. +/// Note that this implementation will clone a [HeaderName] for each value. Consider using +/// [`drain`](Self::drain) to control header name cloning. impl IntoIterator for HeaderMap { type Item = (HeaderName, HeaderValue); type IntoIter = IntoIter; @@ -571,7 +603,7 @@ impl<'a> IntoIterator for &'a HeaderMap { } } -/// Iterator for all values with the same header name. +/// Iterator over borrowed values with the same associated name. /// /// See [`HeaderMap::get_all`]. #[derive(Debug)] @@ -613,18 +645,36 @@ impl<'a> Iterator for GetAll<'a> { } } -/// Iterator for owned [`HeaderValue`]s with the same associated [`HeaderName`] returned from methods -/// on [`HeaderMap`] that remove or replace items. +impl ExactSizeIterator for GetAll<'_> {} + +impl iter::FusedIterator for GetAll<'_> {} + +/// Iterator over removed, owned values with the same associated name. +/// +/// Returned from methods that remove or replace items. See [`HeaderMap::insert`] +/// and [`HeaderMap::remove`]. #[derive(Debug)] pub struct Removed { inner: Option>, } -impl<'a> Removed { +impl Removed { fn new(value: Option) -> Self { let inner = value.map(|value| value.inner.into_iter()); Self { inner } } + + /// Returns true if iterator contains no elements, without consuming it. + /// + /// If called immediately after [`HeaderMap::insert`] or [`HeaderMap::remove`], it will indicate + /// wether any items were actually replaced or removed, respectively. + pub fn is_empty(&self) -> bool { + match self.inner { + // size hint lower bound of smallvec is the correct length + Some(ref iter) => iter.size_hint().0 == 0, + None => true, + } + } } impl Iterator for Removed { @@ -644,7 +694,11 @@ impl Iterator for Removed { } } -/// Iterator over all [`HeaderName`]s in the map. +impl ExactSizeIterator for Removed {} + +impl iter::FusedIterator for Removed {} + +/// Iterator over all names in the map. #[derive(Debug)] pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>); @@ -662,6 +716,11 @@ impl<'a> Iterator for Keys<'a> { } } +impl ExactSizeIterator for Keys<'_> {} + +impl iter::FusedIterator for Keys<'_> {} + +/// Iterator over borrowed name-value pairs. #[derive(Debug)] pub struct Iter<'a> { inner: hash_map::Iter<'a, HeaderName, Value>, @@ -684,7 +743,7 @@ impl<'a> Iterator for Iter<'a> { fn next(&mut self) -> Option { // handle in-progress multi value lists first - if let Some((ref name, ref mut vals)) = self.multi_inner { + if let Some((name, ref mut vals)) = self.multi_inner { match vals.get(self.multi_idx) { Some(val) => { self.multi_idx += 1; @@ -713,6 +772,10 @@ impl<'a> Iterator for Iter<'a> { } } +impl ExactSizeIterator for Iter<'_> {} + +impl iter::FusedIterator for Iter<'_> {} + /// Iterator over drained name-value pairs. /// /// Iterator items are `(Option, HeaderValue)` to avoid cloning. @@ -764,6 +827,10 @@ impl<'a> Iterator for Drain<'a> { } } +impl ExactSizeIterator for Drain<'_> {} + +impl iter::FusedIterator for Drain<'_> {} + /// Iterator over owned name-value pairs. /// /// Implementation necessarily clones header names for each value. @@ -814,12 +881,27 @@ impl Iterator for IntoIter { } } +impl ExactSizeIterator for IntoIter {} + +impl iter::FusedIterator for IntoIter {} + #[cfg(test)] mod tests { + use std::iter::FusedIterator; + use http::header; + use static_assertions::assert_impl_all; use super::*; + assert_impl_all!(HeaderMap: IntoIterator); + assert_impl_all!(Keys<'_>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(GetAll<'_>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Removed: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Iter<'_>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(IntoIter: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Drain<'_>: Iterator, ExactSizeIterator, FusedIterator); + #[test] fn create() { let map = HeaderMap::new(); @@ -945,6 +1027,56 @@ mod tests { assert_eq!(vals.next(), removed.next().as_ref()); } + #[test] + fn get_all_iteration_order_matches_insertion_order() { + let mut map = HeaderMap::new(); + + let mut vals = map.get_all(header::COOKIE); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("1")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("2")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert_eq!(vals.next().unwrap().as_bytes(), b"2"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("3")); + map.append(header::COOKIE, HeaderValue::from_static("4")); + map.append(header::COOKIE, HeaderValue::from_static("5")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert_eq!(vals.next().unwrap().as_bytes(), b"2"); + assert_eq!(vals.next().unwrap().as_bytes(), b"3"); + assert_eq!(vals.next().unwrap().as_bytes(), b"4"); + assert_eq!(vals.next().unwrap().as_bytes(), b"5"); + assert!(vals.next().is_none()); + + let _ = map.insert(header::COOKIE, HeaderValue::from_static("6")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"6"); + assert!(vals.next().is_none()); + + let _ = map.insert(header::COOKIE, HeaderValue::from_static("7")); + let _ = map.insert(header::COOKIE, HeaderValue::from_static("8")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"8"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("9")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"8"); + assert_eq!(vals.next().unwrap().as_bytes(), b"9"); + assert!(vals.next().is_none()); + + // check for fused-ness + assert!(vals.next().is_none()); + } + fn owned_pair<'a>( (name, val): (&'a HeaderName, &'a HeaderValue), ) -> (HeaderName, HeaderValue) { diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs index 18494f555..381842e74 100644 --- a/actix-http/src/header/mod.rs +++ b/actix-http/src/header/mod.rs @@ -29,35 +29,34 @@ pub use http::header::{ X_FRAME_OPTIONS, X_XSS_PROTECTION, }; -use crate::error::ParseError; -use crate::HttpMessage; +use crate::{error::ParseError, HttpMessage}; mod as_name; mod into_pair; mod into_value; -mod utils; - -pub(crate) mod map; +pub mod map; mod shared; - -#[doc(hidden)] -pub use self::shared::*; +mod utils; pub use self::as_name::AsHeaderName; pub use self::into_pair::IntoHeaderPair; pub use self::into_value::IntoHeaderValue; -#[doc(hidden)] -pub use self::map::GetAll; pub use self::map::HeaderMap; -pub use self::utils::*; +pub use self::shared::{ + parse_extended_value, q, Charset, ContentEncoding, ExtendedValue, HttpDate, + LanguageTag, Quality, QualityItem, +}; +pub use self::utils::{ + fmt_comma_delimited, from_comma_delimited, from_one_raw_str, http_percent_encode, +}; -/// A trait for any object that already represents a valid header field and value. +/// An interface for types that already represent a valid header. pub trait Header: IntoHeaderValue { /// Returns the name of the header field fn name() -> HeaderName; /// Parse a header - fn parse(msg: &T) -> Result; + fn parse(msg: &M) -> Result; } /// Convert `http::HeaderMap` to our `HeaderMap`. @@ -68,7 +67,7 @@ impl From for HeaderMap { } /// This encode set is used for HTTP header values and is defined at -/// https://tools.ietf.org/html/rfc5987#section-3.2. +/// . pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS .add(b' ') .add(b'"') diff --git a/actix-http/src/header/shared/charset.rs b/actix-http/src/header/shared/charset.rs index b482f6bce..1e77e1be8 100644 --- a/actix-http/src/header/shared/charset.rs +++ b/actix-http/src/header/shared/charset.rs @@ -1,14 +1,13 @@ -use std::fmt::{self, Display}; -use std::str::FromStr; +use std::{fmt, str}; use self::Charset::*; -/// A Mime charset. +/// A MIME character set. /// /// The string representation is normalized to upper case. /// /// See . -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum Charset { /// US ASCII @@ -88,20 +87,20 @@ impl Charset { Iso_8859_8_E => "ISO-8859-8-E", Iso_8859_8_I => "ISO-8859-8-I", Gb2312 => "GB2312", - Big5 => "big5", + Big5 => "Big5", Koi8_R => "KOI8-R", Ext(ref s) => s, } } } -impl Display for Charset { +impl fmt::Display for Charset { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.label()) } } -impl FromStr for Charset { +impl str::FromStr for Charset { type Err = crate::Error; fn from_str(s: &str) -> Result { @@ -128,7 +127,7 @@ impl FromStr for Charset { "ISO-8859-8-E" => Iso_8859_8_E, "ISO-8859-8-I" => Iso_8859_8_I, "GB2312" => Gb2312, - "big5" => Big5, + "BIG5" => Big5, "KOI8-R" => Koi8_R, s => Ext(s.to_owned()), }) diff --git a/actix-http/src/header/shared/content_encoding.rs b/actix-http/src/header/shared/content_encoding.rs index b9c1d2795..073d90dce 100644 --- a/actix-http/src/header/shared/content_encoding.rs +++ b/actix-http/src/header/shared/content_encoding.rs @@ -1,5 +1,6 @@ -use std::{convert::Infallible, str::FromStr}; +use std::{convert::TryFrom, str::FromStr}; +use derive_more::{Display, Error}; use http::header::InvalidHeaderValue; use crate::{ @@ -8,8 +9,19 @@ use crate::{ HttpMessage, }; +/// Error returned when a content encoding is unknown. +#[derive(Debug, Display, Error)] +#[display(fmt = "unsupported content encoding")] +pub struct ContentEncodingParseError; + /// Represents a supported content encoding. -#[derive(Copy, Clone, PartialEq, Debug)] +/// +/// Includes a commonly-used subset of media types appropriate for use as HTTP content encodings. +/// See [IANA HTTP Content Coding Registry]. +/// +/// [IANA HTTP Content Coding Registry]: https://www.iana.org/assignments/http-parameters/http-parameters.xhtml +#[derive(Debug, Clone, Copy, PartialEq)] +#[non_exhaustive] pub enum ContentEncoding { /// Automatically select encoding based on encoding negotiation. Auto, @@ -23,7 +35,7 @@ pub enum ContentEncoding { /// Gzip algorithm. Gzip, - // Zstd algorithm. + /// Zstd algorithm. Zstd, /// Indicates the identity function (i.e. no compression, nor modification). @@ -37,7 +49,7 @@ impl ContentEncoding { matches!(self, ContentEncoding::Identity | ContentEncoding::Auto) } - /// Convert content encoding to string + /// Convert content encoding to string. #[inline] pub fn as_str(self) -> &'static str { match self { @@ -48,18 +60,6 @@ impl ContentEncoding { ContentEncoding::Identity | ContentEncoding::Auto => "identity", } } - - /// Default Q-factor (quality) value. - #[inline] - pub fn quality(self) -> f64 { - match self { - ContentEncoding::Br => 1.1, - ContentEncoding::Gzip => 1.0, - ContentEncoding::Deflate => 0.9, - ContentEncoding::Identity | ContentEncoding::Auto => 0.1, - ContentEncoding::Zstd => 0.0, - } - } } impl Default for ContentEncoding { @@ -69,31 +69,33 @@ impl Default for ContentEncoding { } impl FromStr for ContentEncoding { - type Err = Infallible; + type Err = ContentEncodingParseError; fn from_str(val: &str) -> Result { - Ok(Self::from(val)) - } -} - -impl From<&str> for ContentEncoding { - fn from(val: &str) -> ContentEncoding { let val = val.trim(); if val.eq_ignore_ascii_case("br") { - ContentEncoding::Br + Ok(ContentEncoding::Br) } else if val.eq_ignore_ascii_case("gzip") { - ContentEncoding::Gzip + Ok(ContentEncoding::Gzip) } else if val.eq_ignore_ascii_case("deflate") { - ContentEncoding::Deflate + Ok(ContentEncoding::Deflate) } else if val.eq_ignore_ascii_case("zstd") { - ContentEncoding::Zstd + Ok(ContentEncoding::Zstd) } else { - ContentEncoding::default() + Err(ContentEncodingParseError) } } } +impl TryFrom<&str> for ContentEncoding { + type Error = ContentEncodingParseError; + + fn try_from(val: &str) -> Result { + val.parse() + } +} + impl IntoHeaderValue for ContentEncoding { type Error = InvalidHeaderValue; diff --git a/actix-http/src/header/shared/extended.rs b/actix-http/src/header/shared/extended.rs index 9fd4cdfb0..60f2d359e 100644 --- a/actix-http/src/header/shared/extended.rs +++ b/actix-http/src/header/shared/extended.rs @@ -1,17 +1,17 @@ +//! Originally taken from `hyper::header::parsing`. + use std::{fmt, str::FromStr}; use language_tags::LanguageTag; use crate::header::{Charset, HTTP_VALUE}; -// From hyper v0.11.27 src/header/parsing.rs - /// The value part of an extended parameter consisting of three parts: /// - The REQUIRED character set name (`charset`). /// - The OPTIONAL language information (`language_tag`). /// - A character sequence representing the actual value (`value`), separated by single quotes. /// -/// It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// It is defined in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2). #[derive(Clone, Debug, PartialEq)] pub struct ExtendedValue { /// The character set that is used to encode the `value` to a string. @@ -24,17 +24,17 @@ pub struct ExtendedValue { pub value: Vec, } -/// Parses extended header parameter values (`ext-value`), as defined in -/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// Parses extended header parameter values (`ext-value`), as defined +/// in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2). /// /// Extended values are denoted by parameter names that end with `*`. /// /// ## ABNF /// -/// ```text +/// ```plain /// ext-value = charset "'" [ language ] "'" value-chars /// ; like RFC 2231's -/// ; (see [RFC2231], Section 7) +/// ; (see [RFC 2231 §7]) /// /// charset = "UTF-8" / "ISO-8859-1" / mime-charset /// @@ -43,22 +43,26 @@ pub struct ExtendedValue { /// / "!" / "#" / "$" / "%" / "&" /// / "+" / "-" / "^" / "_" / "`" /// / "{" / "}" / "~" -/// ; as in Section 2.3 of [RFC2978] +/// ; as in [RFC 2978 §2.3] /// ; except that the single quote is not included /// ; SHOULD be registered in the IANA charset registry /// -/// language = +/// language = /// /// value-chars = *( pct-encoded / attr-char ) /// /// pct-encoded = "%" HEXDIG HEXDIG -/// ; see [RFC3986], Section 2.1 +/// ; see [RFC 3986 §2.1] /// /// attr-char = ALPHA / DIGIT /// / "!" / "#" / "$" / "&" / "+" / "-" / "." /// / "^" / "_" / "`" / "|" / "~" /// ; token except ( "*" / "'" / "%" ) /// ``` +/// +/// [RFC 2231 §7]: https://datatracker.ietf.org/doc/html/rfc2231#section-7 +/// [RFC 2978 §2.3]: https://datatracker.ietf.org/doc/html/rfc2978#section-2.3 +/// [RFC 3986 §2.1]: https://datatracker.ietf.org/doc/html/rfc5646#section-2.1 pub fn parse_extended_value( val: &str, ) -> Result { diff --git a/actix-http/src/header/shared/http_date.rs b/actix-http/src/header/shared/http_date.rs new file mode 100644 index 000000000..8dbdf4a62 --- /dev/null +++ b/actix-http/src/header/shared/http_date.rs @@ -0,0 +1,82 @@ +use std::{fmt, io::Write, str::FromStr, time::SystemTime}; + +use bytes::BytesMut; +use http::header::{HeaderValue, InvalidHeaderValue}; + +use crate::{ + config::DATE_VALUE_LENGTH, error::ParseError, header::IntoHeaderValue, + helpers::MutWriter, +}; + +/// A timestamp with HTTP-style formatting and parsing. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct HttpDate(SystemTime); + +impl FromStr for HttpDate { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match httpdate::parse_http_date(s) { + Ok(sys_time) => Ok(HttpDate(sys_time)), + Err(_) => Err(ParseError::Header), + } + } +} + +impl fmt::Display for HttpDate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let date_str = httpdate::fmt_http_date(self.0); + f.write_str(&date_str) + } +} + +impl IntoHeaderValue for HttpDate { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + let mut buf = BytesMut::with_capacity(DATE_VALUE_LENGTH); + let mut wrt = MutWriter(&mut buf); + + // unwrap: date output is known to be well formed and of known length + write!(wrt, "{}", httpdate::fmt_http_date(self.0)).unwrap(); + + HeaderValue::from_maybe_shared(buf.split().freeze()) + } +} + +impl From for HttpDate { + fn from(sys_time: SystemTime) -> HttpDate { + HttpDate(sys_time) + } +} + +impl From for SystemTime { + fn from(HttpDate(sys_time): HttpDate) -> SystemTime { + sys_time + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn date_header() { + macro_rules! assert_parsed_date { + ($case:expr, $exp:expr) => { + assert_eq!($case.parse::().unwrap(), $exp); + }; + } + + // 784198117 = SystemTime::from(datetime!(1994-11-07 08:48:37).assume_utc()).duration_since(SystemTime::UNIX_EPOCH)); + let nov_07 = HttpDate(SystemTime::UNIX_EPOCH + Duration::from_secs(784198117)); + + assert_parsed_date!("Mon, 07 Nov 1994 08:48:37 GMT", nov_07); + assert_parsed_date!("Monday, 07-Nov-94 08:48:37 GMT", nov_07); + assert_parsed_date!("Mon Nov 7 08:48:37 1994", nov_07); + + assert!("this-is-no-date".parse::().is_err()); + } +} diff --git a/actix-http/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs deleted file mode 100644 index 18278a6d8..000000000 --- a/actix-http/src/header/shared/httpdate.rs +++ /dev/null @@ -1,97 +0,0 @@ -use std::{ - fmt, - io::Write, - str::FromStr, - time::{SystemTime, UNIX_EPOCH}, -}; - -use bytes::buf::BufMut; -use bytes::BytesMut; -use http::header::{HeaderValue, InvalidHeaderValue}; -use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; - -use crate::error::ParseError; -use crate::header::IntoHeaderValue; -use crate::time_parser; - -/// A timestamp with HTTP formatting and parsing. -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct HttpDate(OffsetDateTime); - -impl FromStr for HttpDate { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match time_parser::parse_http_date(s) { - Some(t) => Ok(HttpDate(t.assume_utc())), - None => Err(ParseError::Header), - } - } -} - -impl fmt::Display for HttpDate { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0.format("%a, %d %b %Y %H:%M:%S GMT"), f) - } -} - -impl From for HttpDate { - fn from(sys: SystemTime) -> HttpDate { - HttpDate(PrimitiveDateTime::from(sys).assume_utc()) - } -} - -impl IntoHeaderValue for HttpDate { - type Error = InvalidHeaderValue; - - fn try_into_value(self) -> Result { - let mut wrt = BytesMut::with_capacity(29).writer(); - write!( - wrt, - "{}", - self.0 - .to_offset(UtcOffset::UTC) - .format("%a, %d %b %Y %H:%M:%S GMT") - ) - .unwrap(); - HeaderValue::from_maybe_shared(wrt.get_mut().split().freeze()) - } -} - -impl From for SystemTime { - fn from(date: HttpDate) -> SystemTime { - let dt = date.0; - let epoch = OffsetDateTime::unix_epoch(); - - UNIX_EPOCH + (dt - epoch) - } -} - -#[cfg(test)] -mod tests { - use super::HttpDate; - use time::{date, time, PrimitiveDateTime}; - - #[test] - fn test_date() { - let nov_07 = HttpDate( - PrimitiveDateTime::new(date!(1994 - 11 - 07), time!(8:48:37)).assume_utc(), - ); - - assert_eq!( - "Sun, 07 Nov 1994 08:48:37 GMT".parse::().unwrap(), - nov_07 - ); - assert_eq!( - "Sunday, 07-Nov-94 08:48:37 GMT" - .parse::() - .unwrap(), - nov_07 - ); - assert_eq!( - "Sun Nov 7 08:48:37 1994".parse::().unwrap(), - nov_07 - ); - assert!("this-is-no-date".parse::().is_err()); - } -} diff --git a/actix-http/src/header/shared/mod.rs b/actix-http/src/header/shared/mod.rs index b8f9173f9..257e54d7a 100644 --- a/actix-http/src/header/shared/mod.rs +++ b/actix-http/src/header/shared/mod.rs @@ -3,12 +3,14 @@ mod charset; mod content_encoding; mod extended; -mod httpdate; +mod http_date; +mod quality; mod quality_item; pub use self::charset::Charset; pub use self::content_encoding::ContentEncoding; pub use self::extended::{parse_extended_value, ExtendedValue}; -pub use self::httpdate::HttpDate; -pub use self::quality_item::{q, qitem, Quality, QualityItem}; +pub use self::http_date::HttpDate; +pub use self::quality::{q, Quality}; +pub use self::quality_item::QualityItem; pub use language_tags::LanguageTag; diff --git a/actix-http/src/header/shared/quality.rs b/actix-http/src/header/shared/quality.rs new file mode 100644 index 000000000..5321c754d --- /dev/null +++ b/actix-http/src/header/shared/quality.rs @@ -0,0 +1,208 @@ +use std::{ + convert::{TryFrom, TryInto}, + fmt, +}; + +use derive_more::{Display, Error}; + +const MAX_QUALITY_INT: u16 = 1000; +const MAX_QUALITY_FLOAT: f32 = 1.0; + +/// Represents a quality used in q-factor values. +/// +/// The default value is equivalent to `q=1.0` (the [max](Self::MAX) value). +/// +/// # Implementation notes +/// The quality value is defined as a number between 0.0 and 1.0 with three decimal places. +/// This means there are 1001 possible values. Since floating point numbers are not exact and the +/// smallest floating point data type (`f32`) consumes four bytes, we use an `u16` value to store +/// the quality internally. +/// +/// [RFC 7231 §5.3.1] gives more information on quality values in HTTP header fields. +/// +/// # Examples +/// ``` +/// use actix_http::header::{Quality, q}; +/// assert_eq!(q(1.0), Quality::MAX); +/// +/// assert_eq!(q(0.42).to_string(), "0.42"); +/// assert_eq!(q(1.0).to_string(), "1"); +/// assert_eq!(Quality::MIN.to_string(), "0"); +/// ``` +/// +/// [RFC 7231 §5.3.1]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1 +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Quality(pub(super) u16); + +impl Quality { + /// The maximum quality value, equivalent to `q=1.0`. + pub const MAX: Quality = Quality(MAX_QUALITY_INT); + + /// The minimum quality value, equivalent to `q=0.0`. + pub const MIN: Quality = Quality(0); + + /// Converts a float in the range 0.0–1.0 to a `Quality`. + /// + /// Intentionally private. External uses should rely on the `TryFrom` impl. + /// + /// # Panics + /// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0. + fn from_f32(value: f32) -> Self { + // Check that `value` is within range should be done before calling this method. + // Just in case, this debug_assert should catch if we were forgetful. + debug_assert!( + (0.0f32..=1.0f32).contains(&value), + "q value must be between 0.0 and 1.0" + ); + + Quality((value * MAX_QUALITY_INT as f32) as u16) + } +} + +/// The default value is [`Quality::MAX`]. +impl Default for Quality { + fn default() -> Quality { + Quality::MAX + } +} + +impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + MAX_QUALITY_INT => f.write_str("1"), + + // some number in the range 1–999 + x => { + f.write_str("0.")?; + + // This implementation avoids string allocation for removing trailing zeroes. + // In benchmarks it is twice as fast as approach using something like + // `format!("{}").trim_end_matches('0')` for non-fast-path quality values. + + if x < 10 { + // x in is range 1–9 + + f.write_str("00")?; + + // 0 is already handled so it's not possible to have a trailing 0 in this range + // we can just write the integer + itoa::fmt(f, x) + } else if x < 100 { + // x in is range 10–99 + + f.write_str("0")?; + + if x % 10 == 0 { + // trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } else { + // x is in range 100–999 + + if x % 100 == 0 { + // two trailing 0s, divide by 100 and write + itoa::fmt(f, x / 100) + } else if x % 10 == 0 { + // one trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } + } + } + } +} + +#[derive(Debug, Clone, Display, Error)] +#[display(fmt = "quality out of bounds")] +#[non_exhaustive] +pub struct QualityOutOfBounds; + +impl TryFrom for Quality { + type Error = QualityOutOfBounds; + + #[inline] + fn try_from(value: f32) -> Result { + if (0.0..=MAX_QUALITY_FLOAT).contains(&value) { + Ok(Quality::from_f32(value)) + } else { + Err(QualityOutOfBounds) + } + } +} + +/// Convenience function to create a [`Quality`] from an `f32` (0.0–1.0). +/// +/// Not recommended for use with user input. Rely on the `TryFrom` impls where possible. +/// +/// # Panics +/// Panics if value is out of range. +/// +/// # Examples +/// ``` +/// # use actix_http::header::{q, Quality}; +/// let q1 = q(1.0); +/// assert_eq!(q1, Quality::MAX); +/// +/// let q2 = q(0.0); +/// assert_eq!(q2, Quality::MIN); +/// +/// let q3 = q(0.42); +/// ``` +/// +/// An out-of-range `f32` quality will panic. +/// ```should_panic +/// # use actix_http::header::q; +/// let _q2 = q(1.42); +/// ``` +#[inline] +pub fn q(quality: T) -> Quality +where + T: TryInto, + T::Error: fmt::Debug, +{ + quality.try_into().expect("quality value was out of bounds") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_helper() { + assert_eq!(q(0.5), Quality(500)); + } + + #[test] + fn display_output() { + assert_eq!(q(0.0).to_string(), "0"); + assert_eq!(q(1.0).to_string(), "1"); + assert_eq!(q(0.001).to_string(), "0.001"); + assert_eq!(q(0.5).to_string(), "0.5"); + assert_eq!(q(0.22).to_string(), "0.22"); + assert_eq!(q(0.123).to_string(), "0.123"); + assert_eq!(q(0.999).to_string(), "0.999"); + + for x in 0..=1000 { + // if trailing zeroes are handled correctly, we would not expect the serialized length + // to ever exceed "0." + 3 decimal places = 5 in length + assert!(q(x as f32 / 1000.0).to_string().len() <= 5); + } + } + + #[test] + #[should_panic] + fn negative_quality() { + q(-1.0); + } + + #[test] + #[should_panic] + fn quality_out_of_bounds() { + q(2.0); + } +} diff --git a/actix-http/src/header/shared/quality_item.rs b/actix-http/src/header/shared/quality_item.rs index 240a0afa2..9354915ad 100644 --- a/actix-http/src/header/shared/quality_item.rs +++ b/actix-http/src/header/shared/quality_item.rs @@ -1,101 +1,65 @@ -use std::{ - cmp, - convert::{TryFrom, TryInto}, - fmt, str, -}; +use std::{cmp, convert::TryFrom as _, fmt, str}; -use derive_more::{Display, Error}; +use crate::error::ParseError; -const MAX_QUALITY: u16 = 1000; -const MAX_FLOAT_QUALITY: f32 = 1.0; +use super::Quality; -/// Represents a quality used in quality values. +/// Represents an item with a quality value as defined +/// in [RFC 7231 §5.3.1](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1). /// -/// Can be created with the [`q`] function. +/// # Parsing and Formatting +/// This wrapper be used to parse header value items that have a q-factor annotation as well as +/// serialize items with a their q-factor. /// -/// # Implementation notes +/// # Ordering +/// Since this context of use for this type is header value items, ordering is defined for +/// `QualityItem`s but _only_ considers the item's quality. Order of appearance should be used as +/// the secondary sorting parameter; i.e., a stable sort over the quality values will produce a +/// correctly sorted sequence. /// -/// The quality value is defined as a number between 0 and 1 with three decimal -/// places. This means there are 1001 possible values. Since floating point -/// numbers are not exact and the smallest floating point data type (`f32`) -/// consumes four bytes, hyper uses an `u16` value to store the -/// quality internally. For performance reasons you may set quality directly to -/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality -/// `q=0.532`. +/// # Examples +/// ``` +/// # use actix_http::header::{QualityItem, q}; +/// let q_item: QualityItem = "hello;q=0.3".parse().unwrap(); +/// assert_eq!(&q_item.item, "hello"); +/// assert_eq!(q_item.quality, q(0.3)); /// -/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1) -/// gives more information on quality values in HTTP header fields. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct Quality(u16); - -impl Quality { - /// # Panics - /// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0. - fn from_f32(value: f32) -> Self { - // Check that `value` is within range should be done before calling this method. - // Just in case, this debug_assert should catch if we were forgetful. - debug_assert!( - (0.0f32..=1.0f32).contains(&value), - "q value must be between 0.0 and 1.0" - ); - - Quality((value * MAX_QUALITY as f32) as u16) - } -} - -impl Default for Quality { - fn default() -> Quality { - Quality(MAX_QUALITY) - } -} - -#[derive(Debug, Clone, Display, Error)] -pub struct QualityOutOfBounds; - -impl TryFrom for Quality { - type Error = QualityOutOfBounds; - - fn try_from(value: u16) -> Result { - if (0..=MAX_QUALITY).contains(&value) { - Ok(Quality(value)) - } else { - Err(QualityOutOfBounds) - } - } -} - -impl TryFrom for Quality { - type Error = QualityOutOfBounds; - - fn try_from(value: f32) -> Result { - if (0.0..=MAX_FLOAT_QUALITY).contains(&value) { - Ok(Quality::from_f32(value)) - } else { - Err(QualityOutOfBounds) - } - } -} - -/// Represents an item with a quality value as defined in -/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-5.3.1). -#[derive(Clone, PartialEq, Debug)] +/// // note that format is normalized compared to parsed item +/// assert_eq!(q_item.to_string(), "hello; q=0.3"); +/// +/// // item with q=0.3 is greater than item with q=0.1 +/// let q_item_fallback: QualityItem = "abc;q=0.1".parse().unwrap(); +/// assert!(q_item > q_item_fallback); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] pub struct QualityItem { - /// The actual contents of the field. + /// The wrapped contents of the field. pub item: T, + /// The quality (client or server preference) for the value. pub quality: Quality, } impl QualityItem { - /// Creates a new `QualityItem` from an item and a quality. - /// The item can be of any type. - /// The quality should be a value in the range [0, 1]. - pub fn new(item: T, quality: Quality) -> QualityItem { + /// Constructs a new `QualityItem` from an item and a quality value. + /// + /// The item can be of any type. The quality should be a value in the range [0, 1]. + pub fn new(item: T, quality: Quality) -> Self { QualityItem { item, quality } } + + /// Constructs a new `QualityItem` from an item, using the maximum q-value. + pub fn max(item: T) -> Self { + Self::new(item, Quality::MAX) + } + + /// Constructs a new `QualityItem` from an item, using the minimum q-value. + pub fn min(item: T) -> Self { + Self::new(item, Quality::MIN) + } } -impl cmp::PartialOrd for QualityItem { +impl PartialOrd for QualityItem { fn partial_cmp(&self, other: &QualityItem) -> Option { self.quality.partial_cmp(&other.quality) } @@ -105,97 +69,77 @@ impl fmt::Display for QualityItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.item, f)?; - match self.quality.0 { - MAX_QUALITY => Ok(()), - 0 => f.write_str("; q=0"), - x => write!(f, "; q=0.{}", format!("{:03}", x).trim_end_matches('0')), + match self.quality { + // q-factor value is implied for max value + Quality::MAX => Ok(()), + + Quality::MIN => f.write_str("; q=0"), + q => write!(f, "; q={}", q), } } } impl str::FromStr for QualityItem { - type Err = crate::error::ParseError; + type Err = ParseError; - fn from_str(qitem_str: &str) -> Result, crate::error::ParseError> { - if !qitem_str.is_ascii() { - return Err(crate::error::ParseError::Header); + fn from_str(q_item_str: &str) -> Result { + if !q_item_str.is_ascii() { + return Err(ParseError::Header); } - // Set defaults used if parsing fails. - let mut raw_item = qitem_str; - let mut quality = 1f32; + // set defaults used if quality-item parsing fails, i.e., item has no q attribute + let mut raw_item = q_item_str; + let mut quality = Quality::MAX; - let parts: Vec<_> = qitem_str.rsplitn(2, ';').map(str::trim).collect(); + let parts = q_item_str + .rsplit_once(';') + .map(|(item, q_attr)| (item.trim(), q_attr.trim())); - if parts.len() == 2 { + if let Some((val, q_attr)) = parts { // example for item with q-factor: // - // gzip; q=0.65 - // ^^^^^^ parts[0] - // ^^ start - // ^^^^ q_val - // ^^^^ parts[1] + // gzip;q=0.65 + // ^^^^ val + // ^^^^^^ q_attr + // ^^ q + // ^^^^ q_val - if parts[0].len() < 2 { + if q_attr.len() < 2 { // Can't possibly be an attribute since an attribute needs at least a name followed // by an equals sign. And bare identifiers are forbidden. - return Err(crate::error::ParseError::Header); + return Err(ParseError::Header); } - let start = &parts[0][0..2]; + let q = &q_attr[0..2]; - if start == "q=" || start == "Q=" { - let q_val = &parts[0][2..]; + if q == "q=" || q == "Q=" { + let q_val = &q_attr[2..]; if q_val.len() > 5 { // longer than 5 indicates an over-precise q-factor - return Err(crate::error::ParseError::Header); + return Err(ParseError::Header); } - let q_value = q_val - .parse::() - .map_err(|_| crate::error::ParseError::Header)?; + let q_value = q_val.parse::().map_err(|_| ParseError::Header)?; + let q_value = + Quality::try_from(q_value).map_err(|_| ParseError::Header)?; - if (0f32..=1f32).contains(&q_value) { - quality = q_value; - raw_item = parts[1]; - } else { - return Err(crate::error::ParseError::Header); - } + quality = q_value; + raw_item = val; } } - let item = raw_item - .parse::() - .map_err(|_| crate::error::ParseError::Header)?; + let item = raw_item.parse::().map_err(|_| ParseError::Header)?; - // we already checked above that the quality is within range - Ok(QualityItem::new(item, Quality::from_f32(quality))) + Ok(QualityItem::new(item, quality)) } } -/// Convenience function to wrap a value in a `QualityItem` -/// Sets `q` to the default 1.0 -pub fn qitem(item: T) -> QualityItem { - QualityItem::new(item, Quality::default()) -} - -/// Convenience function to create a `Quality` from a float or integer. -/// -/// Implemented for `u16` and `f32`. Panics if value is out of range. -pub fn q(val: T) -> Quality -where - T: TryInto, - T::Error: fmt::Debug, -{ - // TODO: on next breaking change, handle unwrap differently - val.try_into().unwrap() -} - #[cfg(test)] mod tests { use super::*; // copy of encoding from actix-web headers + #[allow(clippy::enum_variant_names)] // allow Encoding prefix on EncodingExt #[derive(Clone, PartialEq, Debug)] pub enum Encoding { Chunked, @@ -244,7 +188,7 @@ mod tests { #[test] fn test_quality_item_fmt_q_1() { use Encoding::*; - let x = qitem(Chunked); + let x = QualityItem::max(Chunked); assert_eq!(format!("{}", x), "chunked"); } #[test] @@ -343,25 +287,8 @@ mod tests { fn test_quality_item_ordering() { let x: QualityItem = "gzip; q=0.5".parse().ok().unwrap(); let y: QualityItem = "gzip; q=0.273".parse().ok().unwrap(); - let comparision_result: bool = x.gt(&y); - assert!(comparision_result) - } - - #[test] - fn test_quality() { - assert_eq!(q(0.5), Quality(500)); - } - - #[test] - #[should_panic] - fn test_quality_invalid() { - q(-1.0); - } - - #[test] - #[should_panic] - fn test_quality_invalid2() { - q(2.0); + let comparison_result: bool = x.gt(&y); + assert!(comparison_result) } #[test] diff --git a/actix-http/src/header/utils.rs b/actix-http/src/header/utils.rs index 5e9652380..f4f34d347 100644 --- a/actix-http/src/header/utils.rs +++ b/actix-http/src/header/utils.rs @@ -1,3 +1,5 @@ +//! Header parsing utilities. + use std::{fmt, str::FromStr}; use super::HeaderValue; @@ -10,9 +12,12 @@ where I: Iterator + 'a, T: FromStr, { - let mut result = Vec::new(); + let size_guess = all.size_hint().1.unwrap_or(2); + let mut result = Vec::with_capacity(size_guess); + for h in all { let s = h.to_str().map_err(|_| ParseError::Header)?; + result.extend( s.split(',') .filter_map(|x| match x.trim() { @@ -22,6 +27,7 @@ where .filter_map(|x| x.trim().parse().ok()), ) } + Ok(result) } @@ -30,10 +36,12 @@ where pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { if let Some(line) = val { let line = line.to_str().map_err(|_| ParseError::Header)?; + if !line.is_empty() { return T::from_str(line).or(Err(ParseError::Header)); } } + Err(ParseError::Header) } @@ -44,19 +52,53 @@ where T: fmt::Display, { let mut iter = parts.iter(); + if let Some(part) = iter.next() { fmt::Display::fmt(part, f)?; } + for part in iter { f.write_str(", ")?; fmt::Display::fmt(part, f)?; } + Ok(()) } -/// Percent encode a sequence of bytes with a character set defined in -/// +/// Percent encode a sequence of bytes with a character set defined in [RFC 5987 §3.2]. +/// +/// [RFC 5987 §3.2]: https://datatracker.ietf.org/doc/html/rfc5987#section-3.2 +#[inline] pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); fmt::Display::fmt(&encoded, f) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn comma_delimited_parsing() { + let headers = vec![]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![0; 0]); + + let headers = vec![ + HeaderValue::from_static("1, 2"), + HeaderValue::from_static("3,4"), + ]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![1, 2, 3, 4]); + + let headers = vec![ + HeaderValue::from_static(""), + HeaderValue::from_static(","), + HeaderValue::from_static(" "), + HeaderValue::from_static("1 ,"), + HeaderValue::from_static(""), + ]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![1]); + } +} diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 32c41174c..a8b08eb47 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -14,7 +14,7 @@ //! [rustls]: https://crates.io/crates/rustls //! [trust-dns]: https://crates.io/crates/trust-dns -#![deny(rust_2018_idioms, nonstandard_style)] +#![deny(rust_2018_idioms, nonstandard_style, clippy::uninit_assumed_init)] #![allow( clippy::type_complexity, clippy::too_many_arguments, @@ -29,7 +29,6 @@ extern crate log; pub mod body; mod builder; -pub mod client; mod config; #[cfg(feature = "__compress")] @@ -44,7 +43,6 @@ mod request; mod response; mod response_builder; mod service; -mod time_parser; pub mod error; pub mod h1; @@ -104,14 +102,9 @@ type ConnectCallback = dyn Fn(&IO, &mut CloneableExtensions); /// /// # Implementation Details /// Uses Option to reduce necessary allocations when merging with request extensions. +#[derive(Default)] pub(crate) struct OnConnectData(Option); -impl Default for OnConnectData { - fn default() -> Self { - Self(None) - } -} - impl OnConnectData { /// Construct by calling the on-connect callback with the underlying transport I/O. pub(crate) fn from_io( diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index e85d686b7..c8e1ce6db 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -46,8 +46,8 @@ pub trait Head: Default + 'static { #[derive(Debug)] pub struct RequestHead { - pub uri: Uri, pub method: Method, + pub uri: Uri, pub version: Version, pub headers: HeaderMap, pub extensions: RefCell, @@ -58,13 +58,13 @@ pub struct RequestHead { impl Default for RequestHead { fn default() -> RequestHead { RequestHead { - uri: Uri::default(), method: Method::default(), + uri: Uri::default(), version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), - flags: Flags::empty(), - peer_addr: None, extensions: RefCell::new(Extensions::new()), + peer_addr: None, + flags: Flags::empty(), } } } @@ -192,6 +192,7 @@ impl RequestHead { } #[derive(Debug)] +#[allow(clippy::large_enum_variant)] pub enum RequestHeadType { Owned(RequestHead), Rc(Rc, Option), @@ -209,7 +210,7 @@ impl RequestHeadType { impl AsRef for RequestHeadType { fn as_ref(&self) -> &RequestHead { match self { - RequestHeadType::Owned(head) => &head, + RequestHeadType::Owned(head) => head, RequestHeadType::Rc(head, _) => head.as_ref(), } } @@ -317,7 +318,7 @@ impl ResponseHead { } #[inline] - pub(crate) fn ctype(&self) -> Option { + pub(crate) fn conn_type(&self) -> Option { if self.flags.contains(Flags::CLOSE) { Some(ConnectionType::Close) } else if self.flags.contains(Flags::KEEP_ALIVE) { @@ -363,7 +364,7 @@ impl std::ops::Deref for Message { type Target = T; fn deref(&self) -> &Self::Target { - &self.head.as_ref() + self.head.as_ref() } } diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index 2aa38c153..ad41094ae 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -6,14 +6,15 @@ use std::{ }; use bytes::{Bytes, BytesMut}; +use bytestring::ByteString; use crate::{ - body::{AnyBody, MessageBody}, - error::Error, + body::{BoxBody, MessageBody}, extensions::Extensions, + header::{self, IntoHeaderValue}, http::{HeaderMap, StatusCode}, message::{BoxedResponseHead, ResponseHead}, - ResponseBuilder, + Error, ResponseBuilder, }; /// An HTTP response. @@ -22,13 +23,13 @@ pub struct Response { pub(crate) body: B, } -impl Response { +impl Response { /// Constructs a new response with default body. #[inline] pub fn new(status: StatusCode) -> Self { Response { head: BoxedResponseHead::new(status), - body: AnyBody::Empty, + body: BoxBody::new(()), } } @@ -189,6 +190,14 @@ impl Response { } } + #[inline] + pub fn map_into_boxed_body(self) -> Response + where + B: MessageBody + 'static, + { + self.map_body(|_, body| BoxBody::new(body)) + } + /// Returns body, consuming this response. pub fn into_body(self) -> B { self.body @@ -223,81 +232,99 @@ impl Default for Response { } } -impl>, E: Into> From> - for Response +impl>, E: Into> From> + for Response { fn from(res: Result) -> Self { match res { Ok(val) => val.into(), - Err(err) => err.into().into(), + Err(err) => Response::from(err.into()), } } } -impl From for Response { +impl From for Response { fn from(mut builder: ResponseBuilder) -> Self { - builder.finish() + builder.finish().map_into_boxed_body() } } -impl From for Response { +impl From for Response { fn from(val: std::convert::Infallible) -> Self { match val {} } } -impl From<&'static str> for Response { +impl From<&'static str> for Response<&'static str> { fn from(val: &'static str) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } -impl From<&'static [u8]> for Response { +impl From<&'static [u8]> for Response<&'static [u8]> { fn from(val: &'static [u8]) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } -impl From for Response { +impl From for Response { fn from(val: String) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } -impl<'a> From<&'a String> for Response { - fn from(val: &'a String) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) +impl From<&String> for Response { + fn from(val: &String) -> Self { + let mut res = Response::with_body(StatusCode::OK, val.clone()); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } -impl From for Response { +impl From for Response { fn from(val: Bytes) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } -impl From for Response { +impl From for Response { fn from(val: BytesMut) -> Self { - Response::build(StatusCode::OK) - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From for Response { + fn from(val: ByteString) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res } } #[cfg(test)] mod tests { use super::*; - use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE}; + use crate::{ + body::to_bytes, + http::header::{HeaderValue, CONTENT_TYPE, COOKIE}, + }; #[test] fn test_debug() { @@ -309,73 +336,73 @@ mod tests { assert!(dbg.contains("Response")); } - #[test] - fn test_into_response() { - let resp: Response = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); + #[actix_rt::test] + async fn test_into_response() { + let res = Response::from("test"); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); - let resp: Response = b"test".as_ref().into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from(b"test".as_ref()); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); - let resp: Response = "test".to_owned().into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from("test".to_owned()); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); - let resp: Response = (&"test".to_owned()).into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from("test".to_owned()); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); let b = Bytes::from_static(b"test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); let b = Bytes::from_static(b"test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); let b = BytesMut::from("test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); } } diff --git a/actix-http/src/response_builder.rs b/actix-http/src/response_builder.rs index e46d9a28c..0537112d5 100644 --- a/actix-http/src/response_builder.rs +++ b/actix-http/src/response_builder.rs @@ -2,19 +2,11 @@ use std::{ cell::{Ref, RefMut}, - error::Error as StdError, - fmt, - future::Future, - pin::Pin, - str, - task::{Context, Poll}, + fmt, str, }; -use bytes::Bytes; -use futures_core::Stream; - use crate::{ - body::{AnyBody, BodyStream}, + body::{EitherBody, MessageBody}, error::{Error, HttpError}, header::{self, IntoHeaderPair, IntoHeaderValue}, message::{BoxedResponseHead, ConnectionType, ResponseHead}, @@ -235,10 +227,14 @@ impl ResponseBuilder { /// Generate response with a wrapped body. /// /// This `ResponseBuilder` will be left in a useless state. - #[inline] - pub fn body>(&mut self, body: B) -> Response { - self.message_body(body.into()) - .unwrap_or_else(Response::from) + pub fn body(&mut self, body: B) -> Response> + where + B: MessageBody + 'static, + { + match self.message_body(body) { + Ok(res) => res.map_body(|_, body| EitherBody::left(body)), + Err(err) => Response::from(err).map_body(|_, body| EitherBody::right(body)), + } } /// Generate response with a body. @@ -253,24 +249,12 @@ impl ResponseBuilder { Ok(Response { head, body }) } - /// Generate response with a streaming body. - /// - /// This `ResponseBuilder` will be left in a useless state. - #[inline] - pub fn streaming(&mut self, stream: S) -> Response - where - S: Stream> + 'static, - E: Into> + 'static, - { - self.body(AnyBody::from_message(BodyStream::new(stream))) - } - /// Generate response with an empty body. /// /// This `ResponseBuilder` will be left in a useless state. #[inline] - pub fn finish(&mut self) -> Response { - self.body(AnyBody::Empty) + pub fn finish(&mut self) -> Response> { + self.body(()) } /// Create an owned `ResponseBuilder`, leaving the original in a useless state. @@ -327,14 +311,6 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder { } } -impl Future for ResponseBuilder { - type Output = Result, Error>; - - fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(Ok(self.finish())) - } -} - impl fmt::Debug for ResponseBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let head = self.head.as_ref().unwrap(); @@ -356,8 +332,9 @@ impl fmt::Debug for ResponseBuilder { #[cfg(test)] mod tests { + use bytes::Bytes; + use super::*; - use crate::body::Body; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; #[test] @@ -383,20 +360,28 @@ mod tests { #[test] fn test_force_close() { let resp = Response::build(StatusCode::OK).force_close().finish(); - assert!(!resp.keep_alive()) + assert!(!resp.keep_alive()); } #[test] fn test_content_type() { let resp = Response::build(StatusCode::OK) .content_type("text/plain") - .body(Body::Empty); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") + .body(Bytes::new()); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain"); + + let resp = Response::build(StatusCode::OK) + .content_type(mime::APPLICATION_JAVASCRIPT_UTF_8) + .body(Bytes::new()); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + "application/javascript; charset=utf-8" + ); } #[test] fn test_into_builder() { - let mut resp: Response = "test".into(); + let mut resp: Response<_> = "test".into(); assert_eq!(resp.status(), StatusCode::OK); resp.headers_mut().insert( diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index afe47bf2d..7af34ba05 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,5 +1,4 @@ use std::{ - error::Error as StdError, fmt, future::Future, marker::PhantomData, @@ -9,18 +8,16 @@ use std::{ task::{Context, Poll}, }; -use ::h2::server::{handshake as h2_handshake, Handshake as H2Handshake}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; use actix_service::{ fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, }; -use bytes::Bytes; use futures_core::{future::LocalBoxFuture, ready}; -use pin_project::pin_project; +use pin_project_lite::pin_project; use crate::{ - body::{AnyBody, MessageBody}, + body::{BoxBody, MessageBody}, builder::HttpServiceBuilder, config::{KeepAlive, ServiceConfig}, error::DispatchError, @@ -40,7 +37,7 @@ pub struct HttpService { impl HttpService where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, @@ -55,12 +52,11 @@ where impl HttpService where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { /// Create new `HttpService` instance. pub fn new>(service: F) -> Self { @@ -95,7 +91,7 @@ where impl HttpService where S: ServiceFactory, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, @@ -109,7 +105,7 @@ where pub fn expect(self, expect: X1) -> HttpService where X1: ServiceFactory, - X1::Error: Into>, + X1::Error: Into>, X1::InitError: fmt::Debug, { HttpService { @@ -153,17 +149,16 @@ impl HttpService where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory< @@ -172,7 +167,7 @@ where Response = (), >, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { /// Create simple tcp stream service @@ -195,9 +190,14 @@ where #[cfg(feature = "openssl")] mod openssl { - use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, + TlsError, + }; use super::*; @@ -205,17 +205,16 @@ mod openssl { where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory< @@ -224,10 +223,10 @@ mod openssl { Response = (), >, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create openssl based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -239,9 +238,11 @@ mod openssl { InitError = (), > { Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) - .and_then(|io: TlsStream| async { + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 @@ -251,8 +252,9 @@ mod openssl { } else { Protocol::Http1 }; + let peer_addr = io.get_ref().peer_addr().ok(); - Ok((io, proto, peer_addr)) + (io, proto, peer_addr) }) .and_then(self.map_err(TlsError::Service)) } @@ -263,27 +265,28 @@ mod openssl { mod rustls { use std::io; - use actix_tls::accept::rustls::{Acceptor, ServerConfig, Session, TlsStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, + TlsError, + }; use super::*; - use actix_service::ServiceFactoryExt; impl HttpService, S, B, X, U> where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory< @@ -292,10 +295,10 @@ mod rustls { Response = (), >, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create rustls based service + /// Create Rustls based service. pub fn rustls( self, mut config: ServerConfig, @@ -308,14 +311,15 @@ mod rustls { > { let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; protos.extend_from_slice(&config.alpn_protocols); - config.set_protocols(&protos); + config.alpn_protocols = protos; Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) .and_then(|io: TlsStream| async { - let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() - { + let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { Protocol::Http2 } else { @@ -339,22 +343,21 @@ where S: ServiceFactory, S::Future: 'static, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, X: ServiceFactory, X::Future: 'static, - X::Error: Into>, + X::Error: Into>, X::InitError: fmt::Debug, U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Future: 'static, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { type Response = (); @@ -417,11 +420,11 @@ where impl HttpServiceHandler where S: Service, - S::Error: Into>, + S::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed)>, - U::Error: Into>, + U::Error: Into>, { pub(super) fn new( cfg: ServiceConfig, @@ -441,7 +444,7 @@ where pub(super) fn _poll_ready( &self, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?; ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?; @@ -477,18 +480,17 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display + Into>, + U::Error: fmt::Display + Into>, { type Response = (); type Error = DispatchError; @@ -510,23 +512,27 @@ where match proto { Protocol::Http2 => HttpServiceHandlerResponse { - state: State::H2Handshake(Some(( - h2_handshake(io), - self.cfg.clone(), - self.flow.clone(), - on_connect_data, - peer_addr, - ))), + state: State::H2Handshake { + handshake: Some(( + h2::handshake_with_timeout(io, &self.cfg), + self.cfg.clone(), + self.flow.clone(), + on_connect_data, + peer_addr, + )), + }, }, Protocol::Http1 => HttpServiceHandlerResponse { - state: State::H1(h1::Dispatcher::new( - io, - self.cfg.clone(), - self.flow.clone(), - on_connect_data, - peer_addr, - )), + state: State::H1 { + dispatcher: h1::Dispatcher::new( + io, + self.cfg.clone(), + self.flow.clone(), + on_connect_data, + peer_addr, + ), + }, }, proto => unimplemented!("Unsupported HTTP version: {:?}.", proto), @@ -534,58 +540,65 @@ where } } -#[pin_project(project = StateProj)] -enum State -where - T: AsyncRead + AsyncWrite + Unpin, +pin_project! { + #[project = StateProj] + enum State + where + T: AsyncRead, + T: AsyncWrite, + T: Unpin, - S: Service, - S::Future: 'static, - S::Error: Into>, + S: Service, + S::Future: 'static, + S::Error: Into>, - B: MessageBody, - B::Error: Into>, + B: MessageBody, - X: Service, - X::Error: Into>, + X: Service, + X::Error: Into>, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - H1(#[pin] h1::Dispatcher), - H2(#[pin] h2::Dispatcher), - H2Handshake( - Option<( - H2Handshake, - ServiceConfig, - Rc>, - OnConnectData, - Option, - )>, - ), + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + H1 { #[pin] dispatcher: h1::Dispatcher }, + H2 { #[pin] dispatcher: h2::Dispatcher }, + H2Handshake { + handshake: Option<( + h2::HandshakeWithTimeout, + ServiceConfig, + Rc>, + OnConnectData, + Option, + )>, + }, + } } -#[pin_project] -pub struct HttpServiceHandlerResponse -where - T: AsyncRead + AsyncWrite + Unpin, +pin_project! { + pub struct HttpServiceHandlerResponse + where + T: AsyncRead, + T: AsyncWrite, + T: Unpin, - S: Service, - S::Error: Into> + 'static, - S::Future: 'static, - S::Response: Into> + 'static, + S: Service, + S::Error: Into>, + S::Error: 'static, + S::Future: 'static, + S::Response: Into>, + S::Response: 'static, - B: MessageBody, - B::Error: Into>, + B: MessageBody, - X: Service, - X::Error: Into>, + X: Service, + X::Error: Into>, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - #[pin] - state: State, + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + #[pin] + state: State, + } } impl Future for HttpServiceHandlerResponse @@ -593,15 +606,14 @@ where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into> + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, - B::Error: Into>, X: Service, - X::Error: Into>, + X::Error: Into>, U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, @@ -610,27 +622,29 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project().state.project() { - StateProj::H1(disp) => disp.poll(cx), - StateProj::H2(disp) => disp.poll(cx), - StateProj::H2Handshake(data) => { + StateProj::H1 { dispatcher } => dispatcher.poll(cx), + StateProj::H2 { dispatcher } => dispatcher.poll(cx), + StateProj::H2Handshake { handshake: data } => { match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) { - Ok(conn) => { - let (_, cfg, srv, on_connect_data, peer_addr) = + Ok((conn, timer)) => { + let (_, config, flow, on_connect_data, peer_addr) = data.take().unwrap(); - self.as_mut().project().state.set(State::H2( - h2::Dispatcher::new( - srv, + + self.as_mut().project().state.set(State::H2 { + dispatcher: h2::Dispatcher::new( + flow, conn, on_connect_data, - cfg, + config, peer_addr, + timer, ), - )); + }); self.poll(cx) } Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } } } diff --git a/actix-http/src/time_parser.rs b/actix-http/src/time_parser.rs deleted file mode 100644 index fd82fd42e..000000000 --- a/actix-http/src/time_parser.rs +++ /dev/null @@ -1,72 +0,0 @@ -use time::{Date, OffsetDateTime, PrimitiveDateTime}; - -/// Attempt to parse a `time` string as one of either RFC 1123, RFC 850, or asctime. -pub(crate) fn parse_http_date(time: &str) -> Option { - try_parse_rfc_1123(time) - .or_else(|| try_parse_rfc_850(time)) - .or_else(|| try_parse_asctime(time)) -} - -/// Attempt to parse a `time` string as a RFC 1123 formatted date time string. -/// -/// Eg: `Fri, 12 Feb 2021 00:14:29 GMT` -fn try_parse_rfc_1123(time: &str) -> Option { - time::parse(time, "%a, %d %b %Y %H:%M:%S").ok() -} - -/// Attempt to parse a `time` string as a RFC 850 formatted date time string. -/// -/// Eg: `Wednesday, 11-Jan-21 13:37:41 UTC` -fn try_parse_rfc_850(time: &str) -> Option { - let dt = PrimitiveDateTime::parse(time, "%A, %d-%b-%y %H:%M:%S").ok()?; - - // If the `time` string contains a two-digit year, then as per RFC 2616 § 19.3, - // we consider the year as part of this century if it's within the next 50 years, - // otherwise we consider as part of the previous century. - - let now = OffsetDateTime::now_utc(); - let century_start_year = (now.year() / 100) * 100; - let mut expanded_year = century_start_year + dt.year(); - - if expanded_year > now.year() + 50 { - expanded_year -= 100; - } - - let date = Date::try_from_ymd(expanded_year, dt.month(), dt.day()).ok()?; - Some(PrimitiveDateTime::new(date, dt.time())) -} - -/// Attempt to parse a `time` string using ANSI C's `asctime` format. -/// -/// Eg: `Wed Feb 13 15:46:11 2013` -fn try_parse_asctime(time: &str) -> Option { - time::parse(time, "%a %b %_d %H:%M:%S %Y").ok() -} - -#[cfg(test)] -mod tests { - use time::{date, time}; - - use super::*; - - #[test] - fn test_rfc_850_year_shift() { - let date = try_parse_rfc_850("Friday, 19-Nov-82 16:14:55 EST").unwrap(); - assert_eq!(date, date!(1982 - 11 - 19).with_time(time!(16:14:55))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-62 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2062 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-21 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2021 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-23 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2023 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-99 13:37:41 EST").unwrap(); - assert_eq!(date, date!(1999 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-00 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2000 - 01 - 11).with_time(time!(13:37:41))); - } -} diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 8655216fa..d80613e5f 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -63,8 +63,8 @@ pub enum Item { Last(Bytes), } -#[derive(Debug, Copy, Clone)] /// WebSocket protocol codec. +#[derive(Debug, Clone)] pub struct Codec { flags: Flags, max_size: usize, @@ -89,7 +89,8 @@ impl Codec { /// Set max frame size. /// - /// By default max size is set to 64kB. + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Codec, without modifying the original."] pub fn max_size(mut self, size: usize) -> Self { self.max_size = size; self @@ -98,12 +99,19 @@ impl Codec { /// Set decoder to client mode. /// /// By default decoder works in server mode. + #[must_use = "This returns the a new Codec, without modifying the original."] pub fn client_mode(mut self) -> Self { self.flags.remove(Flags::SERVER); self } } +impl Default for Codec { + fn default() -> Self { + Self::new() + } +} + impl Encoder for Codec { type Error = ProtocolError; diff --git a/actix-http/src/ws/dispatcher.rs b/actix-http/src/ws/dispatcher.rs index f49cbe5d4..a3f766e9c 100644 --- a/actix-http/src/ws/dispatcher.rs +++ b/actix-http/src/ws/dispatcher.rs @@ -4,17 +4,21 @@ use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; +use pin_project_lite::pin_project; use super::{Codec, Frame, Message}; -#[pin_project::pin_project] -pub struct Dispatcher -where - S: Service + 'static, - T: AsyncRead + AsyncWrite, -{ - #[pin] - inner: inner::Dispatcher, +pin_project! { + pub struct Dispatcher + where + S: Service, + S: 'static, + T: AsyncRead, + T: AsyncWrite, + { + #[pin] + inner: inner::Dispatcher, + } } impl Dispatcher @@ -72,7 +76,7 @@ mod inner { use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; - use crate::{body::AnyBody, Response}; + use crate::{body::BoxBody, Response}; /// Framed transport errors pub enum DispatcherError @@ -136,7 +140,7 @@ mod inner { } } - impl From> for Response + impl From> for Response where E: fmt::Debug + fmt::Display, U: Encoder + Decoder, @@ -144,7 +148,7 @@ mod inner { ::Error: fmt::Debug, { fn from(err: DispatcherError) -> Self { - Response::internal_server_error().set_body(AnyBody::from(err.to_string())) + Response::internal_server_error().set_body(BoxBody::new(err.to_string())) } } diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs index 276ca4a85..11a6ddc32 100644 --- a/actix-http/src/ws/mask.rs +++ b/actix-http/src/ws/mask.rs @@ -25,8 +25,8 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { // // un aligned prefix and suffix would be mask/unmask per byte. // proper aligned middle slice goes into fast path and operates on 4-byte blocks. - let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::() }; - apply_mask_fallback(&mut prefix, mask); + let (prefix, words, suffix) = unsafe { buf.align_to_mut::() }; + apply_mask_fallback(prefix, mask); let head = prefix.len() & 3; let mask_u32 = if head > 0 { if cfg!(target_endian = "big") { @@ -40,7 +40,7 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { for word in words.iter_mut() { *word ^= mask_u32; } - apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); + apply_mask_fallback(suffix, mask_u32.to_ne_bytes()); } #[cfg(test)] diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 7df924cf5..cb1aa6730 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -8,9 +8,9 @@ use std::io; use derive_more::{Display, Error, From}; use http::{header, Method, StatusCode}; +use crate::body::BoxBody; use crate::{ - body::AnyBody, header::HeaderValue, message::RequestHead, response::Response, - ResponseBuilder, + header::HeaderValue, message::RequestHead, response::Response, ResponseBuilder, }; mod codec; @@ -69,7 +69,7 @@ pub enum ProtocolError { } /// WebSocket handshake errors -#[derive(Debug, PartialEq, Display, Error)] +#[derive(Debug, Clone, Copy, PartialEq, Display, Error)] pub enum HandshakeError { /// Only get method is allowed. #[display(fmt = "Method not allowed.")] @@ -96,8 +96,8 @@ pub enum HandshakeError { BadWebsocketKey, } -impl From<&HandshakeError> for Response { - fn from(err: &HandshakeError) -> Self { +impl From for Response { + fn from(err: HandshakeError) -> Self { match err { HandshakeError::GetMethodRequired => { let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED); @@ -139,9 +139,9 @@ impl From<&HandshakeError> for Response { } } -impl From for Response { - fn from(err: HandshakeError) -> Self { - (&err).into() +impl From<&HandshakeError> for Response { + fn from(err: &HandshakeError) -> Self { + (*err).into() } } @@ -210,7 +210,6 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { Response::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") - .insert_header((header::TRANSFER_ENCODING, "chunked")) .insert_header(( header::SEC_WEBSOCKET_ACCEPT, // key is known to be header value safe ascii @@ -221,9 +220,10 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { #[cfg(test)] mod tests { + use crate::{header, Method}; + use super::*; - use crate::{body::AnyBody, test::TestRequest}; - use http::{header, Method}; + use crate::test::TestRequest; #[test] fn test_handshake() { @@ -337,17 +337,17 @@ mod tests { #[test] fn test_ws_error_http_response() { - let resp: Response = HandshakeError::GetMethodRequired.into(); + let resp: Response = HandshakeError::GetMethodRequired.into(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let resp: Response = HandshakeError::NoWebsocketUpgrade.into(); + let resp: Response = HandshakeError::NoWebsocketUpgrade.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::NoConnectionUpgrade.into(); + let resp: Response = HandshakeError::NoConnectionUpgrade.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::NoVersionHeader.into(); + let resp: Response = HandshakeError::NoVersionHeader.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::UnsupportedVersion.into(); + let resp: Response = HandshakeError::UnsupportedVersion.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::BadWebsocketKey.into(); + let resp: Response = HandshakeError::BadWebsocketKey.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index fdcde5eac..4227f221d 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -3,7 +3,9 @@ use std::{ fmt, }; -/// Operation codes as part of RFC6455. +/// Operation codes defined in [RFC 6455 §11.8]. +/// +/// [RFC 6455]: https://datatracker.ietf.org/doc/html/rfc6455#section-11.8 #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum OpCode { /// Indicates a continuation frame of a fragmented message. @@ -105,7 +107,7 @@ pub enum CloseCode { Abnormal, /// Indicates that an endpoint is terminating the connection because it has received data within - /// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] + /// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC 3629\] /// data within a text message). Invalid, @@ -220,7 +222,8 @@ impl> From<(CloseCode, T)> for CloseReason { } } -/// The WebSocket GUID as stated in the spec. See https://tools.ietf.org/html/rfc6455#section-1.3. +/// The WebSocket GUID as stated in the spec. +/// See . static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; /// Hashes the `Sec-WebSocket-Key` header according to the WebSocket spec. diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index 414266d81..4c923873f 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,7 +1,7 @@ use std::convert::Infallible; use actix_http::{ - body::AnyBody, http, http::StatusCode, HttpMessage, HttpService, Request, Response, + body::BoxBody, http, http::StatusCode, HttpMessage, HttpService, Request, Response, }; use actix_http_test::test_server; use actix_service::ServiceFactoryExt; @@ -99,7 +99,7 @@ async fn test_with_query_parameter() { #[display(fmt = "expect failed")] struct ExpectFailed; -impl From for Response { +impl From for Response { fn from(_: ExpectFailed) -> Self { Response::new(StatusCode::EXPECTATION_FAILED) } diff --git a/actix-http/tests/test_h2_timer.rs b/actix-http/tests/test_h2_timer.rs new file mode 100644 index 000000000..2b9c26e4a --- /dev/null +++ b/actix-http/tests/test_h2_timer.rs @@ -0,0 +1,153 @@ +use std::io; + +use actix_http::{error::Error, HttpService, Response}; +use actix_server::Server; +use tokio::io::AsyncWriteExt; + +#[actix_rt::test] +async fn h2_ping_pong() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(3) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for h2 client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + let (mut tx, conn) = h2::client::handshake(stream).await.unwrap(); + + tokio::spawn(async move { conn.await.unwrap() }); + + let (res, _) = tx.send_request(::http::Request::new(()), true).unwrap(); + let res = res.await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it can not answer ping pong. + std::thread::sleep(std::time::Duration::from_secs(1000)); + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from keep alive and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +} + +#[actix_rt::test] +async fn h2_handshake_timeout() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(30) + // set first request timeout to 5 seconds. + // this is the timeout used for http2 handshake. + .client_timeout(5000) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for tcp client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + // do not send the last new line intentionally. + // This should hang the server handshake + let malicious_buf = b"PRI * HTTP/2.0\r\n\r\nSM\r\n"; + stream.write_all(malicious_buf).await.unwrap(); + stream.flush().await.unwrap(); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it sit idle and not do handshake. + std::thread::sleep(std::time::Duration::from_secs(1000)); + + drop(stream) + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from handshake timeout and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +} diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index a58d0cc70..86ee17c74 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -5,10 +5,10 @@ extern crate tls_openssl as openssl; use std::{convert::Infallible, io}; use actix_http::{ - body::{AnyBody, Body, SizedStream}, + body::{BodyStream, BoxBody, SizedStream}, error::PayloadError, http::{ - header::{self, HeaderName, HeaderValue}, + header::{self, HeaderValue}, Method, StatusCode, Version, }, Error, HttpMessage, HttpService, Request, Response, @@ -143,38 +143,25 @@ async fn test_h2_content_length() { }) .await; - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); + static VALUE: HeaderValue = HeaderValue::from_static("0"); { - for &i in &[0] { - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let _response = req.await.expect_err("should timeout on recv 1xx frame"); - // assert_eq!(response.headers().get(&header), None); + let req = srv.request(Method::HEAD, srv.surl("/0")).send(); + req.await.expect_err("should timeout on recv 1xx frame"); - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let _response = req.await.expect_err("should timeout on recv 1xx frame"); - // assert_eq!(response.headers().get(&header), None); - } + let req = srv.request(Method::GET, srv.surl("/0")).send(); + req.await.expect_err("should timeout on recv 1xx frame"); - for &i in &[1] { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), None); - } + let req = srv.request(Method::GET, srv.surl("/1")).send(); + let response = req.await.unwrap(); + assert!(response.headers().get("content-length").is_none()); for &i in &[2, 3] { let req = srv .request(Method::GET, srv.surl(&format!("/{}", i))) .send(); let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); + assert_eq!(response.headers().get("content-length"), Some(&VALUE)); } } } @@ -361,7 +348,7 @@ async fn test_h2_body_chunked_explicit() { ok::<_, Infallible>( Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .openssl(tls_config()) @@ -412,9 +399,11 @@ async fn test_h2_response_http_error_handling() { #[display(fmt = "error")] struct BadRequest; -impl From for Response { +impl From for Response { fn from(err: BadRequest) -> Self { - Response::build(StatusCode::BAD_REQUEST).body(err.to_string()) + Response::build(StatusCode::BAD_REQUEST) + .body(err.to_string()) + .map_into_boxed_body() } } @@ -422,7 +411,7 @@ impl From for Response { async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| err::, _>(BadRequest)) + .h2(|_| err::, _>(BadRequest)) .openssl(tls_config()) .map_err(|_| ()) }) diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs index cb7c77ad6..873752779 100644 --- a/actix-http/tests/test_rustls.rs +++ b/actix-http/tests/test_rustls.rs @@ -3,14 +3,14 @@ extern crate tls_rustls as rustls; use std::{ - convert::Infallible, + convert::{Infallible, TryFrom}, io::{self, BufReader, Write}, net::{SocketAddr, TcpStream as StdTcpStream}, sync::Arc, }; use actix_http::{ - body::{AnyBody, Body, SizedStream}, + body::{BodyStream, BoxBody, SizedStream}, error::PayloadError, http::{ header::{self, HeaderName, HeaderValue}, @@ -20,16 +20,14 @@ use actix_http::{ }; use actix_http_test::test_server; use actix_service::{fn_factory_with_config, fn_service}; +use actix_tls::connect::rustls::webpki_roots_cert_store; use actix_utils::future::{err, ok}; use bytes::{Bytes, BytesMut}; use derive_more::{Display, Error}; use futures_core::Stream; use futures_util::stream::{once, StreamExt as _}; -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, Session, -}; -use webpki::DNSNameRef; +use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig, ServerName}; +use rustls_pemfile::{certs, pkcs8_private_keys}; async fn load_body(mut stream: S) -> Result where @@ -47,13 +45,24 @@ fn tls_config() -> RustlsServerConfig { let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = RustlsServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let mut config = RustlsServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap(); + + config.alpn_protocols.push(HTTP1_1_ALPN_PROTOCOL.to_vec()); + config.alpn_protocols.push(H2_ALPN_PROTOCOL.to_vec()); config } @@ -62,19 +71,28 @@ pub fn get_negotiated_alpn_protocol( addr: SocketAddr, client_alpn_protocol: &[u8], ) -> Option> { - let mut config = rustls::ClientConfig::new(); + let mut config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + config.alpn_protocols.push(client_alpn_protocol.to_vec()); - let mut sess = rustls::ClientSession::new( - &Arc::new(config), - DNSNameRef::try_from_ascii_str("localhost").unwrap(), - ); + + let mut sess = rustls::ClientConnection::new( + Arc::new(config), + ServerName::try_from("localhost").unwrap(), + ) + .unwrap(); + let mut sock = StdTcpStream::connect(addr).unwrap(); let mut stream = rustls::Stream::new(&mut sess, &mut sock); + // The handshake will fails because the client will not be able to verify the server // certificate, but it doesn't matter here as we are just interested in the negotiated ALPN // protocol let _ = stream.flush(); - sess.get_alpn_protocol().map(|proto| proto.to_vec()) + + sess.alpn_protocol().map(|proto| proto.to_vec()) } #[actix_rt::test] @@ -398,7 +416,7 @@ async fn test_h2_body_chunked_explicit() { ok::<_, Infallible>( Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .rustls(tls_config()) @@ -449,9 +467,9 @@ async fn test_h2_response_http_error_handling() { #[display(fmt = "error")] struct BadRequest; -impl From for Response { +impl From for Response { fn from(_: BadRequest) -> Self { - Response::bad_request().set_body(AnyBody::from("error")) + Response::bad_request().set_body(BoxBody::new("error")) } } @@ -459,7 +477,7 @@ impl From for Response { async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| err::, _>(BadRequest)) + .h2(|_| err::, _>(BadRequest)) .rustls(tls_config()) }) .await; @@ -476,7 +494,7 @@ async fn test_h2_service_error() { async fn test_h1_service_error() { let mut srv = test_server(move || { HttpService::build() - .h1(|_| err::, _>(BadRequest)) + .h1(|_| err::, _>(BadRequest)) .rustls(tls_config()) }) .await; diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 1e6d0b637..adf2a28ca 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -6,7 +6,7 @@ use std::{ }; use actix_http::{ - body::{AnyBody, Body, SizedStream}, + body::{self, BodyStream, BoxBody, SizedStream}, header, http, Error, HttpMessage, HttpService, KeepAlive, Request, Response, StatusCode, }; @@ -24,7 +24,7 @@ use regex::Regex; #[actix_rt::test] async fn test_h1() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) @@ -39,11 +39,13 @@ async fn test_h1() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_2() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) @@ -59,13 +61,15 @@ async fn test_h1_2() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; } #[derive(Debug, Display, Error)] #[display(fmt = "expect failed")] struct ExpectFailed; -impl From for Response { +impl From for Response { fn from(_: ExpectFailed) -> Self { Response::new(StatusCode::EXPECTATION_FAILED) } @@ -73,7 +77,7 @@ impl From for Response { #[actix_rt::test] async fn test_expect_continue() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { if req.head().uri.query() == Some("yes=") { @@ -98,11 +102,13 @@ async fn test_expect_continue() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + srv.stop().await; } #[actix_rt::test] async fn test_expect_continue_h1() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { sleep(Duration::from_millis(20)).then(move |_| { @@ -129,6 +135,8 @@ async fn test_expect_continue_h1() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + srv.stop().await; } #[actix_rt::test] @@ -136,7 +144,7 @@ async fn test_chunked_payload() { let chunk_sizes = vec![32768, 32, 32768]; let total_size: usize = chunk_sizes.iter().sum(); - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(fn_service(|mut request: Request| { request @@ -183,15 +191,18 @@ async fn test_chunked_payload() { Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), None => panic!("Failed to find size in HTTP Response: {}", data), }; + size }; assert_eq!(returned_size, total_size); + + srv.stop().await; } #[actix_rt::test] async fn test_slow_request() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .client_timeout(100) .finish(|_| ok::<_, Infallible>(Response::ok())) @@ -204,11 +215,13 @@ async fn test_slow_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_malformed_request() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() @@ -220,11 +233,13 @@ async fn test_http1_malformed_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() @@ -241,11 +256,13 @@ async fn test_http1_keepalive() { let mut data = vec![0; 1024]; let _ = stream.read(&mut data); assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_timeout() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(1) .h1(|_| ok::<_, Infallible>(Response::ok())) @@ -263,11 +280,13 @@ async fn test_http1_keepalive_timeout() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_close() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() @@ -284,11 +303,13 @@ async fn test_http1_keepalive_close() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http10_keepalive_default_close() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() @@ -304,11 +325,13 @@ async fn test_http10_keepalive_default_close() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http10_keepalive() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() @@ -331,11 +354,13 @@ async fn test_http10_keepalive() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_disabled() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .h1(|_| ok::<_, Infallible>(Response::ok())) @@ -352,6 +377,8 @@ async fn test_http1_keepalive_disabled() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] @@ -361,7 +388,7 @@ async fn test_content_length() { StatusCode, }; - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|req: Request| { let indx: usize = req.uri().path()[1..].parse().unwrap(); @@ -399,6 +426,8 @@ async fn test_content_length() { assert_eq!(response.headers().get(&header), Some(&value)); } } + + srv.stop().await; } #[actix_rt::test] @@ -438,6 +467,8 @@ async fn test_h1_headers() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from(data2)); + + srv.stop().await; } const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -477,6 +508,8 @@ async fn test_h1_body() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -502,6 +535,8 @@ async fn test_h1_head_empty() { // read response let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] @@ -527,11 +562,13 @@ async fn test_h1_head_binary() { // read response let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_head_binary2() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() @@ -548,6 +585,8 @@ async fn test_h1_head_binary2() { .unwrap(); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); } + + srv.stop().await; } #[actix_rt::test] @@ -570,6 +609,8 @@ async fn test_h1_body_length() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -581,7 +622,7 @@ async fn test_h1_body_chunked_explicit() { ok::<_, Infallible>( Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .tcp() @@ -605,6 +646,8 @@ async fn test_h1_body_chunked_explicit() { // decode assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -613,7 +656,9 @@ async fn test_h1_body_chunked_implicit() { HttpService::build() .h1(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, Infallible>(Response::build(StatusCode::OK).streaming(body)) + ok::<_, Infallible>( + Response::build(StatusCode::OK).body(BodyStream::new(body)), + ) }) .tcp() }) @@ -634,6 +679,8 @@ async fn test_h1_body_chunked_implicit() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -661,15 +708,17 @@ async fn test_h1_response_http_error_handling() { bytes, Bytes::from_static(b"error processing HTTP: failed to parse header value") ); + + srv.stop().await; } #[derive(Debug, Display, Error)] #[display(fmt = "error")] struct BadRequest; -impl From for Response { +impl From for Response { fn from(_: BadRequest) -> Self { - Response::bad_request().set_body(AnyBody::from("error")) + Response::bad_request().set_body(BoxBody::new("error")) } } @@ -677,7 +726,7 @@ impl From for Response { async fn test_h1_service_error() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| err::, _>(BadRequest)) + .h1(|_| err::, _>(BadRequest)) .tcp() }) .await; @@ -688,11 +737,13 @@ async fn test_h1_service_error() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"error")); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_on_connect() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .on_connect_ext(|_, data| { data.insert(20isize); @@ -707,4 +758,92 @@ async fn test_h1_on_connect() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; +} + +/// Tests compliance with 304 Not Modified spec in RFC 7232 §4.1. +/// https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 +#[actix_rt::test] +async fn test_not_modified_spec_h1() { + // TODO: this test needing a few seconds to complete reveals some weirdness with either the + // dispatcher or the client, though similar hangs occur on other tests in this file, only + // succeeding, it seems, because of the keepalive timer + + static CL: header::HeaderName = header::CONTENT_LENGTH; + + let mut srv = test_server(|| { + HttpService::build() + .h1(|req: Request| { + let res: Response = match req.path() { + // with no content-length + "/none" => { + Response::with_body(StatusCode::NOT_MODIFIED, body::None::new()) + .map_into_boxed_body() + } + + // with no content-length + "/body" => Response::with_body(StatusCode::NOT_MODIFIED, "1234") + .map_into_boxed_body(), + + // with manual content-length header and specific None body + "/cl-none" => { + let mut res = Response::with_body( + StatusCode::NOT_MODIFIED, + body::None::new(), + ); + res.headers_mut() + .insert(CL.clone(), header::HeaderValue::from_static("24")); + res.map_into_boxed_body() + } + + // with manual content-length header and ignore-able body + "/cl-body" => { + let mut res = + Response::with_body(StatusCode::NOT_MODIFIED, "1234"); + res.headers_mut() + .insert(CL.clone(), header::HeaderValue::from_static("4")); + res.map_into_boxed_body() + } + + _ => panic!("unknown route"), + }; + + ok::<_, Infallible>(res) + }) + .tcp() + }) + .await; + + let res = srv.get("/none").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!(res.headers().get(&CL), None); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/body").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!(res.headers().get(&CL), None); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/cl-none").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!( + res.headers().get(&CL), + Some(&header::HeaderValue::from_static("24")), + ); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/cl-body").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!( + res.headers().get(&CL), + Some(&header::HeaderValue::from_static("4")), + ); + // server does not prevent payload from being sent but clients may choose not to read it + // TODO: this is probably a bug, especially since CL header can differ in length from the body + assert!(!srv.load_body(res).await.unwrap().is_empty()); + + // TODO: add stream response tests + + srv.stop().await; } diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 6d0de2316..c91382013 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -6,7 +6,7 @@ use std::{ use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{ - body::{AnyBody, BodySize}, + body::{BodySize, BoxBody}, h1, ws::{self, CloseCode, Frame, Item, Message}, Error, HttpService, Request, Response, @@ -50,14 +50,14 @@ enum WsServiceError { Dispatcher, } -impl From for Response { +impl From for Response { fn from(err: WsServiceError) -> Self { match err { WsServiceError::Http(err) => err.into(), WsServiceError::Ws(err) => err.into(), WsServiceError::Io(_err) => unreachable!(), WsServiceError::Dispatcher => Response::internal_server_error() - .set_body(AnyBody::from(format!("{}", err))), + .set_body(BoxBody::new(format!("{}", err))), } } } diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index 0b6affa3c..d9ded57a4 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -3,6 +3,31 @@ ## Unreleased - 2021-xx-xx +## 0.4.0-beta.9 - 2021-12-01 +* Polling `Field` after dropping `Multipart` now fails immediately instead of hanging forever. [#2463] + +[#2463]: https://github.com/actix/actix-web/pull/2463 + + +## 0.4.0-beta.8 - 2021-11-22 +* Ensure a correct Content-Disposition header is included in every part of a multipart message. [#2451] +* Added `MultipartError::NoContentDisposition` variant. [#2451] +* Since Content-Disposition is now ensured, `Field::content_disposition` is now infallible. [#2451] +* Added `Field::name` method for getting the field name. [#2451] +* `MultipartError` now marks variants with inner errors as the source. [#2451] +* `MultipartError` is now marked as non-exhaustive. [#2451] + +[#2451]: https://github.com/actix/actix-web/pull/2451 + + +## 0.4.0-beta.7 - 2021-10-20 +* Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.4.0-beta.6 - 2021-09-09 +* Minimum supported Rust version (MSRV) is now 1.51. + + ## 0.4.0-beta.5 - 2021-06-17 * No notable changes. diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index 5103407ca..04a1d75ee 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,13 +1,11 @@ [package] name = "actix-multipart" -version = "0.4.0-beta.5" +version = "0.4.0-beta.9" authors = ["Nikolay Kim "] description = "Multipart form support for Actix Web" -readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-multipart" license = "MIT OR Apache-2.0" edition = "2018" @@ -16,13 +14,12 @@ name = "actix_multipart" path = "src/lib.rs" [dependencies] -actix-web = { version = "4.0.0-beta.8", default-features = false } +actix-web = { version = "4.0.0-beta.11", default-features = false } actix-utils = "3.0.0" bytes = "1" derive_more = "0.99.5" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } -futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] } httparse = "1.3" local-waker = "0.1" log = "0.4" @@ -31,6 +28,7 @@ twoway = "0.2" [dev-dependencies] actix-rt = "2.2" -actix-http = "3.0.0-beta.8" +actix-http = "3.0.0-beta.14" +futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] } tokio = { version = "1", features = ["sync"] } tokio-stream = "0.1" diff --git a/actix-multipart/README.md b/actix-multipart/README.md index 78855b815..85c78c5f3 100644 --- a/actix-multipart/README.md +++ b/actix-multipart/README.md @@ -3,15 +3,15 @@ > Multipart form support for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-multipart?label=latest)](https://crates.io/crates/actix-multipart) -[![Documentation](https://docs.rs/actix-multipart/badge.svg?version=0.4.0-beta.5)](https://docs.rs/actix-multipart/0.4.0-beta.5) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-multipart/badge.svg?version=0.4.0-beta.9)](https://docs.rs/actix-multipart/0.4.0-beta.9) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-multipart.svg)
-[![dependency status](https://deps.rs/crate/actix-multipart/0.4.0-beta.5/status.svg)](https://deps.rs/crate/actix-multipart/0.4.0-beta.5) +[![dependency status](https://deps.rs/crate/actix-multipart/0.4.0-beta.9/status.svg)](https://deps.rs/crate/actix-multipart/0.4.0-beta.9) [![Download](https://img.shields.io/crates/d/actix-multipart.svg)](https://crates.io/crates/actix-multipart) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-multipart) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.52 diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs index 5f91c60df..7d0da35e0 100644 --- a/actix-multipart/src/error.rs +++ b/actix-multipart/src/error.rs @@ -2,39 +2,52 @@ use actix_web::error::{ParseError, PayloadError}; use actix_web::http::StatusCode; use actix_web::ResponseError; -use derive_more::{Display, From}; +use derive_more::{Display, Error, From}; /// A set of errors that can occur during parsing multipart streams -#[derive(Debug, Display, From)] +#[non_exhaustive] +#[derive(Debug, Display, From, Error)] pub enum MultipartError { + /// Content-Disposition header is not found or is not equal to "form-data". + /// + /// According to [RFC 7578 §4.2](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) a + /// Content-Disposition header must always be present and equal to "form-data". + #[display(fmt = "No Content-Disposition `form-data` header")] + NoContentDisposition, + /// Content-Type header is not found - #[display(fmt = "No Content-type header found")] + #[display(fmt = "No Content-Type header found")] NoContentType, + /// Can not parse Content-Type header #[display(fmt = "Can not parse Content-Type header")] ParseContentType, + /// Multipart boundary is not found #[display(fmt = "Multipart boundary is not found")] Boundary, + /// Nested multipart is not supported #[display(fmt = "Nested multipart is not supported")] Nested, + /// Multipart stream is incomplete #[display(fmt = "Multipart stream is incomplete")] Incomplete, + /// Error during field parsing #[display(fmt = "{}", _0)] Parse(ParseError), + /// Payload error #[display(fmt = "{}", _0)] Payload(PayloadError), + /// Not consumed #[display(fmt = "Multipart stream is not consumed")] NotConsumed, } -impl std::error::Error for MultipartError {} - /// Return `BadRequest` for `MultipartError` impl ResponseError for MultipartError { fn status_code(&self) -> StatusCode { diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index c87f8cc2d..1ad1f203d 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -33,7 +33,6 @@ use crate::server::Multipart; impl FromRequest for Multipart { type Error = Error; type Future = Ready>; - type Config = (); #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index b7d251537..8eabcee10 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -1,18 +1,22 @@ //! Multipart response payload support. -use std::cell::{Cell, RefCell, RefMut}; -use std::convert::TryFrom; -use std::marker::PhantomData; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{cmp, fmt}; +use std::{ + cell::{Cell, RefCell, RefMut}, + cmp, + convert::TryFrom, + fmt, + marker::PhantomData, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; -use actix_web::error::{ParseError, PayloadError}; -use actix_web::http::header::{self, ContentDisposition, HeaderMap, HeaderName, HeaderValue}; +use actix_web::{ + error::{ParseError, PayloadError}, + http::header::{self, ContentDisposition, HeaderMap, HeaderName, HeaderValue}, +}; use bytes::{Bytes, BytesMut}; use futures_core::stream::{LocalBoxStream, Stream}; -use futures_util::stream::StreamExt as _; use local_waker::LocalWaker; use crate::error::MultipartError; @@ -28,7 +32,7 @@ const MAX_HEADERS: usize = 32; pub struct Multipart { safety: Safety, error: Option, - inner: Option>>, + inner: Option, } enum InnerMultipartItem { @@ -40,10 +44,13 @@ enum InnerMultipartItem { enum InnerState { /// Stream eof Eof, + /// Skip data until first boundary FirstBoundary, + /// Reading boundary Boundary, + /// Reading Headers, Headers, } @@ -59,7 +66,7 @@ impl Multipart { /// Create multipart instance for boundary. pub fn new(headers: &HeaderMap, stream: S) -> Multipart where - S: Stream> + Unpin + 'static, + S: Stream> + 'static, { match Self::boundary(headers) { Ok(boundary) => Multipart::from_boundary(boundary, stream), @@ -69,39 +76,32 @@ impl Multipart { /// Extract boundary info from headers. pub(crate) fn boundary(headers: &HeaderMap) -> Result { - if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - if let Some(boundary) = ct.get_param(mime::BOUNDARY) { - Ok(boundary.as_str().to_owned()) - } else { - Err(MultipartError::Boundary) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::NoContentType) - } + headers + .get(&header::CONTENT_TYPE) + .ok_or(MultipartError::NoContentType)? + .to_str() + .ok() + .and_then(|content_type| content_type.parse::().ok()) + .ok_or(MultipartError::ParseContentType)? + .get_param(mime::BOUNDARY) + .map(|boundary| boundary.as_str().to_owned()) + .ok_or(MultipartError::Boundary) } /// Create multipart instance for given boundary and stream pub(crate) fn from_boundary(boundary: String, stream: S) -> Multipart where - S: Stream> + Unpin + 'static, + S: Stream> + 'static, { Multipart { error: None, safety: Safety::new(), - inner: Some(Rc::new(RefCell::new(InnerMultipart { + inner: Some(InnerMultipart { boundary, - payload: PayloadRef::new(PayloadBuffer::new(Box::new(stream))), + payload: PayloadRef::new(PayloadBuffer::new(stream)), state: InnerState::FirstBoundary, item: InnerMultipartItem::None, - }))), + }), } } @@ -118,20 +118,27 @@ impl Multipart { impl Stream for Multipart { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(err) = self.error.take() { - Poll::Ready(Some(Err(err))) - } else if self.safety.current() { - let this = self.get_mut(); - let mut inner = this.inner.as_mut().unwrap().borrow_mut(); - if let Some(mut payload) = inner.payload.get_mut(&this.safety) { - payload.poll_stream(cx)?; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + match this.inner.as_mut() { + Some(inner) => { + if let Some(mut buffer) = inner.payload.get_mut(&this.safety) { + // check safety and poll read payload to buffer. + buffer.poll_stream(cx)?; + } else if !this.safety.is_clean() { + // safety violation + return Poll::Ready(Some(Err(MultipartError::NotConsumed))); + } else { + return Poll::Pending; + } + + inner.poll(&this.safety, cx) } - inner.poll(&this.safety, cx) - } else if !self.safety.is_clean() { - Poll::Ready(Some(Err(MultipartError::NotConsumed))) - } else { - Poll::Pending + None => Poll::Ready(Some(Err(this + .error + .take() + .expect("Multipart polled after finish")))), } } } @@ -152,17 +159,15 @@ impl InnerMultipart { Ok(httparse::Status::Complete((_, hdrs))) => { // convert headers let mut headers = HeaderMap::with_capacity(hdrs.len()); + for h in hdrs { - if let Ok(name) = HeaderName::try_from(h.name) { - if let Ok(value) = HeaderValue::try_from(h.value) { - headers.append(name, value); - } else { - return Err(ParseError::Header.into()); - } - } else { - return Err(ParseError::Header.into()); - } + let name = + HeaderName::try_from(h.name).map_err(|_| ParseError::Header)?; + let value = HeaderValue::try_from(h.value) + .map_err(|_| ParseError::Header)?; + headers.append(name, value); } + Ok(Some(headers)) } Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), @@ -332,31 +337,55 @@ impl InnerMultipart { return Poll::Pending; }; - // content type - let mut mt = mime::APPLICATION_OCTET_STREAM; - if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - mt = ct; - } - } - } + // According to RFC 7578 §4.2, a Content-Disposition header must always be present and + // set to "form-data". + + let content_disposition = headers + .get(&header::CONTENT_DISPOSITION) + .and_then(|cd| ContentDisposition::from_raw(cd).ok()) + .filter(|content_disposition| { + let is_form_data = + content_disposition.disposition == header::DispositionType::FormData; + + let has_field_name = content_disposition + .parameters + .iter() + .any(|param| matches!(param, header::DispositionParam::Name(_))); + + is_form_data && has_field_name + }); + + let cd = if let Some(content_disposition) = content_disposition { + content_disposition + } else { + return Poll::Ready(Some(Err(MultipartError::NoContentDisposition))); + }; + + let ct: mime::Mime = headers + .get(&header::CONTENT_TYPE) + .and_then(|ct| ct.to_str().ok()) + .and_then(|ct| ct.parse().ok()) + .unwrap_or(mime::APPLICATION_OCTET_STREAM); self.state = InnerState::Boundary; - // nested multipart stream - if mt.type_() == mime::MULTIPART { - Poll::Ready(Some(Err(MultipartError::Nested))) - } else { - let field = Rc::new(RefCell::new(InnerField::new( - self.payload.clone(), - self.boundary.clone(), - &headers, - )?)); - self.item = InnerMultipartItem::Field(Rc::clone(&field)); - - Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) + // nested multipart stream is not supported + if ct.type_() == mime::MULTIPART { + return Poll::Ready(Some(Err(MultipartError::Nested))); } + + let field = + InnerField::new_in_rc(self.payload.clone(), self.boundary.clone(), &headers)?; + + self.item = InnerMultipartItem::Field(Rc::clone(&field)); + + Poll::Ready(Some(Ok(Field::new( + safety.clone(cx), + headers, + ct, + cd, + field, + )))) } } } @@ -371,6 +400,7 @@ impl Drop for InnerMultipart { /// A single field in a multipart stream pub struct Field { ct: mime::Mime, + cd: ContentDisposition, headers: HeaderMap, inner: Rc>, safety: Safety, @@ -381,35 +411,52 @@ impl Field { safety: Safety, headers: HeaderMap, ct: mime::Mime, + cd: ContentDisposition, inner: Rc>, ) -> Self { Field { ct, + cd, headers, inner, safety, } } - /// Get a map of headers + /// Returns a reference to the field's header map. pub fn headers(&self) -> &HeaderMap { &self.headers } - /// Get the content type of the field + /// Returns a reference to the field's content (mime) type. pub fn content_type(&self) -> &mime::Mime { &self.ct } - /// Get the content disposition of the field, if it exists - pub fn content_disposition(&self) -> Option { - // RFC 7578: 'Each part MUST contain a Content-Disposition header field - // where the disposition type is "form-data".' - if let Some(content_disposition) = self.headers.get(&header::CONTENT_DISPOSITION) { - ContentDisposition::from_raw(content_disposition).ok() - } else { - None - } + /// Returns the field's Content-Disposition. + /// + /// Per [RFC 7578 §4.2]: "Each part MUST contain a Content-Disposition header field where the + /// disposition type is `form-data`. The Content-Disposition header field MUST also contain an + /// additional parameter of `name`; the value of the `name` parameter is the original field name + /// from the form." + /// + /// This crate validates that it exists before returning a `Field`. As such, it is safe to + /// unwrap `.content_disposition().get_name()`. The [name](Self::name) method is provided as + /// a convenience. + /// + /// [RFC 7578 §4.2]: https://datatracker.ietf.org/doc/html/rfc7578#section-4.2 + pub fn content_disposition(&self) -> &ContentDisposition { + &self.cd + } + + /// Returns the field's name. + /// + /// See [content_disposition](Self::content_disposition) regarding guarantees about existence of + /// the name field. + pub fn name(&self) -> &str { + self.content_disposition() + .get_name() + .expect("field name should be guaranteed to exist in multipart form-data") } } @@ -417,17 +464,19 @@ impl Stream for Field { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.safety.current() { - let mut inner = self.inner.borrow_mut(); - if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) { - payload.poll_stream(cx)?; - } - inner.poll(&self.safety) - } else if !self.safety.is_clean() { - Poll::Ready(Some(Err(MultipartError::NotConsumed))) + let this = self.get_mut(); + let mut inner = this.inner.borrow_mut(); + if let Some(mut buffer) = inner.payload.as_ref().unwrap().get_mut(&this.safety) { + // check safety and poll read payload to buffer. + buffer.poll_stream(cx)?; + } else if !this.safety.is_clean() { + // safety violation + return Poll::Ready(Some(Err(MultipartError::NotConsumed))); } else { - Poll::Pending + return Poll::Pending; } + + inner.poll(&this.safety) } } @@ -451,20 +500,23 @@ struct InnerField { } impl InnerField { + fn new_in_rc( + payload: PayloadRef, + boundary: String, + headers: &HeaderMap, + ) -> Result>, PayloadError> { + Self::new(payload, boundary, headers).map(|this| Rc::new(RefCell::new(this))) + } + fn new( payload: PayloadRef, boundary: String, headers: &HeaderMap, ) -> Result { let len = if let Some(len) = headers.get(&header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - Some(len) - } else { - return Err(PayloadError::Incomplete(None)); - } - } else { - return Err(PayloadError::Incomplete(None)); + match len.to_str().ok().and_then(|len| len.parse::().ok()) { + Some(len) => Some(len), + None => return Err(PayloadError::Incomplete(None)), } } else { None @@ -638,10 +690,7 @@ impl PayloadRef { } } - fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option> - where - 'a: 'b, - { + fn get_mut(&self, s: &Safety) -> Option> { if s.current() { Some(self.payload.borrow_mut()) } else { @@ -658,9 +707,11 @@ impl Clone for PayloadRef { } } -/// Counter. It tracks of number of clones of payloads and give access to -/// payload only to top most task panics if Safety get destroyed and it not top -/// most task. +/// Counter. It tracks of number of clones of payloads and give access to payload only to top most. +/// * When dropped, parent task is awakened. This is to support the case where Field is +/// dropped in a separate task than Multipart. +/// * Assumes that parent owners don't move to different tasks; only the top-most is allowed to. +/// * If dropped and is not top most owner, is_clean flag is set to false. #[derive(Debug)] struct Safety { task: LocalWaker, @@ -703,15 +754,16 @@ impl Safety { impl Drop for Safety { fn drop(&mut self) { - // parent task is dead if Rc::strong_count(&self.payload) != self.level { - self.clean.set(true); + // Multipart dropped leaving a Field + self.clean.set(false); } + self.task.wake(); } } -/// Payload buffer +/// Payload buffer. struct PayloadBuffer { eof: bool, buf: BytesMut, @@ -719,7 +771,7 @@ struct PayloadBuffer { } impl PayloadBuffer { - /// Create new `PayloadBuffer` instance + /// Constructs new `PayloadBuffer` instance. fn new(stream: S) -> Self where S: Stream> + 'static, @@ -727,7 +779,7 @@ impl PayloadBuffer { PayloadBuffer { eof: false, buf: BytesMut::new(), - stream: stream.boxed_local(), + stream: Box::pin(stream), } } @@ -767,7 +819,7 @@ impl PayloadBuffer { } /// Read until specified ending - pub fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { + fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { let res = twoway::find_bytes(&self.buf, line) .map(|idx| self.buf.split_to(idx + line.len()).freeze()); @@ -779,12 +831,12 @@ impl PayloadBuffer { } /// Read bytes until new line delimiter - pub fn readline(&mut self) -> Result, MultipartError> { + fn readline(&mut self) -> Result, MultipartError> { self.read_until(b"\n") } /// Read bytes until new line delimiter or eof - pub fn readline_or_eof(&mut self) -> Result, MultipartError> { + fn readline_or_eof(&mut self) -> Result, MultipartError> { match self.readline() { Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), line => line, @@ -792,7 +844,7 @@ impl PayloadBuffer { } /// Put unprocessed data back to the buffer - pub fn unprocessed(&mut self, data: Bytes) { + fn unprocessed(&mut self, data: Bytes) { let buf = BytesMut::from(data.as_ref()); let buf = std::mem::replace(&mut self.buf, buf); self.buf.extend_from_slice(&buf); @@ -805,10 +857,12 @@ mod tests { use actix_http::h1::Payload; use actix_web::http::header::{DispositionParam, DispositionType}; + use actix_web::rt; use actix_web::test::TestRequest; use actix_web::FromRequest; use bytes::Bytes; - use futures_util::future::lazy; + use futures_util::{future::lazy, StreamExt}; + use std::time::Duration; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -914,6 +968,7 @@ mod tests { Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ test\r\n\ --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ data\r\n\ --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", @@ -965,7 +1020,7 @@ mod tests { let mut multipart = Multipart::new(&headers, payload); match multipart.next().await { Some(Ok(mut field)) => { - let cd = field.content_disposition().unwrap(); + let cd = field.content_disposition(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -1027,7 +1082,7 @@ mod tests { let mut multipart = Multipart::new(&headers, payload); match multipart.next().await.unwrap() { Ok(mut field) => { - let cd = field.content_disposition().unwrap(); + let cd = field.content_disposition(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -1182,4 +1237,99 @@ mod tests { _ => unreachable!(), } } + + #[actix_rt::test] + async fn no_content_disposition() { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + let payload = SlowStream::new(bytes); + + let mut multipart = Multipart::new(&headers, payload); + let res = multipart.next().await.unwrap(); + assert!(res.is_err()); + assert!(matches!( + res.unwrap_err(), + MultipartError::NoContentDisposition, + )); + } + + #[actix_rt::test] + async fn no_name_in_content_disposition() { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + let payload = SlowStream::new(bytes); + + let mut multipart = Multipart::new(&headers, payload); + let res = multipart.next().await.unwrap(); + assert!(res.is_err()); + assert!(matches!( + res.unwrap_err(), + MultipartError::NoContentDisposition, + )); + } + + #[actix_rt::test] + async fn test_drop_multipart_dont_hang() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + sender.send(Ok(bytes)).unwrap(); + drop(sender); // eof + + let mut multipart = Multipart::new(&headers, payload); + let mut field = multipart.next().await.unwrap().unwrap(); + + drop(multipart); + + // should fail immediately + match field.next().await { + Some(Err(MultipartError::NotConsumed)) => {} + _ => panic!(), + }; + } + + #[actix_rt::test] + async fn test_drop_field_awaken_multipart() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + sender.send(Ok(bytes)).unwrap(); + drop(sender); // eof + + let mut multipart = Multipart::new(&headers, payload); + let mut field = multipart.next().await.unwrap().unwrap(); + + let task = rt::spawn(async move { + rt::time::sleep(Duration::from_secs(1)).await; + assert_eq!(field.next().await.unwrap().unwrap(), "test"); + drop(field); + }); + + // dropping field should awaken current task + let _ = multipart.next().await.unwrap().unwrap(); + task.await.unwrap(); + } } diff --git a/actix-router/CHANGES.md b/actix-router/CHANGES.md new file mode 100644 index 000000000..c2858f2ba --- /dev/null +++ b/actix-router/CHANGES.md @@ -0,0 +1,132 @@ +# Changes + +## Unreleased - 2021-xx-xx +* Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.5.0-beta.2 - 2021-09-09 +* Introduce `ResourceDef::join`. [#380] +* Disallow prefix routes with tail segments. [#379] +* Enforce path separators on dynamic prefixes. [#378] +* Improve malformed path error message. [#384] +* Prefix segments now always end with with a segment delimiter or end-of-input. [#2355] +* Prefix segments with trailing slashes define a trailing empty segment. [#2355] +* Support multi-pattern prefixes and joins. [#2356] +* `ResourceDef::pattern` now returns the first pattern in multi-pattern resources. [#2356] +* Support `build_resource_path` on multi-pattern resources. [#2356] +* Minimum supported Rust version (MSRV) is now 1.51. + +[#378]: https://github.com/actix/actix-net/pull/378 +[#379]: https://github.com/actix/actix-net/pull/379 +[#380]: https://github.com/actix/actix-net/pull/380 +[#384]: https://github.com/actix/actix-net/pull/384 +[#2355]: https://github.com/actix/actix-web/pull/2355 +[#2356]: https://github.com/actix/actix-web/pull/2356 + + +## 0.5.0-beta.1 - 2021-07-20 +* Fix a bug in multi-patterns where static patterns are interpreted as regex. [#366] +* Introduce `ResourceDef::pattern_iter` to get an iterator over all patterns in a multi-pattern resource. [#373] +* Fix segment interpolation leaving `Path` in unintended state after matching. [#368] +* Fix `ResourceDef` `PartialEq` implementation. [#373] +* Re-work `IntoPatterns` trait, adding a `Patterns` enum. [#372] +* Implement `IntoPatterns` for `bytestring::ByteString`. [#372] +* Rename `Path::{len => segment_count}` to be more descriptive of it's purpose. [#370] +* Rename `ResourceDef::{resource_path => resource_path_from_iter}`. [#371] +* `ResourceDef::resource_path_from_iter` now takes an `IntoIterator`. [#373] +* Rename `ResourceDef::{resource_path_named => resource_path_from_map}`. [#371] +* Rename `ResourceDef::{is_prefix_match => find_match}`. [#373] +* Rename `ResourceDef::{match_path => capture_match_info}`. [#373] +* Rename `ResourceDef::{match_path_checked => capture_match_info_fn}`. [#373] +* Remove `ResourceDef::name_mut` and introduce `ResourceDef::set_name`. [#373] +* Rename `Router::{*_checked => *_fn}`. [#373] +* Return type of `ResourceDef::name` is now `Option<&str>`. [#373] +* Return type of `ResourceDef::pattern` is now `Option<&str>`. [#373] + +[#368]: https://github.com/actix/actix-net/pull/368 +[#366]: https://github.com/actix/actix-net/pull/366 +[#368]: https://github.com/actix/actix-net/pull/368 +[#370]: https://github.com/actix/actix-net/pull/370 +[#371]: https://github.com/actix/actix-net/pull/371 +[#372]: https://github.com/actix/actix-net/pull/372 +[#373]: https://github.com/actix/actix-net/pull/373 + + +## 0.4.0 - 2021-06-06 +* When matching path parameters, `%25` is now kept in the percent-encoded form; no longer decoded to `%`. [#357] +* Path tail patterns now match new lines (`\n`) in request URL. [#360] +* Fixed a safety bug where `Path` could return a malformed string after percent decoding. [#359] +* Methods `Path::{add, add_static}` now take `impl Into>`. [#345] + +[#345]: https://github.com/actix/actix-net/pull/345 +[#357]: https://github.com/actix/actix-net/pull/357 +[#359]: https://github.com/actix/actix-net/pull/359 +[#360]: https://github.com/actix/actix-net/pull/360 + + +## 0.3.0 - 2019-12-31 +* Version was yanked previously. See https://crates.io/crates/actix-router/0.3.0 + + +## 0.2.7 - 2021-02-06 +* Add `Router::recognize_checked` [#247] + +[#247]: https://github.com/actix/actix-net/pull/247 + + +## 0.2.6 - 2021-01-09 +* Use `bytestring` version range compatible with Bytes v1.0. [#246] + +[#246]: https://github.com/actix/actix-net/pull/246 + + +## 0.2.5 - 2020-09-20 +* Fix `from_hex()` method + + +## 0.2.4 - 2019-12-31 +* Add `ResourceDef::resource_path_named()` path generation method + + +## 0.2.3 - 2019-12-25 +* Add impl `IntoPattern` for `&String` + + +## 0.2.2 - 2019-12-25 +* Use `IntoPattern` for `RouterBuilder::path()` + + +## 0.2.1 - 2019-12-25 +* Add `IntoPattern` trait +* Add multi-pattern resources + + +## 0.2.0 - 2019-12-07 +* Update http to 0.2 +* Update regex to 1.3 +* Use bytestring instead of string + + +## 0.1.5 - 2019-05-15 +* Remove debug prints + + +## 0.1.4 - 2019-05-15 +* Fix checked resource match + + +## 0.1.3 - 2019-04-22 +* Added support for `remainder match` (i.e "/path/{tail}*") + + +## 0.1.2 - 2019-04-07 +* Export `Quoter` type +* Allow to reset `Path` instance + + +## 0.1.1 - 2019-04-03 +* Get dynamic segment by name instead of iterator. + + +## 0.1.0 - 2019-03-09 +* Initial release diff --git a/actix-router/Cargo.toml b/actix-router/Cargo.toml new file mode 100644 index 000000000..b95bca505 --- /dev/null +++ b/actix-router/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "actix-router" +version = "0.5.0-beta.2" +authors = [ + "Nikolay Kim ", + "Ali MJ Al-Nasrawy ", + "Rob Ede ", +] +description = "Resource path matching and router" +keywords = ["actix", "router", "routing"] +repository = "https://github.com/actix/actix-web.git" +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +name = "actix_router" +path = "src/lib.rs" + +[features] +default = ["http"] + +[dependencies] +bytestring = ">=0.1.5, <2" +firestorm = "0.4" +http = { version = "0.2.3", optional = true } +log = "0.4" +regex = "1.5" +serde = "1" + +[dev-dependencies] +criterion = { version = "0.3", features = ["html_reports"] } +firestorm = { version = "0.4", features = ["enable_system_time"] } +http = "0.2.5" +serde = { version = "1", features = ["derive"] } + +[[bench]] +name = "router" +harness = false diff --git a/actix-router/LICENSE-APACHE b/actix-router/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-router/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-router/LICENSE-MIT b/actix-router/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-router/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-router/benches/router.rs b/actix-router/benches/router.rs new file mode 100644 index 000000000..a428b9f13 --- /dev/null +++ b/actix-router/benches/router.rs @@ -0,0 +1,194 @@ +//! Based on https://github.com/ibraheemdev/matchit/blob/master/benches/bench.rs + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +macro_rules! register { + (colon) => {{ + register!(finish => ":p1", ":p2", ":p3", ":p4") + }}; + (brackets) => {{ + register!(finish => "{p1}", "{p2}", "{p3}", "{p4}") + }}; + (regex) => {{ + register!(finish => "(.*)", "(.*)", "(.*)", "(.*)") + }}; + (finish => $p1:literal, $p2:literal, $p3:literal, $p4:literal) => {{ + let arr = [ + concat!("/authorizations"), + concat!("/authorizations/", $p1), + concat!("/applications/", $p1, "/tokens/", $p2), + concat!("/events"), + concat!("/repos/", $p1, "/", $p2, "/events"), + concat!("/networks/", $p1, "/", $p2, "/events"), + concat!("/orgs/", $p1, "/events"), + concat!("/users/", $p1, "/received_events"), + concat!("/users/", $p1, "/received_events/public"), + concat!("/users/", $p1, "/events"), + concat!("/users/", $p1, "/events/public"), + concat!("/users/", $p1, "/events/orgs/", $p2), + concat!("/feeds"), + concat!("/notifications"), + concat!("/repos/", $p1, "/", $p2, "/notifications"), + concat!("/notifications/threads/", $p1), + concat!("/notifications/threads/", $p1, "/subscription"), + concat!("/repos/", $p1, "/", $p2, "/stargazers"), + concat!("/users/", $p1, "/starred"), + concat!("/user/starred"), + concat!("/user/starred/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/subscribers"), + concat!("/users/", $p1, "/subscriptions"), + concat!("/user/subscriptions"), + concat!("/repos/", $p1, "/", $p2, "/subscription"), + concat!("/user/subscriptions/", $p1, "/", $p2), + concat!("/users/", $p1, "/gists"), + concat!("/gists"), + concat!("/gists/", $p1), + concat!("/gists/", $p1, "/star"), + concat!("/repos/", $p1, "/", $p2, "/git/blobs/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/refs"), + concat!("/repos/", $p1, "/", $p2, "/git/tags/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/trees/", $p3), + concat!("/issues"), + concat!("/user/issues"), + concat!("/orgs/", $p1, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3), + concat!("/repos/", $p1, "/", $p2, "/assignees"), + concat!("/repos/", $p1, "/", $p2, "/assignees/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/events"), + concat!("/repos/", $p1, "/", $p2, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/labels/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3), + concat!("/emojis"), + concat!("/gitignore/templates"), + concat!("/gitignore/templates/", $p1), + concat!("/meta"), + concat!("/rate_limit"), + concat!("/users/", $p1, "/orgs"), + concat!("/user/orgs"), + concat!("/orgs/", $p1), + concat!("/orgs/", $p1, "/members"), + concat!("/orgs/", $p1, "/members", $p2), + concat!("/orgs/", $p1, "/public_members"), + concat!("/orgs/", $p1, "/public_members/", $p2), + concat!("/orgs/", $p1, "/teams"), + concat!("/teams/", $p1), + concat!("/teams/", $p1, "/members"), + concat!("/teams/", $p1, "/members", $p2), + concat!("/teams/", $p1, "/repos"), + concat!("/teams/", $p1, "/repos/", $p2, "/", $p3), + concat!("/user/teams"), + concat!("/repos/", $p1, "/", $p2, "/pulls"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/files"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/merge"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/comments"), + concat!("/user/repos"), + concat!("/users/", $p1, "/repos"), + concat!("/orgs/", $p1, "/repos"), + concat!("/repositories"), + concat!("/repos/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/contributors"), + concat!("/repos/", $p1, "/", $p2, "/languages"), + concat!("/repos/", $p1, "/", $p2, "/teams"), + concat!("/repos/", $p1, "/", $p2, "/tags"), + concat!("/repos/", $p1, "/", $p2, "/branches"), + concat!("/repos/", $p1, "/", $p2, "/branches/", $p3), + concat!("/repos/", $p1, "/", $p2, "/collaborators"), + concat!("/repos/", $p1, "/", $p2, "/collaborators/", $p3), + concat!("/repos/", $p1, "/", $p2, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/readme"), + concat!("/repos/", $p1, "/", $p2, "/keys"), + concat!("/repos/", $p1, "/", $p2, "/keys", $p3), + concat!("/repos/", $p1, "/", $p2, "/downloads"), + concat!("/repos/", $p1, "/", $p2, "/downloads", $p3), + concat!("/repos/", $p1, "/", $p2, "/forks"), + concat!("/repos/", $p1, "/", $p2, "/hooks"), + concat!("/repos/", $p1, "/", $p2, "/hooks", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases"), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3, "/assets"), + concat!("/repos/", $p1, "/", $p2, "/stats/contributors"), + concat!("/repos/", $p1, "/", $p2, "/stats/commit_activity"), + concat!("/repos/", $p1, "/", $p2, "/stats/code_frequency"), + concat!("/repos/", $p1, "/", $p2, "/stats/participation"), + concat!("/repos/", $p1, "/", $p2, "/stats/punch_card"), + concat!("/repos/", $p1, "/", $p2, "/statuses/", $p3), + concat!("/search/repositories"), + concat!("/search/code"), + concat!("/search/issues"), + concat!("/search/users"), + concat!("/legacy/issues/search/", $p1, "/", $p2, "/", $p3, "/", $p4), + concat!("/legacy/repos/search/", $p1), + concat!("/legacy/user/search/", $p1), + concat!("/legacy/user/email/", $p1), + concat!("/users/", $p1), + concat!("/user"), + concat!("/users"), + concat!("/user/emails"), + concat!("/users/", $p1, "/followers"), + concat!("/user/followers"), + concat!("/users/", $p1, "/following"), + concat!("/user/following"), + concat!("/user/following/", $p1), + concat!("/users/", $p1, "/following", $p2), + concat!("/users/", $p1, "/keys"), + concat!("/user/keys"), + concat!("/user/keys/", $p1), + ]; + std::array::IntoIter::new(arr) + }}; +} + +fn call() -> impl Iterator { + let arr = [ + "/authorizations", + "/user/repos", + "/repos/rust-lang/rust/stargazers", + "/orgs/rust-lang/public_members/nikomatsakis", + "/repos/rust-lang/rust/releases/1.51.0", + ]; + + std::array::IntoIter::new(arr) +} + +fn compare_routers(c: &mut Criterion) { + let mut group = c.benchmark_group("Compare Routers"); + + let mut actix = actix_router::Router::::build(); + for route in register!(brackets) { + actix.path(route, true); + } + let actix = actix.finish(); + group.bench_function("actix", |b| { + b.iter(|| { + for route in call() { + let mut path = actix_router::Path::new(route); + black_box(actix.recognize(&mut path).unwrap()); + } + }); + }); + + let regex_set = regex::RegexSet::new(register!(regex)).unwrap(); + group.bench_function("regex", |b| { + b.iter(|| { + for route in call() { + black_box(regex_set.matches(route)); + } + }); + }); + + group.finish(); +} + +criterion_group!(benches, compare_routers); +criterion_main!(benches); diff --git a/actix-router/examples/flamegraph.rs b/actix-router/examples/flamegraph.rs new file mode 100644 index 000000000..798cc22d9 --- /dev/null +++ b/actix-router/examples/flamegraph.rs @@ -0,0 +1,169 @@ +macro_rules! register { + (brackets) => {{ + register!(finish => "{p1}", "{p2}", "{p3}", "{p4}") + }}; + (finish => $p1:literal, $p2:literal, $p3:literal, $p4:literal) => {{ + let arr = [ + concat!("/authorizations"), + concat!("/authorizations/", $p1), + concat!("/applications/", $p1, "/tokens/", $p2), + concat!("/events"), + concat!("/repos/", $p1, "/", $p2, "/events"), + concat!("/networks/", $p1, "/", $p2, "/events"), + concat!("/orgs/", $p1, "/events"), + concat!("/users/", $p1, "/received_events"), + concat!("/users/", $p1, "/received_events/public"), + concat!("/users/", $p1, "/events"), + concat!("/users/", $p1, "/events/public"), + concat!("/users/", $p1, "/events/orgs/", $p2), + concat!("/feeds"), + concat!("/notifications"), + concat!("/repos/", $p1, "/", $p2, "/notifications"), + concat!("/notifications/threads/", $p1), + concat!("/notifications/threads/", $p1, "/subscription"), + concat!("/repos/", $p1, "/", $p2, "/stargazers"), + concat!("/users/", $p1, "/starred"), + concat!("/user/starred"), + concat!("/user/starred/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/subscribers"), + concat!("/users/", $p1, "/subscriptions"), + concat!("/user/subscriptions"), + concat!("/repos/", $p1, "/", $p2, "/subscription"), + concat!("/user/subscriptions/", $p1, "/", $p2), + concat!("/users/", $p1, "/gists"), + concat!("/gists"), + concat!("/gists/", $p1), + concat!("/gists/", $p1, "/star"), + concat!("/repos/", $p1, "/", $p2, "/git/blobs/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/refs"), + concat!("/repos/", $p1, "/", $p2, "/git/tags/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/trees/", $p3), + concat!("/issues"), + concat!("/user/issues"), + concat!("/orgs/", $p1, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3), + concat!("/repos/", $p1, "/", $p2, "/assignees"), + concat!("/repos/", $p1, "/", $p2, "/assignees/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/events"), + concat!("/repos/", $p1, "/", $p2, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/labels/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3), + concat!("/emojis"), + concat!("/gitignore/templates"), + concat!("/gitignore/templates/", $p1), + concat!("/meta"), + concat!("/rate_limit"), + concat!("/users/", $p1, "/orgs"), + concat!("/user/orgs"), + concat!("/orgs/", $p1), + concat!("/orgs/", $p1, "/members"), + concat!("/orgs/", $p1, "/members", $p2), + concat!("/orgs/", $p1, "/public_members"), + concat!("/orgs/", $p1, "/public_members/", $p2), + concat!("/orgs/", $p1, "/teams"), + concat!("/teams/", $p1), + concat!("/teams/", $p1, "/members"), + concat!("/teams/", $p1, "/members", $p2), + concat!("/teams/", $p1, "/repos"), + concat!("/teams/", $p1, "/repos/", $p2, "/", $p3), + concat!("/user/teams"), + concat!("/repos/", $p1, "/", $p2, "/pulls"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/files"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/merge"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/comments"), + concat!("/user/repos"), + concat!("/users/", $p1, "/repos"), + concat!("/orgs/", $p1, "/repos"), + concat!("/repositories"), + concat!("/repos/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/contributors"), + concat!("/repos/", $p1, "/", $p2, "/languages"), + concat!("/repos/", $p1, "/", $p2, "/teams"), + concat!("/repos/", $p1, "/", $p2, "/tags"), + concat!("/repos/", $p1, "/", $p2, "/branches"), + concat!("/repos/", $p1, "/", $p2, "/branches/", $p3), + concat!("/repos/", $p1, "/", $p2, "/collaborators"), + concat!("/repos/", $p1, "/", $p2, "/collaborators/", $p3), + concat!("/repos/", $p1, "/", $p2, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/readme"), + concat!("/repos/", $p1, "/", $p2, "/keys"), + concat!("/repos/", $p1, "/", $p2, "/keys", $p3), + concat!("/repos/", $p1, "/", $p2, "/downloads"), + concat!("/repos/", $p1, "/", $p2, "/downloads", $p3), + concat!("/repos/", $p1, "/", $p2, "/forks"), + concat!("/repos/", $p1, "/", $p2, "/hooks"), + concat!("/repos/", $p1, "/", $p2, "/hooks", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases"), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3, "/assets"), + concat!("/repos/", $p1, "/", $p2, "/stats/contributors"), + concat!("/repos/", $p1, "/", $p2, "/stats/commit_activity"), + concat!("/repos/", $p1, "/", $p2, "/stats/code_frequency"), + concat!("/repos/", $p1, "/", $p2, "/stats/participation"), + concat!("/repos/", $p1, "/", $p2, "/stats/punch_card"), + concat!("/repos/", $p1, "/", $p2, "/statuses/", $p3), + concat!("/search/repositories"), + concat!("/search/code"), + concat!("/search/issues"), + concat!("/search/users"), + concat!("/legacy/issues/search/", $p1, "/", $p2, "/", $p3, "/", $p4), + concat!("/legacy/repos/search/", $p1), + concat!("/legacy/user/search/", $p1), + concat!("/legacy/user/email/", $p1), + concat!("/users/", $p1), + concat!("/user"), + concat!("/users"), + concat!("/user/emails"), + concat!("/users/", $p1, "/followers"), + concat!("/user/followers"), + concat!("/users/", $p1, "/following"), + concat!("/user/following"), + concat!("/user/following/", $p1), + concat!("/users/", $p1, "/following", $p2), + concat!("/users/", $p1, "/keys"), + concat!("/user/keys"), + concat!("/user/keys/", $p1), + ]; + + arr.to_vec() + }}; +} + +static PATHS: [&str; 5] = [ + "/authorizations", + "/user/repos", + "/repos/rust-lang/rust/stargazers", + "/orgs/rust-lang/public_members/nikomatsakis", + "/repos/rust-lang/rust/releases/1.51.0", +]; + +fn main() { + let mut router = actix_router::Router::::build(); + + for route in register!(brackets) { + router.path(route, true); + } + + let actix = router.finish(); + + if firestorm::enabled() { + firestorm::bench("target", || { + for &route in &PATHS { + let mut path = actix_router::Path::new(route); + actix.recognize(&mut path).unwrap(); + } + }) + .unwrap(); + } +} diff --git a/actix-router/src/de.rs b/actix-router/src/de.rs new file mode 100644 index 000000000..775c48b8a --- /dev/null +++ b/actix-router/src/de.rs @@ -0,0 +1,723 @@ +use serde::de::{self, Deserializer, Error as DeError, Visitor}; +use serde::forward_to_deserialize_any; + +use crate::path::{Path, PathIter}; +use crate::ResourcePath; + +macro_rules! unsupported_type { + ($trait_fn:ident, $name:expr) => { + fn $trait_fn(self, _: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom(concat!( + "unsupported type: ", + $name + ))) + } + }; +} + +macro_rules! parse_single_value { + ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() != 1 { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected 1", + self.path.segment_count() + ) + .as_str(), + )) + } else { + let v = self.path[0].parse().map_err(|_| { + de::value::Error::custom(format!( + "can not parse {:?} to a {}", + &self.path[0], $tp + )) + })?; + visitor.$visit_fn(v) + } + } + }; +} + +pub struct PathDeserializer<'de, T: ResourcePath> { + path: &'de Path, +} + +impl<'de, T: ResourcePath + 'de> PathDeserializer<'de, T> { + pub fn new(path: &'de Path) -> Self { + PathDeserializer { path } + } +} + +impl<'de, T: ResourcePath + 'de> Deserializer<'de> for PathDeserializer<'de, T> { + type Error = de::value::Error; + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(ParamsDeserializer { + params: self.path.iter(), + current: None, + }) + } + + fn deserialize_struct( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() < len { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected {}", + self.path.segment_count(), + len + ) + .as_str(), + )) + } else { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + } + + fn deserialize_tuple_struct( + self, + _: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() < len { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected {}", + self.path.segment_count(), + len + ) + .as_str(), + )) + } else { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + } + + fn deserialize_enum( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.path.is_empty() { + Err(de::value::Error::custom("expected at least one parameters")) + } else { + visitor.visit_enum(ValueEnum { + value: &self.path[0], + }) + } + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() != 1 { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected 1", + self.path.segment_count() + ) + .as_str(), + )) + } else { + visitor.visit_str(&self.path[0]) + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + + unsupported_type!(deserialize_any, "'any'"); + unsupported_type!(deserialize_bytes, "bytes"); + unsupported_type!(deserialize_option, "Option"); + unsupported_type!(deserialize_identifier, "identifier"); + unsupported_type!(deserialize_ignored_any, "ignored_any"); + + parse_single_value!(deserialize_bool, visit_bool, "bool"); + parse_single_value!(deserialize_i8, visit_i8, "i8"); + parse_single_value!(deserialize_i16, visit_i16, "i16"); + parse_single_value!(deserialize_i32, visit_i32, "i32"); + parse_single_value!(deserialize_i64, visit_i64, "i64"); + parse_single_value!(deserialize_u8, visit_u8, "u8"); + parse_single_value!(deserialize_u16, visit_u16, "u16"); + parse_single_value!(deserialize_u32, visit_u32, "u32"); + parse_single_value!(deserialize_u64, visit_u64, "u64"); + parse_single_value!(deserialize_f32, visit_f32, "f32"); + parse_single_value!(deserialize_f64, visit_f64, "f64"); + parse_single_value!(deserialize_string, visit_string, "String"); + parse_single_value!(deserialize_byte_buf, visit_string, "String"); + parse_single_value!(deserialize_char, visit_char, "char"); +} + +struct ParamsDeserializer<'de, T: ResourcePath> { + params: PathIter<'de, T>, + current: Option<(&'de str, &'de str)>, +} + +impl<'de, T: ResourcePath> de::MapAccess<'de> for ParamsDeserializer<'de, T> { + type Error = de::value::Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: de::DeserializeSeed<'de>, + { + self.current = self.params.next().map(|ref item| (item.0, item.1)); + match self.current { + Some((key, _)) => Ok(Some(seed.deserialize(Key { key })?)), + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + if let Some((_, value)) = self.current.take() { + seed.deserialize(Value { value }) + } else { + Err(de::value::Error::custom("unexpected item")) + } + } +} + +struct Key<'de> { + key: &'de str, +} + +impl<'de> Deserializer<'de> for Key<'de> { + type Error = de::value::Error; + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_str(self.key) + } + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("Unexpected")) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum ignored_any + } +} + +macro_rules! parse_value { + ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let v = self.value.parse().map_err(|_| { + de::value::Error::custom(format!("can not parse {:?} to a {}", self.value, $tp)) + })?; + visitor.$visit_fn(v) + } + }; +} + +struct Value<'de> { + value: &'de str, +} + +impl<'de> Deserializer<'de> for Value<'de> { + type Error = de::value::Error; + + parse_value!(deserialize_bool, visit_bool, "bool"); + parse_value!(deserialize_i8, visit_i8, "i8"); + parse_value!(deserialize_i16, visit_i16, "i16"); + parse_value!(deserialize_i32, visit_i32, "i16"); + parse_value!(deserialize_i64, visit_i64, "i64"); + parse_value!(deserialize_u8, visit_u8, "u8"); + parse_value!(deserialize_u16, visit_u16, "u16"); + parse_value!(deserialize_u32, visit_u32, "u32"); + parse_value!(deserialize_u64, visit_u64, "u64"); + parse_value!(deserialize_f32, visit_f32, "f32"); + parse_value!(deserialize_f64, visit_f64, "f64"); + parse_value!(deserialize_string, visit_string, "String"); + parse_value!(deserialize_byte_buf, visit_string, "String"); + parse_value!(deserialize_char, visit_char, "char"); + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.value.as_bytes()) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_enum( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(ValueEnum { value: self.value }) + } + + fn deserialize_newtype_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, _: usize, _: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: tuple")) + } + + fn deserialize_struct( + self, + _: &'static str, + _: &'static [&'static str], + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: struct")) + } + + fn deserialize_tuple_struct( + self, + _: &'static str, + _: usize, + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: tuple struct")) + } + + unsupported_type!(deserialize_any, "any"); + unsupported_type!(deserialize_seq, "seq"); + unsupported_type!(deserialize_map, "map"); + unsupported_type!(deserialize_identifier, "identifier"); +} + +struct ParamsSeq<'de, T: ResourcePath> { + params: PathIter<'de, T>, +} + +impl<'de, T: ResourcePath> de::SeqAccess<'de> for ParamsSeq<'de, T> { + type Error = de::value::Error; + + fn next_element_seed(&mut self, seed: U) -> Result, Self::Error> + where + U: de::DeserializeSeed<'de>, + { + match self.params.next() { + Some(item) => Ok(Some(seed.deserialize(Value { value: item.1 })?)), + None => Ok(None), + } + } +} + +struct ValueEnum<'de> { + value: &'de str, +} + +impl<'de> de::EnumAccess<'de> for ValueEnum<'de> { + type Error = de::value::Error; + type Variant = UnitVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + Ok((seed.deserialize(Key { key: self.value })?, UnitVariant)) + } +} + +struct UnitVariant; + +impl<'de> de::VariantAccess<'de> for UnitVariant { + type Error = de::value::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Err(de::value::Error::custom("not supported")) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("not supported")) + } + + fn struct_variant( + self, + _: &'static [&'static str], + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("not supported")) + } +} + +#[cfg(test)] +mod tests { + use serde::{de, Deserialize}; + + use super::*; + use crate::path::Path; + use crate::router::Router; + + #[derive(Deserialize)] + struct MyStruct { + key: String, + value: String, + } + + #[derive(Deserialize)] + struct Id { + _id: String, + } + + #[derive(Debug, Deserialize)] + struct Test1(String, u32); + + #[derive(Debug, Deserialize)] + struct Test2 { + key: String, + value: u32, + } + + #[derive(Debug, Deserialize, PartialEq)] + #[serde(rename_all = "lowercase")] + enum TestEnum { + Val1, + Val2, + } + + #[derive(Debug, Deserialize)] + struct Test3 { + val: TestEnum, + } + + #[test] + fn test_request_extract() { + let mut router = Router::<()>::build(); + router.path("/{key}/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/user1/"); + assert!(router.recognize(&mut path).is_some()); + + let s: MyStruct = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, "user1"); + + let s: (String, String) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, "user1"); + + let mut router = Router::<()>::build(); + router.path("/{key}/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/32/"); + assert!(router.recognize(&mut path).is_some()); + + let s: Test1 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); + + let s: Test2 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, 32); + + let s: (String, u8) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); + + let res: Vec = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(res[0], "name".to_owned()); + assert_eq!(res[1], "32".to_owned()); + } + + #[test] + fn test_extract_path_single() { + let mut router = Router::<()>::build(); + router.path("/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/32/"); + assert!(router.recognize(&mut path).is_some()); + let i: i8 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, 32); + } + + #[test] + fn test_extract_enum() { + let mut router = Router::<()>::build(); + router.path("/{val}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/"); + assert!(router.recognize(&mut path).is_some()); + let i: TestEnum = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, TestEnum::Val1); + + let mut router = Router::<()>::build(); + router.path("/{val1}/{val2}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/val2/"); + assert!(router.recognize(&mut path).is_some()); + let i: (TestEnum, TestEnum) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, (TestEnum::Val1, TestEnum::Val2)); + } + + #[test] + fn test_extract_enum_value() { + let mut router = Router::<()>::build(); + router.path("/{val}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/"); + assert!(router.recognize(&mut path).is_some()); + let i: Test3 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i.val, TestEnum::Val1); + + let mut path = Path::new("/val3/"); + assert!(router.recognize(&mut path).is_some()); + let i: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(i.is_err()); + assert!(format!("{:?}", i).contains("unknown variant")); + } + + #[test] + fn test_extract_errors() { + let mut router = Router::<()>::build(); + router.path("/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/"); + assert!(router.recognize(&mut path).is_some()); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("wrong number of parameters")); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("can not parse")); + + let s: Result<(String, String), de::value::Error> = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("wrong number of parameters")); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("can not parse")); + } + + // #[test] + // fn test_extract_path_decode() { + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + // macro_rules! test_single_value { + // ($value:expr, $expected:expr) => {{ + // let req = TestRequest::with_uri($value).finish(); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + // assert_eq!( + // *Path::::from_request(&req, &PathConfig::default()).unwrap(), + // $expected + // ); + // }}; + // } + + // test_single_value!("/%25/", "%"); + // test_single_value!("/%40%C2%A3%24%25%5E%26%2B%3D/", "@£$%^&+="); + // test_single_value!("/%2B/", "+"); + // test_single_value!("/%252B/", "%2B"); + // test_single_value!("/%2F/", "/"); + // test_single_value!("/%252F/", "%2F"); + // test_single_value!( + // "/http%3A%2F%2Flocalhost%3A80%2Ffoo/", + // "http://localhost:80/foo" + // ); + // test_single_value!("/%2Fvar%2Flog%2Fsyslog/", "/var/log/syslog"); + // test_single_value!( + // "/http%3A%2F%2Flocalhost%3A80%2Ffile%2F%252Fvar%252Flog%252Fsyslog/", + // "http://localhost:80/file/%2Fvar%2Flog%2Fsyslog" + // ); + + // let req = TestRequest::with_uri("/%25/7/?id=test").finish(); + + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + + // let s = Path::::from_request(&req, &PathConfig::default()).unwrap(); + // assert_eq!(s.key, "%"); + // assert_eq!(s.value, 7); + + // let s = Path::<(String, String)>::from_request(&req, &PathConfig::default()).unwrap(); + // assert_eq!(s.0, "%"); + // assert_eq!(s.1, "7"); + // } + + // #[test] + // fn test_extract_path_no_decode() { + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + // let req = TestRequest::with_uri("/%25/").finish(); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + // assert_eq!( + // *Path::::from_request(&req, &&PathConfig::default().disable_decoding()) + // .unwrap(), + // "%25" + // ); + // } +} diff --git a/actix-router/src/lib.rs b/actix-router/src/lib.rs new file mode 100644 index 000000000..463e59e42 --- /dev/null +++ b/actix-router/src/lib.rs @@ -0,0 +1,149 @@ +//! Resource path matching and router. + +#![deny(rust_2018_idioms, nonstandard_style)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] + +mod de; +mod path; +mod resource; +mod router; + +pub use self::de::PathDeserializer; +pub use self::path::Path; +pub use self::resource::ResourceDef; +pub use self::router::{ResourceInfo, Router, RouterBuilder}; + +// TODO: this trait is necessary, document it +// see impl Resource for ServiceRequest +pub trait Resource { + fn resource_path(&mut self) -> &mut Path; +} + +pub trait ResourcePath { + fn path(&self) -> &str; +} + +impl ResourcePath for String { + fn path(&self) -> &str { + self.as_str() + } +} + +impl<'a> ResourcePath for &'a str { + fn path(&self) -> &str { + self + } +} + +impl ResourcePath for bytestring::ByteString { + fn path(&self) -> &str { + &*self + } +} + +/// One or many patterns. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Patterns { + Single(String), + List(Vec), +} + +impl Patterns { + pub fn is_empty(&self) -> bool { + match self { + Patterns::Single(_) => false, + Patterns::List(pats) => pats.is_empty(), + } + } +} + +/// Helper trait for type that could be converted to one or more path pattern. +pub trait IntoPatterns { + fn patterns(&self) -> Patterns; +} + +impl IntoPatterns for String { + fn patterns(&self) -> Patterns { + Patterns::Single(self.clone()) + } +} + +impl<'a> IntoPatterns for &'a String { + fn patterns(&self) -> Patterns { + Patterns::Single((*self).clone()) + } +} + +impl<'a> IntoPatterns for &'a str { + fn patterns(&self) -> Patterns { + Patterns::Single((*self).to_owned()) + } +} + +impl IntoPatterns for bytestring::ByteString { + fn patterns(&self) -> Patterns { + Patterns::Single(self.to_string()) + } +} + +impl IntoPatterns for Patterns { + fn patterns(&self) -> Patterns { + self.clone() + } +} + +impl> IntoPatterns for Vec { + fn patterns(&self) -> Patterns { + let mut patterns = self.iter().map(|v| v.as_ref().to_owned()); + + match patterns.size_hint() { + (1, _) => Patterns::Single(patterns.next().unwrap()), + _ => Patterns::List(patterns.collect()), + } + } +} + +macro_rules! array_patterns_single (($tp:ty) => { + impl IntoPatterns for [$tp; 1] { + fn patterns(&self) -> Patterns { + Patterns::Single(self[0].to_owned()) + } + } +}); + +macro_rules! array_patterns_multiple (($tp:ty, $str_fn:expr, $($num:tt) +) => { + // for each array length specified in $num + $( + impl IntoPatterns for [$tp; $num] { + fn patterns(&self) -> Patterns { + Patterns::List(self.iter().map($str_fn).collect()) + } + } + )+ +}); + +array_patterns_single!(&str); +array_patterns_multiple!(&str, |&v| v.to_owned(), 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16); + +array_patterns_single!(String); +array_patterns_multiple!(String, |v| v.clone(), 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16); + +#[cfg(feature = "http")] +mod url; + +#[cfg(feature = "http")] +pub use self::url::{Quoter, Url}; + +#[cfg(feature = "http")] +mod http_impls { + use http::Uri; + + use super::ResourcePath; + + impl ResourcePath for Uri { + fn path(&self) -> &str { + self.path() + } + } +} diff --git a/actix-router/src/path.rs b/actix-router/src/path.rs new file mode 100644 index 000000000..9af7b0b8b --- /dev/null +++ b/actix-router/src/path.rs @@ -0,0 +1,220 @@ +use std::borrow::Cow; +use std::ops::Index; + +use firestorm::profile_method; +use serde::de; + +use crate::{de::PathDeserializer, Resource, ResourcePath}; + +#[derive(Debug, Clone)] +pub(crate) enum PathItem { + Static(Cow<'static, str>), + Segment(u16, u16), +} + +impl Default for PathItem { + fn default() -> Self { + Self::Static(Cow::Borrowed("")) + } +} + +/// Resource path match information. +/// +/// If resource path contains variable patterns, `Path` stores them. +#[derive(Debug, Clone, Default)] +pub struct Path { + path: T, + pub(crate) skip: u16, + pub(crate) segments: Vec<(Cow<'static, str>, PathItem)>, +} + +impl Path { + pub fn new(path: T) -> Path { + Path { + path, + skip: 0, + segments: Vec::new(), + } + } + + /// Get reference to inner path instance. + #[inline] + pub fn get_ref(&self) -> &T { + &self.path + } + + /// Get mutable reference to inner path instance. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + &mut self.path + } + + /// Path. + #[inline] + pub fn path(&self) -> &str { + profile_method!(path); + + let skip = self.skip as usize; + let path = self.path.path(); + if skip <= path.len() { + &path[skip..] + } else { + "" + } + } + + /// Set new path. + #[inline] + pub fn set(&mut self, path: T) { + self.skip = 0; + self.path = path; + self.segments.clear(); + } + + /// Reset state. + #[inline] + pub fn reset(&mut self) { + self.skip = 0; + self.segments.clear(); + } + + /// Skip first `n` chars in path. + #[inline] + pub fn skip(&mut self, n: u16) { + self.skip += n; + } + + pub(crate) fn add(&mut self, name: impl Into>, value: PathItem) { + profile_method!(add); + + match value { + PathItem::Static(s) => self.segments.push((name.into(), PathItem::Static(s))), + PathItem::Segment(begin, end) => self.segments.push(( + name.into(), + PathItem::Segment(self.skip + begin, self.skip + end), + )), + } + } + + #[doc(hidden)] + pub fn add_static( + &mut self, + name: impl Into>, + value: impl Into>, + ) { + self.segments + .push((name.into(), PathItem::Static(value.into()))); + } + + /// Check if there are any matched patterns. + #[inline] + pub fn is_empty(&self) -> bool { + self.segments.is_empty() + } + + /// Returns number of interpolated segments. + #[inline] + pub fn segment_count(&self) -> usize { + self.segments.len() + } + + /// Get matched parameter by name without type conversion + pub fn get(&self, name: &str) -> Option<&str> { + profile_method!(get); + + for (seg_name, val) in self.segments.iter() { + if name == seg_name { + return match val { + PathItem::Static(ref s) => Some(s), + PathItem::Segment(s, e) => { + Some(&self.path.path()[(*s as usize)..(*e as usize)]) + } + }; + } + } + + None + } + + /// Get unprocessed part of the path + pub fn unprocessed(&self) -> &str { + &self.path.path()[(self.skip as usize)..] + } + + /// Get matched parameter by name. + /// + /// If keyed parameter is not available empty string is used as default value. + pub fn query(&self, key: &str) -> &str { + profile_method!(query); + + if let Some(s) = self.get(key) { + s + } else { + "" + } + } + + /// Return iterator to items in parameter container. + pub fn iter(&self) -> PathIter<'_, T> { + PathIter { + idx: 0, + params: self, + } + } + + /// Try to deserialize matching parameters to a specified type `U` + pub fn load<'de, U: serde::Deserialize<'de>>(&'de self) -> Result { + profile_method!(load); + de::Deserialize::deserialize(PathDeserializer::new(self)) + } +} + +#[derive(Debug)] +pub struct PathIter<'a, T> { + idx: usize, + params: &'a Path, +} + +impl<'a, T: ResourcePath> Iterator for PathIter<'a, T> { + type Item = (&'a str, &'a str); + + #[inline] + fn next(&mut self) -> Option<(&'a str, &'a str)> { + if self.idx < self.params.segment_count() { + let idx = self.idx; + let res = match self.params.segments[idx].1 { + PathItem::Static(ref s) => s, + PathItem::Segment(s, e) => &self.params.path.path()[(s as usize)..(e as usize)], + }; + self.idx += 1; + return Some((&self.params.segments[idx].0, res)); + } + None + } +} + +impl<'a, T: ResourcePath> Index<&'a str> for Path { + type Output = str; + + fn index(&self, name: &'a str) -> &str { + self.get(name) + .expect("Value for parameter is not available") + } +} + +impl Index for Path { + type Output = str; + + fn index(&self, idx: usize) -> &str { + match self.segments[idx].1 { + PathItem::Static(ref s) => s, + PathItem::Segment(s, e) => &self.path.path()[(s as usize)..(e as usize)], + } + } +} + +impl Resource for Path { + fn resource_path(&mut self) -> &mut Self { + self + } +} diff --git a/actix-router/src/resource.rs b/actix-router/src/resource.rs new file mode 100644 index 000000000..d5f738a05 --- /dev/null +++ b/actix-router/src/resource.rs @@ -0,0 +1,1820 @@ +use std::{ + borrow::{Borrow, Cow}, + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + mem, +}; + +use firestorm::{profile_fn, profile_method, profile_section}; +use regex::{escape, Regex, RegexSet}; + +use crate::{ + path::{Path, PathItem}, + IntoPatterns, Patterns, Resource, ResourcePath, +}; + +const MAX_DYNAMIC_SEGMENTS: usize = 16; + +/// Regex flags to allow '.' in regex to match '\n' +/// +/// See the docs under: https://docs.rs/regex/1/regex/#grouping-and-flags +const REGEX_FLAGS: &str = "(?s-m)"; + +/// Describes the set of paths that match to a resource. +/// +/// `ResourceDef`s are effectively a way to transform the a custom resource pattern syntax into +/// suitable regular expressions from which to check matches with paths and capture portions of a +/// matched path into variables. Common cases are on a fast path that avoids going through the +/// regex engine. +/// +/// +/// # Pattern Format and Matching Behavior +/// +/// Resource pattern is defined as a string of zero or more _segments_ where each segment is +/// preceded by a slash `/`. +/// +/// This means that pattern string __must__ either be empty or begin with a slash (`/`). +/// This also implies that a trailing slash in pattern defines an empty segment. +/// For example, the pattern `"/user/"` has two segments: `["user", ""]` +/// +/// A key point to underhand is that `ResourceDef` matches segments, not strings. +/// It matches segments individually. +/// For example, the pattern `/user/` is not considered a prefix for the path `/user/123/456`, +/// because the second segment doesn't match: `["user", ""]` vs `["user", "123", "456"]`. +/// +/// This definition is consistent with the definition of absolute URL path in +/// [RFC 3986 (section 3.3)](https://datatracker.ietf.org/doc/html/rfc3986#section-3.3) +/// +/// +/// # Static Resources +/// A static resource is the most basic type of definition. Pass a pattern to +/// [new][Self::new]. Conforming paths must match the pattern exactly. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new("/home"); +/// +/// assert!(resource.is_match("/home")); +/// +/// assert!(!resource.is_match("/home/")); +/// assert!(!resource.is_match("/home/new")); +/// assert!(!resource.is_match("/homes")); +/// assert!(!resource.is_match("/search")); +/// ``` +/// +/// +/// # Dynamic Segments +/// Also known as "path parameters". Resources can define sections of a pattern that be extracted +/// from a conforming path, if it conforms to (one of) the resource pattern(s). +/// +/// The marker for a dynamic segment is curly braces wrapping an identifier. For example, +/// `/user/{id}` would match paths like `/user/123` or `/user/james` and be able to extract the user +/// IDs "123" and "james", respectively. +/// +/// However, this resource pattern (`/user/{id}`) would, not cover `/user/123/stars` (unless +/// constructed as a prefix; see next section) since the default pattern for segments matches all +/// characters until it finds a `/` character (or the end of the path). Custom segment patterns are +/// covered further down. +/// +/// Dynamic segments do not need to be delimited by `/` characters, they can be defined within a +/// path segment. For example, `/rust-is-{opinion}` can match the paths `/rust-is-cool` and +/// `/rust-is-hard`. +/// +/// For information on capturing segment values from paths or other custom resource types, +/// see [`capture_match_info`][Self::capture_match_info] +/// and [`capture_match_info_fn`][Self::capture_match_info_fn]. +/// +/// A resource definition can contain at most 16 dynamic segments. +/// +/// ## Examples +/// ``` +/// use actix_router::{Path, ResourceDef}; +/// +/// let resource = ResourceDef::prefix("/user/{id}"); +/// +/// assert!(resource.is_match("/user/123")); +/// assert!(!resource.is_match("/user")); +/// assert!(!resource.is_match("/user/")); +/// +/// let mut path = Path::new("/user/123"); +/// resource.capture_match_info(&mut path); +/// assert_eq!(path.get("id").unwrap(), "123"); +/// ``` +/// +/// +/// # Prefix Resources +/// A prefix resource is defined as pattern that can match just the start of a path, up to a +/// segment boundary. +/// +/// Prefix patterns with a trailing slash may have an unexpected, though correct, behavior. +/// They define and therefore require an empty segment in order to match. Examples are given below. +/// +/// Empty pattern matches any path as a prefix. +/// +/// Prefix resources can contain dynamic segments. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::prefix("/home"); +/// assert!(resource.is_match("/home")); +/// assert!(resource.is_match("/home/new")); +/// assert!(!resource.is_match("/homes")); +/// +/// // prefix pattern with a trailing slash +/// let resource = ResourceDef::prefix("/user/{id}/"); +/// assert!(resource.is_match("/user/123/")); +/// assert!(resource.is_match("/user/123//stars")); +/// assert!(!resource.is_match("/user/123/stars")); +/// assert!(!resource.is_match("/user/123")); +/// ``` +/// +/// +/// # Custom Regex Segments +/// Dynamic segments can be customised to only match a specific regular expression. It can be +/// helpful to do this if resource definitions would otherwise conflict and cause one to +/// be inaccessible. +/// +/// The regex used when capturing segment values can be specified explicitly using this syntax: +/// `{name:regex}`. For example, `/user/{id:\d+}` will only match paths where the user ID +/// is numeric. +/// +/// The regex could potentially match multiple segments. If this is not wanted, then care must be +/// taken to avoid matching a slash `/`. It is guaranteed, however, that the match ends at a +/// segment boundary; the pattern `r"(/|$)` is always appended to the regex. +/// +/// By default, dynamic segments use this regex: `[^/]+`. This shows why it is the case, as shown in +/// the earlier section, that segments capture a slice of the path up to the next `/` character. +/// +/// Custom regex segments can be used in static and prefix resource definition variants. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new(r"/user/{id:\d+}"); +/// assert!(resource.is_match("/user/123")); +/// assert!(resource.is_match("/user/314159")); +/// assert!(!resource.is_match("/user/abc")); +/// ``` +/// +/// +/// # Tail Segments +/// As a shortcut to defining a custom regex for matching _all_ remaining characters (not just those +/// up until a `/` character), there is a special pattern to match (and capture) the remaining +/// path portion. +/// +/// To do this, use the segment pattern: `{name}*`. Since a tail segment also has a name, values are +/// extracted in the same way as non-tail dynamic segments. +/// +/// ## Examples +/// ```rust +/// # use actix_router::{Path, ResourceDef}; +/// let resource = ResourceDef::new("/blob/{tail}*"); +/// assert!(resource.is_match("/blob/HEAD/Cargo.toml")); +/// assert!(resource.is_match("/blob/HEAD/README.md")); +/// +/// let mut path = Path::new("/blob/main/LICENSE"); +/// resource.capture_match_info(&mut path); +/// assert_eq!(path.get("tail").unwrap(), "main/LICENSE"); +/// ``` +/// +/// +/// # Multi-Pattern Resources +/// For resources that can map to multiple distinct paths, it may be suitable to use +/// multi-pattern resources by passing an array/vec to [`new`][Self::new]. They will be combined +/// into a regex set which is usually quicker to check matches on than checking each +/// pattern individually. +/// +/// Multi-pattern resources can contain dynamic segments just like single pattern ones. +/// However, take care to use consistent and semantically-equivalent segment names; it could affect +/// expectations in the router using these definitions and cause runtime panics. +/// +/// ## Examples +/// ```rust +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new(["/home", "/index"]); +/// assert!(resource.is_match("/home")); +/// assert!(resource.is_match("/index")); +/// ``` +/// +/// +/// # Trailing Slashes +/// It should be noted that this library takes no steps to normalize intra-path or trailing slashes. +/// As such, all resource definitions implicitly expect a pre-processing step to normalize paths if +/// they you wish to accommodate "recoverable" path errors. Below are several examples of +/// resource-path pairs that would not be compatible. +/// +/// ## Examples +/// ```rust +/// # use actix_router::ResourceDef; +/// assert!(!ResourceDef::new("/root").is_match("/root/")); +/// assert!(!ResourceDef::new("/root/").is_match("/root")); +/// assert!(!ResourceDef::prefix("/root/").is_match("/root")); +/// ``` +#[derive(Clone, Debug)] +pub struct ResourceDef { + id: u16, + + /// Optional name of resource. + name: Option, + + /// Pattern that generated the resource definition. + patterns: Patterns, + + is_prefix: bool, + + /// Pattern type. + pat_type: PatternType, + + /// List of segments that compose the pattern, in order. + segments: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +enum PatternSegment { + /// Literal slice of pattern. + Const(String), + + /// Name of dynamic segment. + Var(String), +} + +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +enum PatternType { + /// Single constant/literal segment. + Static(String), + + /// Single regular expression and list of dynamic segment names. + Dynamic(Regex, Vec<&'static str>), + + /// Regular expression set and list of component expressions plus dynamic segment names. + DynamicSet(RegexSet, Vec<(Regex, Vec<&'static str>)>), +} + +impl ResourceDef { + /// Constructs a new resource definition from patterns. + /// + /// Multi-pattern resources can be constructed by providing a slice (or vec) of patterns. + /// + /// # Panics + /// Panics if path pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::new("/user/{id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("/user/123/stars")); + /// assert!(!resource.is_match("user/1234")); + /// assert!(!resource.is_match("/foo")); + /// + /// let resource = ResourceDef::new(["/profile", "/user/{id}"]); + /// assert!(resource.is_match("/profile")); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("user/123")); + /// assert!(!resource.is_match("/foo")); + /// ``` + pub fn new(paths: T) -> Self { + profile_method!(new); + Self::new2(paths, false) + } + + /// Constructs a new resource definition using a pattern that performs prefix matching. + /// + /// More specifically, the regular expressions generated for matching are different when using + /// this method vs using `new`; they will not be appended with the `$` meta-character that + /// matches the end of an input. + /// + /// Although it will compile and run correctly, it is meaningless to construct a prefix + /// resource definition with a tail segment; use [`new`][Self::new] in this case. + /// + /// # Panics + /// Panics if path regex pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(resource.is_match("/user/123/stars")); + /// assert!(!resource.is_match("user/123")); + /// assert!(!resource.is_match("user/123/stars")); + /// assert!(!resource.is_match("/foo")); + /// ``` + pub fn prefix(paths: T) -> Self { + profile_method!(prefix); + ResourceDef::new2(paths, true) + } + + /// Constructs a new resource definition using a string pattern that performs prefix matching, + /// inserting a `/` to beginning of the pattern if absent and pattern is not empty. + /// + /// # Panics + /// Panics if path regex pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::root_prefix("user/{id}"); + /// + /// assert_eq!(&resource, &ResourceDef::prefix("/user/{id}")); + /// assert_eq!(&resource, &ResourceDef::root_prefix("/user/{id}")); + /// assert_ne!(&resource, &ResourceDef::new("user/{id}")); + /// assert_ne!(&resource, &ResourceDef::new("/user/{id}")); + /// + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("user/123")); + /// ``` + pub fn root_prefix(path: &str) -> Self { + profile_method!(root_prefix); + ResourceDef::prefix(insert_slash(path).into_owned()) + } + + /// Returns a numeric resource ID. + /// + /// If not explicitly set using [`set_id`][Self::set_id], this will return `0`. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// assert_eq!(resource.id(), 0); + /// + /// resource.set_id(42); + /// assert_eq!(resource.id(), 42); + /// ``` + pub fn id(&self) -> u16 { + self.id + } + + /// Set numeric resource ID. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// resource.set_id(42); + /// assert_eq!(resource.id(), 42); + /// ``` + pub fn set_id(&mut self, id: u16) { + self.id = id; + } + + /// Returns resource definition name, if set. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// assert!(resource.name().is_none()); + /// + /// resource.set_name("root"); + /// assert_eq!(resource.name().unwrap(), "root"); + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + /// Assigns a new name to the resource. + /// + /// # Panics + /// Panics if `name` is an empty string. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// resource.set_name("root"); + /// assert_eq!(resource.name().unwrap(), "root"); + /// ``` + pub fn set_name(&mut self, name: impl Into) { + let name = name.into(); + + assert!(!name.is_empty(), "resource name should not be empty"); + + self.name = Some(name) + } + + /// Returns `true` if pattern type is prefix. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// assert!(ResourceDef::prefix("/user").is_prefix()); + /// assert!(!ResourceDef::new("/user").is_prefix()); + /// ``` + pub fn is_prefix(&self) -> bool { + self.is_prefix + } + + /// Returns the pattern string that generated the resource definition. + /// + /// If definition is constructed with multiple patterns, the first pattern is returned. To get + /// all patterns, use [`patterns_iter`][Self::pattern_iter]. If resource has 0 patterns, + /// returns `None`. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/user/{id}"); + /// assert_eq!(resource.pattern().unwrap(), "/user/{id}"); + /// + /// let mut resource = ResourceDef::new(["/profile", "/user/{id}"]); + /// assert_eq!(resource.pattern(), Some("/profile")); + pub fn pattern(&self) -> Option<&str> { + match &self.patterns { + Patterns::Single(pattern) => Some(pattern.as_str()), + Patterns::List(patterns) => patterns.first().map(AsRef::as_ref), + } + } + + /// Returns iterator of pattern strings that generated the resource definition. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// let mut iter = resource.pattern_iter(); + /// assert_eq!(iter.next().unwrap(), "/root"); + /// assert!(iter.next().is_none()); + /// + /// let mut resource = ResourceDef::new(["/root", "/backup"]); + /// let mut iter = resource.pattern_iter(); + /// assert_eq!(iter.next().unwrap(), "/root"); + /// assert_eq!(iter.next().unwrap(), "/backup"); + /// assert!(iter.next().is_none()); + pub fn pattern_iter(&self) -> impl Iterator { + struct PatternIter<'a> { + patterns: &'a Patterns, + list_idx: usize, + done: bool, + } + + impl<'a> Iterator for PatternIter<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + match &self.patterns { + Patterns::Single(pattern) => { + if self.done { + return None; + } + + self.done = true; + Some(pattern.as_str()) + } + Patterns::List(patterns) if patterns.is_empty() => None, + Patterns::List(patterns) => match patterns.get(self.list_idx) { + Some(pattern) => { + self.list_idx += 1; + Some(pattern.as_str()) + } + None => { + // fast path future call + self.done = true; + None + } + }, + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.patterns { + Patterns::Single(_) => (1, Some(1)), + Patterns::List(patterns) => (patterns.len(), Some(patterns.len())), + } + } + } + + PatternIter { + patterns: &self.patterns, + list_idx: 0, + done: false, + } + } + + /// Joins two resources. + /// + /// Resulting resource is prefix if `other` is prefix. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let joined = ResourceDef::prefix("/root").join(&ResourceDef::prefix("/seg")); + /// assert_eq!(joined, ResourceDef::prefix("/root/seg")); + /// ``` + pub fn join(&self, other: &ResourceDef) -> ResourceDef { + let patterns = self + .pattern_iter() + .flat_map(move |this| other.pattern_iter().map(move |other| (this, other))) + .map(|(this, other)| [this, other].join("")) + .collect::>(); + + match patterns.len() { + 1 => ResourceDef::new2(&patterns[0], other.is_prefix()), + _ => ResourceDef::new2(patterns, other.is_prefix()), + } + } + + /// Returns `true` if `path` matches this resource. + /// + /// The behavior of this method depends on how the `ResourceDef` was constructed. For example, + /// static resources will not be able to match as many paths as dynamic and prefix resources. + /// See [`ResourceDef`] struct docs for details on resource definition types. + /// + /// This method will always agree with [`find_match`][Self::find_match] on whether the path + /// matches or not. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// // static resource + /// let resource = ResourceDef::new("/user"); + /// assert!(resource.is_match("/user")); + /// assert!(!resource.is_match("/users")); + /// assert!(!resource.is_match("/user/123")); + /// assert!(!resource.is_match("/foo")); + /// + /// // dynamic resource + /// let resource = ResourceDef::new("/user/{user_id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("/user/123/stars")); + /// + /// // prefix resource + /// let resource = ResourceDef::prefix("/root"); + /// assert!(resource.is_match("/root")); + /// assert!(resource.is_match("/root/leaf")); + /// assert!(!resource.is_match("/roots")); + /// + /// // more examples are shown in the `ResourceDef` struct docs + /// ``` + #[inline] + pub fn is_match(&self, path: &str) -> bool { + profile_method!(is_match); + + // this function could be expressed as: + // `self.find_match(path).is_some()` + // but this skips some checks and uses potentially faster regex methods + + match &self.pat_type { + PatternType::Static(pattern) => self.static_match(pattern, path).is_some(), + PatternType::Dynamic(re, _) => re.is_match(path), + PatternType::DynamicSet(re, _) => re.is_match(path), + } + } + + /// Tries to match `path` to this resource, returning the position in the path where the + /// match ends. + /// + /// This method will always agree with [`is_match`][Self::is_match] on whether the path matches + /// or not. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// // static resource + /// let resource = ResourceDef::new("/user"); + /// assert_eq!(resource.find_match("/user"), Some(5)); + /// assert!(resource.find_match("/user/").is_none()); + /// assert!(resource.find_match("/user/123").is_none()); + /// assert!(resource.find_match("/foo").is_none()); + /// + /// // constant prefix resource + /// let resource = ResourceDef::prefix("/user"); + /// assert_eq!(resource.find_match("/user"), Some(5)); + /// assert_eq!(resource.find_match("/user/"), Some(5)); + /// assert_eq!(resource.find_match("/user/123"), Some(5)); + /// + /// // dynamic prefix resource + /// let resource = ResourceDef::prefix("/user/{id}"); + /// assert_eq!(resource.find_match("/user/123"), Some(9)); + /// assert_eq!(resource.find_match("/user/1234/"), Some(10)); + /// assert_eq!(resource.find_match("/user/12345/stars"), Some(11)); + /// assert!(resource.find_match("/user/").is_none()); + /// + /// // multi-pattern resource + /// let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + /// assert_eq!(resource.find_match("/user/123"), Some(9)); + /// assert_eq!(resource.find_match("/profile/1234"), Some(13)); + /// ``` + pub fn find_match(&self, path: &str) -> Option { + profile_method!(find_match); + + match &self.pat_type { + PatternType::Static(pattern) => self.static_match(pattern, path), + + PatternType::Dynamic(re, _) => Some(re.captures(path)?[1].len()), + + PatternType::DynamicSet(re, params) => { + let idx = re.matches(path).into_iter().next()?; + let (ref pattern, _) = params[idx]; + Some(pattern.captures(path)?[1].len()) + } + } + } + + /// Collects dynamic segment values into `path`. + /// + /// Returns `true` if `path` matches this resource. + /// + /// # Examples + /// ``` + /// use actix_router::{Path, ResourceDef}; + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// let mut path = Path::new("/user/123/stars"); + /// assert!(resource.capture_match_info(&mut path)); + /// assert_eq!(path.get("id").unwrap(), "123"); + /// assert_eq!(path.unprocessed(), "/stars"); + /// + /// let resource = ResourceDef::new("/blob/{path}*"); + /// let mut path = Path::new("/blob/HEAD/Cargo.toml"); + /// assert!(resource.capture_match_info(&mut path)); + /// assert_eq!(path.get("path").unwrap(), "HEAD/Cargo.toml"); + /// assert_eq!(path.unprocessed(), ""); + /// ``` + pub fn capture_match_info(&self, path: &mut Path) -> bool { + profile_method!(capture_match_info); + self.capture_match_info_fn(path, |_, _| true, ()) + } + + /// Collects dynamic segment values into `resource` after matching paths and executing + /// check function. + /// + /// The check function is given a reference to the passed resource and optional arbitrary data. + /// This is useful if you want to conditionally match on some non-path related aspect of the + /// resource type. + /// + /// Returns `true` if resource path matches this resource definition _and_ satisfies the + /// given check function. + /// + /// # Examples + /// ``` + /// use actix_router::{Path, ResourceDef}; + /// + /// fn try_match(resource: &ResourceDef, path: &mut Path<&str>) -> bool { + /// let admin_allowed = std::env::var("ADMIN_ALLOWED").ok(); + /// + /// resource.capture_match_info_fn( + /// path, + /// // when env var is not set, reject when path contains "admin" + /// |res, admin_allowed| !res.path().contains("admin"), + /// &admin_allowed + /// ) + /// } + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// + /// // path matches; segment values are collected into path + /// let mut path = Path::new("/user/james/stars"); + /// assert!(try_match(&resource, &mut path)); + /// assert_eq!(path.get("id").unwrap(), "james"); + /// assert_eq!(path.unprocessed(), "/stars"); + /// + /// // path matches but fails check function; no segments are collected + /// let mut path = Path::new("/user/admin/stars"); + /// assert!(!try_match(&resource, &mut path)); + /// assert_eq!(path.unprocessed(), "/user/admin/stars"); + /// ``` + pub fn capture_match_info_fn( + &self, + resource: &mut R, + check_fn: F, + user_data: U, + ) -> bool + where + R: Resource, + T: ResourcePath, + F: FnOnce(&R, U) -> bool, + { + profile_method!(capture_match_info_fn); + + let mut segments = <[PathItem; MAX_DYNAMIC_SEGMENTS]>::default(); + let path = resource.resource_path(); + let path_str = path.path(); + + let (matched_len, matched_vars) = match &self.pat_type { + PatternType::Static(pattern) => { + profile_section!(pattern_static_or_prefix); + + match self.static_match(pattern, path_str) { + Some(len) => (len, None), + None => return false, + } + } + + PatternType::Dynamic(re, names) => { + profile_section!(pattern_dynamic); + + let captures = { + profile_section!(pattern_dynamic_regex_exec); + + match re.captures(path.path()) { + Some(captures) => captures, + _ => return false, + } + }; + + { + profile_section!(pattern_dynamic_extract_captures); + + for (no, name) in names.iter().enumerate() { + if let Some(m) = captures.name(name) { + segments[no] = PathItem::Segment(m.start() as u16, m.end() as u16); + } else { + log::error!( + "Dynamic path match but not all segments found: {}", + name + ); + return false; + } + } + }; + + (captures[1].len(), Some(names)) + } + + PatternType::DynamicSet(re, params) => { + profile_section!(pattern_dynamic_set); + + let path = path.path(); + let (pattern, names) = match re.matches(path).into_iter().next() { + Some(idx) => ¶ms[idx], + _ => return false, + }; + + let captures = match pattern.captures(path.path()) { + Some(captures) => captures, + _ => return false, + }; + + for (no, name) in names.iter().enumerate() { + if let Some(m) = captures.name(name) { + segments[no] = PathItem::Segment(m.start() as u16, m.end() as u16); + } else { + log::error!("Dynamic path match but not all segments found: {}", name); + return false; + } + } + + (captures[1].len(), Some(names)) + } + }; + + if !check_fn(resource, user_data) { + return false; + } + + // Modify `path` to skip matched part and store matched segments + let path = resource.resource_path(); + + if let Some(vars) = matched_vars { + for i in 0..vars.len() { + path.add(vars[i], mem::take(&mut segments[i])); + } + } + + path.skip(matched_len as u16); + + true + } + + /// Assembles resource path using a closure that maps variable segment names to values. + fn build_resource_path(&self, path: &mut String, mut vars: F) -> bool + where + F: FnMut(&str) -> Option, + I: AsRef, + { + for segment in &self.segments { + match segment { + PatternSegment::Const(val) => path.push_str(val), + PatternSegment::Var(name) => match vars(name) { + Some(val) => path.push_str(val.as_ref()), + _ => return false, + }, + } + } + + true + } + + /// Assembles full resource path from iterator of dynamic segment values. + /// + /// Returns `true` on success. + /// + /// For multi-pattern resources, the first pattern is used under the assumption that it would be + /// equivalent to any other choice. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut s = String::new(); + /// let resource = ResourceDef::new("/user/{id}/post/{title}"); + /// + /// assert!(resource.resource_path_from_iter(&mut s, &["123", "my-post"])); + /// assert_eq!(s, "/user/123/post/my-post"); + /// ``` + pub fn resource_path_from_iter(&self, path: &mut String, values: I) -> bool + where + I: IntoIterator, + I::Item: AsRef, + { + profile_method!(resource_path_from_iter); + let mut iter = values.into_iter(); + self.build_resource_path(path, |_| iter.next()) + } + + /// Assembles resource path from map of dynamic segment values. + /// + /// Returns `true` on success. + /// + /// For multi-pattern resources, the first pattern is used under the assumption that it would be + /// equivalent to any other choice. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use actix_router::ResourceDef; + /// let mut s = String::new(); + /// let resource = ResourceDef::new("/user/{id}/post/{title}"); + /// + /// let mut map = HashMap::new(); + /// map.insert("id", "123"); + /// map.insert("title", "my-post"); + /// + /// assert!(resource.resource_path_from_map(&mut s, &map)); + /// assert_eq!(s, "/user/123/post/my-post"); + /// ``` + pub fn resource_path_from_map( + &self, + path: &mut String, + values: &HashMap, + ) -> bool + where + K: Borrow + Eq + Hash, + V: AsRef, + S: BuildHasher, + { + profile_method!(resource_path_from_map); + self.build_resource_path(path, |name| values.get(name).map(AsRef::::as_ref)) + } + + /// Returns true if `prefix` acts as a proper prefix (i.e., separated by a slash) in `path`. + fn static_match(&self, pattern: &str, path: &str) -> Option { + let rem = path.strip_prefix(pattern)?; + + match self.is_prefix { + // resource is not a prefix so an exact match is needed + false if rem.is_empty() => Some(pattern.len()), + + // resource is a prefix so rem should start with a path delimiter + true if rem.is_empty() || rem.starts_with('/') => Some(pattern.len()), + + // otherwise, no match + _ => None, + } + } + + fn new2(paths: T, is_prefix: bool) -> Self { + profile_method!(new2); + + let patterns = paths.patterns(); + let (pat_type, segments) = match &patterns { + Patterns::Single(pattern) => ResourceDef::parse(pattern, is_prefix, false), + + // since zero length pattern sets are possible + // just return a useless `ResourceDef` + Patterns::List(patterns) if patterns.is_empty() => ( + PatternType::DynamicSet(RegexSet::empty(), Vec::new()), + Vec::new(), + ), + + Patterns::List(patterns) => { + let mut re_set = Vec::with_capacity(patterns.len()); + let mut pattern_data = Vec::new(); + let mut segments = None; + + for pattern in patterns { + match ResourceDef::parse(pattern, is_prefix, true) { + (PatternType::Dynamic(re, names), segs) => { + re_set.push(re.as_str().to_owned()); + pattern_data.push((re, names)); + segments.get_or_insert(segs); + } + _ => unreachable!(), + } + } + + let pattern_re_set = RegexSet::new(re_set).unwrap(); + let segments = segments.unwrap_or_else(Vec::new); + + ( + PatternType::DynamicSet(pattern_re_set, pattern_data), + segments, + ) + } + }; + + ResourceDef { + id: 0, + name: None, + patterns, + is_prefix, + pat_type, + segments, + } + } + + /// Parses a dynamic segment definition from a pattern. + /// + /// The returned tuple includes: + /// - the segment descriptor, either `Var` or `Tail` + /// - the segment's regex to check values against + /// - the remaining, unprocessed string slice + /// - whether the parsed parameter represents a tail pattern + /// + /// # Panics + /// Panics if given patterns does not contain a dynamic segment. + fn parse_param(pattern: &str) -> (PatternSegment, String, &str, bool) { + profile_method!(parse_param); + + const DEFAULT_PATTERN: &str = "[^/]+"; + const DEFAULT_PATTERN_TAIL: &str = ".*"; + + let mut params_nesting = 0usize; + let close_idx = pattern + .find(|c| match c { + '{' => { + params_nesting += 1; + false + } + '}' => { + params_nesting -= 1; + params_nesting == 0 + } + _ => false, + }) + .unwrap_or_else(|| { + panic!( + r#"pattern "{}" contains malformed dynamic segment"#, + pattern + ) + }); + + let (mut param, mut unprocessed) = pattern.split_at(close_idx + 1); + + // remove outer curly brackets + param = ¶m[1..param.len() - 1]; + + let tail = unprocessed == "*"; + + let (name, pattern) = match param.find(':') { + Some(idx) => { + assert!(!tail, "custom regex is not supported for tail match"); + + let (name, pattern) = param.split_at(idx); + (name, &pattern[1..]) + } + None => ( + param, + if tail { + unprocessed = &unprocessed[1..]; + DEFAULT_PATTERN_TAIL + } else { + DEFAULT_PATTERN + }, + ), + }; + + let segment = PatternSegment::Var(name.to_string()); + let regex = format!(r"(?P<{}>{})", &name, &pattern); + + (segment, regex, unprocessed, tail) + } + + /// Parse `pattern` using `is_prefix` and `force_dynamic` flags. + /// + /// Parameters: + /// - `is_prefix`: Use `true` if `pattern` should be treated as a prefix; i.e., a conforming + /// path will be a match even if it has parts remaining to process + /// - `force_dynamic`: Use `true` to disallow the return of static and prefix segments. + /// + /// The returned tuple includes: + /// - the pattern type detected, either `Static`, `Prefix`, or `Dynamic` + /// - a list of segment descriptors from the pattern + fn parse( + pattern: &str, + is_prefix: bool, + force_dynamic: bool, + ) -> (PatternType, Vec) { + profile_method!(parse); + + if !force_dynamic && pattern.find('{').is_none() && !pattern.ends_with('*') { + // pattern is static + return ( + PatternType::Static(pattern.to_owned()), + vec![PatternSegment::Const(pattern.to_owned())], + ); + } + + let mut unprocessed = pattern; + let mut segments = Vec::new(); + let mut re = format!("{}^", REGEX_FLAGS); + let mut dyn_segment_count = 0; + let mut has_tail_segment = false; + + while let Some(idx) = unprocessed.find('{') { + let (prefix, rem) = unprocessed.split_at(idx); + + segments.push(PatternSegment::Const(prefix.to_owned())); + re.push_str(&escape(prefix)); + + let (param_pattern, re_part, rem, tail) = Self::parse_param(rem); + + if tail { + has_tail_segment = true; + } + + segments.push(param_pattern); + re.push_str(&re_part); + + unprocessed = rem; + dyn_segment_count += 1; + } + + if is_prefix && has_tail_segment { + // tail segments in prefixes have no defined semantics + + #[cfg(not(test))] + log::warn!( + "Prefix resources should not have tail segments. \ + Use `ResourceDef::new` constructor. \ + This may become a panic in the future." + ); + + // panic in tests to make this case detectable + #[cfg(test)] + panic!("prefix resource definitions should not have tail segments"); + } + + if unprocessed.ends_with('*') { + // unnamed tail segment + + #[cfg(not(test))] + log::warn!( + "Tail segments must have names. \ + Consider `.../{{tail}}*`. \ + This may become a panic in the future." + ); + + // panic in tests to make this case detectable + #[cfg(test)] + panic!("tail segments must have names"); + } else if !has_tail_segment && !unprocessed.is_empty() { + // prevent `Const("")` element from being added after last dynamic segment + + segments.push(PatternSegment::Const(unprocessed.to_owned())); + re.push_str(&escape(unprocessed)); + } + + assert!( + dyn_segment_count <= MAX_DYNAMIC_SEGMENTS, + "Only {} dynamic segments are allowed, provided: {}", + MAX_DYNAMIC_SEGMENTS, + dyn_segment_count + ); + + // Store the pattern in capture group #1 to have context info outside it + let mut re = format!("({})", re); + + // Ensure the match ends at a segment boundary + if !has_tail_segment { + if is_prefix { + re.push_str(r"(/|$)"); + } else { + re.push('$'); + } + } + + let re = match Regex::new(&re) { + Ok(re) => re, + Err(err) => panic!("Wrong path pattern: \"{}\" {}", pattern, err), + }; + + // `Bok::leak(Box::new(name))` is an intentional memory leak. In typical applications the + // routing table is only constructed once (per worker) so leak is bounded. If you are + // constructing `ResourceDef`s more than once in your application's lifecycle you would + // expect a linear increase in leaked memory over time. + let names = re + .capture_names() + .filter_map(|name| name.map(|name| Box::leak(Box::new(name.to_owned())).as_str())) + .collect(); + + (PatternType::Dynamic(re, names), segments) + } +} + +impl Eq for ResourceDef {} + +impl PartialEq for ResourceDef { + fn eq(&self, other: &ResourceDef) -> bool { + self.patterns == other.patterns && self.is_prefix == other.is_prefix + } +} + +impl Hash for ResourceDef { + fn hash(&self, state: &mut H) { + self.patterns.hash(state); + } +} + +impl<'a> From<&'a str> for ResourceDef { + fn from(path: &'a str) -> ResourceDef { + ResourceDef::new(path) + } +} + +impl From for ResourceDef { + fn from(path: String) -> ResourceDef { + ResourceDef::new(path) + } +} + +pub(crate) fn insert_slash(path: &str) -> Cow<'_, str> { + profile_fn!(insert_slash); + + if !path.is_empty() && !path.starts_with('/') { + let mut new_path = String::with_capacity(path.len() + 1); + new_path.push('/'); + new_path.push_str(path); + Cow::Owned(new_path) + } else { + Cow::Borrowed(path) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn equivalence() { + assert_eq!( + ResourceDef::root_prefix("/root"), + ResourceDef::prefix("/root") + ); + assert_eq!( + ResourceDef::root_prefix("root"), + ResourceDef::prefix("/root") + ); + assert_eq!( + ResourceDef::root_prefix("/{id}"), + ResourceDef::prefix("/{id}") + ); + assert_eq!( + ResourceDef::root_prefix("{id}"), + ResourceDef::prefix("/{id}") + ); + + assert_eq!(ResourceDef::new("/"), ResourceDef::new(["/"])); + assert_eq!(ResourceDef::new("/"), ResourceDef::new(vec!["/"])); + + assert_ne!(ResourceDef::new(""), ResourceDef::prefix("")); + assert_ne!(ResourceDef::new("/"), ResourceDef::prefix("/")); + assert_ne!(ResourceDef::new("/{id}"), ResourceDef::prefix("/{id}")); + } + + #[test] + fn parse_static() { + let re = ResourceDef::new(""); + + assert!(!re.is_prefix()); + + assert!(re.is_match("")); + assert!(!re.is_match("/")); + assert_eq!(re.find_match(""), Some(0)); + assert_eq!(re.find_match("/"), None); + + let re = ResourceDef::new("/"); + assert!(re.is_match("/")); + assert!(!re.is_match("")); + assert!(!re.is_match("/foo")); + + let re = ResourceDef::new("/name"); + assert!(re.is_match("/name")); + assert!(!re.is_match("/name1")); + assert!(!re.is_match("/name/")); + assert!(!re.is_match("/name~")); + + let mut path = Path::new("/name"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match("/name1"), None); + assert_eq!(re.find_match("/name/"), None); + assert_eq!(re.find_match("/name~"), None); + + let re = ResourceDef::new("/name/"); + assert!(re.is_match("/name/")); + assert!(!re.is_match("/name")); + assert!(!re.is_match("/name/gs")); + + let re = ResourceDef::new("/user/profile"); + assert!(re.is_match("/user/profile")); + assert!(!re.is_match("/user/profile/profile")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + } + + #[test] + fn parse_param() { + let re = ResourceDef::new("/user/{id}"); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/user/1245125"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "1245125"); + assert_eq!(path.unprocessed(), ""); + + let re = ResourceDef::new("/v{version}/resource/{id}"); + assert!(re.is_match("/v1/resource/320120")); + assert!(!re.is_match("/v/resource/1")); + assert!(!re.is_match("/resource")); + + let mut path = Path::new("/v151/resource/adage32"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("version").unwrap(), "151"); + assert_eq!(path.get("id").unwrap(), "adage32"); + assert_eq!(path.unprocessed(), ""); + + let re = ResourceDef::new("/{id:[[:digit:]]{6}}"); + assert!(re.is_match("/012345")); + assert!(!re.is_match("/012")); + assert!(!re.is_match("/01234567")); + assert!(!re.is_match("/XXXXXX")); + + let mut path = Path::new("/012345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "012345"); + assert_eq!(path.unprocessed(), ""); + } + + #[allow(clippy::cognitive_complexity)] + #[test] + fn dynamic_set() { + let re = ResourceDef::new(vec![ + "/user/{id}", + "/v{version}/resource/{id}", + "/{id:[[:digit:]]{6}}", + "/static", + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/user/1245125"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "1245125"); + assert_eq!(path.unprocessed(), ""); + + assert!(re.is_match("/v1/resource/320120")); + assert!(!re.is_match("/v/resource/1")); + assert!(!re.is_match("/resource")); + + let mut path = Path::new("/v151/resource/adage32"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("version").unwrap(), "151"); + assert_eq!(path.get("id").unwrap(), "adage32"); + + assert!(re.is_match("/012345")); + assert!(!re.is_match("/012")); + assert!(!re.is_match("/01234567")); + assert!(!re.is_match("/XXXXXX")); + + assert!(re.is_match("/static")); + assert!(!re.is_match("/a/static")); + assert!(!re.is_match("/static/a")); + + let mut path = Path::new("/012345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "012345"); + + let re = ResourceDef::new([ + "/user/{id}", + "/v{version}/resource/{id}", + "/{id:[[:digit:]]{6}}", + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let re = ResourceDef::new([ + "/user/{id}".to_string(), + "/v{version}/resource/{id}".to_string(), + "/{id:[[:digit:]]{6}}".to_string(), + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + } + + #[test] + fn dynamic_set_prefix() { + let re = ResourceDef::prefix(vec!["/u/{id}", "/{id:[[:digit:]]{3}}"]); + + assert_eq!(re.find_match("/u/abc"), Some(6)); + assert_eq!(re.find_match("/u/abc/123"), Some(6)); + assert_eq!(re.find_match("/s/user/profile"), None); + + assert_eq!(re.find_match("/123"), Some(4)); + assert_eq!(re.find_match("/123/456"), Some(4)); + assert_eq!(re.find_match("/12345"), None); + + let mut path = Path::new("/151/res"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "151"); + assert_eq!(path.unprocessed(), "/res"); + } + + #[test] + fn parse_tail() { + let re = ResourceDef::new("/user/-{id}*"); + + let mut path = Path::new("/user/-profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + + let mut path = Path::new("/user/-2345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + + let mut path = Path::new("/user/-2345/"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345/"); + + let mut path = Path::new("/user/-2345/sdg"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345/sdg"); + } + + #[test] + fn static_tail() { + let re = ResourceDef::new("/user{tail}*"); + assert!(re.is_match("/users")); + assert!(re.is_match("/user-foo")); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(re.is_match("/user/2345/")); + assert!(re.is_match("/user/2345/sdg")); + assert!(!re.is_match("/foo/profile")); + + let re = ResourceDef::new("/user/{tail}*"); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(re.is_match("/user/2345/")); + assert!(re.is_match("/user/2345/sdg")); + assert!(!re.is_match("/foo/profile")); + } + + #[test] + fn dynamic_tail() { + let re = ResourceDef::new("/user/{id}/{tail}*"); + assert!(!re.is_match("/user/2345")); + let mut path = Path::new("/user/2345/sdg"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + assert_eq!(path.get("tail").unwrap(), "sdg"); + assert_eq!(path.unprocessed(), ""); + } + + #[test] + fn newline_patterns_and_paths() { + let re = ResourceDef::new("/user/a\nb"); + assert!(re.is_match("/user/a\nb")); + assert!(!re.is_match("/user/a\nb/profile")); + + let re = ResourceDef::new("/a{x}b/test/a{y}b"); + let mut path = Path::new("/a\nb/test/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("x").unwrap(), "\n"); + assert_eq!(path.get("y").unwrap(), "\n"); + + let re = ResourceDef::new("/user/{tail}*"); + assert!(re.is_match("/user/a\nb/")); + + let re = ResourceDef::new("/user/{id}*"); + let mut path = Path::new("/user/a\nb/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "a\nb/a\nb"); + + let re = ResourceDef::new("/user/{id:.*}"); + let mut path = Path::new("/user/a\nb/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "a\nb/a\nb"); + } + + #[cfg(feature = "http")] + #[test] + fn parse_urlencoded_param() { + use std::convert::TryFrom; + + let re = ResourceDef::new("/user/{id}/test"); + + let mut path = Path::new("/user/2345/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + + let mut path = Path::new("/user/qwe%25/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + + let uri = http::Uri::try_from("/user/qwe%25/test").unwrap(); + let mut path = Path::new(uri); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + } + + #[test] + fn prefix_static() { + let re = ResourceDef::prefix("/name"); + + assert!(re.is_prefix()); + + assert!(re.is_match("/name")); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name/test/test")); + assert!(!re.is_match("/name1")); + assert!(!re.is_match("/name~")); + + let mut path = Path::new("/name"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/name/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), "/test"); + + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match("/name/"), Some(5)); + assert_eq!(re.find_match("/name/test/test"), Some(5)); + assert_eq!(re.find_match("/name1"), None); + assert_eq!(re.find_match("/name~"), None); + + let re = ResourceDef::prefix("/name/"); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name//gs")); + assert!(!re.is_match("/name/gs")); + assert!(!re.is_match("/name")); + + let mut path = Path::new("/name/gs"); + assert!(!re.capture_match_info(&mut path)); + + let mut path = Path::new("/name//gs"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), "/gs"); + + let re = ResourceDef::root_prefix("name/"); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name//gs")); + assert!(!re.is_match("/name/gs")); + assert!(!re.is_match("/name")); + + let mut path = Path::new("/name/gs"); + assert!(!re.capture_match_info(&mut path)); + } + + #[test] + fn prefix_dynamic() { + let re = ResourceDef::prefix("/{name}"); + + assert!(re.is_prefix()); + + assert!(re.is_match("/name/")); + assert!(re.is_match("/name/gs")); + assert!(re.is_match("/name")); + + assert_eq!(re.find_match("/name/"), Some(5)); + assert_eq!(re.find_match("/name/gs"), Some(5)); + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match(""), None); + + let mut path = Path::new("/test2/"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(&path["name"], "test2"); + assert_eq!(&path[0], "test2"); + assert_eq!(path.unprocessed(), "/"); + + let mut path = Path::new("/test2/subpath1/subpath2/index.html"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(&path["name"], "test2"); + assert_eq!(&path[0], "test2"); + assert_eq!(path.unprocessed(), "/subpath1/subpath2/index.html"); + + let resource = ResourceDef::prefix("/user"); + // input string shorter than prefix + assert!(resource.find_match("/foo").is_none()); + } + + #[test] + fn prefix_empty() { + let re = ResourceDef::prefix(""); + + assert!(re.is_prefix()); + + assert!(re.is_match("")); + assert!(re.is_match("/")); + assert!(re.is_match("/name/test/test")); + } + + #[test] + fn build_path_list() { + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/test"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["user1"]).iter())); + assert_eq!(s, "/user/user1/test"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}/test"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/test"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}/"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&["item"]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/"); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&["item"]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut vec!["item", "item2"].iter())); + assert_eq!(s, "/user/item/item2/"); + } + + #[test] + fn multi_pattern_build_path() { + let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut ["123"].iter())); + assert_eq!(s, "/user/123"); + } + + #[test] + fn multi_pattern_capture_segment_values() { + let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + + let mut path = Path::new("/user/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + + let mut path = Path::new("/profile/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + + let resource = ResourceDef::new(["/user/{id}", "/profile/{uid}"]); + + let mut path = Path::new("/user/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + assert!(path.get("uid").is_none()); + + let mut path = Path::new("/profile/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_none()); + assert!(path.get("uid").is_some()); + } + + #[test] + fn dynamic_prefix_proper_segmentation() { + let resource = ResourceDef::prefix(r"/id/{id:\d{3}}"); + + assert!(resource.is_match("/id/123")); + assert!(resource.is_match("/id/123/foo")); + assert!(!resource.is_match("/id/1234")); + assert!(!resource.is_match("/id/123a")); + + assert_eq!(resource.find_match("/id/123"), Some(7)); + assert_eq!(resource.find_match("/id/123/foo"), Some(7)); + assert_eq!(resource.find_match("/id/1234"), None); + assert_eq!(resource.find_match("/id/123a"), None); + } + + #[test] + fn build_path_map() { + let resource = ResourceDef::new("/user/{item1}/{item2}/"); + + let mut map = HashMap::new(); + map.insert("item1", "item"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_map(&mut s, &map)); + + map.insert("item2", "item2"); + + let mut s = String::new(); + assert!(resource.resource_path_from_map(&mut s, &map)); + assert_eq!(s, "/user/item/item2/"); + } + + #[test] + fn build_path_tail() { + let resource = ResourceDef::new("/user/{item1}*"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&[""; 0]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["user1"]).iter())); + assert_eq!(s, "/user/user1"); + + let mut s = String::new(); + let mut map = HashMap::new(); + map.insert("item1", "item"); + assert!(resource.resource_path_from_map(&mut s, &map)); + assert_eq!(s, "/user/item"); + } + + #[test] + fn prefix_trailing_slash() { + // The prefix "/abc/" matches two segments: ["user", ""] + + // These are not prefixes + let re = ResourceDef::prefix("/abc/"); + assert_eq!(re.find_match("/abc/def"), None); + assert_eq!(re.find_match("/abc//def"), Some(5)); + + let re = ResourceDef::prefix("/{id}/"); + assert_eq!(re.find_match("/abc/def"), None); + assert_eq!(re.find_match("/abc//def"), Some(5)); + } + + #[test] + fn join() { + // test joined defs match the same paths as each component separately + + fn seq_find_match(re1: &ResourceDef, re2: &ResourceDef, path: &str) -> Option { + let len1 = re1.find_match(path)?; + let len2 = re2.find_match(&path[len1..])?; + Some(len1 + len2) + } + + macro_rules! join_test { + ($pat1:expr, $pat2:expr => $($test:expr),+) => {{ + let pat1 = $pat1; + let pat2 = $pat2; + $({ + let _path = $test; + let (re1, re2) = (ResourceDef::prefix(pat1), ResourceDef::new(pat2)); + let _seq = seq_find_match(&re1, &re2, _path); + let _join = re1.join(&re2).find_match(_path); + assert_eq!( + _seq, _join, + "patterns: prefix {:?}, {:?}; mismatch on \"{}\"; seq={:?}; join={:?}", + pat1, pat2, _path, _seq, _join + ); + assert!(!re1.join(&re2).is_prefix()); + + let (re1, re2) = (ResourceDef::prefix(pat1), ResourceDef::prefix(pat2)); + let _seq = seq_find_match(&re1, &re2, _path); + let _join = re1.join(&re2).find_match(_path); + assert_eq!( + _seq, _join, + "patterns: prefix {:?}, prefix {:?}; mismatch on \"{}\"; seq={:?}; join={:?}", + pat1, pat2, _path, _seq, _join + ); + assert!(re1.join(&re2).is_prefix()); + })+ + }} + } + + join_test!("", "" => "", "/hello", "/"); + join_test!("/user", "" => "", "/user", "/user/123", "/user11", "user", "user/123"); + join_test!("", "/user" => "", "/user", "foo", "/user11", "user", "user/123"); + join_test!("/user", "/xx" => "", "", "/", "/user", "/xx", "/userxx", "/user/xx"); + + join_test!(["/ver/{v}", "/v{v}"], ["/req/{req}", "/{req}"] => "/v1/abc", + "/ver/1/abc", "/v1/req/abc", "/ver/1/req/abc", "/v1/abc/def", + "/ver1/req/abc/def", "", "/", "/v1/"); + } + + #[test] + fn match_methods_agree() { + macro_rules! match_methods_agree { + ($pat:expr => $($test:expr),+) => {{ + match_methods_agree!(finish $pat, ResourceDef::new($pat), $($test),+); + }}; + (prefix $pat:expr => $($test:expr),+) => {{ + match_methods_agree!(finish $pat, ResourceDef::prefix($pat), $($test),+); + }}; + (finish $pat:expr, $re:expr, $($test:expr),+) => {{ + let re = $re; + $({ + let _is = re.is_match($test); + let _find = re.find_match($test).is_some(); + assert_eq!( + _is, _find, + "pattern: {:?}; mismatch on \"{}\"; is={}; find={}", + $pat, $test, _is, _find + ); + })+ + }} + } + + match_methods_agree!("" => "", "/", "/foo"); + match_methods_agree!("/" => "", "/", "/foo"); + match_methods_agree!("/user" => "user", "/user", "/users", "/user/123", "/foo"); + match_methods_agree!("/v{v}" => "v", "/v", "/v1", "/v222", "/foo"); + match_methods_agree!(["/v{v}", "/version/{v}"] => "/v", "/v1", "/version", "/version/1", "/foo"); + + match_methods_agree!("/path{tail}*" => "/path", "/path1", "/path/123"); + match_methods_agree!("/path/{tail}*" => "/path", "/path1", "/path/123"); + + match_methods_agree!(prefix "" => "", "/", "/foo"); + match_methods_agree!(prefix "/user" => "user", "/user", "/users", "/user/123", "/foo"); + match_methods_agree!(prefix r"/id/{id:\d{3}}" => "/id/123", "/id/1234"); + match_methods_agree!(["/v{v}", "/ver/{v}"] => "", "s/v", "/v1", "/v1/xx", "/ver/i3/5", "/ver/1"); + } + + #[test] + #[should_panic] + fn duplicate_segment_name() { + ResourceDef::new("/user/{id}/post/{id}"); + } + + #[test] + #[should_panic] + fn invalid_dynamic_segment_delimiter() { + ResourceDef::new("/user/{username"); + } + + #[test] + #[should_panic] + fn invalid_dynamic_segment_name() { + ResourceDef::new("/user/{}"); + } + + #[test] + #[should_panic] + fn invalid_too_many_dynamic_segments() { + // valid + ResourceDef::new("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}"); + + // panics + ResourceDef::new( + "/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}", + ); + } + + #[test] + #[should_panic] + fn invalid_custom_regex_for_tail() { + ResourceDef::new(r"/{tail:\d+}*"); + } + + #[test] + #[should_panic] + fn invalid_unnamed_tail_segment() { + ResourceDef::new("/*"); + } + + #[test] + #[should_panic] + fn prefix_plus_tail_match_is_allowed() { + ResourceDef::prefix("/user/{id}*"); + } +} diff --git a/actix-router/src/router.rs b/actix-router/src/router.rs new file mode 100644 index 000000000..fad1a440b --- /dev/null +++ b/actix-router/src/router.rs @@ -0,0 +1,282 @@ +use firestorm::profile_method; + +use crate::{IntoPatterns, Resource, ResourceDef, ResourcePath}; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct ResourceId(pub u16); + +/// Information about current resource +#[derive(Debug, Clone)] +pub struct ResourceInfo { + #[allow(dead_code)] + resource: ResourceId, +} + +/// Resource router. +// T is the resource itself +// U is any other data needed for routing like method guards +pub struct Router { + routes: Vec<(ResourceDef, T, Option)>, +} + +impl Router { + pub fn build() -> RouterBuilder { + RouterBuilder { + resources: Vec::new(), + } + } + + pub fn recognize(&self, resource: &mut R) -> Option<(&T, ResourceId)> + where + R: Resource

, + P: ResourcePath, + { + profile_method!(recognize); + + for item in self.routes.iter() { + if item.0.capture_match_info(resource.resource_path()) { + return Some((&item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_mut(&mut self, resource: &mut R) -> Option<(&mut T, ResourceId)> + where + R: Resource

, + P: ResourcePath, + { + profile_method!(recognize_mut); + + for item in self.routes.iter_mut() { + if item.0.capture_match_info(resource.resource_path()) { + return Some((&mut item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_fn(&self, resource: &mut R, check: F) -> Option<(&T, ResourceId)> + where + F: Fn(&R, &Option) -> bool, + R: Resource

, + P: ResourcePath, + { + profile_method!(recognize_checked); + + for item in self.routes.iter() { + if item.0.capture_match_info_fn(resource, &check, &item.2) { + return Some((&item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_mut_fn( + &mut self, + resource: &mut R, + check: F, + ) -> Option<(&mut T, ResourceId)> + where + F: Fn(&R, &Option) -> bool, + R: Resource

, + P: ResourcePath, + { + profile_method!(recognize_mut_checked); + + for item in self.routes.iter_mut() { + if item.0.capture_match_info_fn(resource, &check, &item.2) { + return Some((&mut item.1, ResourceId(item.0.id()))); + } + } + + None + } +} + +pub struct RouterBuilder { + resources: Vec<(ResourceDef, T, Option)>, +} + +impl RouterBuilder { + /// Register resource for specified path. + pub fn path( + &mut self, + path: P, + resource: T, + ) -> &mut (ResourceDef, T, Option) { + profile_method!(path); + + self.resources + .push((ResourceDef::new(path), resource, None)); + self.resources.last_mut().unwrap() + } + + /// Register resource for specified path prefix. + pub fn prefix(&mut self, prefix: &str, resource: T) -> &mut (ResourceDef, T, Option) { + profile_method!(prefix); + + self.resources + .push((ResourceDef::prefix(prefix), resource, None)); + self.resources.last_mut().unwrap() + } + + /// Register resource for ResourceDef + pub fn rdef(&mut self, rdef: ResourceDef, resource: T) -> &mut (ResourceDef, T, Option) { + profile_method!(rdef); + + self.resources.push((rdef, resource, None)); + self.resources.last_mut().unwrap() + } + + /// Finish configuration and create router instance. + pub fn finish(self) -> Router { + Router { + routes: self.resources, + } + } +} + +#[cfg(test)] +mod tests { + use crate::path::Path; + use crate::router::{ResourceId, Router}; + + #[allow(clippy::cognitive_complexity)] + #[test] + fn test_recognizer_1() { + let mut router = Router::::build(); + router.path("/name", 10).0.set_id(0); + router.path("/name/{val}", 11).0.set_id(1); + router.path("/name/{val}/index.html", 12).0.set_id(2); + router.path("/file/{file}.{ext}", 13).0.set_id(3); + router.path("/v{val}/{val2}/index.html", 14).0.set_id(4); + router.path("/v/{tail:.*}", 15).0.set_id(5); + router.path("/test2/{test}.html", 16).0.set_id(6); + router.path("/{test}/index.html", 17).0.set_id(7); + let mut router = router.finish(); + + let mut path = Path::new("/unknown"); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/name"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + assert_eq!(info, ResourceId(0)); + assert!(path.is_empty()); + + let mut path = Path::new("/name/value"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(info, ResourceId(1)); + assert_eq!(path.get("val").unwrap(), "value"); + assert_eq!(&path["val"], "value"); + + let mut path = Path::new("/name/value2/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 12); + assert_eq!(info, ResourceId(2)); + assert_eq!(path.get("val").unwrap(), "value2"); + + let mut path = Path::new("/file/file.gz"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 13); + assert_eq!(info, ResourceId(3)); + assert_eq!(path.get("file").unwrap(), "file"); + assert_eq!(path.get("ext").unwrap(), "gz"); + + let mut path = Path::new("/vtest/ttt/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 14); + assert_eq!(info, ResourceId(4)); + assert_eq!(path.get("val").unwrap(), "test"); + assert_eq!(path.get("val2").unwrap(), "ttt"); + + let mut path = Path::new("/v/blah-blah/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 15); + assert_eq!(info, ResourceId(5)); + assert_eq!(path.get("tail").unwrap(), "blah-blah/index.html"); + + let mut path = Path::new("/test2/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 16); + assert_eq!(info, ResourceId(6)); + assert_eq!(path.get("test").unwrap(), "index"); + + let mut path = Path::new("/bbb/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 17); + assert_eq!(info, ResourceId(7)); + assert_eq!(path.get("test").unwrap(), "bbb"); + } + + #[test] + fn test_recognizer_2() { + let mut router = Router::::build(); + router.path("/index.json", 10); + router.path("/{source}.json", 11); + let mut router = router.finish(); + + let mut path = Path::new("/index.json"); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test.json"); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + } + + #[test] + fn test_recognizer_with_prefix() { + let mut router = Router::::build(); + router.path("/name", 10).0.set_id(0); + router.path("/name/{val}", 11).0.set_id(1); + let mut router = router.finish(); + + let mut path = Path::new("/name"); + path.skip(5); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test/name"); + path.skip(5); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test/name/value"); + path.skip(5); + let (h, id) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(id, ResourceId(1)); + assert_eq!(path.get("val").unwrap(), "value"); + assert_eq!(&path["val"], "value"); + + // same patterns + let mut router = Router::::build(); + router.path("/name", 10); + router.path("/name/{val}", 11); + let mut router = router.finish(); + + let mut path = Path::new("/name"); + path.skip(6); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test2/name"); + path.skip(6); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test2/name-test"); + path.skip(6); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test2/name/ttt"); + path.skip(6); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(&path["val"], "ttt"); + } +} diff --git a/actix-router/src/url.rs b/actix-router/src/url.rs new file mode 100644 index 000000000..e08a7171a --- /dev/null +++ b/actix-router/src/url.rs @@ -0,0 +1,288 @@ +use crate::ResourcePath; + +#[allow(dead_code)] +const GEN_DELIMS: &[u8] = b":/?#[]@"; +#[allow(dead_code)] +const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,"; +#[allow(dead_code)] +const SUB_DELIMS: &[u8] = b"!$'()*,+?=;"; +#[allow(dead_code)] +const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;"; +#[allow(dead_code)] +const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + 1234567890 + -._~"; +const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + 1234567890 + -._~ + !$'()*,"; +const QS: &[u8] = b"+&=;b"; + +#[inline] +fn bit_at(array: &[u8], ch: u8) -> bool { + array[(ch >> 3) as usize] & (1 << (ch & 7)) != 0 +} + +#[inline] +fn set_bit(array: &mut [u8], ch: u8) { + array[(ch >> 3) as usize] |= 1 << (ch & 7) +} + +thread_local! { + static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+"); +} + +#[derive(Default, Clone, Debug)] +pub struct Url { + uri: http::Uri, + path: Option, +} + +impl Url { + pub fn new(uri: http::Uri) -> Url { + let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); + + Url { uri, path } + } + + pub fn with_quoter(uri: http::Uri, quoter: &Quoter) -> Url { + Url { + path: quoter.requote(uri.path().as_bytes()), + uri, + } + } + + pub fn uri(&self) -> &http::Uri { + &self.uri + } + + pub fn path(&self) -> &str { + if let Some(ref s) = self.path { + s + } else { + self.uri.path() + } + } + + #[inline] + pub fn update(&mut self, uri: &http::Uri) { + self.uri = uri.clone(); + self.path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); + } + + #[inline] + pub fn update_with_quoter(&mut self, uri: &http::Uri, quoter: &Quoter) { + self.uri = uri.clone(); + self.path = quoter.requote(uri.path().as_bytes()); + } +} + +impl ResourcePath for Url { + #[inline] + fn path(&self) -> &str { + self.path() + } +} + +pub struct Quoter { + safe_table: [u8; 16], + protected_table: [u8; 16], +} + +impl Quoter { + pub fn new(safe: &[u8], protected: &[u8]) -> Quoter { + let mut q = Quoter { + safe_table: [0; 16], + protected_table: [0; 16], + }; + + // prepare safe table + for i in 0..128 { + if ALLOWED.contains(&i) { + set_bit(&mut q.safe_table, i); + } + if QS.contains(&i) { + set_bit(&mut q.safe_table, i); + } + } + + for ch in safe { + set_bit(&mut q.safe_table, *ch) + } + + // prepare protected table + for ch in protected { + set_bit(&mut q.safe_table, *ch); + set_bit(&mut q.protected_table, *ch); + } + + q + } + + pub fn requote(&self, val: &[u8]) -> Option { + let mut has_pct = 0; + let mut pct = [b'%', 0, 0]; + let mut idx = 0; + let mut cloned: Option> = None; + + let len = val.len(); + while idx < len { + let ch = val[idx]; + + if has_pct != 0 { + pct[has_pct] = val[idx]; + has_pct += 1; + if has_pct == 3 { + has_pct = 0; + let buf = cloned.as_mut().unwrap(); + + if let Some(ch) = restore_ch(pct[1], pct[2]) { + if ch < 128 { + if bit_at(&self.protected_table, ch) { + buf.extend_from_slice(&pct); + idx += 1; + continue; + } + + if bit_at(&self.safe_table, ch) { + buf.push(ch); + idx += 1; + continue; + } + } + buf.push(ch); + } else { + buf.extend_from_slice(&pct[..]); + } + } + } else if ch == b'%' { + has_pct = 1; + if cloned.is_none() { + let mut c = Vec::with_capacity(len); + c.extend_from_slice(&val[..idx]); + cloned = Some(c); + } + } else if let Some(ref mut cloned) = cloned { + cloned.push(ch) + } + idx += 1; + } + + cloned.map(|data| String::from_utf8_lossy(&data).into_owned()) + } +} + +#[inline] +fn from_hex(v: u8) -> Option { + if (b'0'..=b'9').contains(&v) { + Some(v - 0x30) // ord('0') == 0x30 + } else if (b'A'..=b'F').contains(&v) { + Some(v - 0x41 + 10) // ord('A') == 0x41 + } else if (b'a'..=b'f').contains(&v) { + Some(v - 0x61 + 10) // ord('a') == 0x61 + } else { + None + } +} + +#[inline] +fn restore_ch(d1: u8, d2: u8) -> Option { + from_hex(d1).and_then(|d1| from_hex(d2).map(move |d2| d1 << 4 | d2)) +} + +#[cfg(test)] +mod tests { + use http::Uri; + use std::convert::TryFrom; + + use super::*; + use crate::{Path, ResourceDef}; + + const PROTECTED: &[u8] = b"%/+"; + + fn match_url(pattern: &'static str, url: impl AsRef) -> Path { + let re = ResourceDef::new(pattern); + let uri = Uri::try_from(url.as_ref()).unwrap(); + let mut path = Path::new(Url::new(uri)); + assert!(re.capture_match_info(&mut path)); + path + } + + fn percent_encode(data: &[u8]) -> String { + data.iter().map(|c| format!("%{:02X}", c)).collect() + } + + #[test] + fn test_parse_url() { + let re = "/user/{id}/test"; + + let path = match_url(re, "/user/2345/test"); + assert_eq!(path.get("id").unwrap(), "2345"); + + // "%25" should never be decoded into '%' to guarantee the output is a valid + // percent-encoded format + let path = match_url(re, "/user/qwe%25/test"); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + + let path = match_url(re, "/user/qwe%25rty/test"); + assert_eq!(path.get("id").unwrap(), "qwe%25rty"); + } + + #[test] + fn test_protected_chars() { + let encoded = percent_encode(PROTECTED); + let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); + assert_eq!(path.get("id").unwrap(), &encoded); + } + + #[test] + fn test_non_protecteed_ascii() { + let nonprotected_ascii = ('\u{0}'..='\u{7F}') + .filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8))) + .collect::(); + let encoded = percent_encode(nonprotected_ascii.as_bytes()); + let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); + assert_eq!(path.get("id").unwrap(), &nonprotected_ascii); + } + + #[test] + fn test_valid_utf8_multibyte() { + let test = ('\u{FF00}'..='\u{FFFF}').collect::(); + let encoded = percent_encode(test.as_bytes()); + let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded)); + assert_eq!(path.get("id").unwrap(), &test); + } + + #[test] + fn test_invalid_utf8() { + let invalid_utf8 = percent_encode((0x80..=0xff).collect::>().as_slice()); + let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap(); + let path = Path::new(Url::new(uri)); + + // We should always get a valid utf8 string + assert!(String::from_utf8(path.path().as_bytes().to_owned()).is_ok()); + } + + #[test] + fn test_from_hex() { + let hex = b"0123456789abcdefABCDEF"; + + for i in 0..256 { + let c = i as u8; + if hex.contains(&c) { + assert!(from_hex(c).is_some()) + } else { + assert!(from_hex(c).is_none()) + } + } + + let expected = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, + ]; + for i in 0..hex.len() { + assert_eq!(from_hex(hex[i]).unwrap(), expected[i]); + } + } +} diff --git a/actix-test/CHANGES.md b/actix-test/CHANGES.md index fa554ba2e..b739011f0 100644 --- a/actix-test/CHANGES.md +++ b/actix-test/CHANGES.md @@ -3,6 +3,27 @@ ## Unreleased - 2021-xx-xx +## 0.1.0-beta.7 - 2021-11-22 +* Fix compatibility with experimental `io-uring` feature of `actix-rt`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 + + +## 0.1.0-beta.6 - 2021-11-15 +* No significant changes from `0.1.0-beta.5`. + + +## 0.1.0-beta.5 - 2021-10-20 +* Updated rustls to v0.20. [#2414] +* Minimum supported Rust version (MSRV) is now 1.52. + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 0.1.0-beta.4 - 2021-09-09 +* Minimum supported Rust version (MSRV) is now 1.51. + + ## 0.1.0-beta.3 - 2021-06-20 * No significant changes from `0.1.0-beta.2`. diff --git a/actix-test/Cargo.toml b/actix-test/Cargo.toml index b732cf744..dcaa3e9a3 100644 --- a/actix-test/Cargo.toml +++ b/actix-test/Cargo.toml @@ -1,32 +1,41 @@ [package] name = "actix-test" -version = "0.1.0-beta.3" +version = "0.1.0-beta.7" authors = [ "Nikolay Kim ", "Rob Ede ", ] -edition = "2018" description = "Integration testing tools for Actix Web applications" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] license = "MIT OR Apache-2.0" +edition = "2018" [features] default = [] # rustls -rustls = ["tls-rustls", "actix-http/rustls"] +rustls = ["tls-rustls", "actix-http/rustls", "awc/rustls"] # openssl -openssl = ["tls-openssl", "actix-http/openssl"] +openssl = ["tls-openssl", "actix-http/openssl", "awc/openssl"] [dependencies] -actix-codec = "0.4.0" -actix-http = "3.0.0-beta.8" -actix-http-test = { version = "3.0.0-beta.4", features = [] } +actix-codec = "0.4.1" +actix-http = "3.0.0-beta.14" +actix-http-test = "3.0.0-beta.7" actix-service = "2.0.0" actix-utils = "3.0.0" -actix-web = { version = "4.0.0-beta.8", default-features = false, features = ["cookies"] } +actix-web = { version = "4.0.0-beta.11", default-features = false, features = ["cookies"] } actix-rt = "2.1" -awc = { version = "3.0.0-beta.7", default-features = false, features = ["cookies"] } +awc = { version = "3.0.0-beta.11", default-features = false, features = ["cookies"] } futures-core = { version = "0.3.7", default-features = false, features = ["std"] } futures-util = { version = "0.3.7", default-features = false, features = [] } @@ -35,4 +44,5 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" serde_urlencoded = "0.7" tls-openssl = { package = "openssl", version = "0.10.9", optional = true } -tls-rustls = { package = "rustls", version = "0.19.0", optional = true } +tls-rustls = { package = "rustls", version = "0.20.0", optional = true } +tokio = { version = "1.2", features = ["sync"] } diff --git a/actix-test/src/lib.rs b/actix-test/src/lib.rs index c863af44a..1decd6e98 100644 --- a/actix-test/src/lib.rs +++ b/actix-test/src/lib.rs @@ -31,7 +31,7 @@ extern crate tls_openssl as openssl; #[cfg(feature = "rustls")] extern crate tls_rustls as rustls; -use std::{error::Error as StdError, fmt, net, sync::mpsc, thread, time}; +use std::{fmt, net, thread, time::Duration}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; pub use actix_http::test::TestBuffer; @@ -41,8 +41,10 @@ use actix_http::{ }; use actix_service::{map_config, IntoServiceFactory, ServiceFactory, ServiceFactoryExt as _}; use actix_web::{ - dev::{AppConfig, MessageBody, Server, Service}, - rt, web, Error, + body::MessageBody, + dev::{AppConfig, Server, ServerHandle, Service}, + rt::{self, System}, + web, Error, }; use awc::{error::PayloadError, Client, ClientRequest, ClientResponse, Connector}; use futures_core::Stream; @@ -52,6 +54,7 @@ pub use actix_web::test::{ call_service, default_service, init_service, load_stream, ok_service, read_body, read_body_json, read_response, read_response_json, TestRequest, }; +use tokio::sync::mpsc; /// Start default [`TestServer`]. /// @@ -64,7 +67,7 @@ pub use actix_web::test::{ /// Ok(HttpResponse::Ok()) /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_example() { /// let srv = actix_test::start(|| /// App::new().service(my_handler) @@ -86,7 +89,6 @@ where S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { start_with(TestServerConfig::default(), factory) } @@ -104,7 +106,7 @@ where /// Ok(HttpResponse::Ok()) /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_example() { /// let srv = actix_test::start_with(actix_test::config().h1(), || /// App::new().service(my_handler) @@ -126,9 +128,12 @@ where S::Response: Into> + 'static, >::Future: 'static, B: MessageBody + 'static, - B::Error: Into>, { - let (tx, rx) = mpsc::channel(); + // for sending handles and server info back from the spawned thread + let (started_tx, started_rx) = std::sync::mpsc::channel(); + + // for signaling the shutdown of spawned server and system + let (thread_stop_tx, thread_stop_rx) = mpsc::channel(1); let tls = match cfg.stream { StreamType::Tcp => false, @@ -138,154 +143,189 @@ where StreamType::Rustls(_) => true, }; - // run server in separate thread + // run server in separate orphaned thread thread::spawn(move || { - let sys = rt::System::new(); - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); - let factory = factory.clone(); - let srv_cfg = cfg.clone(); - let timeout = cfg.client_timeout; - let builder = Server::build().workers(1).disable_signals(); + rt::System::new().block_on(async move { + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + let factory = factory.clone(); + let srv_cfg = cfg.clone(); + let timeout = cfg.client_timeout; - let srv = match srv_cfg.stream { - StreamType::Tcp => match srv_cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let builder = Server::build().workers(1).disable_signals().system_exit(); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + let srv = match srv_cfg.stream { + StreamType::Tcp => match srv_cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h1(map_config(fac, move |_| app_cfg.clone())) - .tcp() - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h2(map_config(fac, move |_| app_cfg.clone())) - .tcp() - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .finish(map_config(fac, move |_| app_cfg.clone())) - .tcp() - }), - }, - #[cfg(feature = "openssl")] - StreamType::Openssl(acceptor) => match cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + }, + #[cfg(feature = "openssl")] + StreamType::Openssl(acceptor) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h1(map_config(fac, move |_| app_cfg.clone())) - .openssl(acceptor.clone()) - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h2(map_config(fac, move |_| app_cfg.clone())) - .openssl(acceptor.clone()) - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .finish(map_config(fac, move |_| app_cfg.clone())) - .openssl(acceptor.clone()) - }), - }, - #[cfg(feature = "rustls")] - StreamType::Rustls(config) => match cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + }, + #[cfg(feature = "rustls")] + StreamType::Rustls(config) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h1(map_config(fac, move |_| app_cfg.clone())) - .rustls(config.clone()) - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .h2(map_config(fac, move |_| app_cfg.clone())) - .rustls(config.clone()) - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let app_cfg = - AppConfig::__priv_test_new(false, local_addr.to_string(), local_addr); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); - let fac = factory() - .into_factory() - .map_err(|err| err.into().error_response()); + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); - HttpService::build() - .client_timeout(timeout) - .finish(map_config(fac, move |_| app_cfg.clone())) - .rustls(config.clone()) - }), - }, - } - .unwrap(); + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + }, + } + .expect("test server could not be created"); - sys.block_on(async { let srv = srv.run(); - tx.send((rt::System::current(), srv, local_addr)).unwrap(); + started_tx + .send((System::current(), srv.handle(), local_addr)) + .unwrap(); + + // drive server loop + srv.await.unwrap(); + + // notify TestServer that server and system have shut down + // all thread managed resources should be dropped at this point }); - sys.run() + let _ = thread_stop_tx.send(()); }); - let (system, server, addr) = rx.recv().unwrap(); + let (system, server, addr) = started_rx.recv().unwrap(); let client = { let connector = { @@ -299,15 +339,15 @@ where .set_alpn_protos(b"\x02h2\x08http/1.1") .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) .ssl(builder.build()) } #[cfg(not(feature = "openssl"))] { Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) } }; @@ -315,11 +355,12 @@ where }; TestServer { - addr, + server, + thread_stop_rx, client, system, + addr, tls, - server, } } @@ -405,11 +446,12 @@ impl TestServerConfig { /// /// See [`start`] for usage example. pub struct TestServer { - addr: net::SocketAddr, + server: ServerHandle, + thread_stop_rx: mpsc::Receiver<()>, client: awc::Client, system: rt::System, + addr: net::SocketAddr, tls: bool, - server: Server, } impl TestServer { @@ -504,16 +546,31 @@ impl TestServer { self.client.headers() } - /// Gracefully stop HTTP server. - pub async fn stop(self) { - self.server.stop(true).await; + /// Stop HTTP server. + /// + /// Waits for spawned `Server` and `System` to shutdown (force) shutdown. + pub async fn stop(mut self) { + // signal server to stop + self.server.stop(false).await; + + // also signal system to stop + // though this is handled by `ServerBuilder::exit_system` too self.system.stop(); - rt::time::sleep(time::Duration::from_millis(100)).await; + + // wait for thread to be stopped but don't care about result + let _ = self.thread_stop_rx.recv().await; } } impl Drop for TestServer { fn drop(&mut self) { - self.system.stop() + // calls in this Drop impl should be enough to shut down the server, system, and thread + // without needing to await anything + + // signal server to stop + let _ = self.server.stop(true); + + // signal system to stop + self.system.stop(); } } diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md index bf642ef95..898098ed8 100644 --- a/actix-web-actors/CHANGES.md +++ b/actix-web-actors/CHANGES.md @@ -1,6 +1,15 @@ # Changes ## Unreleased - 2021-xx-xx +* Add `ws:WsResponseBuilder` for building WebSocket session response. [#1920] +* Deprecate `ws::{start_with_addr, start_with_protocols}`. [#1920] +* Minimum supported Rust version (MSRV) is now 1.52. + +[#1920]: https://github.com/actix/actix-web/pull/1920 + + +## 4.0.0-beta.7 - 2021-09-09 +* Minimum supported Rust version (MSRV) is now 1.51. ## 4.0.0-beta.6 - 2021-06-26 diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index fcb5195b8..28b5b29ea 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web-actors" -version = "4.0.0-beta.6" +version = "4.0.0-beta.7" authors = ["Nikolay Kim "] description = "Actix actors support for Actix Web" keywords = ["actix", "http", "web", "framework", "async"] @@ -15,20 +15,20 @@ path = "src/lib.rs" [dependencies] actix = { version = "0.12.0", default-features = false } -actix-codec = "0.4.0" -actix-http = "3.0.0-beta.8" -actix-web = { version = "4.0.0-beta.8", default-features = false } +actix-codec = "0.4.1" +actix-http = "3.0.0-beta.14" +actix-web = { version = "4.0.0-beta.11", default-features = false } bytes = "1" bytestring = "1" futures-core = { version = "0.3.7", default-features = false } -pin-project = "1.0.0" +pin-project-lite = "0.2" tokio = { version = "1", features = ["sync"] } [dev-dependencies] actix-rt = "2.2" -actix-test = "0.1.0-beta.3" +actix-test = "0.1.0-beta.7" -awc = { version = "3.0.0-beta.7", default-features = false } -env_logger = "0.8" +awc = { version = "3.0.0-beta.11", default-features = false } +env_logger = "0.9" futures-util = { version = "0.3.7", default-features = false } diff --git a/actix-web-actors/README.md b/actix-web-actors/README.md index 5f8f78bde..2c29dedf2 100644 --- a/actix-web-actors/README.md +++ b/actix-web-actors/README.md @@ -3,15 +3,15 @@ > Actix actors support for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-web-actors?label=latest)](https://crates.io/crates/actix-web-actors) -[![Documentation](https://docs.rs/actix-web-actors/badge.svg?version=4.0.0-beta.6)](https://docs.rs/actix-web-actors/4.0.0-beta.6) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web-actors/badge.svg?version=4.0.0-beta.7)](https://docs.rs/actix-web-actors/4.0.0-beta.7) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![License](https://img.shields.io/crates/l/actix-web-actors.svg)
-[![dependency status](https://deps.rs/crate/actix-web-actors/4.0.0-beta.6/status.svg)](https://deps.rs/crate/actix-web-actors/4.0.0-beta.6) +[![dependency status](https://deps.rs/crate/actix-web-actors/4.0.0-beta.7/status.svg)](https://deps.rs/crate/actix-web-actors/4.0.0-beta.7) [![Download](https://img.shields.io/crates/d/actix-web-actors.svg)](https://crates.io/crates/actix-web-actors) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-web-actors) -- Minimum supported Rust version: 1.46 or later +- Minimum Supported Rust Version (MSRV): 1.52 diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index f0a53d4e0..c41268b01 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -1,20 +1,24 @@ //! Websocket integration. -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{collections::VecDeque, convert::TryFrom}; - -use actix::dev::{ - AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope, +use std::{ + collections::VecDeque, + convert::TryFrom, + future::Future, + io, mem, + pin::Pin, + task::{Context, Poll}, }; -use actix::fut::ActorFuture; + use actix::{ + dev::{ + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, + ToEnvelope, + }, + fut::ActorFuture, Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage, SpawnHandle, }; -use actix_codec::{Decoder, Encoder}; +use actix_codec::{Decoder as _, Encoder as _}; pub use actix_http::ws::{ CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, }; @@ -30,9 +34,189 @@ use actix_web::{ use bytes::{Bytes, BytesMut}; use bytestring::ByteString; use futures_core::Stream; -use tokio::sync::oneshot::Sender; +use pin_project_lite::pin_project; +use tokio::sync::oneshot; + +/// Builder for Websocket session response. +/// +/// # Examples +/// +/// Create a Websocket session response with default configuration. +/// ```ignore +/// WsResponseBuilder::new(WsActor, &req, stream).start() +/// ``` +/// +/// Create a Websocket session with a specific max frame size, [`Codec`], and protocols. +/// ```ignore +/// const MAX_FRAME_SIZE: usize = 16_384; // 16KiB +/// +/// ws::WsResponseBuilder::new(WsActor, &req, stream) +/// .codec(Codec::new()) +/// .protocols(&["A", "B"]) +/// .frame_size(MAX_FRAME_SIZE) +/// .start() +/// ``` +pub struct WsResponseBuilder<'a, A, T> +where + A: Actor> + StreamHandler>, + T: Stream> + 'static, +{ + actor: A, + req: &'a HttpRequest, + stream: T, + codec: Option, + protocols: Option<&'a [&'a str]>, + frame_size: Option, +} + +impl<'a, A, T> WsResponseBuilder<'a, A, T> +where + A: Actor> + StreamHandler>, + T: Stream> + 'static, +{ + /// Construct a new `WsResponseBuilder` with actor, request, and payload stream. + /// + /// For usage example, see docs on [`WsResponseBuilder`] struct. + pub fn new(actor: A, req: &'a HttpRequest, stream: T) -> Self { + WsResponseBuilder { + actor, + req, + stream, + codec: None, + protocols: None, + frame_size: None, + } + } + + /// Set the protocols for the session. + pub fn protocols(mut self, protocols: &'a [&'a str]) -> Self { + self.protocols = Some(protocols); + self + } + + /// Set the max frame size for each message (in bytes). + /// + /// **Note**: This will override any given [`Codec`]'s max frame size. + pub fn frame_size(mut self, frame_size: usize) -> Self { + self.frame_size = Some(frame_size); + self + } + + /// Set the [`Codec`] for the session. If [`Self::frame_size`] is also set, the given + /// [`Codec`]'s max frame size will be overridden. + pub fn codec(mut self, codec: Codec) -> Self { + self.codec = Some(codec); + self + } + + fn handshake_resp(&self) -> Result { + match self.protocols { + Some(protocols) => handshake_with_protocols(self.req, protocols), + None => handshake(self.req), + } + } + + fn set_frame_size(&mut self) { + if let Some(frame_size) = self.frame_size { + match &mut self.codec { + Some(codec) => { + // modify existing codec's max frame size + let orig_codec = mem::take(codec); + *codec = orig_codec.max_size(frame_size); + } + + None => { + // create a new codec with the given size + self.codec = Some(Codec::new().max_size(frame_size)); + } + } + } + } + + /// Create a new Websocket context from an actor, request stream, and codec. + /// + /// Returns a pair, where the first item is an addr for the created actor, and the second item + /// is a stream intended to be set as part of the response + /// via [`HttpResponseBuilder::streaming()`]. + fn create_with_codec_addr( + actor: A, + stream: S, + codec: Codec, + ) -> (Addr, impl Stream>) + where + A: StreamHandler>, + S: Stream> + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream, codec.clone())); + + let addr = ctx.address(); + + (addr, WebsocketContextFut::new(ctx, actor, mb, codec)) + } + + /// Perform WebSocket handshake and start actor. + /// + /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change. + /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a + /// stream of the body request. + /// + /// If there is a problem with the handshake, an error is returned. + /// + /// If successful, consume the [`WsResponseBuilder`] and return a [`HttpResponse`] wrapped in + /// a [`Result`]. + pub fn start(mut self) -> Result { + let mut res = self.handshake_resp()?; + self.set_frame_size(); + + match self.codec { + Some(codec) => { + let out_stream = WebsocketContext::with_codec(self.actor, self.stream, codec); + Ok(res.streaming(out_stream)) + } + None => { + let out_stream = WebsocketContext::create(self.actor, self.stream); + Ok(res.streaming(out_stream)) + } + } + } + + /// Perform WebSocket handshake and start actor. + /// + /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change. + /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a + /// stream of the body request. + /// + /// If there is a problem with the handshake, an error is returned. + /// + /// If successful, returns a pair where the first item is an address for the created actor and + /// the second item is the [`HttpResponse`] that should be returned from the websocket request. + pub fn start_with_addr(mut self) -> Result<(Addr, HttpResponse), Error> { + let mut res = self.handshake_resp()?; + self.set_frame_size(); + + match self.codec { + Some(codec) => { + let (addr, out_stream) = + Self::create_with_codec_addr(self.actor, self.stream, codec); + Ok((addr, res.streaming(out_stream))) + } + None => { + let (addr, out_stream) = + WebsocketContext::create_with_addr(self.actor, self.stream); + Ok((addr, res.streaming(out_stream))) + } + } + } +} /// Perform WebSocket handshake and start actor. +/// +/// To customize options, see [`WsResponseBuilder`]. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where A: Actor> + StreamHandler>, @@ -44,15 +228,15 @@ where /// Perform WebSocket handshake and start actor. /// -/// `req` is an HTTP Request that should be requesting a websocket protocol -/// change. `stream` should be a `Bytes` stream (such as -/// `actix_web::web::Payload`) that contains a stream of the body request. +/// `req` is an HTTP Request that should be requesting a websocket protocol change. `stream` should +/// be a `Bytes` stream (such as `actix_web::web::Payload`) that contains a stream of the +/// body request. /// /// If there is a problem with the handshake, an error is returned. /// -/// If successful, returns a pair where the first item is an address for the -/// created actor and the second item is the response that should be returned -/// from the WebSocket request. +/// If successful, returns a pair where the first item is an address for the created actor and the +/// second item is the response that should be returned from the WebSocket request. +#[deprecated(since = "4.0.0", note = "Prefer `WsResponseBuilder::start_with_addr`.")] pub fn start_with_addr( actor: A, req: &HttpRequest, @@ -70,6 +254,10 @@ where /// Do WebSocket handshake and start ws actor. /// /// `protocols` is a sequence of known protocols. +#[deprecated( + since = "4.0.0", + note = "Prefer `WsResponseBuilder` for setting protocols." +)] pub fn start_with_protocols( actor: A, protocols: &[&str], @@ -86,20 +274,19 @@ where /// Prepare WebSocket handshake response. /// -/// This function returns handshake `HttpResponse`, ready to send to peer. -/// It does not perform any IO. +/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform +/// any IO. pub fn handshake(req: &HttpRequest) -> Result { handshake_with_protocols(req, &[]) } /// Prepare WebSocket handshake response. /// -/// This function returns handshake `HttpResponse`, ready to send to peer. -/// It does not perform any IO. +/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform +/// any IO. /// -/// `protocols` is a sequence of known protocols. On successful handshake, -/// the returned response headers contain the first protocol in this list -/// which the server also knows. +/// `protocols` is a sequence of known protocols. On successful handshake, the returned response +/// headers contain the first protocol in this list which the server also knows. pub fn handshake_with_protocols( req: &HttpRequest, protocols: &[&str], @@ -246,8 +433,8 @@ impl WebsocketContext where A: Actor, { + /// Create a new Websocket context from a request and an actor. #[inline] - /// Create a new Websocket context from a request and an actor pub fn create(actor: A, stream: S) -> impl Stream> where A: StreamHandler>, @@ -257,12 +444,11 @@ where stream } - #[inline] /// Create a new Websocket context from a request and an actor. /// - /// Returns a pair, where the first item is an addr for the created actor, - /// and the second item is a stream intended to be set as part of the - /// response via `HttpResponseBuilder::streaming()`. + /// Returns a pair, where the first item is an addr for the created actor, and the second item + /// is a stream intended to be set as part of the response + /// via [`HttpResponseBuilder::streaming()`]. pub fn create_with_addr( actor: A, stream: S, @@ -283,7 +469,6 @@ where (addr, WebsocketContextFut::new(ctx, actor, mb, Codec::new())) } - #[inline] /// Create a new Websocket context from a request, an actor, and a codec pub fn with_codec( actor: A, @@ -299,7 +484,7 @@ where inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; - ctx.add_stream(WsStream::new(stream, codec)); + ctx.add_stream(WsStream::new(stream, codec.clone())); WebsocketContextFut::new(ctx, actor, mb, codec) } @@ -457,18 +642,20 @@ where M: ActixMessage + Send + 'static, M::Result: Send, { - fn pack(msg: M, tx: Option>) -> Envelope { + fn pack(msg: M, tx: Option>) -> Envelope { Envelope::new(msg, tx) } } -#[pin_project::pin_project] -struct WsStream { - #[pin] - stream: S, - decoder: Codec, - buf: BytesMut, - closed: bool, +pin_project! { + #[derive(Debug)] + struct WsStream { + #[pin] + stream: S, + decoder: Codec, + buf: BytesMut, + closed: bool, + } } impl WsStream @@ -547,9 +734,12 @@ where #[cfg(test)] mod tests { + use actix_web::{ + http::{header, Method}, + test::TestRequest, + }; + use super::*; - use actix_web::http::{header, Method}; - use actix_web::test::TestRequest; #[test] fn test_handshake() { diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs index 0a8e50b3e..a9eb37699 100644 --- a/actix-web-actors/tests/test_ws.rs +++ b/actix-web-actors/tests/test_ws.rs @@ -1,11 +1,9 @@ use actix::prelude::*; -use actix_web::{ - http::{header, StatusCode}, - web, App, HttpRequest, HttpResponse, -}; -use actix_web_actors::*; +use actix_http::ws::Codec; +use actix_web::{web, App, HttpRequest}; +use actix_web_actors::ws; use bytes::Bytes; -use futures_util::{SinkExt as _, StreamExt as _}; +use futures_util::{SinkExt, StreamExt}; struct Ws; @@ -15,37 +13,34 @@ impl Actor for Ws { impl StreamHandler> for Ws { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - match msg.unwrap() { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(reason) => ctx.close(reason), - _ => {} + match msg { + Ok(ws::Message::Ping(msg)) => ctx.pong(&msg), + Ok(ws::Message::Text(text)) => ctx.text(text), + Ok(ws::Message::Binary(bin)) => ctx.binary(bin), + Ok(ws::Message::Close(reason)) => ctx.close(reason), + _ => ctx.close(Some(ws::CloseCode::Error.into())), } } } -#[actix_rt::test] -async fn test_simple() { - let mut srv = actix_test::start(|| { - App::new().service(web::resource("/").to( - |req: HttpRequest, stream: web::Payload| async move { ws::start(Ws, &req, stream) }, - )) - }); +const MAX_FRAME_SIZE: usize = 10_000; +const DEFAULT_FRAME_SIZE: usize = 10; +async fn common_test_code(mut srv: actix_test::TestServer, frame_size: usize) { // client service let mut framed = srv.ws().await.unwrap(); - framed.send(ws::Message::Text("text".into())).await.unwrap(); + framed.send(ws::Message::Text("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); + let bytes = Bytes::from(vec![0; frame_size]); framed - .send(ws::Message::Binary("text".into())) + .send(ws::Message::Binary(bytes.clone())) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); + assert_eq!(item, ws::Frame::Binary(bytes)); framed.send(ws::Message::Ping("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); @@ -55,55 +50,137 @@ async fn test_simple() { .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .await .unwrap(); - let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); } #[actix_rt::test] -async fn test_with_credentials() { - let mut srv = actix_test::start(|| { +async fn simple_builder() { + let srv = actix_test::start(|| { App::new().service(web::resource("/").to( |req: HttpRequest, stream: web::Payload| async move { - if req.headers().contains_key("Authorization") { - ws::start(Ws, &req, stream) - } else { - Ok(HttpResponse::new(StatusCode::UNAUTHORIZED)) - } + ws::WsResponseBuilder::new(Ws, &req, stream).start() }, )) }); - // client service without credentials - match srv.ws().await { - Ok(_) => panic!("WebSocket client without credentials should panic"), - Err(awc::error::WsClientError::InvalidResponseStatus(status)) => { - assert_eq!(status, StatusCode::UNAUTHORIZED) - } - Err(e) => panic!("Invalid error from WebSocket client: {}", e), - } - - let headers = srv.client_headers().unwrap(); - headers.insert( - header::AUTHORIZATION, - header::HeaderValue::from_static("Bearer Something"), - ); - - // client service with credentials - let client = srv.ws(); - - let mut framed = client.await.unwrap(); - - framed.send(ws::Message::Text("text".into())).await.unwrap(); - - let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); - - framed - .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) - .await - .unwrap(); - - let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_frame_size() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + common_test_code(srv, MAX_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_frame_size_exceeded() { + const MAX_FRAME_SIZE: usize = 64; + + let mut srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + // client service + let mut framed = srv.ws().await.unwrap(); + + // create a request with a frame size larger than expected + let bytes = Bytes::from(vec![0; MAX_FRAME_SIZE + 1]); + framed.send(ws::Message::Binary(bytes)).await.unwrap(); + + let frame = framed.next().await.unwrap().unwrap(); + let close_reason = match frame { + ws::Frame::Close(Some(reason)) => reason, + _ => panic!("close frame expected"), + }; + assert_eq!(close_reason.code, ws::CloseCode::Error); +} + +#[actix_rt::test] +async fn builder_with_codec() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .codec(Codec::new()) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_protocols() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .protocols(&["A", "B"]) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_codec_and_frame_size() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .codec(Codec::new()) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_full() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .codec(Codec::new()) + .protocols(&["A", "B"]) + .start() + }, + )) + }); + + common_test_code(srv, MAX_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn simple_start() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { ws::start(Ws, &req, stream) }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; } diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md index a8a901f72..3811ef030 100644 --- a/actix-web-codegen/CHANGES.md +++ b/actix-web-codegen/CHANGES.md @@ -3,6 +3,22 @@ ## Unreleased - 2021-xx-xx +## 0.5.0-beta.5 - 2021-10-20 +* Improve error recovery potential when macro input is invalid. [#2410] +* Add `#[actix_web::test]` macro for setting up tests with a runtime. [#2409] +* Minimum supported Rust version (MSRV) is now 1.52. + +[#2410]: https://github.com/actix/actix-web/pull/2410 +[#2409]: https://github.com/actix/actix-web/pull/2409 + + +## 0.5.0-beta.4 - 2021-09-09 +* In routing macros, paths are now validated at compile time. [#2350] +* Minimum supported Rust version (MSRV) is now 1.51. + +[#2350]: https://github.com/actix/actix-web/pull/2350 + + ## 0.5.0-beta.3 - 2021-06-17 * No notable changes. diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml index 4d0fd5e26..8497f0b23 100644 --- a/actix-web-codegen/Cargo.toml +++ b/actix-web-codegen/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "actix-web-codegen" -version = "0.5.0-beta.3" +version = "0.5.0-beta.5" description = "Routing and runtime macros for Actix Web" -readme = "README.md" homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web" -documentation = "https://docs.rs/actix-web-codegen" -authors = ["Nikolay Kim "] +repository = "https://github.com/actix/actix-web.git" +authors = [ + "Nikolay Kim ", + "Rob Ede ", +] license = "MIT OR Apache-2.0" edition = "2018" @@ -17,12 +18,14 @@ proc-macro = true quote = "1" syn = { version = "1", features = ["full", "parsing"] } proc-macro2 = "1" +actix-router = "0.5.0-beta.2" [dev-dependencies] actix-rt = "2.2" -actix-test = "0.1.0-beta.3" +actix-macros = "0.2.3" +actix-test = "0.1.0-beta.7" actix-utils = "3.0.0" -actix-web = "4.0.0-beta.8" +actix-web = "4.0.0-beta.11" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } trybuild = "1" diff --git a/actix-web-codegen/README.md b/actix-web-codegen/README.md index 96e4cb51f..2ffd5b31c 100644 --- a/actix-web-codegen/README.md +++ b/actix-web-codegen/README.md @@ -3,18 +3,18 @@ > Routing and runtime macros for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-web-codegen?label=latest)](https://crates.io/crates/actix-web-codegen) -[![Documentation](https://docs.rs/actix-web-codegen/badge.svg?version=0.5.0-beta.3)](https://docs.rs/actix-web-codegen/0.5.0-beta.3) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web-codegen/badge.svg?version=0.5.0-beta.5)](https://docs.rs/actix-web-codegen/0.5.0-beta.5) +[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html) ![License](https://img.shields.io/crates/l/actix-web-codegen.svg)
-[![dependency status](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.3/status.svg)](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.3) +[![dependency status](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.5/status.svg)](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.5) [![Download](https://img.shields.io/crates/d/actix-web-codegen.svg)](https://crates.io/crates/actix-web-codegen) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-web-codegen) -- Minimum supported Rust version: 1.46 or later. +- Minimum Supported Rust Version (MSRV): 1.52 ## Compile Testing diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index 2237f422c..cebf9e5fb 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -59,13 +59,14 @@ #![recursion_limit = "512"] use proc_macro::TokenStream; +use quote::quote; mod route; /// Creates resource handler, allowing multiple HTTP method guards. /// /// # Syntax -/// ```text +/// ```plain /// #[route("path", method="HTTP_METHOD"[, attributes])] /// ``` /// @@ -111,7 +112,7 @@ concat!(" Creates route handler with `actix_web::guard::", stringify!($variant), "`. # Syntax -```text +```plain #[", stringify!($method), r#"("path"[, attributes])] ``` @@ -157,24 +158,41 @@ method_macro! { } /// Marks async main function as the actix system entry-point. -/// -/// # Actix Web Re-export -/// This macro can be applied with `#[actix_web::main]` when used in Actix Web applications. -/// + /// # Examples /// ``` -/// #[actix_web_codegen::main] +/// #[actix_web::main] /// async fn main() { /// async { println!("Hello world"); }.await /// } /// ``` #[proc_macro_attribute] pub fn main(_: TokenStream, item: TokenStream) -> TokenStream { - use quote::quote; - let input = syn::parse_macro_input!(item as syn::ItemFn); - (quote! { - #[actix_web::rt::main(system = "::actix_web::rt::System")] - #input + let mut output: TokenStream = (quote! { + #[::actix_web::rt::main(system = "::actix_web::rt::System")] }) - .into() + .into(); + + output.extend(item); + output +} + +/// Marks async test functions to use the actix system entry-point. +/// +/// # Examples +/// ``` +/// #[actix_web::test] +/// async fn test() { +/// assert_eq!(async { "Hello world" }.await, "Hello world"); +/// } +/// ``` +#[proc_macro_attribute] +pub fn test(_: TokenStream, item: TokenStream) -> TokenStream { + let mut output: TokenStream = (quote! { + #[::actix_web::rt::test(system = "::actix_web::rt::System")] + }) + .into(); + + output.extend(item); + output } diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index 747042527..eac1948a7 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -3,6 +3,7 @@ extern crate proc_macro; use std::collections::HashSet; use std::convert::TryFrom; +use actix_router::ResourceDef; use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{format_ident, quote, ToTokens, TokenStreamExt}; @@ -101,6 +102,7 @@ impl Args { match arg { NestedMeta::Lit(syn::Lit::Str(lit)) => match path { None => { + let _ = ResourceDef::new(lit.value()); path = Some(lit); } _ => { @@ -218,7 +220,7 @@ fn guess_resource_type(typ: &syn::Type) -> ResourceType { impl Route { pub fn new( args: AttributeArgs, - input: TokenStream, + ast: syn::ItemFn, method: Option, ) -> syn::Result { if args.is_empty() { @@ -232,14 +234,11 @@ impl Route { ), )); } - let ast: syn::ItemFn = syn::parse(input)?; + let name = ast.sig.ident.clone(); - // Try and pull out the doc comments so that we can reapply them to the - // generated struct. - // - // Note that multi line doc comments are converted to multiple doc - // attributes. + // Try and pull out the doc comments so that we can reapply them to the generated struct. + // Note that multi line doc comments are converted to multiple doc attributes. let doc_attributes = ast .attrs .iter() @@ -347,8 +346,28 @@ pub(crate) fn with_method( input: TokenStream, ) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - match Route::new(args, input, method) { + + let ast = match syn::parse::(input.clone()) { + Ok(ast) => ast, + // on parse error, make IDEs happy; see fn docs + Err(err) => return input_and_compile_error(input, err), + }; + + match Route::new(args, ast, method) { Ok(route) => route.into_token_stream().into(), - Err(err) => err.to_compile_error().into(), + // on macro related error, make IDEs happy; see fn docs + Err(err) => input_and_compile_error(input, err), } } + +/// Converts the error to a token stream and appends it to the original input. +/// +/// Returning the original input in addition to the error is good for IDEs which can gracefully +/// recover and show more precise errors within the macro body. +/// +/// See for more info. +fn input_and_compile_error(mut item: TokenStream, err: syn::Error) -> TokenStream { + let compile_err = TokenStream::from(err.to_compile_error()); + item.extend(compile_err); + item +} diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index 6b08c409c..769cf2bc3 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -256,7 +256,7 @@ async fn test_auto_async() { assert!(response.status().is_success()); } -#[actix_rt::test] +#[actix_web::test] async fn test_wrap() { let srv = actix_test::start(|| App::new().service(get_wrap)); diff --git a/actix-web-codegen/tests/trybuild.rs b/actix-web-codegen/tests/trybuild.rs index 12e848cf3..dd70cb7ca 100644 --- a/actix-web-codegen/tests/trybuild.rs +++ b/actix-web-codegen/tests/trybuild.rs @@ -1,4 +1,4 @@ -#[rustversion::stable(1.46)] // MSRV +#[rustversion::stable(1.52)] // MSRV #[test] fn compile_macros() { let t = trybuild::TestCases::new(); @@ -10,6 +10,9 @@ fn compile_macros() { t.compile_fail("tests/trybuild/route-missing-method-fail.rs"); t.compile_fail("tests/trybuild/route-duplicate-method-fail.rs"); t.compile_fail("tests/trybuild/route-unexpected-method-fail.rs"); + t.compile_fail("tests/trybuild/route-malformed-path-fail.rs"); t.pass("tests/trybuild/docstring-ok.rs"); + + t.pass("tests/trybuild/test-runtime.rs"); } diff --git a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr index abdc895d7..90cff1b1c 100644 --- a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr @@ -4,8 +4,8 @@ error: HTTP method defined more than once: `GET` 3 | #[route("/", method="GET", method="GET")] | ^^^^^ -error[E0425]: cannot find value `index` in this scope +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied --> $DIR/route-duplicate-method-fail.rs:12:55 | 12 | let srv = actix_test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs new file mode 100644 index 000000000..1258a6f2f --- /dev/null +++ b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs @@ -0,0 +1,33 @@ +use actix_web_codegen::get; + +#[get("/{")] +async fn zero() -> &'static str { + "malformed resource def" +} + +#[get("/{foo")] +async fn one() -> &'static str { + "malformed resource def" +} + +#[get("/{}")] +async fn two() -> &'static str { + "malformed resource def" +} + +#[get("/*")] +async fn three() -> &'static str { + "malformed resource def" +} + +#[get("/{tail:\\d+}*")] +async fn four() -> &'static str { + "malformed resource def" +} + +#[get("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}")] +async fn five() -> &'static str { + "malformed resource def" +} + +fn main() {} diff --git a/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr new file mode 100644 index 000000000..93c510109 --- /dev/null +++ b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr @@ -0,0 +1,42 @@ +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:3:1 + | +3 | #[get("/{")] + | ^^^^^^^^^^^^ + | + = help: message: pattern "{" contains malformed dynamic segment + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:8:1 + | +8 | #[get("/{foo")] + | ^^^^^^^^^^^^^^^ + | + = help: message: pattern "{foo" contains malformed dynamic segment + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:13:1 + | +13 | #[get("/{}")] + | ^^^^^^^^^^^^^ + | + = help: message: Wrong path pattern: "/{}" regex parse error: + ((?s-m)^/(?P<>[^/]+))$ + ^ + error: empty capture group name + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:23:1 + | +23 | #[get("/{tail:\\d+}*")] + | ^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: custom regex is not supported for tail match + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:28:1 + | +28 | #[get("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Only 16 dynamic segments are allowed, provided: 17 diff --git a/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr index 0e16b5e27..c36b090c0 100644 --- a/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr @@ -6,8 +6,8 @@ error: The #[route(..)] macro requires at least one `method` attribute | = note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0425]: cannot find value `index` in this scope +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied --> $DIR/route-missing-method-fail.rs:12:55 | 12 | let srv = actix_test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr index a638a96a6..dda366067 100644 --- a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr @@ -4,8 +4,8 @@ error: Unexpected HTTP method: `UNEXPECTED` 3 | #[route("/", method="UNEXPECTED")] | ^^^^^^^^^^^^ -error[E0425]: cannot find value `index` in this scope +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied --> $DIR/route-unexpected-method-fail.rs:12:55 | 12 | let srv = actix_test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/test-runtime.rs b/actix-web-codegen/tests/trybuild/test-runtime.rs new file mode 100644 index 000000000..0b901b258 --- /dev/null +++ b/actix-web-codegen/tests/trybuild/test-runtime.rs @@ -0,0 +1,6 @@ +#[actix_web::test] +async fn my_test() { + assert!(async { 1 }.await, 1); +} + +fn main() {} diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 16132be1c..ab3362b72 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -3,6 +3,33 @@ ## Unreleased - 2021-xx-xx +## 3.0.0-beta.12 - 2021-11-30 +* Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.11 - 2021-11-22 +* No significant changes from `3.0.0-beta.10`. + + +## 3.0.0-beta.10 - 2021-11-15 +* No significant changes from `3.0.0-beta.9`. + + +## 3.0.0-beta.9 - 2021-10-20 +* Updated rustls to v0.20. [#2414] + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 3.0.0-beta.8 - 2021-09-09 +### Changed +* Send headers within the redirect requests. [#2310] + +[#2310]: https://github.com/actix/actix-web/pull/2310 + + ## 3.0.0-beta.7 - 2021-06-26 ### Changed * Change compression algorithm features flags. [#2250] diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 016d3b48b..fc60f5edb 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "awc" -version = "3.0.0-beta.7" +version = "3.0.0-beta.12" authors = [ "Nikolay Kim ", "fakeshadow <24548779@qq.com>", @@ -14,7 +14,7 @@ categories = [ "web-programming::websocket", ] homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web" +repository = "https://github.com/actix/actix-web.git" license = "MIT OR Apache-2.0" edition = "2018" @@ -30,10 +30,10 @@ features = ["openssl", "rustls", "compress-brotli", "compress-gzip", "compress-z default = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] # openssl -openssl = ["tls-openssl", "actix-http/openssl"] +openssl = ["tls-openssl", "actix-tls/openssl"] # rustls -rustls = ["tls-rustls", "actix-http/rustls"] +rustls = ["tls-rustls", "actix-tls/rustls"] # Brotli algorithm content-encoding support compress-brotli = ["actix-http/compress-brotli", "__compress"] @@ -46,24 +46,34 @@ compress-zstd = ["actix-http/compress-zstd", "__compress"] cookies = ["cookie"] # trust-dns as dns resolver -trust-dns = ["actix-http/trust-dns"] +trust-dns = ["trust-dns-resolver"] # Internal (PRIVATE!) features used to aid testing and cheking feature status. # Don't rely on these whatsoever. They may disappear at anytime. __compress = [] -[dependencies] -actix-codec = "0.4.0" -actix-service = "2.0.0" -actix-http = "3.0.0-beta.8" -actix-rt = { version = "2.1", default-features = false } +# Enable dangerous feature for testing and local network usage: +# - HTTP/2 over TCP(No Tls). +# DO NOT enable this over any internet use case. +dangerous-h2c = [] +[dependencies] +actix-codec = "0.4.1" +actix-service = "2.0.0" +actix-http = "3.0.0-beta.14" +actix-rt = { version = "2.1", default-features = false } +actix-tls = { version = "3.0.0-rc.1", features = ["connect", "uri"] } +actix-utils = "3.0.0" + +ahash = "0.7" base64 = "0.13" bytes = "1" cfg-if = "1" -cookie = { version = "0.15", features = ["percent-encode"], optional = true } derive_more = "0.99.5" futures-core = { version = "0.3.7", default-features = false } +futures-util = { version = "0.3.7", default-features = false } +h2 = "0.3" +http = "0.2.5" itoa = "0.4" log =" 0.4" mime = "0.3" @@ -73,24 +83,31 @@ rand = "0.8" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" -tls-openssl = { version = "0.10.9", package = "openssl", optional = true } -tls-rustls = { version = "0.19.0", package = "rustls", optional = true, features = ["dangerous_configuration"] } +tokio = { version = "1", features = ["sync"] } + +cookie = { version = "0.15", features = ["percent-encode"], optional = true } + +tls-openssl = { package = "openssl", version = "0.10.9", optional = true } +tls-rustls = { package = "rustls", version = "0.20.0", optional = true, features = ["dangerous_configuration"] } + +trust-dns-resolver = { version = "0.20.0", optional = true } [dev-dependencies] -actix-web = { version = "4.0.0-beta.8", features = ["openssl"] } -actix-http = { version = "3.0.0-beta.8", features = ["openssl"] } -actix-http-test = { version = "3.0.0-beta.4", features = ["openssl"] } +actix-web = { version = "4.0.0-beta.11", features = ["openssl"] } +actix-http = { version = "3.0.0-beta.14", features = ["openssl"] } +actix-http-test = { version = "3.0.0-beta.7", features = ["openssl"] } actix-utils = "3.0.0" -actix-server = "2.0.0-beta.3" -actix-tls = { version = "3.0.0-beta.5", features = ["openssl", "rustls"] } -actix-test = { version = "0.1.0-beta.3", features = ["openssl", "rustls"] } +actix-server = "2.0.0-beta.9" +actix-tls = { version = "3.0.0-rc.1", features = ["openssl", "rustls"] } +actix-test = { version = "0.1.0-beta.7", features = ["openssl", "rustls"] } brotli2 = "0.3.2" -env_logger = "0.8" +env_logger = "0.9" flate2 = "1.0.13" futures-util = { version = "0.3.7", default-features = false } +static_assertions = "1.1" rcgen = "0.8" -webpki = "0.21" +rustls-pemfile = "0.2" [[example]] name = "client" diff --git a/awc/README.md b/awc/README.md index dd08c6e10..b0faedc68 100644 --- a/awc/README.md +++ b/awc/README.md @@ -3,16 +3,16 @@ > Async HTTP and WebSocket client library. [![crates.io](https://img.shields.io/crates/v/awc?label=latest)](https://crates.io/crates/awc) -[![Documentation](https://docs.rs/awc/badge.svg?version=3.0.0-beta.7)](https://docs.rs/awc/3.0.0-beta.7) +[![Documentation](https://docs.rs/awc/badge.svg?version=3.0.0-beta.12)](https://docs.rs/awc/3.0.0-beta.12) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/awc) -[![Dependency Status](https://deps.rs/crate/awc/3.0.0-beta.7/status.svg)](https://deps.rs/crate/awc/3.0.0-beta.7) +[![Dependency Status](https://deps.rs/crate/awc/3.0.0-beta.12/status.svg)](https://deps.rs/crate/awc/3.0.0-beta.12) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/awc) - [Example Project](https://github.com/actix/examples/tree/HEAD/security/awc_https) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.52 ## Example diff --git a/awc/src/any_body.rs b/awc/src/any_body.rs new file mode 100644 index 000000000..cb9038ff3 --- /dev/null +++ b/awc/src/any_body.rs @@ -0,0 +1,266 @@ +use std::{ + borrow::Cow, + fmt, mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures_core::Stream; +use pin_project_lite::pin_project; + +use actix_http::body::{BodySize, BodyStream, BoxBody, MessageBody, SizedStream}; + +use crate::BoxError; + +pin_project! { + /// Represents various types of HTTP message body. + #[derive(Clone)] + #[project = AnyBodyProj] + pub enum AnyBody { + /// Empty response. `Content-Length` header is not set. + None, + + /// Complete, in-memory response body. + Bytes { body: Bytes }, + + /// Generic / Other message body. + Body { #[pin] body: B }, + } +} + +impl AnyBody { + /// Constructs a "body" representing an empty response. + pub fn none() -> Self { + Self::None + } + + /// Constructs a new, 0-length body. + pub fn empty() -> Self { + Self::Bytes { body: Bytes::new() } + } + + /// Create boxed body from generic message body. + pub fn new_boxed(body: B) -> Self + where + B: MessageBody + 'static, + { + Self::Body { + body: BoxBody::new(body), + } + } + + /// Constructs new `AnyBody` instance from a slice of bytes by copying it. + /// + /// If your bytes container is owned, it may be cheaper to use a `From` impl. + pub fn copy_from_slice(s: &[u8]) -> Self { + Self::Bytes { + body: Bytes::copy_from_slice(s), + } + } + + #[doc(hidden)] + #[deprecated(since = "4.0.0", note = "Renamed to `copy_from_slice`.")] + pub fn from_slice(s: &[u8]) -> Self { + Self::Bytes { + body: Bytes::copy_from_slice(s), + } + } +} + +impl AnyBody { + /// Create body from generic message body. + pub fn new(body: B) -> Self { + Self::Body { body } + } +} + +impl AnyBody +where + B: MessageBody + 'static, +{ + pub fn into_boxed(self) -> AnyBody { + match self { + Self::None => AnyBody::None, + Self::Bytes { body: bytes } => AnyBody::Bytes { body: bytes }, + Self::Body { body } => AnyBody::new_boxed(body), + } + } +} + +impl MessageBody for AnyBody +where + B: MessageBody, +{ + type Error = crate::BoxError; + + fn size(&self) -> BodySize { + match self { + AnyBody::None => BodySize::None, + AnyBody::Bytes { ref body } => BodySize::Sized(body.len() as u64), + AnyBody::Body { ref body } => body.size(), + } + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + AnyBodyProj::None => Poll::Ready(None), + AnyBodyProj::Bytes { body } => { + let len = body.len(); + if len == 0 { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(body)))) + } + } + + AnyBodyProj::Body { body } => body.poll_next(cx).map_err(|err| err.into()), + } + } +} + +impl PartialEq for AnyBody { + fn eq(&self, other: &AnyBody) -> bool { + match self { + AnyBody::None => matches!(*other, AnyBody::None), + AnyBody::Bytes { body } => match other { + AnyBody::Bytes { body: b2 } => body == b2, + _ => false, + }, + AnyBody::Body { .. } => false, + } + } +} + +impl fmt::Debug for AnyBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + AnyBody::None => write!(f, "AnyBody::None"), + AnyBody::Bytes { ref body } => write!(f, "AnyBody::Bytes({:?})", body), + AnyBody::Body { ref body } => write!(f, "AnyBody::Message({:?})", body), + } + } +} + +impl From<&'static str> for AnyBody { + fn from(string: &'static str) -> Self { + Self::Bytes { + body: Bytes::from_static(string.as_ref()), + } + } +} + +impl From<&'static [u8]> for AnyBody { + fn from(bytes: &'static [u8]) -> Self { + Self::Bytes { + body: Bytes::from_static(bytes), + } + } +} + +impl From> for AnyBody { + fn from(vec: Vec) -> Self { + Self::Bytes { + body: Bytes::from(vec), + } + } +} + +impl From for AnyBody { + fn from(string: String) -> Self { + Self::Bytes { + body: Bytes::from(string), + } + } +} + +impl From<&'_ String> for AnyBody { + fn from(string: &String) -> Self { + Self::Bytes { + body: Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&string)), + } + } +} + +impl From> for AnyBody { + fn from(string: Cow<'_, str>) -> Self { + match string { + Cow::Owned(s) => Self::from(s), + Cow::Borrowed(s) => Self::Bytes { + body: Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(s)), + }, + } + } +} + +impl From for AnyBody { + fn from(bytes: Bytes) -> Self { + Self::Bytes { body: bytes } + } +} + +impl From for AnyBody { + fn from(bytes: BytesMut) -> Self { + Self::Bytes { + body: bytes.freeze(), + } + } +} + +impl From> for AnyBody +where + S: Stream> + 'static, + E: Into + 'static, +{ + fn from(stream: SizedStream) -> Self { + AnyBody::new_boxed(stream) + } +} + +impl From> for AnyBody +where + S: Stream> + 'static, + E: Into + 'static, +{ + fn from(stream: BodyStream) -> Self { + AnyBody::new_boxed(stream) + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomPinned; + + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + + struct PinType(PhantomPinned); + + impl MessageBody for PinType { + type Error = crate::BoxError; + + fn size(&self) -> BodySize { + unimplemented!() + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + unimplemented!() + } + } + + assert_impl_all!(AnyBody<()>: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody>: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody: MessageBody, fmt::Debug, Unpin); + assert_impl_all!(AnyBody: MessageBody); + + assert_not_impl_all!(AnyBody: Send, Sync, Unpin); + assert_not_impl_all!(AnyBody: Send, Sync, Unpin); +} diff --git a/awc/src/builder.rs b/awc/src/builder.rs index c594b4836..70a28c419 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -1,20 +1,16 @@ -use std::convert::TryFrom; -use std::fmt; -use std::net::IpAddr; -use std::rc::Rc; -use std::time::Duration; +use std::{convert::TryFrom, fmt, net::IpAddr, rc::Rc, time::Duration}; -use actix_http::{ - client::{Connector, ConnectorService, TcpConnect, TcpConnectError, TcpConnection}, - http::{self, header, Error as HttpError, HeaderMap, HeaderName, Uri}, -}; +use actix_http::http::{self, header, Error as HttpError, HeaderMap, HeaderName, Uri}; use actix_rt::net::{ActixStream, TcpStream}; use actix_service::{boxed, Service}; -use crate::connect::DefaultConnector; -use crate::error::SendRequestError; -use crate::middleware::{NestTransform, Redirect, Transform}; -use crate::{Client, ClientConfig, ConnectRequest, ConnectResponse}; +use crate::{ + client::{ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection}, + connect::DefaultConnector, + error::SendRequestError, + middleware::{NestTransform, Redirect, Transform}, + Client, ClientConfig, ConnectRequest, ConnectResponse, +}; /// An HTTP Client builder /// @@ -37,7 +33,7 @@ impl ClientBuilder { #[allow(clippy::new_ret_no_self)] pub fn new() -> ClientBuilder< impl Service< - TcpConnect, + ConnectInfo, Response = TcpConnection, Error = TcpConnectError, > + Clone, @@ -60,7 +56,7 @@ impl ClientBuilder { impl ClientBuilder where - S: Service, Response = TcpConnection, Error = TcpConnectError> + S: Service, Response = TcpConnection, Error = TcpConnectError> + Clone + 'static, Io: ActixStream + fmt::Debug + 'static, @@ -69,7 +65,7 @@ where pub fn connector(self, connector: Connector) -> ClientBuilder where S1: Service< - TcpConnect, + ConnectInfo, Response = TcpConnection, Error = TcpConnectError, > + Clone diff --git a/actix-http/src/client/config.rs b/awc/src/client/config.rs similarity index 96% rename from actix-http/src/client/config.rs rename to awc/src/client/config.rs index 1c0405cbc..530c1e03b 100644 --- a/actix-http/src/client/config.rs +++ b/awc/src/client/config.rs @@ -1,5 +1,4 @@ -use std::net::IpAddr; -use std::time::Duration; +use std::{net::IpAddr, time::Duration}; const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB diff --git a/actix-http/src/client/connection.rs b/awc/src/client/connection.rs similarity index 89% rename from actix-http/src/client/connection.rs rename to awc/src/client/connection.rs index a30f651ca..0e1f0bfec 100644 --- a/actix-http/src/client/connection.rs +++ b/awc/src/client/connection.rs @@ -12,10 +12,9 @@ use bytes::Bytes; use futures_core::future::LocalBoxFuture; use h2::client::SendRequest; -use crate::h1::ClientCodec; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::Payload; -use crate::{body::MessageBody, Error}; +use actix_http::{body::MessageBody, h1::ClientCodec, Payload, RequestHeadType, ResponseHead}; + +use crate::BoxError; use super::error::SendRequestError; use super::pool::Acquired; @@ -174,6 +173,7 @@ impl H2ConnectionInner { /// Cancel spawned connection task on drop. impl Drop for H2ConnectionInner { fn drop(&mut self) { + // TODO: this can end up sending extraneous requests; see if there is a better way to handle if self .sender .send_request(http::Request::new(()), true) @@ -184,8 +184,8 @@ impl Drop for H2ConnectionInner { } } +/// Unified connection type cover HTTP/1 Plain/TLS and HTTP/2 protocols. #[allow(dead_code)] -/// Unified connection type cover Http1 Plain/Tls and Http2 protocols pub enum Connection> where A: ConnectionIo, @@ -219,11 +219,7 @@ impl ConnectionType { } } - pub(super) fn from_h1( - io: Io, - created: time::Instant, - acquired: Acquired, - ) -> Self { + pub(super) fn from_h1(io: Io, created: time::Instant, acquired: Acquired) -> Self { Self::H1(H1Connection { io: Some(io), created, @@ -258,7 +254,7 @@ where where H: Into + 'static, RB: MessageBody + 'static, - RB::Error: Into, + RB::Error: Into, { Box::pin(async move { match self { @@ -271,9 +267,7 @@ where Connection::Tls(ConnectionType::H2(conn)) => { h2proto::send_request(conn, head.into(), body).await } - _ => unreachable!( - "Plain Tcp connection can be used only in Http1 protocol" - ), + _ => unreachable!("Plain Tcp connection can be used only in Http1 protocol"), } }) } @@ -301,9 +295,7 @@ where Err(SendRequestError::TunnelNotSupported) } Connection::Tcp(ConnectionType::H2(_)) => { - unreachable!( - "Plain Tcp connection can be used only in Http1 protocol" - ) + unreachable!("Plain Tcp connection can be used only in Http1 protocol") } } }) @@ -321,12 +313,8 @@ where buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { - Connection::Tcp(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_read(cx, buf) - } - Connection::Tls(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_read(cx, buf) - } + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), _ => unreachable!("H2Connection can not impl AsyncRead trait"), } } @@ -345,12 +333,8 @@ where buf: &[u8], ) -> Poll> { match self.get_mut() { - Connection::Tcp(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_write(cx, buf) - } - Connection::Tls(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_write(cx, buf) - } + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), _ => unreachable!(H2_UNREACHABLE_WRITE), } } @@ -363,17 +347,10 @@ where } } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Connection::Tcp(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_shutdown(cx) - } - Connection::Tls(ConnectionType::H1(conn)) => { - Pin::new(conn).poll_shutdown(cx) - } + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), _ => unreachable!(H2_UNREACHABLE_WRITE), } } diff --git a/actix-http/src/client/connector.rs b/awc/src/client/connector.rs similarity index 65% rename from actix-http/src/client/connector.rs rename to awc/src/client/connector.rs index bd46919e8..40b3c4d32 100644 --- a/actix-http/src/client/connector.rs +++ b/awc/src/client/connector.rs @@ -8,38 +8,35 @@ use std::{ time::Duration, }; +use actix_http::Protocol; use actix_rt::{ net::{ActixStream, TcpStream}, time::{sleep, Sleep}, }; use actix_service::Service; use actix_tls::connect::{ - new_connector, Connect as TcpConnect, ConnectError as TcpConnectError, - Connection as TcpConnection, Resolver, + ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, + Connector as TcpConnector, Resolver, }; use futures_core::{future::LocalBoxFuture, ready}; use http::Uri; -use pin_project::pin_project; +use pin_project_lite::pin_project; use super::config::ConnectorConfig; use super::connection::{Connection, ConnectionIo}; use super::error::ConnectError; use super::pool::ConnectionPool; use super::Connect; -use super::Protocol; -#[cfg(feature = "openssl")] -use actix_tls::connect::ssl::openssl::SslConnector as OpensslConnector; -#[cfg(feature = "rustls")] -use actix_tls::connect::ssl::rustls::ClientConfig; - -enum SslConnector { - #[allow(dead_code)] +enum OurTlsConnector { + #[allow(dead_code)] // only dead when no TLS feature is enabled None, + #[cfg(feature = "openssl")] - Openssl(OpensslConnector), + Openssl(actix_tls::connect::openssl::reexports::SslConnector), + #[cfg(feature = "rustls")] - Rustls(std::sync::Arc), + Rustls(std::sync::Arc), } /// Manages HTTP client network connectivity. @@ -58,30 +55,54 @@ enum SslConnector { pub struct Connector { connector: T, config: ConnectorConfig, - #[allow(dead_code)] - ssl: SslConnector, + + #[allow(dead_code)] // only dead when no TLS feature is enabled + ssl: OurTlsConnector, } impl Connector<()> { #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] pub fn new() -> Connector< impl Service< - TcpConnect, + ConnectInfo, Response = TcpConnection, Error = actix_tls::connect::ConnectError, > + Clone, > { Connector { - ssl: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), - connector: new_connector(resolver::resolver()), + connector: TcpConnector::new(resolver::resolver()).service(), config: ConnectorConfig::default(), + ssl: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), } } - // Build Ssl connector with openssl, based on supplied alpn protocols - #[cfg(feature = "openssl")] - fn build_ssl(protocols: Vec>) -> SslConnector { - use actix_tls::connect::ssl::openssl::SslMethod; + /// Provides an empty TLS connector when no TLS feature is enabled. + #[cfg(not(any(feature = "openssl", feature = "rustls")))] + fn build_ssl(_: Vec>) -> OurTlsConnector { + OurTlsConnector::None + } + + /// Build TLS connector with rustls, based on supplied ALPN protocols + /// + /// Note that if both `openssl` and `rustls` features are enabled, rustls will be used. + #[cfg(feature = "rustls")] + fn build_ssl(protocols: Vec>) -> OurTlsConnector { + use actix_tls::connect::rustls::{reexports::ClientConfig, webpki_roots_cert_store}; + + let mut config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + + config.alpn_protocols = protocols; + + OurTlsConnector::Rustls(std::sync::Arc::new(config)) + } + + /// Build TLS connector with openssl, based on supplied ALPN protocols + #[cfg(all(feature = "openssl", not(feature = "rustls")))] + fn build_ssl(protocols: Vec>) -> OurTlsConnector { + use actix_tls::connect::openssl::reexports::{SslConnector, SslMethod}; use bytes::{BufMut, BytesMut}; let mut alpn = BytesMut::with_capacity(20); @@ -90,28 +111,12 @@ impl Connector<()> { alpn.put(proto.as_slice()); } - let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); - let _ = ssl - .set_alpn_protos(&alpn) - .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); - SslConnector::Openssl(ssl.build()) - } + let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap(); + if let Err(err) = ssl.set_alpn_protos(&alpn) { + log::error!("Can not set ALPN protocol: {:?}", err); + } - // Build Ssl connector with rustls, based on supplied alpn protocols - #[cfg(all(not(feature = "openssl"), feature = "rustls"))] - fn build_ssl(protocols: Vec>) -> SslConnector { - let mut config = ClientConfig::new(); - config.set_protocols(&protocols); - config.root_store.add_server_trust_anchors( - &actix_tls::connect::ssl::rustls::TLS_SERVER_ROOTS, - ); - SslConnector::Rustls(std::sync::Arc::new(config)) - } - - // ssl turned off, provides empty ssl connector - #[cfg(not(any(feature = "openssl", feature = "rustls")))] - fn build_ssl(_: Vec>) -> SslConnector { - SslConnector::None + OurTlsConnector::Openssl(ssl.build()) } } @@ -121,7 +126,7 @@ impl Connector { where Io1: ActixStream + fmt::Debug + 'static, S1: Service< - TcpConnect, + ConnectInfo, Response = TcpConnection, Error = TcpConnectError, > + Clone, @@ -134,7 +139,7 @@ impl Connector { } } -impl Connector +impl Connector where // Note: // Input Io type is bound to ActixStream trait but internally in client module they @@ -143,12 +148,9 @@ where // // This remap is to hide ActixStream's trait methods. They are not meant to be called // from user code. - Io: ActixStream + fmt::Debug + 'static, - S: Service< - TcpConnect, - Response = TcpConnection, - Error = TcpConnectError, - > + Clone + IO: ActixStream + fmt::Debug + 'static, + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + 'static, { /// Tcp connection timeout, i.e. max time to connect to remote host including dns name @@ -167,15 +169,21 @@ where #[cfg(feature = "openssl")] /// Use custom `SslConnector` instance. - pub fn ssl(mut self, connector: OpensslConnector) -> Self { - self.ssl = SslConnector::Openssl(connector); + pub fn ssl( + mut self, + connector: actix_tls::connect::openssl::reexports::SslConnector, + ) -> Self { + self.ssl = OurTlsConnector::Openssl(connector); self } #[cfg(feature = "rustls")] - /// Use custom `SslConnector` instance. - pub fn rustls(mut self, connector: std::sync::Arc) -> Self { - self.ssl = SslConnector::Rustls(connector); + /// Use custom `ClientConfig` instance. + pub fn rustls( + mut self, + connector: std::sync::Arc, + ) -> Self { + self.ssl = OurTlsConnector::Rustls(connector); self } @@ -187,7 +195,7 @@ where http::Version::HTTP_11 => vec![b"http/1.1".to_vec()], http::Version::HTTP_2 => vec![b"h2".to_vec(), b"http/1.1".to_vec()], _ => { - unimplemented!("actix-http:client: supported versions http/1.1, http/2") + unimplemented!("actix-http client only supports versions http/1.1 & http/2") } }; self.ssl = Connector::build_ssl(versions); @@ -264,7 +272,7 @@ where /// Finish configuration process and create connector service. /// The Connector builder always concludes by calling `finish()` last in /// its combinator chain. - pub fn finish(self) -> ConnectorService { + pub fn finish(self) -> ConnectorService { let local_address = self.config.local_address; let timeout = self.config.timeout; @@ -277,14 +285,70 @@ where }; let tls_service = match self.ssl { - SslConnector::None => None, + OurTlsConnector::None => { + #[cfg(not(feature = "dangerous-h2c"))] + { + None + } + + #[cfg(feature = "dangerous-h2c")] + { + use std::io; + + use actix_tls::connect::Connection; + use actix_utils::future::{ready, Ready}; + + impl IntoConnectionIo for TcpConnection> { + fn into_connection_io(self) -> (Box, Protocol) { + let io = self.into_parts().0; + (io, Protocol::Http2) + } + } + + /// With the `dangerous-h2c` feature enabled, this connector uses a no-op TLS + /// connection service that passes through plain TCP as a TLS connection. + /// + /// The protocol version of this fake TLS connection is set to be HTTP/2. + #[derive(Clone)] + struct NoOpTlsConnectorService; + + impl Service> for NoOpTlsConnectorService + where + IO: ActixStream + 'static, + { + type Response = Connection>; + type Error = io::Error; + type Future = Ready>; + + actix_service::always_ready!(); + + fn call(&self, connection: Connection) -> Self::Future { + let (io, connection) = connection.replace_io(()); + let (_, connection) = connection.replace_io(Box::new(io) as _); + + ready(Ok(connection)) + } + } + + let handshake_timeout = self.config.handshake_timeout; + + let tls_service = TlsConnectorService { + tcp_service: tcp_service_inner, + tls_service: NoOpTlsConnectorService, + timeout: handshake_timeout, + }; + + Some(actix_service::boxed::rc_service(tls_service)) + } + } + #[cfg(feature = "openssl")] - SslConnector::Openssl(tls) => { + OurTlsConnector::Openssl(tls) => { const H2: &[u8] = b"h2"; - use actix_tls::connect::ssl::openssl::{OpensslConnector, SslStream}; + use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector}; - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -303,27 +367,26 @@ where let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, - tls_service: OpensslConnector::service(tls), + tls_service: TlsConnector::service(tls), timeout: handshake_timeout, }; Some(actix_service::boxed::rc_service(tls_service)) } + #[cfg(feature = "rustls")] - SslConnector::Rustls(tls) => { + OurTlsConnector::Rustls(tls) => { const H2: &[u8] = b"h2"; - use actix_tls::connect::ssl::rustls::{ - RustlsConnector, Session, TlsStream, - }; + use actix_tls::connect::rustls::{reexports::AsyncTlsStream, TlsConnector}; - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock .get_ref() .1 - .get_alpn_protocol() + .alpn_protocol() .map_or(false, |protos| protos.windows(2).any(|w| w == H2)); if h2 { (Box::new(sock), Protocol::Http2) @@ -337,7 +400,7 @@ where let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, - tls_service: RustlsConnector::service(tls), + tls_service: TlsConnector::service(tls), timeout: handshake_timeout, }; @@ -350,8 +413,8 @@ where let tcp_pool = ConnectionPool::new(tcp_service, tcp_config); let tls_config = self.config; - let tls_pool = tls_service - .map(move |tls_service| ConnectionPool::new(tls_service, tls_config)); + let tls_pool = + tls_service.map(move |tls_service| ConnectionPool::new(tls_service, tls_config)); ConnectorServicePriv { tcp_pool, tls_pool } } @@ -382,10 +445,12 @@ where } } -#[pin_project] -pub struct TcpConnectorFuture { - #[pin] - fut: Fut, +pin_project! { + #[project = TcpConnectorFutureProj] + pub struct TcpConnectorFuture { + #[pin] + fut: Fut, + } } impl Future for TcpConnectorFuture @@ -404,26 +469,28 @@ where /// service for establish tcp connection and do client tls handshake. /// operation is canceled when timeout limit reached. -struct TlsConnectorService { - /// tcp connection is canceled on `TcpConnectorInnerService`'s timeout setting. - tcp_service: S, - /// tls connection is canceled on `TlsConnectorService`'s timeout setting. - tls_service: St, +struct TlsConnectorService { + /// TCP connection is canceled on `TcpConnectorInnerService`'s timeout setting. + tcp_service: Tcp, + + /// TLS connection is canceled on `TlsConnectorService`'s timeout setting. + tls_service: Tls, + timeout: Duration, } -impl Service for TlsConnectorService +impl Service for TlsConnectorService where - S: Service, Error = ConnectError> + Tcp: Service, Error = ConnectError> + Clone + 'static, - St: Service, Error = std::io::Error> + Clone + 'static, - Io: ConnectionIo, - St::Response: IntoConnectionIo, + Tls: Service, Error = std::io::Error> + Clone + 'static, + Tls::Response: IntoConnectionIo, + IO: ConnectionIo, { type Response = (Box, Protocol); type Error = ConnectError; - type Future = TlsConnectorFuture; + type Future = TlsConnectorFuture; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { ready!(self.tcp_service.poll_ready(cx))?; @@ -444,23 +511,25 @@ where } } -#[pin_project(project = TlsConnectorProj)] -#[allow(clippy::large_enum_variant)] -enum TlsConnectorFuture { - TcpConnect { - #[pin] - fut: Fut1, - tls_service: Option, - timeout: Duration, - }, - TlsConnect { - #[pin] - fut: Fut2, - #[pin] - timeout: Sleep, - }, -} +pin_project! { + #[project = TlsConnectorProj] + #[allow(clippy::large_enum_variant)] + enum TlsConnectorFuture { + TcpConnect { + #[pin] + fut: Fut1, + tls_service: Option, + timeout: Duration, + }, + TlsConnect { + #[pin] + fut: Fut2, + #[pin] + timeout: Sleep, + }, + } +} /// helper trait for generic over different TlsStream types between tls crates. trait IntoConnectionIo { fn into_connection_io(self) -> (Box, Protocol); @@ -468,12 +537,7 @@ trait IntoConnectionIo { impl Future for TlsConnectorFuture where - S: Service< - TcpConnection, - Response = Res, - Error = std::io::Error, - Future = Fut2, - >, + S: Service, Response = Res, Error = std::io::Error, Future = Fut2>, S::Response: IntoConnectionIo, Fut1: Future, ConnectError>>, Fut2: Future>, @@ -515,11 +579,7 @@ pub struct TcpConnectorInnerService { } impl TcpConnectorInnerService { - fn new( - service: S, - timeout: Duration, - local_address: Option, - ) -> Self { + fn new(service: S, timeout: Duration, local_address: Option) -> Self { Self { service, timeout, @@ -530,11 +590,8 @@ impl TcpConnectorInnerService { impl Service for TcpConnectorInnerService where - S: Service< - TcpConnect, - Response = TcpConnection, - Error = TcpConnectError, - > + Clone + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + 'static, { type Response = S::Response; @@ -544,7 +601,7 @@ where actix_service::forward_ready!(service); fn call(&self, req: Connect) -> Self::Future { - let mut req = TcpConnect::new(req.uri).set_addr(req.addr); + let mut req = ConnectInfo::new(req.uri).set_addr(req.addr); if let Some(local_addr) = self.local_address { req = req.set_local_addr(local_addr); @@ -557,12 +614,14 @@ where } } -#[pin_project] -pub struct TcpConnectorInnerFuture { - #[pin] - fut: Fut, - #[pin] - timeout: Sleep, +pin_project! { + #[project = TcpConnectorInnerFutureProj] + pub struct TcpConnectorInnerFuture { + #[pin] + fut: Fut, + #[pin] + timeout: Sleep, + } } impl Future for TcpConnectorInnerFuture @@ -581,8 +640,8 @@ where } /// Connector service for pooled Plain/Tls Tcp connections. -pub type ConnectorService = ConnectorServicePriv< - TcpConnectorService>, +pub type ConnectorService = ConnectorServicePriv< + TcpConnectorService>, Rc< dyn Service< Connect, @@ -594,7 +653,7 @@ pub type ConnectorService = ConnectorServicePriv< >, >, >, - Io, + IO, Box, >; @@ -611,12 +670,8 @@ where impl Service for ConnectorServicePriv where - S1: Service - + Clone - + 'static, - S2: Service - + Clone - + 'static, + S1: Service + Clone + 'static, + S2: Service + Clone + 'static, Io1: ConnectionIo, Io2: ConnectionIo, { @@ -636,38 +691,46 @@ where match req.uri.scheme_str() { Some("https") | Some("wss") => match self.tls_pool { None => ConnectorServiceFuture::SslIsNotSupported, - Some(ref pool) => ConnectorServiceFuture::Tls(pool.call(req)), + Some(ref pool) => ConnectorServiceFuture::Tls { + fut: pool.call(req), + }, + }, + _ => ConnectorServiceFuture::Tcp { + fut: self.tcp_pool.call(req), }, - _ => ConnectorServiceFuture::Tcp(self.tcp_pool.call(req)), } } } -#[pin_project(project = ConnectorServiceProj)] -pub enum ConnectorServiceFuture -where - S1: Service - + Clone - + 'static, - S2: Service - + Clone - + 'static, - Io1: ConnectionIo, - Io2: ConnectionIo, -{ - Tcp(#[pin] as Service>::Future), - Tls(#[pin] as Service>::Future), - SslIsNotSupported, +pin_project! { + #[project = ConnectorServiceFutureProj] + pub enum ConnectorServiceFuture + where + S1: Service, + S1: Clone, + S1: 'static, + S2: Service, + S2: Clone, + S2: 'static, + Io1: ConnectionIo, + Io2: ConnectionIo, + { + Tcp { + #[pin] + fut: as Service>::Future + }, + Tls { + #[pin] + fut: as Service>::Future + }, + SslIsNotSupported + } } impl Future for ConnectorServiceFuture where - S1: Service - + Clone - + 'static, - S2: Service - + Clone - + 'static, + S1: Service + Clone + 'static, + S2: Service + Clone + 'static, Io1: ConnectionIo, Io2: ConnectionIo, { @@ -675,9 +738,9 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { - ConnectorServiceProj::Tcp(fut) => fut.poll(cx).map_ok(Connection::Tcp), - ConnectorServiceProj::Tls(fut) => fut.poll(cx).map_ok(Connection::Tls), - ConnectorServiceProj::SslIsNotSupported => { + ConnectorServiceFutureProj::Tcp { fut } => fut.poll(cx).map_ok(Connection::Tcp), + ConnectorServiceFutureProj::Tls { fut } => fut.poll(cx).map_ok(Connection::Tls), + ConnectorServiceFutureProj::SslIsNotSupported => { Poll::Ready(Err(ConnectError::SslIsNotSupported)) } } @@ -689,7 +752,7 @@ mod resolver { use super::*; pub(super) fn resolver() -> Resolver { - Resolver::Default + Resolver::default() } } @@ -731,8 +794,7 @@ mod resolver { } } - // dns struct is cached in thread local. - // so new client constructor can reuse the existing dns resolver. + // resolver struct is cached in thread local so new clients can reuse the existing instance thread_local! { static TRUST_DNS_RESOLVER: RefCell> = RefCell::new(None); } @@ -740,8 +802,10 @@ mod resolver { // get from thread local or construct a new trust-dns resolver. TRUST_DNS_RESOLVER.with(|local| { let resolver = local.borrow().as_ref().map(Clone::clone); + match resolver { Some(resolver) => resolver, + None => { let (cfg, opts) = match read_system_conf() { Ok((cfg, opts)) => (cfg, opts), @@ -754,11 +818,51 @@ mod resolver { let resolver = TokioAsyncResolver::tokio(cfg, opts).unwrap(); // box trust dns resolver and put it in thread local. - let resolver = Resolver::new_custom(TrustDnsResolver(resolver)); + let resolver = Resolver::custom(TrustDnsResolver(resolver)); *local.borrow_mut() = Some(resolver.clone()); + resolver } } }) } } + +#[cfg(feature = "dangerous-h2c")] +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use actix_http::{HttpService, Request, Response, Version}; + use actix_http_test::test_server; + use actix_service::ServiceFactoryExt as _; + + use super::*; + use crate::Client; + + #[actix_rt::test] + async fn h2c_connector() { + let mut srv = test_server(|| { + HttpService::build() + .h2(|_req: Request| async { Ok::<_, Infallible>(Response::ok()) }) + .tcp() + .map_err(|_| ()) + }) + .await; + + let connector = Connector { + connector: TcpConnector::new(resolver::resolver()).service(), + config: ConnectorConfig::default(), + ssl: OurTlsConnector::None, + }; + + let client = Client::builder().connector(connector).finish(); + + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + srv.stop().await; + } +} diff --git a/actix-http/src/client/error.rs b/awc/src/client/error.rs similarity index 92% rename from actix-http/src/client/error.rs rename to awc/src/client/error.rs index 34833503b..d351b5d5e 100644 --- a/actix-http/src/client/error.rs +++ b/awc/src/client/error.rs @@ -1,12 +1,13 @@ -use std::{error::Error as StdError, fmt, io}; +use std::{fmt, io}; use derive_more::{Display, From}; -#[cfg(feature = "openssl")] -use actix_tls::accept::openssl::SslError; +use actix_http::{error::ParseError, http::Error as HttpError}; -use crate::error::{Error, ParseError}; -use crate::http::Error as HttpError; +#[cfg(feature = "openssl")] +use actix_tls::accept::openssl::reexports::Error as OpensslError; + +use crate::BoxError; /// A set of errors that can occur while connecting to an HTTP host #[derive(Debug, Display, From)] @@ -19,7 +20,7 @@ pub enum ConnectError { /// SSL error #[cfg(feature = "openssl")] #[display(fmt = "{}", _0)] - SslError(SslError), + SslError(OpensslError), /// Failed to resolve the hostname #[display(fmt = "Failed resolving hostname: {}", _0)] @@ -117,11 +118,11 @@ pub enum SendRequestError { TunnelNotSupported, /// Error sending request body - Body(Error), + Body(BoxError), /// Other errors that can occur after submitting a request. #[display(fmt = "{:?}: {}", _1, _0)] - Custom(Box, Box), + Custom(BoxError, Box), } impl std::error::Error for SendRequestError {} @@ -140,7 +141,7 @@ pub enum FreezeRequestError { /// Other errors that can occur after submitting a request. #[display(fmt = "{:?}: {}", _1, _0)] - Custom(Box, Box), + Custom(BoxError, Box), } impl std::error::Error for FreezeRequestError {} diff --git a/actix-http/src/client/h1proto.rs b/awc/src/client/h1proto.rs similarity index 87% rename from actix-http/src/client/h1proto.rs rename to awc/src/client/h1proto.rs index 65a30748c..b26a97eeb 100644 --- a/actix-http/src/client/h1proto.rs +++ b/awc/src/client/h1proto.rs @@ -5,24 +5,27 @@ use std::{ }; use actix_codec::Framed; +use actix_http::{ + body::{BodySize, MessageBody}, + error::PayloadError, + h1, + http::{ + header::{HeaderMap, IntoHeaderValue, EXPECT, HOST}, + StatusCode, + }, + Payload, RequestHeadType, ResponseHead, +}; use actix_utils::future::poll_fn; use bytes::buf::BufMut; use bytes::{Bytes, BytesMut}; use futures_core::{ready, Stream}; use futures_util::SinkExt as _; +use pin_project_lite::pin_project; -use crate::h1; -use crate::http::{ - header::{HeaderMap, IntoHeaderValue, EXPECT, HOST}, - StatusCode, -}; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::Payload; -use crate::{error::PayloadError, Error}; +use crate::BoxError; use super::connection::{ConnectionIo, H1Connection}; use super::error::{ConnectError, SendRequestError}; -use crate::body::{BodySize, MessageBody}; pub(crate) async fn send_request( io: H1Connection, @@ -32,7 +35,7 @@ pub(crate) async fn send_request( where Io: ConnectionIo, B: MessageBody, - B::Error: Into, + B::Error: Into, { // set request host header if !head.as_ref().headers.contains_key(HOST) @@ -65,11 +68,10 @@ where let mut framed = Framed::new(io, h1::ClientCodec::default()); // Check EXPECT header and enable expect handle flag accordingly. - // - // RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1 + // See https://datatracker.ietf.org/doc/html/rfc7231#section-5.1.1 let is_expect = if head.as_ref().headers.contains_key(EXPECT) { match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => { + BodySize::None | BodySize::Sized(0) => { let keep_alive = framed.codec_ref().keepalive(); framed.io_mut().on_release(keep_alive); @@ -103,7 +105,7 @@ where if do_send { // send request body match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => {} + BodySize::None | BodySize::Sized(0) => {} _ => send_body(body, pin_framed.as_mut()).await?, }; @@ -155,7 +157,7 @@ pub(crate) async fn send_body( where Io: ConnectionIo, B: MessageBody, - B::Error: Into, + B::Error: Into, { actix_rt::pin!(body); @@ -166,7 +168,7 @@ where Some(Ok(chunk)) => { framed.as_mut().write(h1::Message::Chunk(Some(chunk)))?; } - Some(Err(err)) => return Err(err.into().into()), + Some(Err(err)) => return Err(SendRequestError::Body(err.into())), None => { eof = true; framed.as_mut().write(h1::Message::Chunk(None))?; @@ -194,10 +196,11 @@ where Ok(()) } -#[pin_project::pin_project] -pub(crate) struct PlStream { - #[pin] - framed: Framed, h1::ClientPayloadCodec>, +pin_project! { + pub(crate) struct PlStream { + #[pin] + framed: Framed, h1::ClientPayloadCodec>, + } } impl PlStream { @@ -211,10 +214,7 @@ impl PlStream { impl Stream for PlStream { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); match ready!(this.framed.as_mut().next_item(cx)?) { diff --git a/actix-http/src/client/h2proto.rs b/awc/src/client/h2proto.rs similarity index 88% rename from actix-http/src/client/h2proto.rs rename to awc/src/client/h2proto.rs index b9d5f96bd..9ced5776b 100644 --- a/actix-http/src/client/h2proto.rs +++ b/awc/src/client/h2proto.rs @@ -8,15 +8,16 @@ use h2::{ }; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, Method, Version}; +use log::trace; -use crate::{ +use actix_http::{ body::{BodySize, MessageBody}, header::HeaderMap, - message::{RequestHeadType, ResponseHead}, - payload::Payload, - Error, + Payload, RequestHeadType, ResponseHead, }; +use crate::BoxError; + use super::{ config::ConnectorConfig, connection::{ConnectionIo, H2Connection}, @@ -31,16 +32,13 @@ pub(crate) async fn send_request( where Io: ConnectionIo, B: MessageBody, - B::Error: Into, + B::Error: Into, { trace!("Sending client request: {:?} {:?}", head, body.size()); let head_req = head.as_ref().method == Method::HEAD; let length = body.size(); - let eof = matches!( - length, - BodySize::None | BodySize::Empty | BodySize::Sized(0) - ); + let eof = matches!(length, BodySize::None | BodySize::Sized(0)); let mut req = Request::new(()); *req.uri_mut() = head.as_ref().uri.clone(); @@ -53,13 +51,11 @@ where // Content length let _ = match length { BodySize::None => None, - BodySize::Stream => { - skip_len = false; - None - } - BodySize::Empty => req + + BodySize::Sized(0) => req .headers_mut() .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => { let mut buf = itoa::Buffer::new(); @@ -68,6 +64,11 @@ where HeaderValue::from_str(buf.format(len)).unwrap(), ) } + + BodySize::Stream => { + skip_len = false; + None + } }; // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. @@ -91,7 +92,7 @@ where for (key, value) in headers { match *key { // TODO: consider skipping other headers according to: - // https://tools.ietf.org/html/rfc7540#section-8.1.2.2 + // https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2 // omit HTTP/1.x only headers CONNECTION | TRANSFER_ENCODING => continue, CONTENT_LENGTH if skip_len => continue, @@ -131,16 +132,15 @@ where Ok((head, payload)) } -async fn send_body( - body: B, - mut send: SendStream, -) -> Result<(), SendRequestError> +async fn send_body(body: B, mut send: SendStream) -> Result<(), SendRequestError> where B: MessageBody, - B::Error: Into, + B::Error: Into, { let mut buf = None; + actix_rt::pin!(body); + loop { if buf.is_none() { match poll_fn(|cx| body.as_mut().poll_next(cx)).await { @@ -148,10 +148,10 @@ where send.reserve_capacity(b.len()); buf = Some(b); } - Some(Err(e)) => return Err(e.into().into()), + Some(Err(err)) => return Err(SendRequestError::Body(err.into())), None => { - if let Err(e) = send.send_data(Bytes::new(), true) { - return Err(e.into()); + if let Err(err) = send.send_data(Bytes::new(), true) { + return Err(err.into()); } send.reserve_capacity(0); return Ok(()); @@ -184,8 +184,7 @@ where pub(crate) fn handshake( io: Io, config: &ConnectorConfig, -) -> impl Future, Connection), h2::Error>> -{ +) -> impl Future, Connection), h2::Error>> { let mut builder = Builder::new(); builder .initial_window_size(config.stream_window_size) diff --git a/actix-http/src/client/mod.rs b/awc/src/client/mod.rs similarity index 80% rename from actix-http/src/client/mod.rs rename to awc/src/client/mod.rs index 41d5fef2a..0d5c899bc 100644 --- a/actix-http/src/client/mod.rs +++ b/awc/src/client/mod.rs @@ -11,13 +11,12 @@ mod h2proto; mod pool; pub use actix_tls::connect::{ - Connect as TcpConnect, ConnectError as TcpConnectError, Connection as TcpConnection, + ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, }; pub use self::connection::{Connection, ConnectionIo}; pub use self::connector::{Connector, ConnectorService}; pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; -pub use crate::Protocol; #[derive(Clone)] pub struct Connect { diff --git a/actix-http/src/client/pool.rs b/awc/src/client/pool.rs similarity index 95% rename from actix-http/src/client/pool.rs rename to awc/src/client/pool.rs index 88188038f..9d130412b 100644 --- a/actix-http/src/client/pool.rs +++ b/awc/src/client/pool.rs @@ -14,22 +14,21 @@ use std::{ }; use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_http::Protocol; use actix_rt::time::{sleep, Sleep}; use actix_service::Service; use ahash::AHashMap; use futures_core::future::LocalBoxFuture; +use futures_util::FutureExt; use http::uri::Authority; -use pin_project::pin_project; +use pin_project_lite::pin_project; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use super::config::ConnectorConfig; -use super::connection::{ - ConnectionInnerType, ConnectionIo, ConnectionType, H2ConnectionInner, -}; +use super::connection::{ConnectionInnerType, ConnectionIo, ConnectionType, H2ConnectionInner}; use super::error::ConnectError; use super::h2proto::handshake; use super::Connect; -use super::Protocol; #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub struct Key { @@ -152,9 +151,7 @@ where impl Service for ConnectionPool where - S: Service - + Clone - + 'static, + S: Service + Clone + 'static, Io: ConnectionIo, { type Response = ConnectionType; @@ -195,8 +192,8 @@ where let config = &inner.config; let idle_dur = now - c.used; let age = now - c.created; - let conn_ineligible = idle_dur > config.conn_keep_alive - || age > config.conn_lifetime; + let conn_ineligible = + idle_dur > config.conn_keep_alive || age > config.conn_lifetime; if conn_ineligible { // drop connections that are too old @@ -205,7 +202,7 @@ where // check if the connection is still usable if let ConnectionInnerType::H1(ref mut io) = c.conn { let check = ConnectionCheckFuture { io }; - match check.await { + match check.now_or_never().expect("ConnectionCheckFuture must never yield with Poll::Pending.") { ConnectionState::Tainted => { inner.close(c.conn); continue; @@ -231,9 +228,7 @@ where // match the connection and spawn new one if did not get anything. match conn { - Some(conn) => { - Ok(ConnectionType::from_pool(conn.conn, conn.created, acquired)) - } + Some(conn) => Ok(ConnectionType::from_pool(conn.conn, conn.created, acquired)), None => { let (io, proto) = connector.call(req).await?; @@ -284,9 +279,7 @@ where let mut read_buf = ReadBuf::new(&mut buf); let state = match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => { - ConnectionState::Tainted - } + Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => ConnectionState::Tainted, Poll::Pending => ConnectionState::Live, _ => ConnectionState::Skip, @@ -302,11 +295,13 @@ struct PooledConnection { created: Instant, } -#[pin_project] -struct CloseConnection { - io: Io, - #[pin] - timeout: Sleep, +pin_project! { + #[project = CloseConnectionProj] + struct CloseConnection { + io: Io, + #[pin] + timeout: Sleep, + } } impl CloseConnection @@ -413,17 +408,11 @@ mod test { unimplemented!() } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { unimplemented!() } - fn poll_shutdown( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 6a9fc4630..19870b069 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -7,18 +7,17 @@ use std::{ }; use actix_codec::Framed; -use actix_http::{ - body::Body, - client::{ - Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, - }, - h1::ClientCodec, - Payload, RequestHead, RequestHeadType, ResponseHead, -}; +use actix_http::{h1::ClientCodec, Payload, RequestHead, RequestHeadType, ResponseHead}; use actix_service::Service; use futures_core::{future::LocalBoxFuture, ready}; -use crate::response::ClientResponse; +use crate::{ + any_body::AnyBody, + client::{ + Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, + }, + response::ClientResponse, +}; pub type BoxConnectorService = Rc< dyn Service< @@ -32,7 +31,7 @@ pub type BoxConnectorService = Rc< pub type BoxedSocket = Box; pub enum ConnectRequest { - Client(RequestHeadType, Body, Option), + Client(RequestHeadType, AnyBody, Option), Tunnel(RequestHead, Option), } diff --git a/awc/src/error.rs b/awc/src/error.rs index c83c5ebbf..726e1a506 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -1,15 +1,17 @@ //! HTTP client errors -pub use actix_http::client::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; -pub use actix_http::error::PayloadError; -pub use actix_http::http::Error as HttpError; -pub use actix_http::ws::HandshakeError as WsHandshakeError; -pub use actix_http::ws::ProtocolError as WsProtocolError; +pub use actix_http::{ + error::PayloadError, + http::{header::HeaderValue, Error as HttpError, StatusCode}, + ws::{HandshakeError as WsHandshakeError, ProtocolError as WsProtocolError}, +}; +use derive_more::{Display, From}; use serde_json::error::Error as JsonError; -use actix_http::http::{header::HeaderValue, StatusCode}; -use derive_more::{Display, From}; +pub use crate::client::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; + +// TODO: address display, error, and from impls /// Websocket client error #[derive(Debug, Display, From)] diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index cb8c0f1bf..472397359 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -1,18 +1,18 @@ -use std::{convert::TryFrom, error::Error as StdError, net, rc::Rc, time::Duration}; +use std::{convert::TryFrom, net, rc::Rc, time::Duration}; use bytes::Bytes; use futures_core::Stream; use serde::Serialize; use actix_http::{ - body::Body, http::{header::IntoHeaderValue, Error as HttpError, HeaderMap, HeaderName, Method, Uri}, RequestHead, }; use crate::{ + any_body::AnyBody, sender::{RequestSender, SendClientRequest}, - ClientConfig, + BoxError, ClientConfig, }; /// `FrozenClientRequest` struct represents cloneable client request. @@ -45,7 +45,7 @@ impl FrozenClientRequest { /// Send a body. pub fn send_body(&self, body: B) -> SendClientRequest where - B: Into, + B: Into, { RequestSender::Rc(self.head.clone(), None).send_body( self.addr, @@ -82,7 +82,7 @@ impl FrozenClientRequest { pub fn send_stream(&self, stream: S) -> SendClientRequest where S: Stream> + Unpin + 'static, - E: Into> + 'static, + E: Into + 'static, { RequestSender::Rc(self.head.clone(), None).send_stream( self.addr, @@ -158,7 +158,7 @@ impl FrozenSendBuilder { /// Complete request construction and send a body. pub fn send_body(self, body: B) -> SendClientRequest where - B: Into, + B: Into, { if let Some(e) = self.err { return e.into(); @@ -207,7 +207,7 @@ impl FrozenSendBuilder { pub fn send_stream(self, stream: S) -> SendClientRequest where S: Stream> + Unpin + 'static, - E: Into> + 'static, + E: Into + 'static, { if let Some(e) = self.err { return e.into(); diff --git a/awc/src/lib.rs b/awc/src/lib.rs index c0290ddcf..2f4183120 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -104,22 +104,9 @@ #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] -use std::{convert::TryFrom, rc::Rc, time::Duration}; - -#[cfg(feature = "cookies")] -pub use cookie; - -pub use actix_http::{client::Connector, http}; - -use actix_http::{ - client::{TcpConnect, TcpConnectError, TcpConnection}, - http::{Error as HttpError, HeaderMap, Method, Uri}, - RequestHead, -}; -use actix_rt::net::TcpStream; -use actix_service::Service; - +mod any_body; mod builder; +mod client; mod connect; pub mod error; mod frozen; @@ -130,13 +117,31 @@ mod sender; pub mod test; pub mod ws; +pub use actix_http::http; +#[cfg(feature = "cookies")] +pub use cookie; + pub use self::builder::ClientBuilder; +pub use self::client::Connector; pub use self::connect::{BoxConnectorService, BoxedSocket, ConnectRequest, ConnectResponse}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; pub use self::sender::SendClientRequest; +use std::{convert::TryFrom, rc::Rc, time::Duration}; + +use actix_http::{ + http::{Error as HttpError, HeaderMap, Method, Uri}, + RequestHead, +}; +use actix_rt::net::TcpStream; +use actix_service::Service; + +use self::client::{ConnectInfo, TcpConnectError, TcpConnection}; + +pub(crate) type BoxError = Box; + /// An asynchronous HTTP and WebSocket client. /// /// You should take care to create, at most, one `Client` per thread. Otherwise, expect higher CPU @@ -184,7 +189,7 @@ impl Client { /// This function is equivalent of `ClientBuilder::new()`. pub fn builder() -> ClientBuilder< impl Service< - TcpConnect, + ConnectInfo, Response = TcpConnection, Error = TcpConnectError, > + Clone, diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index ae09edf9c..89cff22cd 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -8,8 +8,6 @@ use std::{ }; use actix_http::{ - body::Body, - client::{InvalidUrl, SendRequestError}, http::{header, Method, StatusCode, Uri}, RequestHead, RequestHeadType, }; @@ -18,9 +16,12 @@ use bytes::Bytes; use futures_core::ready; use super::Transform; - -use crate::connect::{ConnectRequest, ConnectResponse}; -use crate::ClientResponse; +use crate::{ + any_body::AnyBody, + client::{InvalidUrl, SendRequestError}, + connect::{ConnectRequest, ConnectResponse}, + ClientResponse, +}; pub struct Redirect { max_redirect_times: u8, @@ -85,15 +86,17 @@ where let max_redirect_times = self.max_redirect_times; // backup the uri and method for reuse schema and authority. - let (uri, method) = match head { - RequestHeadType::Owned(ref head) => (head.uri.clone(), head.method.clone()), + let (uri, method, headers) = match head { + RequestHeadType::Owned(ref head) => { + (head.uri.clone(), head.method.clone(), head.headers.clone()) + } RequestHeadType::Rc(ref head, ..) => { - (head.uri.clone(), head.method.clone()) + (head.uri.clone(), head.method.clone(), head.headers.clone()) } }; let body_opt = match body { - Body::Bytes(ref b) => Some(b.clone()), + AnyBody::Bytes { ref body } => Some(body.clone()), _ => None, }; @@ -104,6 +107,7 @@ where max_redirect_times, uri: Some(uri), method: Some(method), + headers: Some(headers), body: body_opt, addr, connector: Some(connector), @@ -127,9 +131,10 @@ pin_project_lite::pin_project! { max_redirect_times: u8, uri: Option, method: Option, + headers: Option, body: Option, addr: Option, - connector: Option> + connector: Option>, } } } @@ -148,6 +153,7 @@ where max_redirect_times, uri, method, + headers, body, addr, connector, @@ -156,79 +162,62 @@ where StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT if *max_redirect_times > 0 => { - let org_uri = uri.take().unwrap(); - // rebuild uri from the location header value. - let uri = rebuild_uri(&res, org_uri)?; + let is_redirect = res.head().status == StatusCode::TEMPORARY_REDIRECT + || res.head().status == StatusCode::PERMANENT_REDIRECT; - // reset method - let method = method.take().unwrap(); - let method = match method { - Method::GET | Method::HEAD => method, - _ => Method::GET, - }; + let prev_uri = uri.take().unwrap(); + + // rebuild uri from the location header value. + let next_uri = build_next_uri(&res, &prev_uri)?; // take ownership of states that could be reused let addr = addr.take(); let connector = connector.take(); - let mut max_redirect_times = *max_redirect_times; - // use a new request head. - let mut head = RequestHead::default(); - head.uri = uri.clone(); - head.method = method.clone(); - - let head = RequestHeadType::Owned(head); - - max_redirect_times -= 1; - - let fut = connector - .as_ref() - .unwrap() - // remove body - .call(ConnectRequest::Client(head, Body::None, addr)); - - self.set(RedirectServiceFuture::Client { - fut, - max_redirect_times, - uri: Some(uri), - method: Some(method), - // body is dropped on 301,302,303 - body: None, - addr, - connector, - }); - - self.poll(cx) - } - StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT - if *max_redirect_times > 0 => - { - let org_uri = uri.take().unwrap(); - // rebuild uri from the location header value. - let uri = rebuild_uri(&res, org_uri)?; - - // try to reuse body - let body = body.take(); - let body_new = match body { - Some(ref bytes) => Body::Bytes(bytes.clone()), - // TODO: should this be Body::Empty or Body::None. - _ => Body::Empty, + // reset method + let method = if is_redirect { + method.take().unwrap() + } else { + let method = method.take().unwrap(); + match method { + Method::GET | Method::HEAD => method, + _ => Method::GET, + } }; - let addr = addr.take(); - let method = method.take().unwrap(); - let connector = connector.take(); - let mut max_redirect_times = *max_redirect_times; + let mut body = body.take(); + let body_new = if is_redirect { + // try to reuse body + match body { + Some(ref bytes) => AnyBody::Bytes { + body: bytes.clone(), + }, + // TODO: should this be AnyBody::Empty or AnyBody::None. + _ => AnyBody::empty(), + } + } else { + body = None; + // remove body + AnyBody::None + }; + + let mut headers = headers.take().unwrap(); + + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); // use a new request head. let mut head = RequestHead::default(); - head.uri = uri.clone(); + head.uri = next_uri.clone(); head.method = method.clone(); + head.headers = headers.clone(); let head = RequestHeadType::Owned(head); + let mut max_redirect_times = *max_redirect_times; max_redirect_times -= 1; let fut = connector @@ -239,8 +228,9 @@ where self.set(RedirectServiceFuture::Client { fut, max_redirect_times, - uri: Some(uri), + uri: Some(next_uri), method: Some(method), + headers: Some(headers), body, addr, connector, @@ -256,7 +246,7 @@ where } } -fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result { +fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result { let uri = res .headers() .get(header::LOCATION) @@ -266,8 +256,8 @@ fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result(uri) @@ -281,12 +271,25 @@ fn rebuild_uri(res: &ClientResponse, org_uri: Uri) -> Result HttpResponse { + HttpResponse::TemporaryRedirect() + .append_header(("location", "/test")) + .finish() + } + + async fn test(req: HttpRequest, body: Bytes) -> HttpResponse { + if req.method() == Method::POST && !body.is_empty() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let res = srv.post("/").send_body("Hello").await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_status_kind_301_302_303() { + let srv = actix_test::start(|| { + async fn root() -> HttpResponse { + HttpResponse::Found() + .append_header(("location", "/test")) + .finish() + } + + async fn test(req: HttpRequest, body: Bytes) -> HttpResponse { + if (req.method() == Method::GET || req.method() == Method::HEAD) + && body.is_empty() + { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let res = srv.post("/").send_body("Hello").await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + let res = srv.post("/").send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_headers() { + let srv = actix_test::start(|| { + async fn root(req: HttpRequest) -> HttpResponse { + if req + .headers() + .get("custom") + .unwrap_or(&HeaderValue::from_str("").unwrap()) + == "value" + { + HttpResponse::Found() + .append_header(("location", "/test")) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test(req: HttpRequest) -> HttpResponse { + if req + .headers() + .get("custom") + .unwrap_or(&HeaderValue::from_str("").unwrap()) + == "value" + { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let client = ClientBuilder::new() + .header("custom", "value") + .disable_redirects() + .finish(); + let res = client.get(srv.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 302); + + let client = ClientBuilder::new().header("custom", "value").finish(); + let res = client.get(srv.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + let client = ClientBuilder::new().finish(); + let res = client + .get(srv.url("/")) + .insert_header(("custom", "value")) + .send() + .await + .unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_cross_origin_headers() { + // defining two services to have two different origins + let srv2 = actix_test::start(|| { + async fn root(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_none() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new().service(web::resource("/").route(web::to(root))) + }); + let srv2_port: u16 = srv2.addr().port(); + + let srv1 = actix_test::start(move || { + async fn root(req: HttpRequest) -> HttpResponse { + let port = *req.app_data::().unwrap(); + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Found() + .append_header(( + "location", + format!("http://localhost:{}/", port).as_str(), + )) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test1(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Found() + .append_header(("location", "/test2")) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test2(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .app_data(srv2_port) + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test1").route(web::to(test1))) + .service(web::resource("/test2").route(web::to(test2))) + }); + + // send a request to different origins, http://srv1/ then http://srv2/. So it should remove the header + let client = ClientBuilder::new() + .header(header::AUTHORIZATION, "auth_key_value") + .finish(); + let res = client.get(srv1.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + // send a request to same origin, http://srv1/test1 then http://srv1/test2. So it should NOT remove any header + let res = client.get(srv1.url("/test1")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_remove_sensitive_headers() { + fn gen_headers() -> header::HeaderMap { + let mut headers = header::HeaderMap::new(); + headers.insert(header::USER_AGENT, HeaderValue::from_str("value").unwrap()); + headers.insert( + header::AUTHORIZATION, + HeaderValue::from_str("value").unwrap(), + ); + headers.insert( + header::PROXY_AUTHORIZATION, + HeaderValue::from_str("value").unwrap(), + ); + headers.insert(header::COOKIE, HeaderValue::from_str("value").unwrap()); + headers + } + + // Same origin + let prev_uri = Uri::from_str("https://host/path1").unwrap(); + let next_uri = Uri::from_str("https://host/path2").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 4); + + // different schema + let prev_uri = Uri::from_str("http://host/").unwrap(); + let next_uri = Uri::from_str("https://host/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different host + let prev_uri = Uri::from_str("https://host1/").unwrap(); + let next_uri = Uri::from_str("https://host2/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different port + let prev_uri = Uri::from_str("https://host:12/").unwrap(); + let next_uri = Uri::from_str("https://host:23/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different everything! + let prev_uri = Uri::from_str("http://host1:12/path1").unwrap(); + let next_uri = Uri::from_str("https://host2:23/path2").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + } } diff --git a/awc/src/request.rs b/awc/src/request.rs index 812c76318..d26b703f6 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -1,11 +1,10 @@ -use std::{convert::TryFrom, error::Error as StdError, fmt, net, rc::Rc, time::Duration}; +use std::{convert::TryFrom, fmt, net, rc::Rc, time::Duration}; use bytes::Bytes; use futures_core::Stream; use serde::Serialize; use actix_http::{ - body::Body, http::{ header::{self, IntoHeaderPair}, ConnectionType, Error as HttpError, HeaderMap, HeaderValue, Method, Uri, Version, @@ -13,15 +12,17 @@ use actix_http::{ RequestHead, }; -#[cfg(feature = "cookies")] -use crate::cookie::{Cookie, CookieJar}; use crate::{ + any_body::AnyBody, error::{FreezeRequestError, InvalidUrl}, frozen::FrozenClientRequest, sender::{PrepForSendingError, RequestSender, SendClientRequest}, - ClientConfig, + BoxError, ClientConfig, }; +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, CookieJar}; + /// An HTTP Client request builder /// /// This type can be used to construct an instance of `ClientRequest` through a @@ -115,10 +116,10 @@ impl ClientRequest { &self.head.method } - #[doc(hidden)] /// Set HTTP version of this request. /// /// By default requests's HTTP version depends on network stream + #[doc(hidden)] #[inline] pub fn version(mut self, version: Version) -> Self { self.head.version = version; @@ -350,7 +351,7 @@ impl ClientRequest { /// Complete request construction and send body. pub fn send_body(self, body: B) -> SendClientRequest where - B: Into, + B: Into, { let slf = match self.prep_for_sending() { Ok(slf) => slf, @@ -404,7 +405,7 @@ impl ClientRequest { pub fn send_stream(self, stream: S) -> SendClientRequest where S: Stream> + Unpin + 'static, - E: Into> + 'static, + E: Into + 'static, { let slf = match self.prep_for_sending() { Ok(slf) => slf, diff --git a/awc/src/sender.rs b/awc/src/sender.rs index c0639606e..51fce1913 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -1,5 +1,4 @@ use std::{ - error::Error as StdError, future::Future, net, pin::Pin, @@ -9,12 +8,12 @@ use std::{ }; use actix_http::{ - body::{Body, BodyStream}, + body::BodyStream, http::{ header::{self, HeaderMap, HeaderName, IntoHeaderValue}, Error as HttpError, }, - Error, RequestHead, RequestHeadType, + RequestHead, RequestHeadType, }; use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; @@ -26,8 +25,9 @@ use serde::Serialize; use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream}; use crate::{ + any_body::AnyBody, error::{FreezeRequestError, InvalidUrl, SendRequestError}, - ClientConfig, ClientResponse, ConnectRequest, ConnectResponse, + BoxError, ClientConfig, ClientResponse, ConnectRequest, ConnectResponse, }; #[derive(Debug, From)] @@ -162,12 +162,6 @@ impl From for SendClientRequest { } } -impl From for SendClientRequest { - fn from(e: Error) -> Self { - SendClientRequest::Err(Some(e.into())) - } -} - impl From for SendClientRequest { fn from(e: HttpError) -> Self { SendClientRequest::Err(Some(e.into())) @@ -196,7 +190,7 @@ impl RequestSender { body: B, ) -> SendClientRequest where - B: Into, + B: Into, { let req = match self { RequestSender::Owned(head) => { @@ -236,7 +230,9 @@ impl RequestSender { response_decompress, timeout, config, - Body::Bytes(Bytes::from(body)), + AnyBody::Bytes { + body: Bytes::from(body), + }, ) } @@ -265,7 +261,9 @@ impl RequestSender { response_decompress, timeout, config, - Body::Bytes(Bytes::from(body)), + AnyBody::Bytes { + body: Bytes::from(body), + }, ) } @@ -279,14 +277,14 @@ impl RequestSender { ) -> SendClientRequest where S: Stream> + Unpin + 'static, - E: Into> + 'static, + E: Into + 'static, { self.send_body( addr, response_decompress, timeout, config, - Body::from_message(BodyStream::new(stream)), + AnyBody::new_boxed(BodyStream::new(stream)), ) } @@ -297,7 +295,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, Body::Empty) + self.send_body(addr, response_decompress, timeout, config, AnyBody::empty()) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 2fe36399c..e2f1f86d0 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -312,9 +312,8 @@ impl WebsocketsRequest { ); } - // Generate a random key for the `Sec-WebSocket-Key` header. - // a base64-encoded (see Section 4 of [RFC4648]) value that, - // when decoded, is 16 bytes in length (RFC 6455) + // Generate a random key for the `Sec-WebSocket-Key` header which is a base64-encoded + // (see RFC 4648 §4) value that, when decoded, is 16 bytes in length (RFC 6455 §1.3). let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 615789fb3..5abb63e39 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -1,20 +1,26 @@ -use std::collections::HashMap; -use std::io::{Read, Write}; -use std::net::{IpAddr, Ipv4Addr}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Duration; +use std::{ + collections::HashMap, + io::{Read, Write}, + net::{IpAddr, Ipv4Addr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; use actix_utils::future::ok; -use brotli2::write::BrotliEncoder; use bytes::Bytes; use cookie::Cookie; -use flate2::read::GzDecoder; -use flate2::write::GzEncoder; -use flate2::Compression; use futures_util::stream; use rand::Rng; +#[cfg(feature = "compress-brotli")] +use brotli2::write::BrotliEncoder; + +#[cfg(feature = "compress-gzip")] +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; + use actix_http::{ http::{self, StatusCode}, HttpService, @@ -24,7 +30,6 @@ use actix_service::{fn_service, map_config, ServiceFactoryExt as _}; use actix_web::{ dev::{AppConfig, BodyEncoding}, http::header, - middleware::Compress, web, App, Error, HttpRequest, HttpResponse, }; use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; @@ -122,7 +127,7 @@ async fn test_timeout() { }); let connector = awc::Connector::new() - .connector(actix_tls::connect::default_connector()) + .connector(actix_tls::connect::ConnectorService::default()) .timeout(Duration::from_secs(15)); let client = awc::Client::builder() @@ -463,11 +468,12 @@ async fn test_with_query_parameter() { assert!(res.status().is_success()); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_no_decompress() { let srv = actix_test::start(|| { App::new() - .wrap(Compress::default()) + .wrap(actix_web::middleware::Compress::default()) .service(web::resource("/").route(web::to(|| { let mut res = HttpResponse::Ok().body(STR); res.encoding(header::ContentEncoding::Gzip); @@ -507,6 +513,7 @@ async fn test_no_decompress() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding() { let srv = actix_test::start(|| { @@ -530,6 +537,7 @@ async fn test_client_gzip_encoding() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding_large() { let srv = actix_test::start(|| { @@ -553,6 +561,7 @@ async fn test_client_gzip_encoding_large() { assert_eq!(bytes, Bytes::from(STR.repeat(10))); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() @@ -581,6 +590,7 @@ async fn test_client_gzip_encoding_large_random() { assert_eq!(bytes, Bytes::from(data)); } +#[cfg(feature = "compress-brotli")] #[actix_rt::test] async fn test_client_brotli_encoding() { let srv = actix_test::start(|| { @@ -603,6 +613,7 @@ async fn test_client_brotli_encoding() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[cfg(feature = "compress-brotli")] #[actix_rt::test] async fn test_client_brotli_encoding_large_random() { let data = rand::thread_rng() @@ -795,17 +806,15 @@ async fn client_unread_response() { let lst = std::net::TcpListener::bind(addr).unwrap(); std::thread::spawn(move || { - for stream in lst.incoming() { - let mut stream = stream.unwrap(); - let mut b = [0; 1000]; - let _ = stream.read(&mut b).unwrap(); - let _ = stream.write_all( - b"HTTP/1.1 200 OK\r\n\ + let (mut stream, _) = lst.accept().unwrap(); + let mut b = [0; 1000]; + let _ = stream.read(&mut b).unwrap(); + let _ = stream.write_all( + b"HTTP/1.1 200 OK\r\n\ connection: close\r\n\ \r\n\ welcome!", - ); - } + ); }); // client request diff --git a/awc/tests/test_connector.rs b/awc/tests/test_connector.rs index 632f68b72..588c51463 100644 --- a/awc/tests/test_connector.rs +++ b/awc/tests/test_connector.rs @@ -39,7 +39,7 @@ fn tls_config() -> SslAcceptor { #[actix_rt::test] async fn test_connection_window_size() { - let srv = test_server(move || { + let srv = test_server(|| { HttpService::build() .h2(map_config( App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index bc811c046..652997de6 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -8,44 +8,59 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, + time::SystemTime, }; use actix_http::HttpService; use actix_http_test::test_server; use actix_service::{fn_service, map_config, ServiceFactoryExt}; +use actix_tls::connect::rustls::webpki_roots_cert_store; use actix_utils::future::ok; use actix_web::{dev::AppConfig, http::Version, web, App, HttpResponse}; -use rustls::internal::pemfile::{certs, pkcs8_private_keys}; -use rustls::{ClientConfig, NoClientAuth, ServerConfig}; +use rustls::{ + client::{ServerCertVerified, ServerCertVerifier}, + Certificate, ClientConfig, PrivateKey, ServerConfig, ServerName, +}; +use rustls_pemfile::{certs, pkcs8_private_keys}; fn tls_config() -> ServerConfig { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = ServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - config + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap() } mod danger { + use super::*; + pub struct NoCertificateVerification; - impl rustls::ServerCertVerifier for NoCertificateVerification { + impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, - _roots: &rustls::RootCertStore, - _presented_certs: &[rustls::Certificate], - _dns_name: webpki::DNSNameRef<'_>, - _ocsp: &[u8], - ) -> Result { - Ok(rustls::ServerCertVerified::assertion()) + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) } } } @@ -73,10 +88,15 @@ async fn test_connection_reuse_h2() { }) .await; - // disable TLS verification - let mut config = ClientConfig::new(); + let mut config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - config.set_protocols(&protos); + config.alpn_protocols = protos; + + // disable TLS verification config .dangerous() .set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); diff --git a/benches/responder.rs b/benches/responder.rs index 5d0b98d5f..20aae3351 100644 --- a/benches/responder.rs +++ b/benches/responder.rs @@ -1,9 +1,10 @@ use std::{future::Future, time::Instant}; +use actix_http::body::BoxBody; use actix_utils::future::{ready, Ready}; -use actix_web::http::StatusCode; -use actix_web::test::TestRequest; -use actix_web::{error, Error, HttpRequest, HttpResponse, Responder}; +use actix_web::{ + error, http::StatusCode, test::TestRequest, Error, HttpRequest, HttpResponse, Responder, +}; use criterion::{criterion_group, criterion_main, Criterion}; use futures_util::future::{join_all, Either}; @@ -50,7 +51,9 @@ where } impl Responder for StringResponder { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { HttpResponse::build(StatusCode::OK) .content_type("text/plain; charset=utf-8") .body(self.0) @@ -58,9 +61,11 @@ impl Responder for StringResponder { } impl Responder for OptionResponder { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self.0 { - Some(t) => t.respond_to(req), + Some(t) => t.respond_to(req).map_into_boxed_body(), None => HttpResponse::from_error(error::ErrorInternalServerError("err")), } } diff --git a/clippy.toml b/clippy.toml index eb66960ac..cef91fde7 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1 +1 @@ -msrv = "1.46" +msrv = "1.52" diff --git a/docs/graphs/net-only.dot b/docs/graphs/net-only.dot index bee0185ab..8a58ec2b8 100644 --- a/docs/graphs/net-only.dot +++ b/docs/graphs/net-only.dot @@ -4,7 +4,7 @@ digraph { subgraph cluster_net { label="actix-net" "actix-codec" "actix-macros" "actix-rt" "actix-server" "actix-service" - "actix-tls" "actix-tracing" "actix-utils" "actix-router" + "actix-tls" "actix-tracing" "actix-utils" } subgraph cluster_other { @@ -25,7 +25,6 @@ digraph { "actix-tls" -> { "tokio-util" }[color="#009900"] "actix-server" -> { "actix-service" "actix-rt" "actix-utils" "tokio" } "actix-rt" -> { "actix-macros" "tokio" } - "actix-router" -> { "bytestring" } "local-channel" -> { "local-waker" } diff --git a/docs/graphs/web-focus.dot b/docs/graphs/web-focus.dot index 2c6e2779b..63b3eaa82 100644 --- a/docs/graphs/web-focus.dot +++ b/docs/graphs/web-focus.dot @@ -10,6 +10,7 @@ digraph { "web-actors" "web-codegen" "http-test" + "router" { rank=same; "multipart" "web-actors" "http-test" }; { rank=same; "files" "awc" "web" }; @@ -36,7 +37,7 @@ digraph { "rt" -> { "macros" } { rank=same; "utils" "codec" }; - { rank=same; "rt" "macros" "service" "router" }; + { rank=same; "rt" "macros" "service" }; // actix diff --git a/docs/graphs/web-only.dot b/docs/graphs/web-only.dot index b0decd818..ee74c292b 100644 --- a/docs/graphs/web-only.dot +++ b/docs/graphs/web-only.dot @@ -10,9 +10,10 @@ digraph { "actix-web-codegen" "actix-http-test" "actix-test" + "actix-router" } - "actix-web" -> { "actix-web-codegen" "actix-http" } + "actix-web" -> { "actix-web-codegen" "actix-http" "actix-router" } "awc" -> { "actix-http" } "actix-web-actors" -> { "actix" "actix-web" "actix-http" } "actix-multipart" -> { "actix-web" } diff --git a/examples/basic.rs b/examples/basic.rs index 796f002e8..d29546129 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -35,7 +35,7 @@ async fn main() -> std::io::Result<()> { ) .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) - .bind("127.0.0.1:8080")? + .bind(("127.0.0.1", 8080))? .workers(1) .run() .await diff --git a/examples/on_connect.rs b/examples/on_connect.rs index bd77229f0..6d05716f1 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -9,6 +9,7 @@ use std::{any::Any, io, net::SocketAddr}; use actix_http::CloneableExtensions; use actix_web::{rt::net::TcpStream, web, App, HttpServer}; +#[allow(dead_code)] #[derive(Debug, Clone)] struct ConnectionInfo { bind: SocketAddr, diff --git a/scripts/bump b/scripts/bump new file mode 100755 index 000000000..8b6a3c424 --- /dev/null +++ b/scripts/bump @@ -0,0 +1,111 @@ +#!/bin/sh + +# developed on macOS and probably doesn't work on Linux yet due to minor +# differences in flags on sed + +# requires github cli tool for automatic release draft creation + +set -euo pipefail + +DIR=$1 + +LINUX="" +MACOS="" + +if [ "$(uname)" = "Darwin" ]; then + MACOS="1" +fi + +CARGO_MANIFEST=$DIR/Cargo.toml +CHANGELOG_FILE=$DIR/CHANGES.md +README_FILE=$DIR/README.md + +# get current version +PACKAGE_NAME="$(sed -nE 's/^name ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST" | head -n 1)" +CURRENT_VERSION="$(sed -nE 's/^version ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST")" + +CHANGE_CHUNK_FILE="$(mktemp)" +echo saving changelog to $CHANGE_CHUNK_FILE +echo + +# get changelog chunk and save to temp file +cat "$CHANGELOG_FILE" | + # skip up to unreleased heading + sed '1,/Unreleased/ d' | + # take up to previous version heading + sed "/$CURRENT_VERSION/ q" | + # drop last line + sed '$d' \ + >"$CHANGE_CHUNK_FILE" + +# if word count of changelog chunk is 0 then insert filler changelog chunk +if [ "$(wc -w "$CHANGE_CHUNK_FILE" | awk '{ print $1 }')" = "0" ]; then + echo "* No significant changes since \`$CURRENT_VERSION\`." >"$CHANGE_CHUNK_FILE" +fi + +if [ -n "${2-}" ]; then + NEW_VERSION="$2" +else + echo + echo "--- Changes since $CURRENT_VERSION ----" + cat "$CHANGE_CHUNK_FILE" + echo + read -p "Update version to: " NEW_VERSION +fi + +DATE="$(date -u +"%Y-%m-%d")" +echo "updating from $CURRENT_VERSION => $NEW_VERSION ($DATE)" + +# update package.version field +sed -i.bak -E "s/^version ?= ?\"[^\"]+\"$/version = \"$NEW_VERSION\"/" "$CARGO_MANIFEST" + +# update readme +[ -f "$README_FILE" ] && sed -i.bak -E "s#$CURRENT_VERSION([/)])#$NEW_VERSION\1#g" "$README_FILE" + +# update changelog file +( + sed '/Unreleased/ q' "$CHANGELOG_FILE" # up to unreleased heading + echo # blank line + echo # blank line + echo "## $NEW_VERSION - $DATE" # new version heading + cat "$CHANGE_CHUNK_FILE" # previously unreleased changes + sed "/$CURRENT_VERSION/ q" "$CHANGELOG_FILE" | tail -n 1 # the previous version heading + sed "1,/$CURRENT_VERSION/ d" "$CHANGELOG_FILE" # everything after previous version heading +) >"$CHANGELOG_FILE.bak" +mv "$CHANGELOG_FILE.bak" "$CHANGELOG_FILE" + +# done; remove backup files +rm -f $CARGO_MANIFEST.bak +rm -f $CHANGELOG_FILE.bak +rm -f $README_FILE.bak + +echo "manifest, changelog, and readme updated" +echo +echo "check other references:" +rg "$PACKAGE_NAME =" || true +rg "package = \"$PACKAGE_NAME\"" || true + +if [ $MACOS ]; then + printf "prepare $PACKAGE_NAME release $NEW_VERSION" | pbcopy +else + echo + echo "commit message:" + echo "prepare $PACKAGE_NAME release $NEW_VERSION" +fi + +SHORT_PACKAGE_NAME="$(echo $PACKAGE_NAME | sed 's/^actix-web-//' | sed 's/^actix-//')" +GIT_TAG="$(echo $SHORT_PACKAGE_NAME-v$NEW_VERSION)" +RELEASE_TITLE="$(echo $PACKAGE_NAME: v$NEW_VERSION)" + +echo +echo "GitHub release command:" +echo "gh release create \"$GIT_TAG\" --draft --title \"$RELEASE_TITLE\" --notes-file \"$CHANGE_CHUNK_FILE\" --prerelease" + +read -p "Submit draft GH release: (y/N) " GH_RELEASE +GH_RELEASE="${GH_RELEASE:-n}" + +if [ "$GH_RELEASE" = 'y' ] || [ "$GH_RELEASE" = 'Y' ]; then + gh release create "$GIT_TAG" --draft --title "$RELEASE_TITLE" --notes-file "$CHANGE_CHUNK_FILE" --prerelease +fi + +echo diff --git a/scripts/ci-test b/scripts/ci-test new file mode 100755 index 000000000..98e13927d --- /dev/null +++ b/scripts/ci-test @@ -0,0 +1,18 @@ +#!/bin/sh + +# run tests matching what CI does for non-linux feature sets + +set -x + +cargo test --lib --tests -p=actix-router --all-features +cargo test --lib --tests -p=actix-http --all-features +cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls +cargo test --lib --tests -p=actix-web-codegen --all-features +cargo test --lib --tests -p=awc --all-features +cargo test --lib --tests -p=actix-http-test --all-features +cargo test --lib --tests -p=actix-test --all-features +cargo test --lib --tests -p=actix-files +cargo test --lib --tests -p=actix-multipart --all-features +cargo test --lib --tests -p=actix-web-actors --all-features + +cargo test --workspace --doc diff --git a/src/app.rs b/src/app.rs index 5cff20568..efc108cb9 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,37 +1,35 @@ -use std::cell::RefCell; -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::rc::Rc; +use std::{cell::RefCell, fmt, future::Future, marker::PhantomData, rc::Rc}; -use actix_http::body::{Body, MessageBody}; -use actix_http::{Extensions, Request}; -use actix_service::boxed::{self, BoxServiceFactory}; +use actix_http::{ + body::{BoxBody, MessageBody}, + Extensions, Request, +}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, ServiceFactoryExt, Transform, + apply, apply_fn_factory, boxed, IntoServiceFactory, ServiceFactory, ServiceFactoryExt, + Transform, }; use futures_util::future::FutureExt as _; -use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; -use crate::config::ServiceConfig; -use crate::data::{Data, DataFactory, FnDataFactory}; -use crate::dev::ResourceDef; -use crate::error::Error; -use crate::resource::Resource; -use crate::route::Route; -use crate::service::{ - AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, - ServiceResponse, +use crate::{ + app_service::{AppEntry, AppInit, AppRoutingFactory}, + config::ServiceConfig, + data::{Data, DataFactory, FnDataFactory}, + dev::ResourceDef, + error::Error, + resource::Resource, + route::Route, + service::{ + AppServiceFactory, BoxedHttpServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, + ServiceRequest, ServiceResponse, + }, }; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; - /// Application builder - structure that follows the builder pattern /// for building application instances. pub struct App { endpoint: T, services: Vec>, - default: Option>, + default: Option>, factory_ref: Rc>>, data_factories: Vec, external: Vec, @@ -39,7 +37,7 @@ pub struct App { _phantom: PhantomData, } -impl App { +impl App { /// Create application builder. Application can be configured with a builder-like pattern. #[allow(clippy::new_without_default)] pub fn new() -> Self { @@ -142,10 +140,6 @@ where /// Add application data factory. This function is similar to `.data()` but it accepts a /// "data factory". Data values are constructed asynchronously during application /// initialization, before the server starts accepting requests. - #[deprecated( - since = "4.0.0", - note = "Construct data value before starting server and use `.app_data(Data::new(val))` instead." - )] pub fn data_factory(mut self, data: F) -> Self where F: Fn() -> Out + 'static, @@ -287,7 +281,7 @@ where /// ); /// } /// ``` - pub fn default_service(mut self, f: F) -> Self + pub fn default_service(mut self, svc: F) -> Self where F: IntoServiceFactory, U: ServiceFactory< @@ -298,10 +292,12 @@ where > + 'static, U::InitError: fmt::Debug, { - // create and configure default resource - self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( - |e| log::error!("Can not construct default service: {:?}", e), - )))); + let svc = svc + .into_factory() + .map(|res| res.map_into_boxed_body()) + .map_init_err(|e| log::error!("Can not construct default service: {:?}", e)); + + self.default = Some(Rc::new(boxed::factory(svc))); self } @@ -334,7 +330,7 @@ where U: AsRef, { let mut rdef = ResourceDef::new(url.as_ref()); - *rdef.name_mut() = name.as_ref().to_string(); + rdef.set_name(name.as_ref()); self.external.push(rdef); self } diff --git a/src/app_service.rs b/src/app_service.rs index 3c1b78474..bca8f2629 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -2,10 +2,7 @@ use std::{cell::RefCell, mem, rc::Rc}; use actix_http::{Extensions, Request}; use actix_router::{Path, ResourceDef, Router, Url}; -use actix_service::{ - boxed::{self, BoxService, BoxServiceFactory}, - fn_service, Service, ServiceFactory, -}; +use actix_service::{boxed, fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; @@ -15,13 +12,14 @@ use crate::{ guard::Guard, request::{HttpRequest, HttpRequestPool}, rmap::ResourceMap, - service::{AppServiceFactory, ServiceRequest, ServiceResponse}, + service::{ + AppServiceFactory, BoxedHttpService, BoxedHttpServiceFactory, ServiceRequest, + ServiceResponse, + }, Error, HttpResponse, }; type Guards = Vec>; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Service factory to convert `Request` to a `ServiceRequest`. /// It also executes data factories. @@ -39,7 +37,7 @@ where pub(crate) extensions: RefCell>, pub(crate) async_data_factories: Rc<[FnDataFactory]>, pub(crate) services: Rc>>>, - pub(crate) default: Option>, + pub(crate) default: Option>, pub(crate) factory_ref: Rc>>, pub(crate) external: RefCell>, } @@ -79,7 +77,7 @@ where .into_iter() .for_each(|mut srv| srv.register(&mut config)); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); let (config, services) = config.into_services(); @@ -104,7 +102,7 @@ where // complete ResourceMap tree creation let rmap = Rc::new(rmap); - rmap.finish(rmap.clone()); + ResourceMap::finish(&rmap); // construct all async data factory futures let factory_futs = join_all(self.async_data_factories.iter().map(|f| f())); @@ -230,8 +228,14 @@ where } pub struct AppRoutingFactory { - services: Rc<[(ResourceDef, HttpNewService, RefCell>)]>, - default: Rc, + services: Rc< + [( + ResourceDef, + BoxedHttpServiceFactory, + RefCell>, + )], + >, + default: Rc, } impl ServiceFactory for AppRoutingFactory { @@ -279,8 +283,8 @@ impl ServiceFactory for AppRoutingFactory { /// The Actix Web router default entry point. pub struct AppRouting { - router: Router, - default: HttpService, + router: Router, + default: BoxedHttpService, } impl Service for AppRouting { @@ -291,7 +295,7 @@ impl Service for AppRouting { actix_service::always_ready!(); fn call(&self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_checked(&mut req, |req, guards| { + let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { for f in guards { if !f.check(req.head()) { diff --git a/src/config.rs b/src/config.rs index b072ace16..9e77c0f96 100644 --- a/src/config.rs +++ b/src/config.rs @@ -249,7 +249,7 @@ impl ServiceConfig { U: AsRef, { let mut rdef = ResourceDef::new(url.as_ref()); - *rdef.name_mut() = name.as_ref().to_string(); + rdef.set_name(name.as_ref()); self.external.push(rdef); self } diff --git a/src/data.rs b/src/data.rs index 174faba37..b29e4ecf4 100644 --- a/src/data.rs +++ b/src/data.rs @@ -75,7 +75,9 @@ impl Data { pub fn new(state: T) -> Data { Data(Arc::new(state)) } +} +impl Data { /// Get reference to inner app data. pub fn get_ref(&self) -> &T { self.0.as_ref() @@ -120,7 +122,6 @@ where } impl FromRequest for Data { - type Config = (); type Error = Error; type Future = Ready>; @@ -136,7 +137,7 @@ impl FromRequest for Data { type_name::(), ); err(ErrorInternalServerError( - "App data is not configured, to configure use App::data()", + "App data is not configured, to configure construct it with web::Data::new() and pass it to App::app_data()", )) } } @@ -283,7 +284,7 @@ mod tests { async fn test_data_from_arc() { let data_new = Data::new(String::from("test-123")); let data_from_arc = Data::from(Arc::new(String::from("test-123"))); - assert_eq!(data_new.0, data_from_arc.0) + assert_eq!(data_new.0, data_from_arc.0); } #[actix_rt::test] @@ -305,4 +306,38 @@ mod tests { let data_arc = Data::from(dyn_arc); assert_eq!(data_arc_box.get_num(), data_arc.get_num()) } + + #[actix_rt::test] + async fn test_dyn_data_into_arc() { + trait TestTrait { + fn get_num(&self) -> i32; + } + struct A {} + impl TestTrait for A { + fn get_num(&self) -> i32 { + 42 + } + } + let dyn_arc: Arc = Arc::new(A {}); + let data_arc = Data::from(dyn_arc); + let arc_from_data = data_arc.clone().into_inner(); + assert_eq!(data_arc.get_num(), arc_from_data.get_num()) + } + + #[actix_rt::test] + async fn test_get_ref_from_dyn_data() { + trait TestTrait { + fn get_num(&self) -> i32; + } + struct A {} + impl TestTrait for A { + fn get_num(&self) -> i32 { + 42 + } + } + let dyn_arc: Arc = Arc::new(A {}); + let data_arc = Data::from(dyn_arc); + let ref_data = data_arc.get_ref(); + assert_eq!(data_arc.get_num(), ref_data.get_num()) + } } diff --git a/src/dev.rs b/src/dev.rs index 13bc48a29..66a013ea0 100644 --- a/src/dev.rs +++ b/src/dev.rs @@ -1,7 +1,7 @@ //! Lower-level types and re-exports. //! //! Most users will not have to interact with the types in this module, but it is useful for those -//! writing extractors, middleware and libraries, or interacting with the service API directly. +//! writing extractors, middleware, libraries, or interacting with the service API directly. pub use crate::config::{AppConfig, AppService}; #[doc(hidden)] @@ -14,27 +14,37 @@ pub use crate::types::form::UrlEncoded; pub use crate::types::json::JsonBody; pub use crate::types::readlines::Readlines; -pub use actix_http::body::{AnyBody, Body, BodySize, MessageBody, ResponseBody, SizedStream}; - -#[cfg(feature = "__compress")] -pub use actix_http::encoding::Decoder as Decompress; pub use actix_http::{ - CloneableExtensions, Extensions, Payload, PayloadStream, RequestHead, ResponseHead, + CloneableExtensions, Extensions, Payload, PayloadStream, RequestHead, Response, + ResponseHead, }; pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; -pub use actix_server::Server; +pub use actix_server::{Server, ServerHandle}; pub use actix_service::{ always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, }; -use crate::http::header::ContentEncoding; -use actix_http::{Response, ResponseBuilder}; +#[cfg(feature = "__compress")] +pub use actix_http::encoding::Decoder as Decompress; -pub(crate) fn insert_leading_slash(mut patterns: Vec) -> Vec { - for path in &mut patterns { - if !path.is_empty() && !path.starts_with('/') { - path.insert(0, '/'); - }; +use crate::http::header::ContentEncoding; + +use actix_router::Patterns; + +pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns { + match &mut patterns { + Patterns::Single(pat) => { + if !pat.is_empty() && !pat.starts_with('/') { + pat.insert(0, '/'); + }; + } + Patterns::List(pats) => { + for pat in pats { + if !pat.is_empty() && !pat.starts_with('/') { + pat.insert(0, '/'); + }; + } + } } patterns @@ -52,7 +62,7 @@ pub trait BodyEncoding { fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self; } -impl BodyEncoding for ResponseBuilder { +impl BodyEncoding for actix_http::ResponseBuilder { fn get_encoding(&self) -> Option { self.extensions().get::().map(|enc| enc.0) } @@ -63,7 +73,7 @@ impl BodyEncoding for ResponseBuilder { } } -impl BodyEncoding for Response { +impl BodyEncoding for actix_http::Response { fn get_encoding(&self) -> Option { self.extensions().get::().map(|enc| enc.0) } @@ -95,3 +105,41 @@ impl BodyEncoding for crate::HttpResponse { self } } + +// TODO: remove this if it doesn't appear to be needed + +#[allow(dead_code)] +#[derive(Debug)] +pub(crate) enum AnyBody { + None, + Full { body: crate::web::Bytes }, + Boxed { body: actix_http::body::BoxBody }, +} + +impl crate::body::MessageBody for AnyBody { + type Error = crate::BoxError; + + /// Body size hint. + fn size(&self) -> crate::body::BodySize { + match self { + AnyBody::None => crate::body::BodySize::None, + AnyBody::Full { body } => body.size(), + AnyBody::Boxed { body } => body.size(), + } + } + + /// Attempt to pull out the next chunk of body bytes. + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + match self.get_mut() { + AnyBody::None => std::task::Poll::Ready(None), + AnyBody::Full { body } => { + let bytes = std::mem::take(body); + std::task::Poll::Ready(Some(Ok(bytes))) + } + AnyBody::Boxed { body } => body.as_pin_mut().poll_next(cx), + } + } +} diff --git a/src/error/error.rs b/src/error/error.rs index add290867..be17c1962 100644 --- a/src/error/error.rs +++ b/src/error/error.rs @@ -1,6 +1,6 @@ use std::{error::Error as StdError, fmt}; -use actix_http::{body::AnyBody, Response}; +use actix_http::{body::BoxBody, Response}; use crate::{HttpResponse, ResponseError}; @@ -69,8 +69,8 @@ impl From for Error { } } -impl From for Response { - fn from(err: Error) -> Response { +impl From for Response { + fn from(err: Error) -> Response { err.error_response().into() } } diff --git a/src/error/internal.rs b/src/error/internal.rs index 1d9ca904e..c766ba83e 100644 --- a/src/error/internal.rs +++ b/src/error/internal.rs @@ -1,6 +1,10 @@ use std::{cell::RefCell, fmt, io::Write as _}; -use actix_http::{body::Body, header, StatusCode}; +use actix_http::{ + body::BoxBody, + header::{self, IntoHeaderValue as _}, + StatusCode, +}; use bytes::{BufMut as _, BytesMut}; use crate::{Error, HttpRequest, HttpResponse, Responder, ResponseError}; @@ -84,11 +88,10 @@ where let mut buf = BytesMut::new().writer(); let _ = write!(buf, "{}", self); - res.headers_mut().insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain; charset=utf-8"), - ); - res.set_body(Body::from(buf.into_inner())) + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + + res.set_body(BoxBody::new(buf.into_inner())) } InternalErrorType::Response(ref resp) => { @@ -106,7 +109,9 @@ impl Responder for InternalError where T: fmt::Debug + fmt::Display + 'static, { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { HttpResponse::from_error(self) } } diff --git a/src/error/macros.rs b/src/error/macros.rs index 38650c5e8..78b1ed9f6 100644 --- a/src/error/macros.rs +++ b/src/error/macros.rs @@ -97,7 +97,7 @@ mod tests { let resp_body: &mut dyn MB = &mut body; let body = resp_body.downcast_ref::().unwrap(); assert_eq!(body, "hello cast"); - let body = &mut resp_body.downcast_mut::().unwrap(); + let body = resp_body.downcast_mut::().unwrap(); body.push('!'); let body = resp_body.downcast_ref::().unwrap(); assert_eq!(body, "hello cast!"); diff --git a/src/error/mod.rs b/src/error/mod.rs index 3ccd5bba6..46d0dccc6 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -29,15 +29,15 @@ pub type Result = std::result::Result; #[derive(Debug, PartialEq, Display, Error, From)] #[non_exhaustive] pub enum UrlGenerationError { - /// Resource not found + /// Resource not found. #[display(fmt = "Resource not found")] ResourceNotFound, - /// Not all path pattern covered - #[display(fmt = "Not all path pattern covered")] + /// Not all URL parameters covered. + #[display(fmt = "Not all URL parameters covered")] NotEnoughElements, - /// URL parse error + /// URL parse error. #[display(fmt = "{}", _0)] ParseError(UrlParseError), } diff --git a/src/error/response_error.rs b/src/error/response_error.rs index c3c543419..7260efa1a 100644 --- a/src/error/response_error.rs +++ b/src/error/response_error.rs @@ -6,11 +6,17 @@ use std::{ io::{self, Write as _}, }; -use actix_http::{body::AnyBody, header, Response, StatusCode}; +use actix_http::{ + body::BoxBody, + header::{self, IntoHeaderValue}, + Response, StatusCode, +}; use bytes::BytesMut; -use crate::error::{downcast_dyn, downcast_get_type_id}; -use crate::{helpers, HttpResponse}; +use crate::{ + error::{downcast_dyn, downcast_get_type_id}, + helpers, HttpResponse, +}; /// Errors that can generate responses. // TODO: add std::error::Error bound when replacement for Box is found @@ -27,18 +33,16 @@ pub trait ResponseError: fmt::Debug + fmt::Display { /// /// By default, the generated response uses a 500 Internal Server Error status code, a /// `Content-Type` of `text/plain`, and the body is set to `Self`'s `Display` impl. - fn error_response(&self) -> HttpResponse { + fn error_response(&self) -> HttpResponse { let mut res = HttpResponse::new(self.status_code()); let mut buf = BytesMut::new(); let _ = write!(helpers::MutWriter(&mut buf), "{}", self); - res.headers_mut().insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain; charset=utf-8"), - ); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); - res.set_body(AnyBody::from(buf)) + res.set_body(BoxBody::new(buf)) } downcast_get_type_id!(); @@ -49,7 +53,7 @@ downcast_dyn!(ResponseError); impl ResponseError for Box {} #[cfg(feature = "openssl")] -impl ResponseError for actix_tls::accept::openssl::SslError {} +impl ResponseError for actix_tls::accept::openssl::reexports::Error {} impl ResponseError for serde::de::value::Error { fn status_code(&self) -> StatusCode { @@ -86,8 +90,8 @@ impl ResponseError for actix_http::Error { StatusCode::INTERNAL_SERVER_ERROR } - fn error_response(&self) -> HttpResponse { - HttpResponse::new(self.status_code()).set_body(self.to_string().into()) + fn error_response(&self) -> HttpResponse { + HttpResponse::with_body(self.status_code(), self.to_string()).map_into_boxed_body() } } @@ -123,8 +127,8 @@ impl ResponseError for actix_http::error::ContentTypeError { } impl ResponseError for actix_http::ws::HandshakeError { - fn error_response(&self) -> HttpResponse { - Response::from(self).into() + fn error_response(&self) -> HttpResponse { + Response::from(self).map_into_boxed_body().into() } } diff --git a/src/extract.rs b/src/extract.rs index 592f7ab83..bb2dabb9f 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -10,16 +10,56 @@ use std::{ use actix_http::http::{Method, Uri}; use actix_utils::future::{ok, Ready}; use futures_core::ready; +use pin_project_lite::pin_project; use crate::{dev::Payload, Error, HttpRequest}; -/// Trait implemented by types that can be extracted from request. +/// A type that implements [`FromRequest`] is called an **extractor** and can extract data from +/// the request. Some types that implement this trait are: [`Json`], [`Header`], and [`Path`]. /// -/// Types that implement this trait can be used with `Route` handlers. +/// # Configuration +/// An extractor can be customized by injecting the corresponding configuration with one of: +/// +/// - [`App::app_data()`][crate::App::app_data] +/// - [`Scope::app_data()`][crate::Scope::app_data] +/// - [`Resource::app_data()`][crate::Resource::app_data] +/// +/// Here are some built-in extractors and their corresponding configuration. +/// Please refer to the respective documentation for details. +/// +/// | Extractor | Configuration | +/// |-------------|-------------------| +/// | [`Header`] | _None_ | +/// | [`Path`] | [`PathConfig`] | +/// | [`Json`] | [`JsonConfig`] | +/// | [`Form`] | [`FormConfig`] | +/// | [`Query`] | [`QueryConfig`] | +/// | [`Bytes`] | [`PayloadConfig`] | +/// | [`String`] | [`PayloadConfig`] | +/// | [`Payload`] | [`PayloadConfig`] | +/// +/// # Implementing An Extractor +/// To reduce duplicate code in handlers where extracting certain parts of a request has a common +/// structure, you can implement `FromRequest` for your own types. +/// +/// Note that the request payload can only be consumed by one extractor. +/// +/// [`Header`]: crate::web::Header +/// [`Json`]: crate::web::Json +/// [`JsonConfig`]: crate::web::JsonConfig +/// [`Form`]: crate::web::Form +/// [`FormConfig`]: crate::web::FormConfig +/// [`Path`]: crate::web::Path +/// [`PathConfig`]: crate::web::PathConfig +/// [`Query`]: crate::web::Query +/// [`QueryConfig`]: crate::web::QueryConfig +/// [`Payload`]: crate::web::Payload +/// [`PayloadConfig`]: crate::web::PayloadConfig +/// [`String`]: FromRequest#impl-FromRequest-for-String +/// [`Bytes`]: crate::web::Bytes#impl-FromRequest +/// [`Either`]: crate::web::Either +#[doc(alias = "extract", alias = "extractor")] pub trait FromRequest: Sized { - /// Configuration for this extractor. - type Config: Default + 'static; - /// The associated error which can be returned. type Error: Into; @@ -35,14 +75,6 @@ pub trait FromRequest: Sized { fn extract(req: &HttpRequest) -> Self::Future { Self::from_request(req, &mut Payload::None) } - - /// Create and configure config instance. - fn configure(f: F) -> Self::Config - where - F: FnOnce(Self::Config) -> Self::Config, - { - f(Self::Config::default()) - } } /// Optionally extract a field from the request @@ -65,7 +97,6 @@ pub trait FromRequest: Sized { /// impl FromRequest for Thing { /// type Error = Error; /// type Future = Ready>; -/// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { @@ -100,7 +131,6 @@ where { type Error = Error; type Future = FromRequestOptFuture; - type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -110,10 +140,11 @@ where } } -#[pin_project::pin_project] -pub struct FromRequestOptFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct FromRequestOptFuture { + #[pin] + fut: Fut, + } } impl Future for FromRequestOptFuture @@ -156,7 +187,6 @@ where /// impl FromRequest for Thing { /// type Error = Error; /// type Future = Ready>; -/// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { @@ -189,7 +219,6 @@ where { type Error = Error; type Future = FromRequestResFuture; - type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -199,10 +228,11 @@ where } } -#[pin_project::pin_project] -pub struct FromRequestResFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct FromRequestResFuture { + #[pin] + fut: Fut, + } } impl Future for FromRequestResFuture @@ -233,7 +263,6 @@ where impl FromRequest for Uri { type Error = Infallible; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ok(req.uri().clone()) @@ -255,7 +284,6 @@ impl FromRequest for Uri { impl FromRequest for Method { type Error = Infallible; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ok(req.method().clone()) @@ -266,110 +294,110 @@ impl FromRequest for Method { impl FromRequest for () { type Error = Infallible; type Future = Ready>; - type Config = (); fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { ok(()) } } -macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { - - // This module is a trick to get around the inability of - // `macro_rules!` macros to make new idents. We want to make - // a new `FutWrapper` struct for each distinct invocation of - // this macro. Ideally, we would name it something like - // `FutWrapper_$fut_type`, but this can't be done in a macro_rules - // macro. - // - // Instead, we put everything in a module named `$fut_type`, thus allowing - // us to use the name `FutWrapper` without worrying about conflicts. - // This macro only exists to generate trait impls for tuples - these - // are inherently global, so users don't have to care about this - // weird trick. - #[allow(non_snake_case)] - mod $fut_type { - - // Bring everything into scope, so we don't need - // redundant imports - use super::*; - - /// A helper struct to allow us to pin-project through - /// to individual fields - #[pin_project::pin_project] - struct FutWrapper<$($T: FromRequest),+>($(#[pin] $T::Future),+); - - /// FromRequest implementation for tuple - #[doc(hidden)] - #[allow(unused_parens)] - impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) - { - type Error = Error; - type Future = $fut_type<$($T),+>; - type Config = ($($T::Config),+); - - fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - $fut_type { - items: <($(Option<$T>,)+)>::default(), - futs: FutWrapper($($T::from_request(req, payload),)+), - } - } - } - - #[doc(hidden)] - #[pin_project::pin_project] - pub struct $fut_type<$($T: FromRequest),+> { - items: ($(Option<$T>,)+), - #[pin] - futs: FutWrapper<$($T,)+>, - } - - impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> - { - type Output = Result<($($T,)+), Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - let mut ready = true; - $( - if this.items.$n.is_none() { - match this.futs.as_mut().project().$n.poll(cx) { - Poll::Ready(Ok(item)) => { - this.items.$n = Some(item); - } - Poll::Pending => ready = false, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - } - } - )+ - - if ready { - Poll::Ready(Ok( - ($(this.items.$n.take().unwrap(),)+) - )) - } else { - Poll::Pending - } - } - } - } -}); - -#[rustfmt::skip] -mod m { +#[doc(hidden)] +#[allow(non_snake_case)] +mod tuple_from_req { use super::*; - tuple_from_req!(TupleFromRequest1, (0, A)); - tuple_from_req!(TupleFromRequest2, (0, A), (1, B)); - tuple_from_req!(TupleFromRequest3, (0, A), (1, B), (2, C)); - tuple_from_req!(TupleFromRequest4, (0, A), (1, B), (2, C), (3, D)); - tuple_from_req!(TupleFromRequest5, (0, A), (1, B), (2, C), (3, D), (4, E)); - tuple_from_req!(TupleFromRequest6, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); - tuple_from_req!(TupleFromRequest7, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); - tuple_from_req!(TupleFromRequest8, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); - tuple_from_req!(TupleFromRequest9, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); - tuple_from_req!(TupleFromRequest10, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); + macro_rules! tuple_from_req { + ($fut: ident; $($T: ident),*) => { + /// FromRequest implementation for tuple + #[allow(unused_parens)] + impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) + { + type Error = Error; + type Future = $fut<$($T),+>; + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + $fut { + $( + $T: ExtractFuture::Future { + fut: $T::from_request(req, payload) + }, + )+ + } + } + } + + pin_project! { + pub struct $fut<$($T: FromRequest),+> { + $( + #[pin] + $T: ExtractFuture<$T::Future, $T>, + )+ + } + } + + impl<$($T: FromRequest),+> Future for $fut<$($T),+> + { + type Output = Result<($($T,)+), Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + let mut ready = true; + $( + match this.$T.as_mut().project() { + ExtractProj::Future { fut } => match fut.poll(cx) { + Poll::Ready(Ok(output)) => { + let _ = this.$T.as_mut().project_replace(ExtractFuture::Done { output }); + }, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Pending => ready = false, + }, + ExtractProj::Done { .. } => {}, + ExtractProj::Empty => unreachable!("FromRequest polled after finished"), + } + )+ + + if ready { + Poll::Ready(Ok( + ($( + match this.$T.project_replace(ExtractFuture::Empty) { + ExtractReplaceProj::Done { output } => output, + _ => unreachable!("FromRequest polled after finished"), + }, + )+) + )) + } else { + Poll::Pending + } + } + } + }; + } + + pin_project! { + #[project = ExtractProj] + #[project_replace = ExtractReplaceProj] + enum ExtractFuture { + Future { + #[pin] + fut: Fut + }, + Done { + output: Res, + }, + Empty + } + } + + tuple_from_req! { TupleFromRequest1; A } + tuple_from_req! { TupleFromRequest2; A, B } + tuple_from_req! { TupleFromRequest3; A, B, C } + tuple_from_req! { TupleFromRequest4; A, B, C, D } + tuple_from_req! { TupleFromRequest5; A, B, C, D, E } + tuple_from_req! { TupleFromRequest6; A, B, C, D, E, F } + tuple_from_req! { TupleFromRequest7; A, B, C, D, E, F, G } + tuple_from_req! { TupleFromRequest8; A, B, C, D, E, F, G, H } + tuple_from_req! { TupleFromRequest9; A, B, C, D, E, F, G, H, I } + tuple_from_req! { TupleFromRequest10; A, B, C, D, E, F, G, H, I, J } } #[cfg(test)] @@ -471,4 +499,26 @@ mod tests { let method = Method::extract(&req).await.unwrap(); assert_eq!(method, Method::GET); } + + #[actix_rt::test] + async fn test_concurrent() { + let (req, mut pl) = TestRequest::default() + .uri("/foo/bar") + .method(Method::GET) + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, "11")) + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + let (method, uri, form) = <(Method, Uri, Form)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(method, Method::GET); + assert_eq!(uri.path(), "/foo/bar"); + assert_eq!( + form, + Form(Info { + hello: "world".into() + }) + ); + } } diff --git a/src/handler.rs b/src/handler.rs index bc91ce41b..e543ecc7f 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,24 +1,21 @@ use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use actix_service::{Service, ServiceFactory}; -use actix_utils::future::{ready, Ready}; -use futures_core::ready; -use pin_project::pin_project; +use actix_service::{boxed, fn_service}; use crate::{ - service::{ServiceRequest, ServiceResponse}, - Error, FromRequest, HttpRequest, HttpResponse, Responder, + body::MessageBody, + service::{BoxedHttpServiceFactory, ServiceRequest, ServiceResponse}, + BoxError, FromRequest, HttpResponse, Responder, }; /// A request handler is an async function that accepts zero or more parameters that can be -/// extracted from a request (i.e., [`impl FromRequest`](crate::FromRequest)) and returns a type -/// that can be converted into an [`HttpResponse`] (that is, it impls the [`Responder`] trait). +/// extracted from a request (i.e., [`impl FromRequest`]) and returns a type that can be converted +/// into an [`HttpResponse`] (that is, it impls the [`Responder`] trait). /// /// If you got the error `the trait Handler<_, _, _> is not implemented`, then your function is not -/// a valid handler. See [Request Handlers](https://actix.rs/docs/handlers/) for more information. +/// a valid handler. See for more information. +/// +/// [`impl FromRequest`]: crate::FromRequest pub trait Handler: Clone + 'static where R: Future, @@ -27,142 +24,44 @@ where fn call(&self, param: T) -> R; } -#[doc(hidden)] -/// Extract arguments from request, run factory function and make response. -pub struct HandlerService +pub(crate) fn handler_service(handler: F) -> BoxedHttpServiceFactory where F: Handler, T: FromRequest, R: Future, R::Output: Responder, + ::Body: MessageBody, + <::Body as MessageBody>::Error: Into, { - hnd: F, - _phantom: PhantomData<(T, R)>, -} + boxed::factory(fn_service(move |req: ServiceRequest| { + let handler = handler.clone(); -impl HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - pub fn new(hnd: F) -> Self { - Self { - hnd, - _phantom: PhantomData, + async move { + let (req, mut payload) = req.into_parts(); + + let res = match T::from_request(&req, &mut payload).await { + Err(err) => HttpResponse::from_error(err), + + Ok(data) => handler + .call(data) + .await + .respond_to(&req) + .map_into_boxed_body(), + }; + + Ok(ServiceResponse::new(req, res)) } - } + })) } -impl Clone for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - fn clone(&self) -> Self { - Self { - hnd: self.hnd.clone(), - _phantom: PhantomData, - } - } -} - -impl ServiceFactory for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - type Response = ServiceResponse; - type Error = Error; - type Config = (); - type Service = Self; - type InitError = (); - type Future = Ready>; - - fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(self.clone())) - } -} - -/// HandlerService is both it's ServiceFactory and Service Type. -impl Service for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - type Response = ServiceResponse; - type Error = Error; - type Future = HandlerServiceFuture; - - actix_service::always_ready!(); - - fn call(&self, req: ServiceRequest) -> Self::Future { - let (req, mut payload) = req.into_parts(); - let fut = T::from_request(&req, &mut payload); - HandlerServiceFuture::Extract(fut, Some(req), self.hnd.clone()) - } -} - -#[doc(hidden)] -#[pin_project(project = HandlerProj)] -pub enum HandlerServiceFuture -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - Extract(#[pin] T::Future, Option, F), - Handle(#[pin] R, Option), -} - -impl Future for HandlerServiceFuture -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - // Error type in this future is a placeholder type. - // all instances of error must be converted to ServiceResponse and return in Ok. - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match self.as_mut().project() { - HandlerProj::Extract(fut, req, handle) => { - match ready!(fut.poll(cx)) { - Ok(item) => { - let fut = handle.call(item); - let state = HandlerServiceFuture::Handle(fut, req.take()); - self.as_mut().set(state); - } - Err(err) => { - let req = req.take().unwrap(); - let res = HttpResponse::from_error(err.into()); - return Poll::Ready(Ok(ServiceResponse::new(req, res))); - } - }; - } - HandlerProj::Handle(fut, req) => { - let res = ready!(fut.poll(cx)); - let req = req.take().unwrap(); - let res = res.respond_to(&req); - return Poll::Ready(Ok(ServiceResponse::new(req, res))); - } - } - } - } -} - -/// FromRequest trait impl for tuples +/// Generates a [`Handler`] trait impl for N-ary functions where N is specified with a sequence of +/// space separated type parameters. +/// +/// # Examples +/// ```ignore +/// factory_tuple! {} // implements Handler for types: fn() -> Res +/// factory_tuple! { A B C } // implements Handler for types: fn(A, B, C) -> Res +/// ``` macro_rules! factory_tuple ({ $($param:ident)* } => { impl Handler<($($param,)*), Res> for Func where Func: Fn($($param),*) -> Res + Clone + 'static, diff --git a/src/http/header/accept.rs b/src/http/header/accept.rs index 75366dfae..c61e6ab49 100644 --- a/src/http/header/accept.rs +++ b/src/http/header/accept.rs @@ -2,11 +2,12 @@ use std::cmp::Ordering; use mime::Mime; -use super::{qitem, QualityItem}; +use super::QualityItem; use crate::http::header; crate::http::header::common_header! { - /// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2) + /// `Accept` header, defined + /// in [RFC 7231 §5.3.2](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2) /// /// The `Accept` header field can be used by user agents to specify /// response media types that are acceptable. Accept header fields can @@ -15,8 +16,7 @@ crate::http::header::common_header! { /// in-line image /// /// # ABNF - /// - /// ```text + /// ```plain /// Accept = #( media-range [ accept-params ] ) /// /// media-range = ( "*/*" @@ -27,97 +27,94 @@ crate::http::header::common_header! { /// accept-ext = OWS ";" OWS token [ "=" ( token / quoted-string ) ] /// ``` /// - /// # Example values + /// # Example Values /// * `audio/*; q=0.2, audio/basic` /// * `text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c` /// /// # Examples /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, qitem}; + /// use actix_web::http::header::{Accept, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::TEXT_HTML), + /// QualityItem::max(mime::TEXT_HTML), /// ]) /// ); /// ``` /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, qitem}; + /// use actix_web::http::header::{Accept, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::APPLICATION_JSON), + /// QualityItem::max(mime::APPLICATION_JSON), /// ]) /// ); /// ``` /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{Accept, QualityItem, q, qitem}; + /// use actix_web::http::header::{Accept, QualityItem, q}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::TEXT_HTML), - /// qitem("application/xhtml+xml".parse().unwrap()), - /// QualityItem::new( - /// mime::TEXT_XML, - /// q(900) - /// ), - /// qitem("image/webp".parse().unwrap()), - /// QualityItem::new( - /// mime::STAR_STAR, - /// q(800) - /// ), + /// QualityItem::max(mime::TEXT_HTML), + /// QualityItem::max("application/xhtml+xml".parse().unwrap()), + /// QualityItem::new(mime::TEXT_XML, q(0.9)), + /// QualityItem::max("image/webp".parse().unwrap()), + /// QualityItem::new(mime::STAR_STAR, q(0.8)), /// ]) /// ); /// ``` - (Accept, header::ACCEPT) => (QualityItem)+ + (Accept, header::ACCEPT) => (QualityItem)* - test_accept { + test_parse_and_format { // Tests from the RFC crate::http::header::common_header_test!( test1, vec![b"audio/*; q=0.2, audio/basic"], Some(Accept(vec![ - QualityItem::new("audio/*".parse().unwrap(), q(200)), - qitem("audio/basic".parse().unwrap()), + QualityItem::new("audio/*".parse().unwrap(), q(0.2)), + QualityItem::max("audio/basic".parse().unwrap()), ]))); + crate::http::header::common_header_test!( test2, vec![b"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"], Some(Accept(vec![ - QualityItem::new(mime::TEXT_PLAIN, q(500)), - qitem(mime::TEXT_HTML), + QualityItem::new(mime::TEXT_PLAIN, q(0.5)), + QualityItem::max(mime::TEXT_HTML), QualityItem::new( "text/x-dvi".parse().unwrap(), - q(800)), - qitem("text/x-c".parse().unwrap()), + q(0.8)), + QualityItem::max("text/x-c".parse().unwrap()), ]))); + // Custom tests crate::http::header::common_header_test!( test3, vec![b"text/plain; charset=utf-8"], Some(Accept(vec![ - qitem(mime::TEXT_PLAIN_UTF_8), + QualityItem::max(mime::TEXT_PLAIN_UTF_8), ]))); crate::http::header::common_header_test!( test4, vec![b"text/plain; charset=utf-8; q=0.5"], Some(Accept(vec![ QualityItem::new(mime::TEXT_PLAIN_UTF_8, - q(500)), + q(0.5)), ]))); #[test] fn test_fuzzing1() { - use actix_http::test::TestRequest; - let req = TestRequest::default().insert_header((crate::http::header::ACCEPT, "chunk#;e")).finish(); + let req = test::TestRequest::default() + .insert_header((header::ACCEPT, "chunk#;e")) + .finish(); let header = Accept::parse(&req); assert!(header.is_ok()); } @@ -127,34 +124,38 @@ crate::http::header::common_header! { impl Accept { /// Construct `Accept: */*`. pub fn star() -> Accept { - Accept(vec![qitem(mime::STAR_STAR)]) + Accept(vec![QualityItem::max(mime::STAR_STAR)]) } /// Construct `Accept: application/json`. pub fn json() -> Accept { - Accept(vec![qitem(mime::APPLICATION_JSON)]) + Accept(vec![QualityItem::max(mime::APPLICATION_JSON)]) } /// Construct `Accept: text/*`. pub fn text() -> Accept { - Accept(vec![qitem(mime::TEXT_STAR)]) + Accept(vec![QualityItem::max(mime::TEXT_STAR)]) } /// Construct `Accept: image/*`. pub fn image() -> Accept { - Accept(vec![qitem(mime::IMAGE_STAR)]) + Accept(vec![QualityItem::max(mime::IMAGE_STAR)]) } /// Construct `Accept: text/html`. pub fn html() -> Accept { - Accept(vec![qitem(mime::TEXT_HTML)]) + Accept(vec![QualityItem::max(mime::TEXT_HTML)]) } /// Returns a sorted list of mime types from highest to lowest preference, accounting for /// [q-factor weighting] and specificity. /// - /// [q-factor weighting]: https://tools.ietf.org/html/rfc7231#section-5.3.2 - pub fn mime_precedence(&self) -> Vec { + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn ranked(&self) -> Vec { + if self.is_empty() { + return vec![]; + } + let mut types = self.0.clone(); // use stable sort so items with equal q-factor and specificity retain listed order @@ -201,12 +202,29 @@ impl Accept { /// If no q-factors are provided, the first mime type is chosen. Note that items without /// q-factors are given the maximum preference value. /// - /// Returns `None` if contained list is empty. + /// As per the spec, will return [`mime::STAR_STAR`] (indicating no preference) if the contained + /// list is empty. /// - /// [q-factor weighting]: https://tools.ietf.org/html/rfc7231#section-5.3.2 - pub fn mime_preference(&self) -> Option { - let types = self.mime_precedence(); - types.first().cloned() + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn preference(&self) -> Mime { + use actix_http::header::Quality; + + let mut max_item = None; + let mut max_pref = Quality::MIN; + + // uses manual max lookup loop since we want the first occurrence in the case of same + // preference but `Iterator::max_by_key` would give us the last occurrence + + for pref in &self.0 { + // only change if strictly greater + // equal items, even while unsorted, still have higher preference if they appear first + if pref.quality > max_pref { + max_pref = pref.quality; + max_item = Some(pref.item.clone()); + } + } + + max_item.unwrap_or(mime::STAR_STAR) } } @@ -216,21 +234,21 @@ mod tests { use crate::http::header::q; #[test] - fn test_mime_precedence() { + fn ranking_precedence() { let test = Accept(vec![]); - assert!(test.mime_precedence().is_empty()); + assert!(test.ranked().is_empty()); - let test = Accept(vec![qitem(mime::APPLICATION_JSON)]); - assert_eq!(test.mime_precedence(), vec!(mime::APPLICATION_JSON)); + let test = Accept(vec![QualityItem::max(mime::APPLICATION_JSON)]); + assert_eq!(test.ranked(), vec!(mime::APPLICATION_JSON)); let test = Accept(vec![ - qitem(mime::TEXT_HTML), + QualityItem::max(mime::TEXT_HTML), "application/xhtml+xml".parse().unwrap(), QualityItem::new("application/xml".parse().unwrap(), q(0.9)), QualityItem::new(mime::STAR_STAR, q(0.8)), ]); assert_eq!( - test.mime_precedence(), + test.ranked(), vec![ mime::TEXT_HTML, "application/xhtml+xml".parse().unwrap(), @@ -240,33 +258,33 @@ mod tests { ); let test = Accept(vec![ - qitem(mime::STAR_STAR), - qitem(mime::IMAGE_STAR), - qitem(mime::IMAGE_PNG), + QualityItem::max(mime::STAR_STAR), + QualityItem::max(mime::IMAGE_STAR), + QualityItem::max(mime::IMAGE_PNG), ]); assert_eq!( - test.mime_precedence(), + test.ranked(), vec![mime::IMAGE_PNG, mime::IMAGE_STAR, mime::STAR_STAR] ); } #[test] - fn test_mime_preference() { + fn preference_selection() { let test = Accept(vec![ - qitem(mime::TEXT_HTML), + QualityItem::max(mime::TEXT_HTML), "application/xhtml+xml".parse().unwrap(), QualityItem::new("application/xml".parse().unwrap(), q(0.9)), QualityItem::new(mime::STAR_STAR, q(0.8)), ]); - assert_eq!(test.mime_preference(), Some(mime::TEXT_HTML)); + assert_eq!(test.preference(), mime::TEXT_HTML); let test = Accept(vec![ QualityItem::new("video/*".parse().unwrap(), q(0.8)), - qitem(mime::IMAGE_PNG), + QualityItem::max(mime::IMAGE_PNG), QualityItem::new(mime::STAR_STAR, q(0.5)), - qitem(mime::IMAGE_SVG), + QualityItem::max(mime::IMAGE_SVG), QualityItem::new(mime::IMAGE_STAR, q(0.8)), ]); - assert_eq!(test.mime_preference(), Some(mime::IMAGE_PNG)); + assert_eq!(test.preference(), mime::IMAGE_PNG); } } diff --git a/src/http/header/accept_charset.rs b/src/http/header/accept_charset.rs index bb7d86516..c8b918c91 100644 --- a/src/http/header/accept_charset.rs +++ b/src/http/header/accept_charset.rs @@ -2,7 +2,7 @@ use super::{Charset, QualityItem, ACCEPT_CHARSET}; crate::http::header::common_header! { /// `Accept-Charset` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.3) + /// [RFC 7231 §5.3.3](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.3) /// /// The `Accept-Charset` header field can be sent by a user agent to /// indicate what charsets are acceptable in textual response content. @@ -12,22 +12,21 @@ crate::http::header::common_header! { /// those charsets. /// /// # ABNF - /// - /// ```text + /// ```plain /// Accept-Charset = 1#( ( charset / "*" ) [ weight ] ) /// ``` /// - /// # Example values + /// # Example Values /// * `iso-8859-5, unicode-1-1;q=0.8` /// /// # Examples /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptCharset, Charset, qitem}; + /// use actix_web::http::header::{AcceptCharset, Charset, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( - /// AcceptCharset(vec![qitem(Charset::Us_Ascii)]) + /// AcceptCharset(vec![QualityItem::max(Charset::Us_Ascii)]) /// ); /// ``` /// @@ -38,24 +37,24 @@ crate::http::header::common_header! { /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// AcceptCharset(vec![ - /// QualityItem::new(Charset::Us_Ascii, q(900)), - /// QualityItem::new(Charset::Iso_8859_10, q(200)), + /// QualityItem::new(Charset::Us_Ascii, q(0.9)), + /// QualityItem::new(Charset::Iso_8859_10, q(0.2)), /// ]) /// ); /// ``` /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptCharset, Charset, qitem}; + /// use actix_web::http::header::{AcceptCharset, Charset, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( - /// AcceptCharset(vec![qitem(Charset::Ext("utf-8".to_owned()))]) + /// AcceptCharset(vec![QualityItem::max(Charset::Ext("utf-8".to_owned()))]) /// ); /// ``` (AcceptCharset, ACCEPT_CHARSET) => (QualityItem)+ - test_accept_charset { + test_parse_and_format { // Test case from RFC crate::http::header::common_header_test!(test1, vec![b"iso-8859-5, unicode-1-1;q=0.8"]); } diff --git a/src/http/header/accept_encoding.rs b/src/http/header/accept_encoding.rs index cfd29bf77..828a0533c 100644 --- a/src/http/header/accept_encoding.rs +++ b/src/http/header/accept_encoding.rs @@ -1,8 +1,11 @@ -use header::{Encoding, QualityItem}; +use actix_http::header::QualityItem; -header! { - /// `Accept-Encoding` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.4) +use super::{common_header, Encoding}; +use crate::http::header; + +common_header! { + /// `Accept-Encoding` header, defined + /// in [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.4) /// /// The `Accept-Encoding` header field can be used by user agents to /// indicate what response content-codings are @@ -11,13 +14,12 @@ header! { /// preferred. /// /// # ABNF - /// - /// ```text + /// ```plain /// Accept-Encoding = #( codings [ weight ] ) /// codings = content-coding / "identity" / "*" /// ``` /// - /// # Example values + /// # Example Values /// * `compress, gzip` /// * `` /// * `*` @@ -27,49 +29,55 @@ header! { /// # Examples /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptEncoding, Encoding, qitem}; + /// use actix_web::http::header::{AcceptEncoding, Encoding, QualityItem}; /// - /// let mut builder = HttpResponse::new(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( - /// AcceptEncoding(vec![qitem(Encoding::Chunked)]) + /// AcceptEncoding(vec![QualityItem::max(Encoding::Chunked)]) /// ); /// ``` + /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptEncoding, Encoding, qitem}; + /// use actix_web::http::header::{AcceptEncoding, Encoding, QualityItem}; /// - /// let mut builder = HttpResponse::new(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// AcceptEncoding(vec![ - /// qitem(Encoding::Chunked), - /// qitem(Encoding::Gzip), - /// qitem(Encoding::Deflate), + /// QualityItem::max(Encoding::Chunked), + /// QualityItem::max(Encoding::Gzip), + /// QualityItem::max(Encoding::Deflate), /// ]) /// ); /// ``` + /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptEncoding, Encoding, QualityItem, q, qitem}; + /// use actix_web::http::header::{AcceptEncoding, Encoding, QualityItem, q}; /// - /// let mut builder = HttpResponse::new(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// AcceptEncoding(vec![ - /// qitem(Encoding::Chunked), - /// QualityItem::new(Encoding::Gzip, q(600)), - /// QualityItem::new(Encoding::EncodingExt("*".to_owned()), q(0)), + /// QualityItem::max(Encoding::Chunked), + /// QualityItem::new(Encoding::Gzip, q(0.60)), + /// QualityItem::min(Encoding::EncodingExt("*".to_owned())), /// ]) /// ); /// ``` - (AcceptEncoding, "Accept-Encoding") => (QualityItem)* + (AcceptEncoding, header::ACCEPT_ENCODING) => (QualityItem)* - test_accept_encoding { + test_parse_and_format { // From the RFC - crate::http::header::common_header_test!(test1, vec![b"compress, gzip"]); - crate::http::header::common_header_test!(test2, vec![b""], Some(AcceptEncoding(vec![]))); - crate::http::header::common_header_test!(test3, vec![b"*"]); + common_header_test!(test1, vec![b"compress, gzip"]); + common_header_test!(test2, vec![b""], Some(AcceptEncoding(vec![]))); + common_header_test!(test3, vec![b"*"]); + // Note: Removed quality 1 from gzip - crate::http::header::common_header_test!(test4, vec![b"compress;q=0.5, gzip"]); + common_header_test!(test4, vec![b"compress;q=0.5, gzip"]); + // Note: Removed quality 1 from gzip - crate::http::header::common_header_test!(test5, vec![b"gzip, identity; q=0.5, *;q=0"]); + common_header_test!(test5, vec![b"gzip, identity; q=0.5, *;q=0"]); } } + +// TODO: shortcut for EncodingExt(*) = Any diff --git a/src/http/header/accept_language.rs b/src/http/header/accept_language.rs index 1552f6578..011257b87 100644 --- a/src/http/header/accept_language.rs +++ b/src/http/header/accept_language.rs @@ -1,66 +1,223 @@ use language_tags::LanguageTag; -use super::{QualityItem, ACCEPT_LANGUAGE}; +use super::{common_header, Preference, Quality, QualityItem}; +use crate::http::header; -crate::http::header::common_header! { - /// `Accept-Language` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.5) +common_header! { + /// `Accept-Language` header, defined + /// in [RFC 7231 §5.3.5](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.5) /// - /// The `Accept-Language` header field can be used by user agents to - /// indicate the set of natural languages that are preferred in the - /// response. + /// The `Accept-Language` header field can be used by user agents to indicate the set of natural + /// languages that are preferred in the response. + /// + /// The `Accept-Language` header is defined in + /// [RFC 7231 §5.3.5](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.5) using language + /// ranges defined in [RFC 4647 §2.1](https://datatracker.ietf.org/doc/html/rfc4647#section-2.1). /// /// # ABNF - /// - /// ```text + /// ```plain /// Accept-Language = 1#( language-range [ weight ] ) - /// language-range = + /// language-range = (1*8ALPHA *("-" 1*8alphanum)) / "*" + /// alphanum = ALPHA / DIGIT + /// weight = OWS ";" OWS "q=" qvalue + /// qvalue = ( "0" [ "." 0*3DIGIT ] ) + /// / ( "1" [ "." 0*3("0") ] ) /// ``` /// - /// # Example values - /// * `da, en-gb;q=0.8, en;q=0.7` - /// * `en-us;q=1.0, en;q=0.5, fr` + /// # Example Values + /// - `da, en-gb;q=0.8, en;q=0.7` + /// - `en-us;q=1.0, en;q=0.5, fr` + /// - `fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptLanguage, LanguageTag, qitem}; + /// use actix_web::http::header::{AcceptLanguage, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); - /// let langtag = LanguageTag::parse("en-US").unwrap(); /// builder.insert_header( /// AcceptLanguage(vec![ - /// qitem(langtag), + /// QualityItem::max("en-US".parse().unwrap()) /// ]) /// ); /// ``` /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{AcceptLanguage, LanguageTag, QualityItem, q, qitem}; + /// use actix_web::http::header::{AcceptLanguage, QualityItem, q}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// AcceptLanguage(vec![ - /// qitem(LanguageTag::parse("da").unwrap()), - /// QualityItem::new(LanguageTag::parse("en-GB").unwrap(), q(800)), - /// QualityItem::new(LanguageTag::parse("en").unwrap(), q(700)), + /// QualityItem::max("da".parse().unwrap()), + /// QualityItem::new("en-GB".parse().unwrap(), q(0.8)), + /// QualityItem::new("en".parse().unwrap(), q(0.7)), /// ]) /// ); /// ``` - (AcceptLanguage, ACCEPT_LANGUAGE) => (QualityItem)+ + (AcceptLanguage, header::ACCEPT_LANGUAGE) => (QualityItem>)* - test_accept_language { - // From the RFC - crate::http::header::common_header_test!(test1, vec![b"da, en-gb;q=0.8, en;q=0.7"]); - // Own test - crate::http::header::common_header_test!( - test2, vec![b"en-US, en; q=0.5, fr"], + test_parse_and_format { + common_header_test!(no_headers, vec![b""; 0], Some(AcceptLanguage(vec![]))); + + common_header_test!(empty_header, vec![b""; 1], Some(AcceptLanguage(vec![]))); + + common_header_test!( + example_from_rfc, + vec![b"da, en-gb;q=0.8, en;q=0.7"] + ); + + + common_header_test!( + not_ordered_by_weight, + vec![b"en-US, en; q=0.5, fr"], Some(AcceptLanguage(vec![ - qitem("en-US".parse().unwrap()), - QualityItem::new("en".parse().unwrap(), q(500)), - qitem("fr".parse().unwrap()), - ]))); + QualityItem::max("en-US".parse().unwrap()), + QualityItem::new("en".parse().unwrap(), q(0.5)), + QualityItem::max("fr".parse().unwrap()), + ])) + ); + + common_header_test!( + has_wildcard, + vec![b"fr-CH, fr; q=0.9, en; q=0.8, de; q=0.7, *; q=0.5"], + Some(AcceptLanguage(vec![ + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::new("fr".parse().unwrap(), q(0.9)), + QualityItem::new("en".parse().unwrap(), q(0.8)), + QualityItem::new("de".parse().unwrap(), q(0.7)), + QualityItem::new("*".parse().unwrap(), q(0.5)), + ])) + ); + } +} + +impl AcceptLanguage { + /// Returns a sorted list of languages from highest to lowest precedence, accounting + /// for [q-factor weighting]. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn ranked(&self) -> Vec> { + if self.0.is_empty() { + return vec![]; + } + + let mut types = self.0.clone(); + + // use stable sort so items with equal q-factor retain listed order + types.sort_by(|a, b| { + // sort by q-factor descending + b.quality.cmp(&a.quality) + }); + + types.into_iter().map(|qitem| qitem.item).collect() + } + + /// Extracts the most preferable language, accounting for [q-factor weighting]. + /// + /// If no q-factors are provided, the first language is chosen. Note that items without + /// q-factors are given the maximum preference value. + /// + /// As per the spec, returns [`Preference::Any`] if contained list is empty. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn preference(&self) -> Preference { + let mut max_item = None; + let mut max_pref = Quality::MIN; + + // uses manual max lookup loop since we want the first occurrence in the case of same + // preference but `Iterator::max_by_key` would give us the last occurrence + + for pref in &self.0 { + // only change if strictly greater + // equal items, even while unsorted, still have higher preference if they appear first + if pref.quality > max_pref { + max_pref = pref.quality; + max_item = Some(pref.item.clone()); + } + } + + max_item.unwrap_or(Preference::Any) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::header::*; + + #[test] + fn ranking_precedence() { + let test = AcceptLanguage(vec![]); + assert!(test.ranked().is_empty()); + + let test = AcceptLanguage(vec![QualityItem::max("fr-CH".parse().unwrap())]); + assert_eq!(test.ranked(), vec!("fr-CH".parse().unwrap())); + + let test = AcceptLanguage(vec![ + QualityItem::new("fr".parse().unwrap(), q(0.900)), + QualityItem::new("fr-CH".parse().unwrap(), q(1.0)), + QualityItem::new("en".parse().unwrap(), q(0.800)), + QualityItem::new("*".parse().unwrap(), q(0.500)), + QualityItem::new("de".parse().unwrap(), q(0.700)), + ]); + assert_eq!( + test.ranked(), + vec![ + "fr-CH".parse().unwrap(), + "fr".parse().unwrap(), + "en".parse().unwrap(), + "de".parse().unwrap(), + "*".parse().unwrap(), + ] + ); + + let test = AcceptLanguage(vec![ + QualityItem::max("fr".parse().unwrap()), + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::max("en".parse().unwrap()), + QualityItem::max("*".parse().unwrap()), + QualityItem::max("de".parse().unwrap()), + ]); + assert_eq!( + test.ranked(), + vec![ + "fr".parse().unwrap(), + "fr-CH".parse().unwrap(), + "en".parse().unwrap(), + "*".parse().unwrap(), + "de".parse().unwrap(), + ] + ); + } + + #[test] + fn preference_selection() { + let test = AcceptLanguage(vec![ + QualityItem::new("fr".parse().unwrap(), q(0.900)), + QualityItem::new("fr-CH".parse().unwrap(), q(1.0)), + QualityItem::new("en".parse().unwrap(), q(0.800)), + QualityItem::new("*".parse().unwrap(), q(0.500)), + QualityItem::new("de".parse().unwrap(), q(0.700)), + ]); + assert_eq!( + test.preference(), + Preference::Specific("fr-CH".parse().unwrap()) + ); + + let test = AcceptLanguage(vec![ + QualityItem::max("fr".parse().unwrap()), + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::max("en".parse().unwrap()), + QualityItem::max("*".parse().unwrap()), + QualityItem::max("de".parse().unwrap()), + ]); + assert_eq!( + test.preference(), + Preference::Specific("fr".parse().unwrap()) + ); + + let test = AcceptLanguage(vec![]); + assert_eq!(test.preference(), Preference::Any); } } diff --git a/src/http/header/allow.rs b/src/http/header/allow.rs index 946f70e0a..c8cc153e8 100644 --- a/src/http/header/allow.rs +++ b/src/http/header/allow.rs @@ -1,27 +1,27 @@ -use crate::http::header; use actix_http::http::Method; +use crate::http::header; + crate::http::header::common_header! { - /// `Allow` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.4.1) + /// `Allow` header, defined + /// in [RFC 7231 §7.4.1](https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1) /// /// The `Allow` header field lists the set of methods advertised as - /// supported by the target resource. The purpose of this field is + /// supported by the target resource. The purpose of this field is /// strictly to inform the recipient of valid request methods associated /// with the resource. /// /// # ABNF - /// - /// ```text + /// ```plain /// Allow = #method /// ``` /// - /// # Example values + /// # Example Values /// * `GET, HEAD, PUT` /// * `OPTIONS, GET, PUT, POST, DELETE, HEAD, TRACE, CONNECT, PATCH, fOObAr` /// * `` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::{header::Allow, Method}; @@ -47,7 +47,7 @@ crate::http::header::common_header! { /// ``` (Allow, header::ALLOW) => (Method)* - test_allow { + test_parse_and_format { // From the RFC crate::http::header::common_header_test!( test1, diff --git a/src/http/header/any_or_some.rs b/src/http/header/any_or_some.rs new file mode 100644 index 000000000..e5a37e495 --- /dev/null +++ b/src/http/header/any_or_some.rs @@ -0,0 +1,70 @@ +use std::{ + fmt::{self, Write as _}, + str, +}; + +/// A wrapper for types used in header values where wildcard (`*`) items are allowed but the +/// underlying type does not support them. +/// +/// For example, we use the `language-tags` crate for the [`AcceptLanguage`](super::AcceptLanguage) +/// typed header but it does parse `*` successfully. On the other hand, the `mime` crate, used for +/// [`Accept`](super::Accept), has first-party support for wildcard items so this wrapper is not +/// used in those header types. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)] +pub enum AnyOrSome { + /// A wildcard value. + Any, + + /// A valid `T`. + Item(T), +} + +impl AnyOrSome { + /// Returns true if item is wildcard (`*`) variant. + pub fn is_any(&self) -> bool { + matches!(self, Self::Any) + } + + /// Returns true if item is a valid item (`T`) variant. + pub fn is_item(&self) -> bool { + matches!(self, Self::Item(_)) + } + + /// Returns reference to value in `Item` variant, if it is set. + pub fn item(&self) -> Option<&T> { + match self { + AnyOrSome::Item(ref item) => Some(item), + AnyOrSome::Any => None, + } + } + + /// Consumes the container, returning the value in the `Item` variant, if it is set. + pub fn into_item(self) -> Option { + match self { + AnyOrSome::Item(item) => Some(item), + AnyOrSome::Any => None, + } + } +} + +impl fmt::Display for AnyOrSome { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AnyOrSome::Any => f.write_char('*'), + AnyOrSome::Item(item) => fmt::Display::fmt(item, f), + } + } +} + +impl str::FromStr for AnyOrSome { + type Err = T::Err; + + #[inline] + fn from_str(s: &str) -> Result { + match s.trim() { + "*" => Ok(Self::Any), + other => other.parse().map(AnyOrSome::Item), + } + } +} diff --git a/src/http/header/cache_control.rs b/src/http/header/cache_control.rs index 05903e3a3..490d36558 100644 --- a/src/http/header/cache_control.rs +++ b/src/http/header/cache_control.rs @@ -1,94 +1,99 @@ -use std::fmt::{self, Write}; -use std::str::FromStr; - -use super::{fmt_comma_delimited, from_comma_delimited, Header, IntoHeaderValue, Writer}; +use std::{fmt, str}; +use super::common_header; use crate::http::header; -/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) -/// -/// The `Cache-Control` header field is used to specify directives for -/// caches along the request/response chain. Such cache directives are -/// unidirectional in that the presence of a directive in a request does -/// not imply that the same directive is to be given in the response. -/// -/// # ABNF -/// -/// ```text -/// Cache-Control = 1#cache-directive -/// cache-directive = token [ "=" ( token / quoted-string ) ] -/// ``` -/// -/// # Example values -/// -/// * `no-cache` -/// * `private, community="UCI"` -/// * `max-age=30` -/// -/// # Examples -/// ``` -/// use actix_web::HttpResponse; -/// use actix_web::http::header::{CacheControl, CacheDirective}; -/// -/// let mut builder = HttpResponse::Ok(); -/// builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); -/// ``` -/// -/// ``` -/// use actix_web::HttpResponse; -/// use actix_web::http::header::{CacheControl, CacheDirective}; -/// -/// let mut builder = HttpResponse::Ok(); -/// builder.insert_header(CacheControl(vec![ -/// CacheDirective::NoCache, -/// CacheDirective::Private, -/// CacheDirective::MaxAge(360u32), -/// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), -/// ])); -/// ``` -#[derive(PartialEq, Clone, Debug)] -pub struct CacheControl(pub Vec); +common_header! { + /// `Cache-Control` header, defined + /// in [RFC 7234 §5.2](https://datatracker.ietf.org/doc/html/rfc7234#section-5.2). + /// + /// The `Cache-Control` header field is used to specify directives for + /// caches along the request/response chain. Such cache directives are + /// unidirectional in that the presence of a directive in a request does + /// not imply that the same directive is to be given in the response. + /// + /// # ABNF + /// ```text + /// Cache-Control = 1#cache-directive + /// cache-directive = token [ "=" ( token / quoted-string ) ] + /// ``` + /// + /// # Example Values + /// * `no-cache` + /// * `private, community="UCI"` + /// * `max-age=30` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{CacheControl, CacheDirective}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{CacheControl, CacheDirective}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header(CacheControl(vec![ + /// CacheDirective::NoCache, + /// CacheDirective::Private, + /// CacheDirective::MaxAge(360u32), + /// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), + /// ])); + /// ``` + (CacheControl, header::CACHE_CONTROL) => (CacheDirective)+ -crate::http::header::common_header_deref!(CacheControl => Vec); + test_parse_and_format { + common_header_test!(no_headers, vec![b""; 0], None); + common_header_test!(empty_header, vec![b""; 1], None); + common_header_test!(bad_syntax, vec![b"foo="], None); -// TODO: this could just be the crate::http::header::common_header! macro -impl Header for CacheControl { - fn name() -> header::HeaderName { - header::CACHE_CONTROL - } + common_header_test!( + multiple_headers, + vec![&b"no-cache"[..], &b"private"[..]], + Some(CacheControl(vec![ + CacheDirective::NoCache, + CacheDirective::Private, + ])) + ); - #[inline] - fn parse(msg: &T) -> Result - where - T: crate::HttpMessage, - { - let directives = from_comma_delimited(msg.headers().get_all(&Self::name()))?; - if !directives.is_empty() { - Ok(CacheControl(directives)) - } else { - Err(crate::error::ParseError::Header) + common_header_test!( + argument, + vec![b"max-age=100, private"], + Some(CacheControl(vec![ + CacheDirective::MaxAge(100), + CacheDirective::Private, + ])) + ); + + common_header_test!( + extension, + vec![b"foo, bar=baz"], + Some(CacheControl(vec![ + CacheDirective::Extension("foo".to_owned(), None), + CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), + ])) + ); + + #[test] + fn parse_quote_form() { + let req = test::TestRequest::default() + .insert_header((header::CACHE_CONTROL, "max-age=\"200\"")) + .finish(); + + assert_eq!( + Header::parse(&req).ok(), + Some(CacheControl(vec![CacheDirective::MaxAge(200)])) + ) } } } -impl fmt::Display for CacheControl { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt_comma_delimited(f, &self.0[..]) - } -} - -impl IntoHeaderValue for CacheControl { - type Error = header::InvalidHeaderValue; - - fn try_into_value(self) -> Result { - let mut writer = Writer::new(); - let _ = write!(&mut writer, "{}", self); - header::HeaderValue::from_maybe_shared(writer.take()) - } -} - /// `CacheControl` contains a list of these directives. -#[derive(PartialEq, Clone, Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum CacheDirective { /// "no-cache" NoCache, @@ -126,38 +131,40 @@ pub enum CacheDirective { impl fmt::Display for CacheDirective { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::CacheDirective::*; - fmt::Display::fmt( - match *self { - NoCache => "no-cache", - NoStore => "no-store", - NoTransform => "no-transform", - OnlyIfCached => "only-if-cached", - MaxAge(secs) => return write!(f, "max-age={}", secs), - MaxStale(secs) => return write!(f, "max-stale={}", secs), - MinFresh(secs) => return write!(f, "min-fresh={}", secs), + let dir_str = match self { + NoCache => "no-cache", + NoStore => "no-store", + NoTransform => "no-transform", + OnlyIfCached => "only-if-cached", - MustRevalidate => "must-revalidate", - Public => "public", - Private => "private", - ProxyRevalidate => "proxy-revalidate", - SMaxAge(secs) => return write!(f, "s-maxage={}", secs), + MaxAge(secs) => return write!(f, "max-age={}", secs), + MaxStale(secs) => return write!(f, "max-stale={}", secs), + MinFresh(secs) => return write!(f, "min-fresh={}", secs), - Extension(ref name, None) => &name[..], - Extension(ref name, Some(ref arg)) => { - return write!(f, "{}={}", name, arg); - } - }, - f, - ) + MustRevalidate => "must-revalidate", + Public => "public", + Private => "private", + ProxyRevalidate => "proxy-revalidate", + SMaxAge(secs) => return write!(f, "s-maxage={}", secs), + + Extension(name, None) => name.as_str(), + Extension(name, Some(arg)) => return write!(f, "{}={}", name, arg), + }; + + f.write_str(dir_str) } } -impl FromStr for CacheDirective { - type Err = Option<::Err>; - fn from_str(s: &str) -> Result::Err>> { +impl str::FromStr for CacheDirective { + type Err = Option<::Err>; + + fn from_str(s: &str) -> Result { use self::CacheDirective::*; + match s { + "" => Err(None), + "no-cache" => Ok(NoCache), "no-store" => Ok(NoStore), "no-transform" => Ok(NoTransform), @@ -166,7 +173,7 @@ impl FromStr for CacheDirective { "public" => Ok(Public), "private" => Ok(Private), "proxy-revalidate" => Ok(ProxyRevalidate), - "" => Err(None), + _ => match s.find('=') { Some(idx) if idx + 1 < s.len() => { match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) { @@ -183,76 +190,3 @@ impl FromStr for CacheDirective { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::http::header::Header; - use actix_http::test::TestRequest; - - #[test] - fn test_parse_multiple_headers() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "no-cache, private")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::NoCache, - CacheDirective::Private, - ])) - ) - } - - #[test] - fn test_parse_argument() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "max-age=100, private")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::MaxAge(100), - CacheDirective::Private, - ])) - ) - } - - #[test] - fn test_parse_quote_form() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "max-age=\"200\"")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![CacheDirective::MaxAge(200)])) - ) - } - - #[test] - fn test_parse_extension() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "foo, bar=baz")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::Extension("foo".to_owned(), None), - CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), - ])) - ) - } - - #[test] - fn test_parse_bad_syntax() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "foo=")) - .finish(); - let cache: Result = Header::parse(&req); - assert_eq!(cache.ok(), None) - } -} diff --git a/src/http/header/content_disposition.rs b/src/http/header/content_disposition.rs index 9f67baffb..945a58f7f 100644 --- a/src/http/header/content_disposition.rs +++ b/src/http/header/content_disposition.rs @@ -1,10 +1,14 @@ -//! # References +//! The `Content-Disposition` header and associated types. //! -//! "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt -//! "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt -//! "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt -//! Browser conformance tests at: http://greenbytes.de/tech/tc2231/ -//! IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml +//! # References +//! - "The Content-Disposition Header Field": +//! +//! - "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)": +//! +//! - "Returning Values from Forms: multipart/form-data": +//! +//! - Browser conformance tests at: +//! - IANA assignment: use once_cell::sync::Lazy; use regex::Regex; @@ -34,15 +38,19 @@ fn split_once_and_trim(haystack: &str, needle: char) -> (&str, &str) { /// The implied disposition of the content of the HTTP body. #[derive(Clone, Debug, PartialEq)] pub enum DispositionType { - /// Inline implies default processing + /// Inline implies default processing. Inline, + /// Attachment implies that the recipient should prompt the user to save the response locally, /// rather than process it normally (as per its media type). Attachment, + /// Used in *multipart/form-data* as defined in - /// [RFC7578](https://tools.ietf.org/html/rfc7578) to carry the field name and the file name. + /// [RFC 7578](https://datatracker.ietf.org/doc/html/rfc7578) to carry the field name and + /// optional filename. FormData, - /// Extension type. Should be handled by recipients the same way as Attachment + + /// Extension type. Should be handled by recipients the same way as Attachment. Ext(String), } @@ -76,25 +84,32 @@ pub enum DispositionParam { /// For [`DispositionType::FormData`] (i.e. *multipart/form-data*), the name of an field from /// the form. Name(String), + /// A plain file name. /// - /// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any - /// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where + /// It is [not supposed](https://datatracker.ietf.org/doc/html/rfc6266#appendix-D) to contain + /// any non-ASCII characters when used in a *Content-Disposition* HTTP response header, where /// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead /// in case there are Unicode characters in file names. Filename(String), + /// An extended file name. It must not exist for `ContentType::Formdata` according to - /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). + /// [RFC 7578 §4.2](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2). FilenameExt(ExtendedValue), + /// An unrecognized regular parameter as defined in - /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *reg-parameter*, in - /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *token "=" value*. Recipients should - /// ignore unrecognizable parameters. + /// [RFC 5987 §3.2.1](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2.1) as + /// `reg-parameter`, in + /// [RFC 6266 §4.1](https://datatracker.ietf.org/doc/html/rfc6266#section-4.1) as + /// `token "=" value`. Recipients should ignore unrecognizable parameters. Unknown(String, String), + /// An unrecognized extended parameter as defined in - /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *ext-parameter*, in - /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *ext-token "=" ext-value*. The single - /// trailing asterisk is not included. Recipients should ignore unrecognizable parameters. + /// [RFC 5987 §3.2.1](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2.1) as + /// `ext-parameter`, in + /// [RFC 6266 §4.1](https://datatracker.ietf.org/doc/html/rfc6266#section-4.1) as + /// `ext-token "=" ext-value`. The single trailing asterisk is not included. Recipients should + /// ignore unrecognizable parameters. UnknownExt(String, ExtendedValue), } @@ -188,10 +203,10 @@ impl DispositionParam { } /// A *Content-Disposition* header. It is compatible to be used either as -/// [a response header for the main body](https://mdn.io/Content-Disposition#As_a_response_header_for_the_main_body) -/// as (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266), or as +/// [a response header for the main body](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition#as_a_response_header_for_the_main_body) +/// as (re)defined in [RFC 6266](https://datatracker.ietf.org/doc/html/rfc6266), or as /// [a header for a multipart body](https://mdn.io/Content-Disposition#As_a_header_for_a_multipart_body) -/// as (re)defined in [RFC7587](https://tools.ietf.org/html/rfc7578). +/// as (re)defined in [RFC 7587](https://datatracker.ietf.org/doc/html/rfc7578). /// /// In a regular HTTP response, the *Content-Disposition* response header is a header indicating if /// the content is expected to be displayed *inline* in the browser, that is, as a Web page or as @@ -205,8 +220,7 @@ impl DispositionParam { /// itself, *Content-Disposition* has no effect. /// /// # ABNF - -/// ```text +/// ```plain /// content-disposition = "Content-Disposition" ":" /// disposition-type *( ";" disposition-parm ) /// @@ -227,19 +241,17 @@ impl DispositionParam { /// ``` /// /// # Note +/// *filename* is [not supposed](https://datatracker.ietf.org/doc/html/rfc6266#appendix-D) to +/// contain any non-ASCII characters when used in a *Content-Disposition* HTTP response header, +/// where filename* with charset UTF-8 may be used instead in case there are Unicode characters in +/// file names. Filename is [acceptable](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) +/// to be UTF-8 encoded directly in a *Content-Disposition* header for +/// *multipart/form-data*, though. /// -/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any -/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where -/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file -/// names. -/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded -/// directly in a *Content-Disposition* header for *multipart/form-data*, though. -/// -/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// *filename* [must not](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) be used within /// *multipart/form-data*. /// -/// # Example -/// +/// # Examples /// ``` /// use actix_web::http::header::{ /// Charset, ContentDisposition, DispositionParam, DispositionType, @@ -285,14 +297,16 @@ impl DispositionParam { /// ``` /// /// # Security Note -/// /// If "filename" parameter is supplied, do not use the file name blindly, check and possibly /// change to match local file system conventions if applicable, and do not use directory path -/// information that may be present. See [RFC2183](https://tools.ietf.org/html/rfc2183#section-2.3). +/// information that may be present. +/// See [RFC 2183 §2.3](https://datatracker.ietf.org/doc/html/rfc2183#section-2.3). +// TODO: think about using private fields and smallvec #[derive(Clone, Debug, PartialEq)] pub struct ContentDisposition { /// The disposition type pub disposition: DispositionType, + /// Disposition parameters pub parameters: Vec, } @@ -334,7 +348,7 @@ impl ContentDisposition { } else { // regular parameters let value = if left.starts_with('\"') { - // quoted-string: defined in RFC6266 -> RFC2616 Section 3.6 + // quoted-string: defined in RFC 6266 -> RFC 2616 Section 3.6 let mut escaping = false; let mut quoted_string = vec![]; let mut end = None; @@ -385,22 +399,22 @@ impl ContentDisposition { Ok(cd) } - /// Returns `true` if it is [`Inline`](DispositionType::Inline). + /// Returns `true` if type is [`Inline`](DispositionType::Inline). pub fn is_inline(&self) -> bool { matches!(self.disposition, DispositionType::Inline) } - /// Returns `true` if it is [`Attachment`](DispositionType::Attachment). + /// Returns `true` if type is [`Attachment`](DispositionType::Attachment). pub fn is_attachment(&self) -> bool { matches!(self.disposition, DispositionType::Attachment) } - /// Returns `true` if it is [`FormData`](DispositionType::FormData). + /// Returns `true` if type is [`FormData`](DispositionType::FormData). pub fn is_form_data(&self) -> bool { matches!(self.disposition, DispositionType::FormData) } - /// Returns `true` if it is [`Ext`](DispositionType::Ext) and the `disp_type` matches. + /// Returns `true` if type is [`Ext`](DispositionType::Ext) and the `disp_type` matches. pub fn is_ext(&self, disp_type: impl AsRef) -> bool { matches!( self.disposition, @@ -457,7 +471,7 @@ impl Header for ContentDisposition { fn parse(msg: &T) -> Result { if let Some(h) = msg.headers().get(&Self::name()) { - Self::from_raw(&h) + Self::from_raw(h) } else { Err(crate::error::ParseError::Header) } @@ -479,7 +493,9 @@ impl fmt::Display for DispositionParam { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and // backslash should be escaped in quoted-string (i.e. "foobar"). - // Ref: RFC6266 S4.1 -> RFC2616 S3.6 + // + // Ref: RFC 6266 §4.1 -> RFC 2616 §3.6 + // // filename-parm = "filename" "=" value // value = token | quoted-string // quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) @@ -493,7 +509,7 @@ impl fmt::Display for DispositionParam { // CTL = // - // Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1 + // Ref: RFC 7578 S4.2 -> RFC 2183 S2 -> RFC 2045 S5.1 // parameter := attribute "=" value // attribute := token // ; Matching of attributes @@ -509,22 +525,28 @@ impl fmt::Display for DispositionParam { // // // See also comments in test_from_raw_unnecessary_percent_decode. + static RE: Lazy = Lazy::new(|| Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap()); + match self { DispositionParam::Name(ref value) => write!(f, "name={}", value), + DispositionParam::Filename(ref value) => { write!(f, "filename=\"{}\"", RE.replace_all(value, "\\$0").as_ref()) } + DispositionParam::Unknown(ref name, ref value) => write!( f, "{}=\"{}\"", name, &RE.replace_all(value, "\\$0").as_ref() ), + DispositionParam::FilenameExt(ref ext_value) => { write!(f, "filename*={}", ext_value) } + DispositionParam::UnknownExt(ref name, ref ext_value) => { write!(f, "{}*={}", name, ext_value) } @@ -732,7 +754,7 @@ mod tests { #[test] fn from_raw_with_unicode() { - /* RFC7578 Section 4.2: + /* RFC 7578 Section 4.2: Some commonly deployed systems use multipart/form-data with file names directly encoded including octets outside the US-ASCII range. The encoding used for the file names is typically UTF-8, although HTML forms will use the charset associated with the form. @@ -805,9 +827,9 @@ mod tests { #[test] fn test_from_raw_unnecessary_percent_decode() { - // In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with + // In fact, RFC 7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with // non-ASCII characters MAY be percent-encoded. - // On the contrary, RFC6266 or other RFCs related to Content-Disposition response header + // On the contrary, RFC 6266 or other RFCs related to Content-Disposition response header // do not mention such percent-encoding. // So, it appears to be undecidable whether to percent-decode or not without // knowing the usage scenario (multipart/form-data v.s. HTTP response header) and diff --git a/src/http/header/content_language.rs b/src/http/header/content_language.rs index 604ada83c..ff317e1de 100644 --- a/src/http/header/content_language.rs +++ b/src/http/header/content_language.rs @@ -1,9 +1,10 @@ -use super::{QualityItem, CONTENT_LANGUAGE}; use language_tags::LanguageTag; -crate::http::header::common_header! { - /// `Content-Language` header, defined in - /// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2) +use super::{common_header, QualityItem, CONTENT_LANGUAGE}; + +common_header! { + /// `Content-Language` header, defined + /// in [RFC 7231 §3.1.3.2](https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.3.2) /// /// The `Content-Language` header field describes the natural language(s) /// of the intended audience for the representation. Note that this @@ -11,45 +12,42 @@ crate::http::header::common_header! { /// representation. /// /// # ABNF - /// - /// ```text + /// ```plain /// Content-Language = 1#language-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `da` /// * `mi, en` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{ContentLanguage, LanguageTag, qitem}; + /// use actix_web::http::header::{ContentLanguage, LanguageTag, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// ContentLanguage(vec![ - /// qitem(LanguageTag::parse("en").unwrap()), + /// QualityItem::max(LanguageTag::parse("en").unwrap()), /// ]) /// ); /// ``` /// /// ``` /// use actix_web::HttpResponse; - /// use actix_web::http::header::{ContentLanguage, LanguageTag, qitem}; + /// use actix_web::http::header::{ContentLanguage, LanguageTag, QualityItem}; /// /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// ContentLanguage(vec![ - /// qitem(LanguageTag::parse("da").unwrap()), - /// qitem(LanguageTag::parse("en-GB").unwrap()), + /// QualityItem::max(LanguageTag::parse("da").unwrap()), + /// QualityItem::max(LanguageTag::parse("en-GB").unwrap()), /// ]) /// ); /// ``` (ContentLanguage, CONTENT_LANGUAGE) => (QualityItem)+ - test_content_language { + test_parse_and_format { crate::http::header::common_header_test!(test1, vec![b"da"]); crate::http::header::common_header_test!(test2, vec![b"mi, en"]); } diff --git a/src/http/header/content_range.rs b/src/http/header/content_range.rs index 3bdead2c0..90b3f7fe2 100644 --- a/src/http/header/content_range.rs +++ b/src/http/header/content_range.rs @@ -1,15 +1,17 @@ -use std::fmt::{self, Display, Write}; -use std::str::FromStr; +use std::{ + fmt::{self, Display, Write}, + str::FromStr, +}; use super::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer, CONTENT_RANGE}; use crate::error::ParseError; crate::http::header::common_header! { - /// `Content-Range` header, defined in - /// [RFC7233](http://tools.ietf.org/html/rfc7233#section-4.2) + /// `Content-Range` header, defined + /// in [RFC 7233 §4.2](https://datatracker.ietf.org/doc/html/rfc7233#section-4.2) (ContentRange, CONTENT_RANGE) => [ContentRangeSpec] - test_content_range { + test_parse_and_format { crate::http::header::common_header_test!(test_bytes, vec![b"bytes 0-499/500"], Some(ContentRange(ContentRangeSpec::Bytes { @@ -69,11 +71,11 @@ crate::http::header::common_header! { } } -/// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2) +/// Content-Range header, defined +/// in [RFC 7233 §4.2](https://datatracker.ietf.org/doc/html/rfc7233#section-4.2) /// /// # ABNF -/// -/// ```text +/// ```plain /// Content-Range = byte-content-range /// / other-content-range /// @@ -89,7 +91,7 @@ crate::http::header::common_header! { /// other-content-range = other-range-unit SP other-range-resp /// other-range-resp = *CHAR /// ``` -#[derive(PartialEq, Clone, Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ContentRangeSpec { /// Byte range Bytes { diff --git a/src/http/header/content_type.rs b/src/http/header/content_type.rs index e1c419c22..1fc75d0e2 100644 --- a/src/http/header/content_type.rs +++ b/src/http/header/content_type.rs @@ -2,8 +2,8 @@ use super::CONTENT_TYPE; use mime::Mime; crate::http::header::common_header! { - /// `Content-Type` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-3.1.1.5) + /// `Content-Type` header, defined + /// in [RFC 7231 §3.1.1.5](https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.5) /// /// The `Content-Type` header field indicates the media type of the /// associated representation: either the representation enclosed in the @@ -18,18 +18,15 @@ crate::http::header::common_header! { /// this is an issue, it's possible to implement `Header` on a custom struct. /// /// # ABNF - /// - /// ```text + /// ```plain /// Content-Type = media-type /// ``` /// - /// # Example values - /// + /// # Example Values /// * `text/html; charset=utf-8` /// * `application/json` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::header::ContentType; @@ -51,7 +48,7 @@ crate::http::header::common_header! { /// ``` (ContentType, CONTENT_TYPE) => [Mime] - test_content_type { + test_parse_and_format { crate::http::header::common_header_test!( test1, vec![b"text/html"], @@ -60,57 +57,56 @@ crate::http::header::common_header! { } impl ContentType { - /// A constructor to easily create a `Content-Type: application/json` + /// A constructor to easily create a `Content-Type: application/json` /// header. #[inline] pub fn json() -> ContentType { ContentType(mime::APPLICATION_JSON) } - /// A constructor to easily create a `Content-Type: text/plain; + /// A constructor to easily create a `Content-Type: text/plain; /// charset=utf-8` header. #[inline] pub fn plaintext() -> ContentType { ContentType(mime::TEXT_PLAIN_UTF_8) } - /// A constructor to easily create a `Content-Type: text/html` header. + /// A constructor to easily create a `Content-Type: text/html; charset=utf-8` + /// header. #[inline] pub fn html() -> ContentType { - ContentType(mime::TEXT_HTML) + ContentType(mime::TEXT_HTML_UTF_8) } - /// A constructor to easily create a `Content-Type: text/xml` header. + /// A constructor to easily create a `Content-Type: text/xml` header. #[inline] pub fn xml() -> ContentType { ContentType(mime::TEXT_XML) } - /// A constructor to easily create a `Content-Type: + /// A constructor to easily create a `Content-Type: /// application/www-form-url-encoded` header. #[inline] pub fn form_url_encoded() -> ContentType { ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) } - /// A constructor to easily create a `Content-Type: image/jpeg` header. + /// A constructor to easily create a `Content-Type: image/jpeg` header. #[inline] pub fn jpeg() -> ContentType { ContentType(mime::IMAGE_JPEG) } - /// A constructor to easily create a `Content-Type: image/png` header. + /// A constructor to easily create a `Content-Type: image/png` header. #[inline] pub fn png() -> ContentType { ContentType(mime::IMAGE_PNG) } - /// A constructor to easily create a `Content-Type: + /// A constructor to easily create a `Content-Type: /// application/octet-stream` header. #[inline] pub fn octet_stream() -> ContentType { ContentType(mime::APPLICATION_OCTET_STREAM) } } - -impl Eq for ContentType {} diff --git a/src/http/header/date.rs b/src/http/header/date.rs index 4d1717886..4063deab1 100644 --- a/src/http/header/date.rs +++ b/src/http/header/date.rs @@ -2,19 +2,18 @@ use super::{HttpDate, DATE}; use std::time::SystemTime; crate::http::header::common_header! { - /// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2) + /// `Date` header, defined + /// in [RFC 7231 §7.1.1.2](https://datatracker.ietf.org/doc/html/rfc7231#section-7.1.1.2) /// /// The `Date` header field represents the date and time at which the /// message was originated. /// /// # ABNF - /// - /// ```text + /// ```plain /// Date = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Tue, 15 Nov 1994 08:12:31 GMT` /// /// # Example @@ -31,7 +30,7 @@ crate::http::header::common_header! { /// ``` (Date, DATE) => [HttpDate] - test_date { + test_parse_and_format { crate::http::header::common_header_test!(test1, vec![b"Tue, 15 Nov 1994 08:12:31 GMT"]); } } diff --git a/src/http/header/encoding.rs b/src/http/header/encoding.rs index ce31c100f..a61edda67 100644 --- a/src/http/header/encoding.rs +++ b/src/http/header/encoding.rs @@ -4,26 +4,33 @@ pub use self::Encoding::{ Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip, Identity, Trailers, Zstd, }; -/// A value to represent an encoding used in `Transfer-Encoding` -/// or `Accept-Encoding` header. -#[derive(Clone, PartialEq, Debug)] +/// A value to represent an encoding used in `Transfer-Encoding` or `Accept-Encoding` header. +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Encoding { /// The `chunked` encoding. Chunked, + /// The `br` encoding. Brotli, + /// The `gzip` encoding. Gzip, + /// The `deflate` encoding. Deflate, + /// The `compress` encoding. Compress, + /// The `identity` encoding. Identity, + /// The `trailers` encoding. Trailers, + /// The `zstd` encoding. Zstd, + /// Some other encoding that is less common, can be any String. EncodingExt(String), } diff --git a/src/http/header/entity.rs b/src/http/header/entity.rs index 5073ed692..50b40b7b2 100644 --- a/src/http/header/entity.rs +++ b/src/http/header/entity.rs @@ -1,5 +1,7 @@ -use std::fmt::{self, Display, Write}; -use std::str::FromStr; +use std::{ + fmt::{self, Display, Write}, + str::FromStr, +}; use super::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; @@ -15,7 +17,8 @@ fn check_slice_validity(slice: &str) -> bool { slice.bytes().all(entity_validate_char) } -/// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) +/// An entity tag, defined +/// in [RFC 7232 §2.3](https://datatracker.ietf.org/doc/html/rfc7232#section-2.3) /// /// An entity tag consists of a string enclosed by two literal double quotes. /// Preceding the first double quote is an optional weakness indicator, @@ -23,8 +26,7 @@ fn check_slice_validity(slice: &str) -> bool { /// `W/"xyzzy"`. /// /// # ABNF -/// -/// ```text +/// ```plain /// entity-tag = [ weak ] opaque-tag /// weak = %x57.2F ; "W/", case-sensitive /// opaque-tag = DQUOTE *etagc DQUOTE diff --git a/src/http/header/etag.rs b/src/http/header/etag.rs index aded72665..4724c917e 100644 --- a/src/http/header/etag.rs +++ b/src/http/header/etag.rs @@ -1,7 +1,8 @@ use super::{EntityTag, ETAG}; crate::http::header::common_header! { - /// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3) + /// `ETag` header, defined in + /// [RFC 7232 §2.3](https://datatracker.ietf.org/doc/html/rfc7232#section-2.3) /// /// The `ETag` header field in a response provides the current entity-tag /// for the selected representation, as determined at the conclusion of @@ -14,19 +15,16 @@ crate::http::header::common_header! { /// prefixed by a weakness indicator. /// /// # ABNF - /// - /// ```text + /// ```plain /// ETag = entity-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `"xyzzy"` /// * `W/"xyzzy"` /// * `""` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::header::{ETag, EntityTag}; @@ -48,7 +46,7 @@ crate::http::header::common_header! { /// ``` (ETag, ETAG) => [EntityTag] - test_etag { + test_parse_and_format { // From the RFC crate::http::header::common_header_test!(test1, vec![b"\"xyzzy\""], diff --git a/src/http/header/expires.rs b/src/http/header/expires.rs index e810fe267..5b6c65c53 100644 --- a/src/http/header/expires.rs +++ b/src/http/header/expires.rs @@ -1,7 +1,8 @@ use super::{HttpDate, EXPIRES}; crate::http::header::common_header! { - /// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3) + /// `Expires` header, defined + /// in [RFC 7234 §5.3](https://datatracker.ietf.org/doc/html/rfc7234#section-5.3) /// /// The `Expires` header field gives the date/time after which the /// response is considered stale. @@ -11,12 +12,11 @@ crate::http::header::common_header! { /// time. /// /// # ABNF - /// - /// ```text + /// ```plain /// Expires = HTTP-date /// ``` /// - /// # Example values + /// # Example Values /// * `Thu, 01 Dec 1994 16:00:00 GMT` /// /// # Example @@ -34,7 +34,7 @@ crate::http::header::common_header! { /// ``` (Expires, EXPIRES) => [HttpDate] - test_expires { + test_parse_and_format { // Test case from RFC crate::http::header::common_header_test!(test1, vec![b"Thu, 01 Dec 1994 16:00:00 GMT"]); } diff --git a/src/http/header/if_match.rs b/src/http/header/if_match.rs index 87a94a809..a565b9125 100644 --- a/src/http/header/if_match.rs +++ b/src/http/header/if_match.rs @@ -1,8 +1,8 @@ -use super::{EntityTag, IF_MATCH}; +use super::{common_header, EntityTag, IF_MATCH}; -crate::http::header::common_header! { - /// `If-Match` header, defined in - /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.1) +common_header! { + /// `If-Match` header, defined + /// in [RFC 7232 §3.1](https://datatracker.ietf.org/doc/html/rfc7232#section-3.1) /// /// The `If-Match` header field makes the request method conditional on /// the recipient origin server either having at least one current @@ -17,18 +17,15 @@ crate::http::header::common_header! { /// there have been any changes to the representation data. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Match = "*" / 1#entity-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `"xyzzy"` /// * "xyzzy", "r2d2xxxx", "c3piozzzz" /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::header::IfMatch; @@ -52,7 +49,7 @@ crate::http::header::common_header! { /// ``` (IfMatch, IF_MATCH) => {Any / (EntityTag)+} - test_if_match { + test_parse_and_format { crate::http::header::common_header_test!( test1, vec![b"\"xyzzy\""], diff --git a/src/http/header/if_modified_since.rs b/src/http/header/if_modified_since.rs index 254003523..14d6c3553 100644 --- a/src/http/header/if_modified_since.rs +++ b/src/http/header/if_modified_since.rs @@ -1,8 +1,8 @@ use super::{HttpDate, IF_MODIFIED_SINCE}; crate::http::header::common_header! { - /// `If-Modified-Since` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.3) + /// `If-Modified-Since` header, defined + /// in [RFC 7232 §3.3](https://datatracker.ietf.org/doc/html/rfc7232#section-3.3) /// /// The `If-Modified-Since` header field makes a GET or HEAD request /// method conditional on the selected representation's modification date @@ -11,12 +11,11 @@ crate::http::header::common_header! { /// data has not changed. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Unmodified-Since = HTTP-date /// ``` /// - /// # Example values + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example @@ -34,7 +33,7 @@ crate::http::header::common_header! { /// ``` (IfModifiedSince, IF_MODIFIED_SINCE) => [HttpDate] - test_if_modified_since { + test_parse_and_format { // Test case from RFC crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); } diff --git a/src/http/header/if_none_match.rs b/src/http/header/if_none_match.rs index e1422bd36..fb1895fc8 100644 --- a/src/http/header/if_none_match.rs +++ b/src/http/header/if_none_match.rs @@ -1,8 +1,8 @@ use super::{EntityTag, IF_NONE_MATCH}; crate::http::header::common_header! { - /// `If-None-Match` header, defined in - /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.2) + /// `If-None-Match` header, defined + /// in [RFC 7232 §3.2](https://datatracker.ietf.org/doc/html/rfc7232#section-3.2) /// /// The `If-None-Match` header field makes the request method conditional /// on a recipient cache or origin server either not having any current @@ -16,13 +16,11 @@ crate::http::header::common_header! { /// the representation data. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-None-Match = "*" / 1#entity-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `"xyzzy"` /// * `W/"xyzzy"` /// * `"xyzzy", "r2d2xxxx", "c3piozzzz"` @@ -30,7 +28,6 @@ crate::http::header::common_header! { /// * `*` /// /// # Examples - /// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::header::IfNoneMatch; @@ -54,7 +51,7 @@ crate::http::header::common_header! { /// ``` (IfNoneMatch, IF_NONE_MATCH) => {Any / (EntityTag)+} - test_if_none_match { + test_parse_and_format { crate::http::header::common_header_test!(test1, vec![b"\"xyzzy\""]); crate::http::header::common_header_test!(test2, vec![b"W/\"xyzzy\""]); crate::http::header::common_header_test!(test3, vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""]); diff --git a/src/http/header/if_range.rs b/src/http/header/if_range.rs index cf69e7269..5af9255f6 100644 --- a/src/http/header/if_range.rs +++ b/src/http/header/if_range.rs @@ -8,7 +8,8 @@ use crate::error::ParseError; use crate::http::header; use crate::HttpMessage; -/// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) +/// `If-Range` header, defined +/// in [RFC 7233 §3.2](https://datatracker.ietf.org/doc/html/rfc7233#section-3.2) /// /// If a client has a partial copy of a representation and wishes to have /// an up-to-date copy of the entire representation, it could use the @@ -24,18 +25,16 @@ use crate::HttpMessage; /// in Range; otherwise, send me the entire representation. /// /// # ABNF -/// -/// ```text +/// ```plain /// If-Range = entity-tag / HTTP-date /// ``` /// -/// # Example values +/// # Example Values /// /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// * `\"xyzzy\"` /// /// # Examples -/// /// ``` /// use actix_web::HttpResponse; /// use actix_web::http::header::{EntityTag, IfRange}; @@ -108,10 +107,11 @@ impl IntoHeaderValue for IfRange { } #[cfg(test)] -mod test_if_range { +mod test_parse_and_format { + use std::str; + use super::IfRange as HeaderField; use crate::http::header::*; - use std::str; crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); crate::http::header::common_header_test!(test2, vec![b"\"abc\""]); diff --git a/src/http/header/if_unmodified_since.rs b/src/http/header/if_unmodified_since.rs index 1cc7b304e..0df6d7ba0 100644 --- a/src/http/header/if_unmodified_since.rs +++ b/src/http/header/if_unmodified_since.rs @@ -1,8 +1,8 @@ use super::{HttpDate, IF_UNMODIFIED_SINCE}; crate::http::header::common_header! { - /// `If-Unmodified-Since` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.4) + /// `If-Unmodified-Since` header, defined + /// in [RFC 7232 §3.4](https://datatracker.ietf.org/doc/html/rfc7232#section-3.4) /// /// The `If-Unmodified-Since` header field makes the request method /// conditional on the selected representation's last modification date @@ -11,13 +11,11 @@ crate::http::header::common_header! { /// the user agent does not have an entity-tag for the representation. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Unmodified-Since = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example @@ -35,7 +33,7 @@ crate::http::header::common_header! { /// ``` (IfUnmodifiedSince, IF_UNMODIFIED_SINCE) => [HttpDate] - test_if_unmodified_since { + test_parse_and_format { // Test case from RFC crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); } diff --git a/src/http/header/last_modified.rs b/src/http/header/last_modified.rs index c43bf3ac9..e15443ed1 100644 --- a/src/http/header/last_modified.rs +++ b/src/http/header/last_modified.rs @@ -1,8 +1,8 @@ use super::{HttpDate, LAST_MODIFIED}; crate::http::header::common_header! { - /// `Last-Modified` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.2) + /// `Last-Modified` header, defined + /// in [RFC 7232 §2.2](https://datatracker.ietf.org/doc/html/rfc7232#section-2.2) /// /// The `Last-Modified` header field in a response provides a timestamp /// indicating the date and time at which the origin server believes the @@ -10,13 +10,11 @@ crate::http::header::common_header! { /// conclusion of handling the request. /// /// # ABNF - /// - /// ```text + /// ```plain /// Expires = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example @@ -34,8 +32,8 @@ crate::http::header::common_header! { /// ``` (LastModified, LAST_MODIFIED) => [HttpDate] - test_last_modified { - // Test case from RFC - crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); - } + test_parse_and_format { + // Test case from RFC + crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + } } diff --git a/src/http/header/macros.rs b/src/http/header/macros.rs index 419d4fb6e..3f530658c 100644 --- a/src/http/header/macros.rs +++ b/src/http/header/macros.rs @@ -1,33 +1,17 @@ -macro_rules! common_header_deref { - ($from:ty => $to:ty) => { - impl ::std::ops::Deref for $from { - type Target = $to; - - #[inline] - fn deref(&self) -> &$to { - &self.0 - } - } - - impl ::std::ops::DerefMut for $from { - #[inline] - fn deref_mut(&mut self) -> &mut $to { - &mut self.0 - } - } - }; -} - macro_rules! common_header_test_module { ($id:ident, $tm:ident{$($tf:item)*}) => { - #[allow(unused_imports)] #[cfg(test)] mod $tm { - use std::str; - use actix_http::http::Method; - use mime::*; - use $crate::http::header::*; - use super::$id as HeaderField; + #![allow(unused_imports)] + + use ::core::str; + + use ::actix_http::{http::Method, test}; + use ::mime::*; + + use $crate::http::header::{self, *}; + use super::{$id as HeaderField, *}; + $($tf)* } } @@ -38,18 +22,23 @@ macro_rules! common_header_test { ($id:ident, $raw:expr) => { #[test] fn $id() { - use actix_http::test; + use ::actix_http::test; let raw = $raw; - let a: Vec> = raw.iter().map(|x| x.to_vec()).collect(); + let headers = raw.iter().map(|x| x.to_vec()).collect::>(); + let mut req = test::TestRequest::default(); - for item in a { - req = req.insert_header((HeaderField::name(), item)).take(); + + for item in headers { + req = req.append_header((HeaderField::name(), item)).take(); } + let req = req.finish(); let value = HeaderField::parse(&req); + let result = format!("{}", value.unwrap()); - let expected = String::from_utf8(raw[0].to_vec()).unwrap(); + let expected = ::std::string::String::from_utf8(raw[0].to_vec()).unwrap(); + let result_cmp: Vec = result .to_ascii_lowercase() .split(' ') @@ -60,154 +49,181 @@ macro_rules! common_header_test { .split(' ') .map(|x| x.to_owned()) .collect(); + assert_eq!(result_cmp.concat(), expected_cmp.concat()); } }; - ($id:ident, $raw:expr, $typed:expr) => { + + ($id:ident, $raw:expr, $exp:expr) => { #[test] fn $id() { use actix_http::test; - let a: Vec> = $raw.iter().map(|x| x.to_vec()).collect(); + let headers = $raw.iter().map(|x| x.to_vec()).collect::>(); let mut req = test::TestRequest::default(); - for item in a { - req.insert_header((HeaderField::name(), item)); + + for item in headers { + req.append_header((HeaderField::name(), item)); } + let req = req.finish(); let val = HeaderField::parse(&req); - let typed: Option = $typed; - // Test parsing - assert_eq!(val.ok(), typed); - // Test formatting - if typed.is_some() { + + let exp: ::core::option::Option = $exp; + + // test parsing + assert_eq!(val.ok(), exp); + + // test formatting + if let Some(exp) = exp { let raw = &($raw)[..]; let mut iter = raw.iter().map(|b| str::from_utf8(&b[..]).unwrap()); let mut joined = String::new(); - joined.push_str(iter.next().unwrap()); - for s in iter { - joined.push_str(", "); + if let Some(s) = iter.next() { joined.push_str(s); + for s in iter { + joined.push_str(", "); + joined.push_str(s); + } } - assert_eq!(format!("{}", typed.unwrap()), joined); + assert_eq!(format!("{}", exp), joined); } } }; } macro_rules! common_header { - // $a:meta: Attributes associated with the header item (usually docs) + // TODO: these docs are wrong, there's no $n or $nn + // $attrs:meta: Attributes associated with the header item (usually docs) // $id:ident: Identifier of the header // $n:expr: Lowercase name of the header // $nn:expr: Nice name of the header // List header, zero or more items - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)*) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)*) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] pub struct $id(pub Vec<$item>); - crate::http::header::common_header_deref!($id => Vec<$item>); + impl $crate::http::header::Header for $id { #[inline] fn name() -> $crate::http::header::HeaderName { $name } + #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name())).map($id) + fn parse(msg: &M) -> Result { + let headers = msg.headers().get_all(Self::name()); + $crate::http::header::from_comma_delimited(headers).map($id) } } - impl std::fmt::Display for $id { + + impl ::core::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } + impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; + #[inline] fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; + use ::core::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) } } }; + // List header, one or more items - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)+) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)+) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] pub struct $id(pub Vec<$item>); - crate::http::header::common_header_deref!($id => Vec<$item>); + impl $crate::http::header::Header for $id { #[inline] fn name() -> $crate::http::header::HeaderName { $name } + #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name())).map($id) + fn parse(msg: &M) -> Result{ + let headers = msg.headers().get_all(Self::name()); + + $crate::http::header::from_comma_delimited(headers) + .and_then(|items| { + if items.is_empty() { + Err($crate::error::ParseError::Header) + } else { + Ok($id(items)) + } + }) } } - impl std::fmt::Display for $id { + + impl ::core::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { $crate::http::header::fmt_comma_delimited(f, &self.0[..]) } } + impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; + #[inline] fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; + use ::core::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) } } }; + // Single value header - ($(#[$a:meta])*($id:ident, $name:expr) => [$value:ty]) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] + ($(#[$attrs:meta])*($id:ident, $name:expr) => [$value:ty]) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] pub struct $id(pub $value); - crate::http::header::common_header_deref!($id => $value); + impl $crate::http::header::Header for $id { #[inline] fn name() -> $crate::http::header::HeaderName { $name } + #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_one_raw_str( - msg.headers().get(Self::name())).map($id) + fn parse(msg: &M) -> Result { + let header = msg.headers().get(Self::name()); + $crate::http::header::from_one_raw_str(header).map($id) } } - impl std::fmt::Display for $id { + + impl ::core::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self.0, f) + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + ::core::fmt::Display::fmt(&self.0, f) } } + impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; + #[inline] fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { self.0.try_into_value() } } }; + // List header, one or more items with "*" option - ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+}) => { - $(#[$a])* + ($(#[$attrs:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+}) => { + $(#[$attrs])* #[derive(Clone, Debug, PartialEq)] pub enum $id { /// Any value is a match @@ -215,42 +231,47 @@ macro_rules! common_header { /// Only the listed items are a match Items(Vec<$item>), } + impl $crate::http::header::Header for $id { #[inline] fn name() -> $crate::http::header::HeaderName { $name } - #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - let any = msg.headers().get(Self::name()).and_then(|hdr| { - hdr.to_str().ok().and_then(|hdr| Some(hdr.trim() == "*"))}); - if let Some(true) = any { + #[inline] + fn parse(msg: &M) -> Result { + let is_any = msg + .headers() + .get(Self::name()) + .and_then(|hdr| hdr.to_str().ok()) + .map(|hdr| hdr.trim() == "*"); + + if let Some(true) = is_any { Ok($id::Any) } else { - Ok($id::Items( - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name()))?)) + let headers = msg.headers().get_all(Self::name()); + Ok($id::Items($crate::http::header::from_comma_delimited(headers)?)) } } } - impl std::fmt::Display for $id { + + impl ::core::fmt::Display for $id { #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match *self { $id::Any => f.write_str("*"), - $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( - f, &fields[..]) + $id::Items(ref fields) => + $crate::http::header::fmt_comma_delimited(f, &fields[..]) } } } + impl $crate::http::header::IntoHeaderValue for $id { type Error = $crate::http::header::InvalidHeaderValue; + #[inline] fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; + use ::core::fmt::Write; let mut writer = $crate::http::header::Writer::new(); let _ = write!(&mut writer, "{}", self); $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) @@ -259,32 +280,32 @@ macro_rules! common_header { }; // optional test module - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)* $tm:ident{$($tf:item)*}) => { + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)* $tm:ident{$($tf:item)*}) => { crate::http::header::common_header! { - $(#[$a])* + $(#[$attrs])* ($id, $name) => ($item)* } crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} }; - ($(#[$a:meta])*($id:ident, $n:expr) => ($item:ty)+ $tm:ident{$($tf:item)*}) => { + ($(#[$attrs:meta])*($id:ident, $n:expr) => ($item:ty)+ $tm:ident{$($tf:item)*}) => { crate::http::header::common_header! { - $(#[$a])* + $(#[$attrs])* ($id, $n) => ($item)+ } crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} }; - ($(#[$a:meta])*($id:ident, $name:expr) => [$item:ty] $tm:ident{$($tf:item)*}) => { + ($(#[$attrs:meta])*($id:ident, $name:expr) => [$item:ty] $tm:ident{$($tf:item)*}) => { crate::http::header::common_header! { - $(#[$a])* ($id, $name) => [$item] + $(#[$attrs])* ($id, $name) => [$item] } crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} }; - ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+} $tm:ident{$($tf:item)*}) => { + ($(#[$attrs:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+} $tm:ident{$($tf:item)*}) => { crate::http::header::common_header! { - $(#[$a])* + $(#[$attrs])* ($id, $name) => {Any / ($item)+} } @@ -292,7 +313,7 @@ macro_rules! common_header { }; } -pub(crate) use {common_header, common_header_deref, common_header_test_module}; +pub(crate) use {common_header, common_header_test_module}; #[cfg(test)] pub(crate) use common_header_test; diff --git a/src/http/header/mod.rs b/src/http/header/mod.rs index 79ba5772b..07b7592d7 100644 --- a/src/http/header/mod.rs +++ b/src/http/header/mod.rs @@ -1,71 +1,25 @@ //! A Collection of Header implementations for common HTTP Headers. //! -//! ## Mime -//! +//! ## Mime Types //! Several header fields use MIME values for their contents. Keeping with the strongly-typed theme, //! the [mime] crate is used in such headers as [`ContentType`] and [`Accept`]. -use bytes::{Bytes, BytesMut}; use std::fmt; -pub use self::accept_charset::AcceptCharset; +use bytes::{Bytes, BytesMut}; + +// re-export from actix-http +// - header name / value types +// - relevant traits for converting to header name / value +// - all const header names +// - header map +// - the few typed headers from actix-http +// - header parsing utils pub use actix_http::http::header::*; -//pub use self::accept_encoding::AcceptEncoding; -pub use self::accept::Accept; -pub use self::accept_language::AcceptLanguage; -pub use self::allow::Allow; -pub use self::cache_control::{CacheControl, CacheDirective}; -pub use self::content_disposition::{ContentDisposition, DispositionParam, DispositionType}; -pub use self::content_language::ContentLanguage; -pub use self::content_range::{ContentRange, ContentRangeSpec}; -pub use self::content_type::ContentType; -pub use self::date::Date; -pub use self::encoding::Encoding; -pub use self::entity::EntityTag; -pub use self::etag::ETag; -pub use self::expires::Expires; -pub use self::if_match::IfMatch; -pub use self::if_modified_since::IfModifiedSince; -pub use self::if_none_match::IfNoneMatch; -pub use self::if_range::IfRange; -pub use self::if_unmodified_since::IfUnmodifiedSince; -pub use self::last_modified::LastModified; -//pub use self::range::{Range, ByteRangeSpec}; -pub(crate) use actix_http::http::header::{ - fmt_comma_delimited, from_comma_delimited, from_one_raw_str, -}; -#[derive(Debug, Default)] -struct Writer { - buf: BytesMut, -} - -impl Writer { - pub fn new() -> Writer { - Writer::default() - } - - pub fn take(&mut self) -> Bytes { - self.buf.split().freeze() - } -} - -impl fmt::Write for Writer { - #[inline] - fn write_str(&mut self, s: &str) -> fmt::Result { - self.buf.extend_from_slice(s.as_bytes()); - Ok(()) - } - - #[inline] - fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { - fmt::write(self, args) - } -} - -mod accept_charset; -// mod accept_encoding; mod accept; +mod accept_charset; +mod accept_encoding; mod accept_language; mod allow; mod cache_control; @@ -84,8 +38,65 @@ mod if_none_match; mod if_range; mod if_unmodified_since; mod last_modified; - mod macros; +mod preference; +mod range; + #[cfg(test)] pub(crate) use macros::common_header_test; -pub(crate) use macros::{common_header, common_header_deref, common_header_test_module}; +pub(crate) use macros::{common_header, common_header_test_module}; + +pub use self::accept::Accept; +pub use self::accept_charset::AcceptCharset; +pub use self::accept_encoding::AcceptEncoding; +pub use self::accept_language::AcceptLanguage; +pub use self::allow::Allow; +pub use self::cache_control::{CacheControl, CacheDirective}; +pub use self::content_disposition::{ContentDisposition, DispositionParam, DispositionType}; +pub use self::content_language::ContentLanguage; +pub use self::content_range::{ContentRange, ContentRangeSpec}; +pub use self::content_type::ContentType; +pub use self::date::Date; +pub use self::encoding::Encoding; +pub use self::entity::EntityTag; +pub use self::etag::ETag; +pub use self::expires::Expires; +pub use self::if_match::IfMatch; +pub use self::if_modified_since::IfModifiedSince; +pub use self::if_none_match::IfNoneMatch; +pub use self::if_range::IfRange; +pub use self::if_unmodified_since::IfUnmodifiedSince; +pub use self::last_modified::LastModified; +pub use self::preference::Preference; +pub use self::range::{ByteRangeSpec, Range}; + +/// Format writer ([`fmt::Write`]) for a [`BytesMut`]. +#[derive(Debug, Default)] +struct Writer { + buf: BytesMut, +} + +impl Writer { + /// Constructs new bytes writer. + pub fn new() -> Writer { + Writer::default() + } + + /// Splits bytes out of writer, leaving writer buffer empty. + pub fn take(&mut self) -> Bytes { + self.buf.split().freeze() + } +} + +impl fmt::Write for Writer { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + self.buf.extend_from_slice(s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} diff --git a/src/http/header/preference.rs b/src/http/header/preference.rs new file mode 100644 index 000000000..979fc7720 --- /dev/null +++ b/src/http/header/preference.rs @@ -0,0 +1,70 @@ +use std::{ + fmt::{self, Write as _}, + str, +}; + +/// A wrapper for types used in header values where wildcard (`*`) items are allowed but the +/// underlying type does not support them. +/// +/// For example, we use the `language-tags` crate for the [`AcceptLanguage`](super::AcceptLanguage) +/// typed header but it does not parse `*` successfully. On the other hand, the `mime` crate, used +/// for [`Accept`](super::Accept), has first-party support for wildcard items so this wrapper is not +/// used in those header types. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)] +pub enum Preference { + /// A wildcard value. + Any, + + /// A valid `T`. + Specific(T), +} + +impl Preference { + /// Returns true if preference is the any/wildcard (`*`) value. + pub fn is_any(&self) -> bool { + matches!(self, Self::Any) + } + + /// Returns true if preference is the specific item (`T`) variant. + pub fn is_specific(&self) -> bool { + matches!(self, Self::Specific(_)) + } + + /// Returns reference to value in `Specific` variant, if it is set. + pub fn item(&self) -> Option<&T> { + match self { + Preference::Specific(ref item) => Some(item), + Preference::Any => None, + } + } + + /// Consumes the container, returning the value in the `Specific` variant, if it is set. + pub fn into_item(self) -> Option { + match self { + Preference::Specific(item) => Some(item), + Preference::Any => None, + } + } +} + +impl fmt::Display for Preference { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Preference::Any => f.write_char('*'), + Preference::Specific(item) => fmt::Display::fmt(item, f), + } + } +} + +impl str::FromStr for Preference { + type Err = T::Err; + + #[inline] + fn from_str(s: &str) -> Result { + match s.trim() { + "*" => Ok(Self::Any), + other => other.parse().map(Preference::Specific), + } + } +} diff --git a/src/http/header/range.rs b/src/http/header/range.rs index a9b40b403..c1d60f1ee 100644 --- a/src/http/header/range.rs +++ b/src/http/header/range.rs @@ -1,20 +1,23 @@ -use std::fmt::{self, Display}; -use std::str::FromStr; +use std::{ + cmp, + fmt::{self, Display, Write}, + str::FromStr, +}; -use super::parsing::from_one_raw_str; -use super::{Header, Raw}; +use actix_http::{error::ParseError, header, HttpMessage}; -/// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) +use super::{Header, HeaderName, HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; + +/// `Range` header, defined +/// in [RFC 7233 §3.1](https://datatracker.ietf.org/doc/html/rfc7233#section-3.1) /// -/// The "Range" header field on a GET request modifies the method -/// semantics to request transfer of only one or more sub-ranges of the -/// selected representation data, rather than the entire selected +/// The "Range" header field on a GET request modifies the method semantics to request transfer of +/// only one or more sub-ranges of the selected representation data, rather than the entire selected /// representation data. /// /// # ABNF -/// -/// ```text -/// Range = byte-ranges-specifier / other-ranges-specifier +/// ```plain +/// Range = byte-ranges-specifier / other-ranges-specifier /// other-ranges-specifier = other-range-unit "=" other-range-set /// other-range-set = 1*VCHAR /// @@ -23,121 +26,126 @@ use super::{Header, Raw}; /// byte-ranges-specifier = bytes-unit "=" byte-range-set /// byte-range-set = 1#(byte-range-spec / suffix-byte-range-spec) /// byte-range-spec = first-byte-pos "-" [last-byte-pos] +/// suffix-byte-range-spec = "-" suffix-length +/// suffix-length = 1*DIGIT /// first-byte-pos = 1*DIGIT /// last-byte-pos = 1*DIGIT /// ``` /// -/// # Example values -/// +/// # Example Values /// * `bytes=1000-` -/// * `bytes=-2000` +/// * `bytes=-50` /// * `bytes=0-1,30-40` /// * `bytes=0-10,20-90,-100` /// * `custom_unit=0-123` /// * `custom_unit=xxx-yyy` /// /// # Examples -/// /// ``` -/// use hyper::header::{Headers, Range, ByteRangeSpec}; +/// use actix_web::http::header::{Range, ByteRangeSpec}; +/// use actix_web::HttpResponse; /// -/// let mut headers = Headers::new(); -/// headers.set(Range::Bytes( -/// vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)] +/// let mut builder = HttpResponse::Ok(); +/// builder.insert_header(Range::Bytes( +/// vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::From(200)] /// )); -/// -/// headers.clear(); -/// headers.set(Range::Unregistered("letters".to_owned(), "a-f".to_owned())); -/// ``` -/// -/// ``` -/// use hyper::header::{Headers, Range}; -/// -/// let mut headers = Headers::new(); -/// headers.set(Range::bytes(1, 100)); -/// -/// headers.clear(); -/// headers.set(Range::bytes_multi(vec![(1, 100), (200, 300)])); +/// builder.insert_header(Range::Unregistered("letters".to_owned(), "a-f".to_owned())); +/// builder.insert_header(Range::bytes(1, 100)); +/// builder.insert_header(Range::bytes_multi(vec![(1, 100), (200, 300)])); /// ``` #[derive(PartialEq, Clone, Debug)] pub enum Range { - /// Byte range + /// Byte range. Bytes(Vec), - /// Custom range, with unit not registered at IANA + + /// Custom range, with unit not registered at IANA. + /// /// (`other-range-unit`: String , `other-range-set`: String) Unregistered(String, String), } -/// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`. -/// Each `ByteRangeSpec` defines a range of bytes to fetch -#[derive(PartialEq, Clone, Debug)] +/// A range of bytes to fetch. +/// +/// Each [`Range::Bytes`] header can contain one or more `ByteRangeSpec`s. +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ByteRangeSpec { - /// Get all bytes between x and y ("x-y") + /// All bytes from `x` to `y`, inclusive. + /// + /// Serialized as `x-y`. + /// + /// Example: `bytes=500-999` would represent the second 500 bytes. FromTo(u64, u64), - /// Get all bytes starting from x ("x-") - AllFrom(u64), - /// Get last x bytes ("-x") + + /// All bytes starting from `x`, inclusive. + /// + /// Serialized as `x-`. + /// + /// Example: For a file of 1000 bytes, `bytes=950-` would represent bytes 950-999, inclusive. + From(u64), + + /// The last `y` bytes, inclusive. + /// + /// Using the spec terminology, this is `suffix-byte-range-spec`. Serialized as `-y`. + /// + /// Example: For a file of 1000 bytes, `bytes=-50` is equivalent to `bytes=950-`. Last(u64), } impl ByteRangeSpec { - /// Given the full length of the entity, attempt to normalize the byte range - /// into an satisfiable end-inclusive (from, to) range. + /// Given the full length of the entity, attempt to normalize the byte range into an satisfiable + /// end-inclusive `(from, to)` range. /// - /// The resulting range is guaranteed to be a satisfiable range within the - /// bounds of `0 <= from <= to < full_length`. + /// The resulting range is guaranteed to be a satisfiable range within the bounds + /// of `0 <= from <= to < full_length`. /// - /// If the byte range is deemed unsatisfiable, `None` is returned. - /// An unsatisfiable range is generally cause for a server to either reject - /// the client request with a `416 Range Not Satisfiable` status code, or to - /// simply ignore the range header and serve the full entity using a `200 - /// OK` status code. + /// If the byte range is deemed unsatisfiable, `None` is returned. An unsatisfiable range is + /// generally cause for a server to either reject the client request with a + /// `416 Range Not Satisfiable` status code, or to simply ignore the range header and serve the + /// full entity using a `200 OK` status code. /// - /// This function closely follows [RFC 7233][1] section 2.1. - /// As such, it considers ranges to be satisfiable if they meet the - /// following conditions: + /// This function closely follows [RFC 7233 §2.1]. As such, it considers ranges to be + /// satisfiable if they meet the following conditions: /// - /// > If a valid byte-range-set includes at least one byte-range-spec with - /// a first-byte-pos that is less than the current length of the - /// representation, or at least one suffix-byte-range-spec with a - /// non-zero suffix-length, then the byte-range-set is satisfiable. - /// Otherwise, the byte-range-set is unsatisfiable. + /// > If a valid byte-range-set includes at least one byte-range-spec with a first-byte-pos that + /// is less than the current length of the representation, or at least one + /// suffix-byte-range-spec with a non-zero suffix-length, then the byte-range-set + /// is satisfiable. Otherwise, the byte-range-set is unsatisfiable. /// /// The function also computes remainder ranges based on the RFC: /// - /// > If the last-byte-pos value is - /// absent, or if the value is greater than or equal to the current - /// length of the representation data, the byte range is interpreted as - /// the remainder of the representation (i.e., the server replaces the - /// value of last-byte-pos with a value that is one less than the current - /// length of the selected representation). + /// > If the last-byte-pos value is absent, or if the value is greater than or equal to the + /// current length of the representation data, the byte range is interpreted as the remainder + /// of the representation (i.e., the server replaces the value of last-byte-pos with a value + /// that is one less than the current length of the selected representation). /// - /// [1]: https://tools.ietf.org/html/rfc7233 + /// [RFC 7233 §2.1]: https://datatracker.ietf.org/doc/html/rfc7233 pub fn to_satisfiable_range(&self, full_length: u64) -> Option<(u64, u64)> { // If the full length is zero, there is no satisfiable end-inclusive range. if full_length == 0 { return None; } - match self { - &ByteRangeSpec::FromTo(from, to) => { + + match *self { + ByteRangeSpec::FromTo(from, to) => { if from < full_length && from <= to { - Some((from, ::std::cmp::min(to, full_length - 1))) + Some((from, cmp::min(to, full_length - 1))) } else { None } } - &ByteRangeSpec::AllFrom(from) => { + + ByteRangeSpec::From(from) => { if from < full_length { Some((from, full_length - 1)) } else { None } } - &ByteRangeSpec::Last(last) => { + + ByteRangeSpec::Last(last) => { if last > 0 { - // From the RFC: If the selected representation is shorter - // than the specified suffix-length, - // the entire representation is used. + // From the RFC: If the selected representation is shorter than the specified + // suffix-length, the entire representation is used. if last > full_length { Some((0, full_length - 1)) } else { @@ -152,48 +160,53 @@ impl ByteRangeSpec { } impl Range { - /// Get the most common byte range header ("bytes=from-to") + /// Constructs a common byte range header. + /// + /// Eg: `bytes=from-to` pub fn bytes(from: u64, to: u64) -> Range { Range::Bytes(vec![ByteRangeSpec::FromTo(from, to)]) } - /// Get byte range header with multiple subranges - /// ("bytes=from1-to1,from2-to2,fromX-toX") + /// Constructs a byte range header with multiple subranges. + /// + /// Eg: `bytes=from1-to1,from2-to2,fromX-toX` pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { Range::Bytes( ranges - .iter() - .map(|r| ByteRangeSpec::FromTo(r.0, r.1)) + .into_iter() + .map(|(from, to)| ByteRangeSpec::FromTo(from, to)) .collect(), ) } } impl fmt::Display for ByteRangeSpec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { ByteRangeSpec::FromTo(from, to) => write!(f, "{}-{}", from, to), ByteRangeSpec::Last(pos) => write!(f, "-{}", pos), - ByteRangeSpec::AllFrom(pos) => write!(f, "{}-", pos), + ByteRangeSpec::From(pos) => write!(f, "{}-", pos), } } } impl fmt::Display for Range { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Range::Bytes(ref ranges) => { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Range::Bytes(ranges) => { write!(f, "bytes=")?; for (i, range) in ranges.iter().enumerate() { if i != 0 { f.write_str(",")?; } + Display::fmt(range, f)?; } Ok(()) } - Range::Unregistered(ref unit, ref range_str) => { + + Range::Unregistered(unit, range_str) => { write!(f, "{}={}", unit, range_str) } } @@ -201,53 +214,77 @@ impl fmt::Display for Range { } impl FromStr for Range { - type Err = ::Error; + type Err = ParseError; - fn from_str(s: &str) -> ::Result { - let mut iter = s.splitn(2, '='); + fn from_str(s: &str) -> Result { + let (unit, val) = s.split_once('=').ok_or(ParseError::Header)?; - match (iter.next(), iter.next()) { - (Some("bytes"), Some(ranges)) => { + match (unit, val) { + ("bytes", ranges) => { let ranges = from_comma_delimited(ranges); + if ranges.is_empty() { - return Err(::Error::Header); + return Err(ParseError::Header); } + Ok(Range::Bytes(ranges)) } - (Some(unit), Some(range_str)) if unit != "" && range_str != "" => { - Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())) - } - _ => Err(::Error::Header), + + (_, "") => Err(ParseError::Header), + ("", _) => Err(ParseError::Header), + + (unit, range_str) => Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())), } } } impl FromStr for ByteRangeSpec { - type Err = ::Error; + type Err = ParseError; - fn from_str(s: &str) -> ::Result { - let mut parts = s.splitn(2, '-'); + fn from_str(s: &str) -> Result { + let (start, end) = s.split_once('-').ok_or(ParseError::Header)?; - match (parts.next(), parts.next()) { - (Some(""), Some(end)) => end + match (start, end) { + ("", end) => end .parse() - .or(Err(::Error::Header)) + .or(Err(ParseError::Header)) .map(ByteRangeSpec::Last), - (Some(start), Some("")) => start + + (start, "") => start .parse() - .or(Err(::Error::Header)) - .map(ByteRangeSpec::AllFrom), - (Some(start), Some(end)) => match (start.parse(), end.parse()) { - (Ok(start), Ok(end)) if start <= end => { - Ok(ByteRangeSpec::FromTo(start, end)) - } - _ => Err(::Error::Header), + .or(Err(ParseError::Header)) + .map(ByteRangeSpec::From), + + (start, end) => match (start.parse(), end.parse()) { + (Ok(start), Ok(end)) if start <= end => Ok(ByteRangeSpec::FromTo(start, end)), + _ => Err(ParseError::Header), }, - _ => Err(::Error::Header), } } } +impl Header for Range { + fn name() -> HeaderName { + header::RANGE + } + + #[inline] + fn parse(msg: &T) -> Result { + header::from_one_raw_str(msg.headers().get(&Self::name())) + } +} + +impl IntoHeaderValue for Range { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + let mut wrt = Writer::new(); + let _ = write!(wrt, "{}", self); + HeaderValue::from_maybe_shared(wrt.take()) + } +} + +/// Parses 0 or more items out of a comma delimited string, ignoring invalid items. fn from_comma_delimited(s: &str) -> Vec { s.split(',') .filter_map(|x| match x.trim() { @@ -258,45 +295,37 @@ fn from_comma_delimited(s: &str) -> Vec { .collect() } -impl Header for Range { - fn header_name() -> &'static str { - static NAME: &'static str = "Range"; - NAME - } - - fn parse_header(raw: &Raw) -> ::Result { - from_one_raw_str(raw) - } - - fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { - f.fmt_line(self) - } -} - #[cfg(test)] mod tests { + use actix_http::{test::TestRequest, Request}; + use super::*; + fn req(s: &str) -> Request { + TestRequest::default() + .insert_header((header::RANGE, s)) + .finish() + } + #[test] fn test_parse_bytes_range_valid() { - let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap(); - let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap(); + let r: Range = Header::parse(&req("bytes=1-100")).unwrap(); + let r2: Range = Header::parse(&req("bytes=1-100,-")).unwrap(); let r3 = Range::bytes(1, 100); assert_eq!(r, r2); assert_eq!(r2, r3); - let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap(); - let r2: Range = - Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); + let r: Range = Header::parse(&req("bytes=1-100,200-")).unwrap(); + let r2: Range = Header::parse(&req("bytes= 1-100 , 101-xxx, 200- ")).unwrap(); let r3 = Range::Bytes(vec![ ByteRangeSpec::FromTo(1, 100), - ByteRangeSpec::AllFrom(200), + ByteRangeSpec::From(200), ]); assert_eq!(r, r2); assert_eq!(r2, r3); - let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap(); - let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap(); + let r: Range = Header::parse(&req("bytes=1-100,-100")).unwrap(); + let r2: Range = Header::parse(&req("bytes=1-100, ,,-100")).unwrap(); let r3 = Range::Bytes(vec![ ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::Last(100), @@ -304,71 +333,65 @@ mod tests { assert_eq!(r, r2); assert_eq!(r2, r3); - let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); + let r: Range = Header::parse(&req("custom=1-100,-100")).unwrap(); let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); assert_eq!(r, r2); } #[test] fn test_parse_unregistered_range_valid() { - let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); + let r: Range = Header::parse(&req("custom=1-100,-100")).unwrap(); let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); assert_eq!(r, r2); - let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap(); + let r: Range = Header::parse(&req("custom=abcd")).unwrap(); let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); assert_eq!(r, r2); - let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap(); + let r: Range = Header::parse(&req("custom=xxx-yyy")).unwrap(); let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); assert_eq!(r, r2); } #[test] fn test_parse_invalid() { - let r: ::Result = Header::parse_header(&"bytes=1-a,-".into()); + let r: Result = Header::parse(&req("bytes=1-a,-")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"bytes=1-2-3".into()); + let r: Result = Header::parse(&req("bytes=1-2-3")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"abc".into()); + let r: Result = Header::parse(&req("abc")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"bytes=1-100=".into()); + let r: Result = Header::parse(&req("bytes=1-100=")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"bytes=".into()); + let r: Result = Header::parse(&req("bytes=")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"custom=".into()); + let r: Result = Header::parse(&req("custom=")); assert_eq!(r.ok(), None); - let r: ::Result = Header::parse_header(&"=1-100".into()); + let r: Result = Header::parse(&req("=1-100")); assert_eq!(r.ok(), None); } #[test] fn test_fmt() { - use header::Headers; - - let mut headers = Headers::new(); - - headers.set(Range::Bytes(vec![ + let range = Range::Bytes(vec![ ByteRangeSpec::FromTo(0, 1000), - ByteRangeSpec::AllFrom(2000), - ])); - assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n"); + ByteRangeSpec::From(2000), + ]); + assert_eq!(&range.to_string(), "bytes=0-1000,2000-"); - headers.clear(); - headers.set(Range::Bytes(vec![])); + let range = Range::Bytes(vec![]); - assert_eq!(&headers.to_string(), "Range: bytes=\r\n"); + assert_eq!(&range.to_string(), "bytes="); - headers.clear(); - headers.set(Range::Unregistered("custom".to_owned(), "1-xxx".to_owned())); + let range = Range::Unregistered("custom".to_owned(), "1-xxx".to_owned()); - assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n"); + assert_eq!(&range.to_string(), "custom=1-xxx"); } #[test] @@ -389,17 +412,11 @@ mod tests { assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)); assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)); - assert_eq!( - Some((0, 2)), - ByteRangeSpec::AllFrom(0).to_satisfiable_range(3) - ); - assert_eq!( - Some((2, 2)), - ByteRangeSpec::AllFrom(2).to_satisfiable_range(3) - ); - assert_eq!(None, ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)); + assert_eq!(Some((0, 2)), ByteRangeSpec::From(0).to_satisfiable_range(3)); + assert_eq!(Some((2, 2)), ByteRangeSpec::From(2).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(3).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(5).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(0).to_satisfiable_range(0)); assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3)); assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3)); diff --git a/src/info.rs b/src/info.rs index de8ad67ee..d928a1e63 100644 --- a/src/info.rs +++ b/src/info.rs @@ -209,7 +209,6 @@ impl ConnectionInfo { impl FromRequest for ConnectionInfo { type Error = Infallible; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ok(req.connection_info().clone()) @@ -252,7 +251,6 @@ impl ResponseError for MissingPeerAddr {} impl FromRequest for PeerAddr { type Error = MissingPeerAddr; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { match req.peer_addr() { diff --git a/src/lib.rs b/src/lib.rs index 714c759cf..f6ec4082a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,7 +53,7 @@ //! * SSL support using OpenSSL or Rustls //! * Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) //! * Includes an async [HTTP client](https://docs.rs/awc/) -//! * Runs on stable Rust 1.46+ +//! * Runs on stable Rust 1.52+ //! //! # Crate Features //! * `cookies` - cookies support (enabled by default) @@ -96,7 +96,6 @@ pub mod test; pub(crate) mod types; pub mod web; -pub use actix_http::Response as BaseHttpResponse; pub use actix_http::{body, HttpMessage}; #[doc(inline)] pub use actix_rt as rt; @@ -116,3 +115,5 @@ pub use crate::scope::Scope; pub use crate::server::HttpServer; // TODO: is exposing the error directly really needed pub use crate::types::{Either, EitherExtractError}; + +pub(crate) type BoxError = Box; diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs index 0a6256fe2..e6ef1806f 100644 --- a/src/middleware/compat.rs +++ b/src/middleware/compat.rs @@ -1,15 +1,15 @@ //! For middleware documentation, see [`Compat`]. use std::{ - error::Error as StdError, future::Future, pin::Pin, task::{Context, Poll}, }; -use actix_http::body::{Body, MessageBody}; +use actix_http::body::MessageBody; use actix_service::{Service, Transform}; use futures_core::{future::LocalBoxFuture, ready}; +use pin_project_lite::pin_project; use crate::{error::Error, service::ServiceResponse}; @@ -89,10 +89,11 @@ where } } -#[pin_project::pin_project] -pub struct CompatMiddlewareFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct CompatMiddlewareFuture { + #[pin] + fut: Fut, + } } impl Future for CompatMiddlewareFuture @@ -121,10 +122,9 @@ pub trait MapServiceResponseBody { impl MapServiceResponseBody for ServiceResponse where B: MessageBody + Unpin + 'static, - B::Error: Into>, { fn map_body(self) -> ServiceResponse { - self.map_body(|_, body| Body::from_message(body)) + self.map_into_boxed_body() } } diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index a9128bc47..d017e9a5a 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -2,27 +2,29 @@ use std::{ cmp, + convert::TryFrom, future::Future, marker::PhantomData, pin::Pin, - str::FromStr, task::{Context, Poll}, }; use actix_http::{ - body::{MessageBody, ResponseBody}, + body::{EitherBody, MessageBody}, encoding::Encoder, http::header::{ContentEncoding, ACCEPT_ENCODING}, + StatusCode, }; use actix_service::{Service, Transform}; -use actix_utils::future::{ok, Ready}; +use actix_utils::future::{ok, Either, Ready}; use futures_core::ready; -use pin_project::pin_project; +use once_cell::sync::Lazy; +use pin_project_lite::pin_project; use crate::{ dev::BodyEncoding, service::{ServiceRequest, ServiceResponse}, - Error, + Error, HttpResponse, }; /// Middleware for compressing response payloads. @@ -59,7 +61,7 @@ where B: MessageBody, S: Service, Error = Error>, { - type Response = ServiceResponse>>; + type Response = ServiceResponse>>; type Error = Error; type Transform = CompressMiddleware; type InitError = (); @@ -78,48 +80,93 @@ pub struct CompressMiddleware { encoding: ContentEncoding, } +static SUPPORTED_ALGORITHM_NAMES: Lazy = Lazy::new(|| { + #[allow(unused_mut)] // only unused when no compress features enabled + let mut encoding: Vec<&str> = vec![]; + + #[cfg(feature = "compress-brotli")] + { + encoding.push("br"); + } + + #[cfg(feature = "compress-gzip")] + { + encoding.push("gzip"); + encoding.push("deflate"); + } + + #[cfg(feature = "compress-zstd")] + encoding.push("zstd"); + + assert!( + !encoding.is_empty(), + "encoding can not be empty unless __compress feature has been explicitly enabled by itself" + ); + + encoding.join(", ") +}); + impl Service for CompressMiddleware where - B: MessageBody, S: Service, Error = Error>, + B: MessageBody, { - type Response = ServiceResponse>>; + type Response = ServiceResponse>>; type Error = Error; - type Future = CompressResponse; + type Future = Either, Ready>>; actix_service::forward_ready!(service); #[allow(clippy::borrow_interior_mutable_const)] fn call(&self, req: ServiceRequest) -> Self::Future { // negotiate content-encoding - let encoding = if let Some(val) = req.headers().get(&ACCEPT_ENCODING) { - if let Ok(enc) = val.to_str() { - AcceptEncoding::parse(enc, self.encoding) - } else { - ContentEncoding::Identity - } - } else { - ContentEncoding::Identity - }; + let encoding_result = req + .headers() + .get(&ACCEPT_ENCODING) + .and_then(|val| val.to_str().ok()) + .map(|enc| AcceptEncoding::try_parse(enc, self.encoding)); - CompressResponse { - encoding, - fut: self.service.call(req), - _phantom: PhantomData, + match encoding_result { + // Missing header => fallback to identity + None => Either::left(CompressResponse { + encoding: ContentEncoding::Identity, + fut: self.service.call(req), + _phantom: PhantomData, + }), + + // Valid encoding + Some(Ok(encoding)) => Either::left(CompressResponse { + encoding, + fut: self.service.call(req), + _phantom: PhantomData, + }), + + // There is an HTTP header but we cannot match what client as asked for + Some(Err(_)) => { + let res = HttpResponse::with_body( + StatusCode::NOT_ACCEPTABLE, + SUPPORTED_ALGORITHM_NAMES.clone(), + ); + + Either::right(ok(req + .into_response(res) + .map_into_boxed_body() + .map_into_right_body())) + } } } } -#[pin_project] -pub struct CompressResponse -where - S: Service, - B: MessageBody, -{ - #[pin] - fut: S::Future, - encoding: ContentEncoding, - _phantom: PhantomData, +pin_project! { + pub struct CompressResponse + where + S: Service, + { + #[pin] + fut: S::Future, + encoding: ContentEncoding, + _phantom: PhantomData, + } } impl Future for CompressResponse @@ -127,7 +174,7 @@ where B: MessageBody, S: Service, Error = Error>, { - type Output = Result>>, Error>; + type Output = Result>>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -141,16 +188,18 @@ where }; Poll::Ready(Ok(resp.map_body(move |head, body| { - Encoder::response(enc, head, ResponseBody::Body(body)) + EitherBody::left(Encoder::response(enc, head, body)) }))) } - Err(e) => Poll::Ready(Err(e)), + + Err(err) => Poll::Ready(Err(err)), } } } struct AcceptEncoding { encoding: ContentEncoding, + // TODO: use Quality or QualityItem quality: f64, } @@ -177,42 +226,149 @@ impl PartialOrd for AcceptEncoding { impl PartialEq for AcceptEncoding { fn eq(&self, other: &AcceptEncoding) -> bool { - self.quality == other.quality + self.encoding == other.encoding && self.quality == other.quality } } +/// Parse q-factor from quality strings. +/// +/// If parse fail, then fallback to default value which is 1. +/// More details available here: +fn parse_quality(parts: &[&str]) -> f64 { + for part in parts { + if part.trim().starts_with("q=") { + return part[2..].parse().unwrap_or(1.0); + } + } + + 1.0 +} + +#[derive(Debug, PartialEq, Eq)] +enum AcceptEncodingError { + /// This error occurs when client only support compressed response and server do not have any + /// algorithm that match client accepted algorithms. + CompressionAlgorithmMismatch, +} + impl AcceptEncoding { fn new(tag: &str) -> Option { let parts: Vec<&str> = tag.split(';').collect(); let encoding = match parts.len() { 0 => return None, - _ => ContentEncoding::from(parts[0]), - }; - let quality = match parts.len() { - 1 => encoding.quality(), - _ => f64::from_str(parts[1]).unwrap_or(0.0), + _ => match ContentEncoding::try_from(parts[0]) { + Err(_) => return None, + Ok(x) => x, + }, }; + + let quality = parse_quality(&parts[1..]); + if quality <= 0.0 || quality > 1.0 { + return None; + } + Some(AcceptEncoding { encoding, quality }) } - /// Parse a raw Accept-Encoding header value into an ordered list. - pub fn parse(raw: &str, encoding: ContentEncoding) -> ContentEncoding { + /// Parse a raw Accept-Encoding header value into an ordered list then return the best match + /// based on middleware configuration. + pub fn try_parse( + raw: &str, + encoding: ContentEncoding, + ) -> Result { let mut encodings = raw .replace(' ', "") .split(',') - .filter_map(|l| AcceptEncoding::new(l)) + .filter_map(AcceptEncoding::new) .collect::>(); encodings.sort(); for enc in encodings { - if encoding == ContentEncoding::Auto { - return enc.encoding; - } else if encoding == enc.encoding { - return encoding; + if encoding == ContentEncoding::Auto || encoding == enc.encoding { + return Ok(enc.encoding); } } - ContentEncoding::Identity + // Special case if user cannot accept uncompressed data. + // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding + // TODO: account for whitespace + if raw.contains("*;q=0") || raw.contains("identity;q=0") { + return Err(AcceptEncodingError::CompressionAlgorithmMismatch); + } + + Ok(ContentEncoding::Identity) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_parse_eq { + ($raw:expr, $result:expr) => { + assert_eq!( + AcceptEncoding::try_parse($raw, ContentEncoding::Auto), + Ok($result) + ); + }; + } + + macro_rules! assert_parse_fail { + ($raw:expr) => { + assert!(AcceptEncoding::try_parse($raw, ContentEncoding::Auto).is_err()); + }; + } + + #[test] + fn test_parse_encoding() { + // Test simple case + assert_parse_eq!("br", ContentEncoding::Br); + assert_parse_eq!("gzip", ContentEncoding::Gzip); + assert_parse_eq!("deflate", ContentEncoding::Deflate); + assert_parse_eq!("zstd", ContentEncoding::Zstd); + + // Test space, trim, missing values + assert_parse_eq!("br,,,,", ContentEncoding::Br); + assert_parse_eq!("gzip , br, zstd", ContentEncoding::Gzip); + + // Test float number parsing + assert_parse_eq!("br;q=1 ,", ContentEncoding::Br); + assert_parse_eq!("br;q=1.0 , br", ContentEncoding::Br); + + // Test wildcard + assert_parse_eq!("*", ContentEncoding::Identity); + assert_parse_eq!("*;q=1.0", ContentEncoding::Identity); + } + + #[test] + fn test_parse_encoding_qfactor_ordering() { + assert_parse_eq!("gzip, br, zstd", ContentEncoding::Gzip); + assert_parse_eq!("zstd, br, gzip", ContentEncoding::Zstd); + + assert_parse_eq!("gzip;q=0.4, br;q=0.6", ContentEncoding::Br); + assert_parse_eq!("gzip;q=0.8, br;q=0.4", ContentEncoding::Gzip); + } + + #[test] + fn test_parse_encoding_qfactor_invalid() { + // Out of range + assert_parse_eq!("gzip;q=-5.0", ContentEncoding::Identity); + assert_parse_eq!("gzip;q=5.0", ContentEncoding::Identity); + + // Disabled + assert_parse_eq!("gzip;q=0", ContentEncoding::Identity); + } + + #[test] + fn test_parse_compression_required() { + // Check we fallback to identity if there is an unsupported compression algorithm + assert_parse_eq!("compress", ContentEncoding::Identity); + + // User do not want any compression + assert_parse_fail!("compress, identity;q=0"); + assert_parse_fail!("compress, identity;q=0.0"); + assert_parse_fail!("compress, *;q=0"); + assert_parse_fail!("compress, *;q=0.0"); } } diff --git a/src/middleware/default_headers.rs b/src/middleware/default_headers.rs index d8a947aab..426810247 100644 --- a/src/middleware/default_headers.rs +++ b/src/middleware/default_headers.rs @@ -11,6 +11,7 @@ use std::{ use actix_utils::future::{ready, Ready}; use futures_core::ready; +use pin_project_lite::pin_project; use crate::{ dev::{Service, Transform}, @@ -153,12 +154,13 @@ where } } -#[pin_project::pin_project] -pub struct DefaultHeaderFuture, B> { - #[pin] - fut: S::Future, - inner: Rc, - _body: PhantomData, +pin_project! { + pub struct DefaultHeaderFuture, B> { + #[pin] + fut: S::Future, + inner: Rc, + _body: PhantomData, + } } impl Future for DefaultHeaderFuture diff --git a/src/middleware/err_handlers.rs b/src/middleware/err_handlers.rs index 75cc819bc..1a834c1e8 100644 --- a/src/middleware/err_handlers.rs +++ b/src/middleware/err_handlers.rs @@ -10,6 +10,7 @@ use std::{ use actix_service::{Service, Transform}; use ahash::AHashMap; use futures_core::{future::LocalBoxFuture, ready}; +use pin_project_lite::pin_project; use crate::{ dev::{ServiceRequest, ServiceResponse}, @@ -130,19 +131,21 @@ where } } -#[pin_project::pin_project(project = ErrorHandlersProj)] -pub enum ErrorHandlersFuture -where - Fut: Future, -{ - ServiceFuture { - #[pin] - fut: Fut, - handlers: Handlers, - }, - HandlerFuture { - fut: LocalBoxFuture<'static, Fut::Output>, - }, +pin_project! { + #[project = ErrorHandlersProj] + pub enum ErrorHandlersFuture + where + Fut: Future, + { + ServiceFuture { + #[pin] + fut: Fut, + handlers: Handlers, + }, + HandlerFuture { + fut: LocalBoxFuture<'static, Fut::Output>, + }, + } } impl Future for ErrorHandlersFuture diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index bbb0e3dc4..f89b13a1c 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -13,16 +13,17 @@ use std::{ }; use actix_service::{Service, Transform}; -use actix_utils::future::{ok, Ready}; +use actix_utils::future::{ready, Ready}; use bytes::Bytes; use futures_core::ready; use log::{debug, warn}; +use pin_project_lite::pin_project; use regex::{Regex, RegexSet}; -use time::OffsetDateTime; +use time::{format_description::well_known::Rfc3339, OffsetDateTime}; use crate::{ - dev::{BodySize, MessageBody}, - http::{HeaderName, StatusCode}, + body::{BodySize, MessageBody}, + http::HeaderName, service::{ServiceRequest, ServiceResponse}, Error, HttpResponse, Result, }; @@ -180,8 +181,8 @@ where { type Response = ServiceResponse>; type Error = Error; - type InitError = (); type Transform = LoggerMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -195,10 +196,10 @@ where } } - ok(LoggerMiddleware { + ready(Ok(LoggerMiddleware { service, inner: self.0.clone(), - }) + })) } } @@ -246,17 +247,18 @@ where } } -#[pin_project::pin_project] -pub struct LoggerResponse -where - B: MessageBody, - S: Service, -{ - #[pin] - fut: S::Future, - time: OffsetDateTime, - format: Option, - _phantom: PhantomData, +pin_project! { + pub struct LoggerResponse + where + B: MessageBody, + S: Service, + { + #[pin] + fut: S::Future, + time: OffsetDateTime, + format: Option, + _phantom: PhantomData, + } } impl Future for LoggerResponse @@ -275,9 +277,7 @@ where }; if let Some(error) = res.response().error() { - if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { - debug!("Error in response: {:?}", error); - } + debug!("Error in response: {:?}", error); } if let Some(ref mut format) = this.format { @@ -298,28 +298,25 @@ where } } -use pin_project::{pin_project, pinned_drop}; - -#[pin_project(PinnedDrop)] -pub struct StreamLog { - #[pin] - body: B, - format: Option, - size: usize, - time: OffsetDateTime, -} - -#[pinned_drop] -impl PinnedDrop for StreamLog { - fn drop(self: Pin<&mut Self>) { - if let Some(ref format) = self.format { - let render = |fmt: &mut fmt::Formatter<'_>| { - for unit in &format.0 { - unit.render(fmt, self.size, self.time)?; - } - Ok(()) - }; - log::info!("{}", FormatDisplay(&render)); +pin_project! { + pub struct StreamLog { + #[pin] + body: B, + format: Option, + size: usize, + time: OffsetDateTime, + } + impl PinnedDrop for StreamLog { + fn drop(this: Pin<&mut Self>) { + if let Some(ref format) = this.format { + let render = |fmt: &mut fmt::Formatter<'_>| { + for unit in &format.0 { + unit.render(fmt, this.size, this.time)?; + } + Ok(()) + }; + log::info!("{}", FormatDisplay(&render)); + } } } } @@ -341,7 +338,6 @@ where ) -> Poll>> { let this = self.project(); - // TODO: MSRV 1.51: poll_map_err match ready!(this.body.poll_next(cx)) { Some(Ok(chunk)) => { *this.size += chunk.len(); @@ -539,7 +535,7 @@ impl FormatText { }; } FormatText::UrlPath => *self = FormatText::Str(req.path().to_string()), - FormatText::RequestTime => *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")), + FormatText::RequestTime => *self = FormatText::Str(now.format(&Rfc3339).unwrap()), FormatText::RequestHeader(ref name) => { let s = if let Some(val) = req.headers().get(name) { if let Ok(s) = val.to_str() { @@ -553,7 +549,7 @@ impl FormatText { *self = FormatText::Str(s.to_string()); } FormatText::RemoteAddr => { - let s = if let Some(ref peer) = req.connection_info().remote_addr() { + let s = if let Some(peer) = req.connection_info().remote_addr() { FormatText::Str((*peer).to_string()) } else { FormatText::Str("-".to_string()) @@ -768,7 +764,7 @@ mod tests { Ok(()) }; let s = format!("{}", FormatDisplay(&render)); - assert!(s.contains(&now.format("%Y-%m-%dT%H:%M:%S"))); + assert!(s.contains(&now.format(&Rfc3339).unwrap())); } #[actix_rt::test] diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 96a361fcf..d19cb64e9 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -19,3 +19,43 @@ mod compress; #[cfg(feature = "__compress")] pub use self::compress::Compress; + +#[cfg(test)] +mod tests { + use crate::{http::StatusCode, App}; + + use super::*; + + #[test] + fn common_combinations() { + // ensure there's no reason that the built-in middleware cannot compose + + let _ = App::new() + .wrap(Compat::new(Logger::default())) + .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) + .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { + Ok(ErrorHandlerResponse::Response(res)) + })) + .wrap(Logger::default()) + .wrap(NormalizePath::new(TrailingSlash::Trim)); + + let _ = App::new() + .wrap(NormalizePath::new(TrailingSlash::Trim)) + .wrap(Logger::default()) + .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { + Ok(ErrorHandlerResponse::Response(res)) + })) + .wrap(DefaultHeaders::new().header("X-Test2", "X-Value2")) + .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(Compat::new(Logger::default())); + + #[cfg(feature = "__compress")] + { + let _ = App::new().wrap(Compress::default()).wrap(Logger::default()); + let _ = App::new().wrap(Logger::default()).wrap(Compress::default()); + let _ = App::new().wrap(Compat::new(Compress::default())); + let _ = App::new().wrap(Condition::new(true, Compat::new(Compress::default()))); + } + } +} diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index 219af1c6a..8ad0bb3f0 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -59,7 +59,7 @@ impl Default for TrailingSlash { /// /// # actix_web::rt::System::new().block_on(async { /// let app = App::new() -/// .wrap(middleware::NormalizePath::default()) +/// .wrap(middleware::NormalizePath::trim()) /// .route("/test", web::get().to(|| async { "test" })) /// .route("/unmatchable/", web::get().to(|| async { "unmatchable" })); /// @@ -85,13 +85,31 @@ impl Default for TrailingSlash { /// assert_eq!(res.status(), StatusCode::NOT_FOUND); /// # }) /// ``` -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy)] pub struct NormalizePath(TrailingSlash); +impl Default for NormalizePath { + fn default() -> Self { + log::warn!( + "`NormalizePath::default()` is deprecated. The default trailing slash behavior changed \ + in v4 from `Always` to `Trim`. Update your call to `NormalizePath::new(...)`." + ); + + Self(TrailingSlash::Trim) + } +} + impl NormalizePath { /// Create new `NormalizePath` middleware with the specified trailing slash style. pub fn new(trailing_slash_style: TrailingSlash) -> Self { - NormalizePath(trailing_slash_style) + Self(trailing_slash_style) + } + + /// Constructs a new `NormalizePath` middleware with [trim](TrailingSlash::Trim) semantics. + /// + /// Use this instead of `NormalizePath::default()` to avoid deprecation warning. + pub fn trim() -> Self { + Self::new(TrailingSlash::Trim) } } diff --git a/src/request.rs b/src/request.rs index 4b950e758..58222da47 100644 --- a/src/request.rs +++ b/src/request.rs @@ -100,7 +100,7 @@ impl HttpRequest { &self.head().headers } - /// The target path of this Request. + /// The target path of this request. #[inline] pub fn path(&self) -> &str { self.head().uri.path() @@ -108,18 +108,22 @@ impl HttpRequest { /// The query string in the URL. /// - /// E.g., id=10 + /// Example: `id=10` #[inline] pub fn query_string(&self) -> &str { self.uri().query().unwrap_or_default() } - /// Get a reference to the Path parameters. + /// Returns a reference to the URL parameters container. /// - /// Params is a container for url parameters. - /// A variable segment is specified in the form `{identifier}`, - /// where the identifier can be used later in a request handler to - /// access the matched value for that segment. + /// A url parameter is specified in the form `{identifier}`, where the identifier can be used + /// later in a request handler to access the matched value for that parameter. + /// + /// # Percent Encoding and URL Parameters + /// Because each URL parameter is able to capture multiple path segments, both `["%2F", "%25"]` + /// found in the request URI are not decoded into `["/", "%"]` in order to preserve path + /// segment boundaries. If a url parameter is expected to contain these characters, then it is + /// on the user to decode them. #[inline] pub fn match_info(&self) -> &Path { &self.inner.path @@ -161,30 +165,36 @@ impl HttpRequest { self.head().extensions_mut() } - /// Generate url for named resource + /// Generates URL for a named resource. /// + /// This substitutes in sequence all URL parameters that appear in the resource itself and in + /// parent [scopes](crate::web::scope), if any. + /// + /// It is worth noting that the characters `['/', '%']` are not escaped and therefore a single + /// URL parameter may expand into multiple path segments and `elements` can be percent-encoded + /// beforehand without worrying about double encoding. Any other character that is not valid in + /// a URL path context is escaped using percent-encoding. + /// + /// # Examples /// ``` /// # use actix_web::{web, App, HttpRequest, HttpResponse}; - /// # /// fn index(req: HttpRequest) -> HttpResponse { - /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource + /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate URL for "foo" resource /// HttpResponse::Ok().into() /// } /// - /// fn main() { - /// let app = App::new() - /// .service(web::resource("/test/{one}/{two}/{three}") - /// .name("foo") // <- set resource name, then it could be used in `url_for` - /// .route(web::get().to(|| HttpResponse::Ok())) - /// ); - /// } + /// let app = App::new() + /// .service(web::resource("/test/{one}/{two}/{three}") + /// .name("foo") // <- set resource name so it can be used in `url_for` + /// .route(web::get().to(|| HttpResponse::Ok())) + /// ); /// ``` pub fn url_for(&self, name: &str, elements: U) -> Result where U: IntoIterator, I: AsRef, { - self.resource_map().url_for(&self, name, elements) + self.resource_map().url_for(self, name, elements) } /// Generate url for named resource @@ -196,10 +206,10 @@ impl HttpRequest { self.url_for(name, &NO_PARAMS) } - #[inline] /// Get a reference to a `ResourceMap` of current application. + #[inline] pub fn resource_map(&self) -> &ResourceMap { - &self.app_state().rmap() + self.app_state().rmap() } /// Peer socket address. @@ -358,7 +368,6 @@ impl Drop for HttpRequest { /// } /// ``` impl FromRequest for HttpRequest { - type Config = (); type Error = Error; type Future = Ready>; @@ -509,9 +518,9 @@ mod tests { #[test] fn test_url_for() { let mut res = ResourceDef::new("/user/{name}.{ext}"); - *res.name_mut() = "index".to_string(); + res.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut res, None); assert!(rmap.has_resource("/user/test.html")); assert!(!rmap.has_resource("/test/unknown")); @@ -539,9 +548,9 @@ mod tests { #[test] fn test_url_for_static() { let mut rdef = ResourceDef::new("/index.html"); - *rdef.name_mut() = "index".to_string(); + rdef.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); assert!(rmap.has_resource("/index.html")); @@ -560,9 +569,9 @@ mod tests { #[test] fn test_match_name() { let mut rdef = ResourceDef::new("/index.html"); - *rdef.name_mut() = "index".to_string(); + rdef.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); assert!(rmap.has_resource("/index.html")); @@ -579,11 +588,10 @@ mod tests { fn test_url_for_external() { let mut rdef = ResourceDef::new("https://youtube.com/watch/{video_id}"); - *rdef.name_mut() = "youtube".to_string(); + rdef.set_name("youtube"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); - assert!(rmap.has_resource("https://youtube.com/watch/unknown")); let req = TestRequest::default().rmap(rmap).to_http_request(); let url = req.url_for("youtube", &["oHg5SJYRHA0"]); diff --git a/src/request_data.rs b/src/request_data.rs index 581943015..575dc1eb3 100644 --- a/src/request_data.rs +++ b/src/request_data.rs @@ -64,7 +64,6 @@ impl Deref for ReqData { } impl FromRequest for ReqData { - type Config = (); type Error = Error; type Future = Ready>; diff --git a/src/resource.rs b/src/resource.rs index 20d1ee17e..fc417bac2 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,32 +1,29 @@ -use std::cell::RefCell; -use std::fmt; -use std::future::Future; -use std::rc::Rc; +use std::{cell::RefCell, fmt, future::Future, rc::Rc}; use actix_http::Extensions; -use actix_router::IntoPattern; -use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_router::{IntoPatterns, Patterns}; use actix_service::{ - apply, apply_fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, + apply, apply_fn_factory, boxed, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt, Transform, }; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; use crate::{ + body::MessageBody, data::Data, - dev::{insert_leading_slash, AppService, HttpServiceFactory, ResourceDef}, + dev::{ensure_leading_slash, AppService, ResourceDef}, guard::Guard, handler::Handler, responder::Responder, route::{Route, RouteService}, - service::{ServiceRequest, ServiceResponse}, - Error, FromRequest, HttpResponse, + service::{ + BoxedHttpService, BoxedHttpServiceFactory, HttpServiceFactory, ServiceRequest, + ServiceResponse, + }, + BoxError, Error, FromRequest, HttpResponse, }; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; - /// *Resource* is an entry in resources table which corresponds to requested URL. /// /// Resource in turn has at least one route. @@ -51,17 +48,17 @@ type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Err /// Default behavior could be overridden with `default_resource()` method. pub struct Resource { endpoint: T, - rdef: Vec, + rdef: Patterns, name: Option, routes: Vec, app_data: Option, guards: Vec>, - default: HttpNewService, + default: BoxedHttpServiceFactory, factory_ref: Rc>>, } impl Resource { - pub fn new(path: T) -> Resource { + pub fn new(path: T) -> Resource { let fref = Rc::new(RefCell::new(None)); Resource { @@ -242,6 +239,8 @@ where I: FromRequest + 'static, R: Future + 'static, R::Output: Responder + 'static, + ::Body: MessageBody, + <::Body as MessageBody>::Error: Into, { self.routes.push(Route::new().to(handler)); self @@ -391,13 +390,13 @@ where }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(insert_leading_slash(self.rdef.clone())) + ResourceDef::new(ensure_leading_slash(self.rdef.clone())) } else { ResourceDef::new(self.rdef.clone()) }; if let Some(ref name) = self.name { - *rdef.name_mut() = name.clone(); + rdef.set_name(name); } *self.factory_ref.borrow_mut() = Some(ResourceFactory { @@ -422,7 +421,7 @@ where pub struct ResourceFactory { routes: Vec, - default: HttpNewService, + default: BoxedHttpServiceFactory, } impl ServiceFactory for ResourceFactory { @@ -454,7 +453,7 @@ impl ServiceFactory for ResourceFactory { pub struct ResourceService { routes: Vec, - default: HttpService, + default: BoxedHttpService, } impl Service for ResourceService { diff --git a/src/responder.rs b/src/responder.rs index c5852a501..9d8a0e8ed 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -1,19 +1,21 @@ use std::borrow::Cow; use actix_http::{ - body::Body, + body::{BoxBody, EitherBody, MessageBody}, http::{header::IntoHeaderPair, Error as HttpError, HeaderMap, StatusCode}, }; use bytes::{Bytes, BytesMut}; -use crate::{Error, HttpRequest, HttpResponse, HttpResponseBuilder}; +use crate::{BoxError, Error, HttpRequest, HttpResponse, HttpResponseBuilder}; /// Trait implemented by types that can be converted to an HTTP response. /// /// Any types that implement this trait can be used in the return type of a handler. pub trait Responder { + type Body: MessageBody + 'static; + /// Convert self to `HttpResponse`. - fn respond_to(self, req: &HttpRequest) -> HttpResponse; + fn respond_to(self, req: &HttpRequest) -> HttpResponse; /// Override a status code for a Responder. /// @@ -59,38 +61,52 @@ pub trait Responder { } impl Responder for HttpResponse { + type Body = BoxBody; + #[inline] - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + fn respond_to(self, _: &HttpRequest) -> HttpResponse { self } } -impl Responder for actix_http::Response { +impl Responder for actix_http::Response { + type Body = BoxBody; + #[inline] - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + fn respond_to(self, _: &HttpRequest) -> HttpResponse { HttpResponse::from(self) } } impl Responder for HttpResponseBuilder { + type Body = BoxBody; + #[inline] - fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { + fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { self.finish() } } impl Responder for actix_http::ResponseBuilder { + type Body = BoxBody; + #[inline] - fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { - HttpResponse::from(self.finish()) + fn respond_to(mut self, req: &HttpRequest) -> HttpResponse { + self.finish().map_into_boxed_body().respond_to(req) } } -impl Responder for Option { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { +impl Responder for Option +where + T: Responder, + ::Error: Into, +{ + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Some(val) => val.respond_to(req), - None => HttpResponse::new(StatusCode::NOT_FOUND), + Some(val) => val.respond_to(req).map_into_left_body(), + None => HttpResponse::new(StatusCode::NOT_FOUND).map_into_right_body(), } } } @@ -98,47 +114,69 @@ impl Responder for Option { impl Responder for Result where T: Responder, + ::Error: Into, E: Into, { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Ok(val) => val.respond_to(req), - Err(e) => HttpResponse::from_error(e.into()), + Ok(val) => val.respond_to(req).map_into_left_body(), + Err(err) => HttpResponse::from_error(err.into()).map_into_right_body(), } } } impl Responder for (T, StatusCode) { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = T::Body; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { let mut res = self.0.respond_to(req); *res.status_mut() = self.1; res } } -macro_rules! impl_responder { - ($res: ty, $ct: path) => { +macro_rules! impl_responder_by_forward_into_base_response { + ($res:ty, $body:ty) => { impl Responder for $res { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok().content_type($ct).body(self) + type Body = $body; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + let res: actix_http::Response<_> = self.into(); + res.into() + } + } + }; + + ($res:ty) => { + impl_responder_by_forward_into_base_response!($res, $res); + }; +} + +impl_responder_by_forward_into_base_response!(&'static [u8]); +impl_responder_by_forward_into_base_response!(Bytes); +impl_responder_by_forward_into_base_response!(BytesMut); + +impl_responder_by_forward_into_base_response!(&'static str); +impl_responder_by_forward_into_base_response!(String); + +macro_rules! impl_into_string_responder { + ($res:ty) => { + impl Responder for $res { + type Body = String; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + let string: String = self.into(); + let res: actix_http::Response<_> = string.into(); + res.into() } } }; } -impl_responder!(&'static str, mime::TEXT_PLAIN_UTF_8); - -impl_responder!(String, mime::TEXT_PLAIN_UTF_8); - -impl_responder!(&'_ String, mime::TEXT_PLAIN_UTF_8); - -impl_responder!(Cow<'_, str>, mime::TEXT_PLAIN_UTF_8); - -impl_responder!(&'static [u8], mime::APPLICATION_OCTET_STREAM); - -impl_responder!(Bytes, mime::APPLICATION_OCTET_STREAM); - -impl_responder!(BytesMut, mime::APPLICATION_OCTET_STREAM); +impl_into_string_responder!(&'_ String); +impl_into_string_responder!(Cow<'_, str>); /// Allows overriding status code and headers for a responder. pub struct CustomResponder { @@ -204,11 +242,17 @@ impl CustomResponder { } } -impl Responder for CustomResponder { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { +impl Responder for CustomResponder +where + T: Responder, + ::Error: Into, +{ + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { let headers = match self.headers { Ok(headers) => headers, - Err(err) => return HttpResponse::from_error(Error::from(err)), + Err(err) => return HttpResponse::from_error(err).map_into_right_body(), }; let mut res = self.responder.respond_to(req); @@ -222,7 +266,7 @@ impl Responder for CustomResponder { res.headers_mut().insert(k, v); } - res + res.map_into_left_body() } } @@ -231,11 +275,15 @@ pub(crate) mod tests { use actix_service::Service; use bytes::{Bytes, BytesMut}; + use actix_http::body::to_bytes; + use super::*; - use crate::dev::{Body, ResponseBody}; - use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{init_service, TestRequest}; - use crate::{error, web, App}; + use crate::{ + error, + http::{header::CONTENT_TYPE, HeaderValue, StatusCode}, + test::{assert_body_eq, init_service, TestRequest}, + web, App, + }; #[actix_rt::test] async fn test_option_responder() { @@ -253,133 +301,116 @@ pub(crate) mod tests { let req = TestRequest::with_uri("/some").to_request(); let resp = srv.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - match resp.response().body() { - Body::Bytes(ref b) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"some")); - } - _ => panic!(), - } - } - - pub(crate) trait BodyTest { - fn bin_ref(&self) -> &[u8]; - fn body(&self) -> &Body; - } - - impl BodyTest for Body { - fn bin_ref(&self) -> &[u8] { - match self { - Body::Bytes(ref bin) => &bin, - _ => unreachable!("bug in test impl"), - } - } - fn body(&self) -> &Body { - self - } - } - - impl BodyTest for ResponseBody { - fn bin_ref(&self) -> &[u8] { - match self { - ResponseBody::Body(ref b) => match b { - Body::Bytes(ref bin) => &bin, - _ => unreachable!("bug in test impl"), - }, - ResponseBody::Other(ref b) => match b { - Body::Bytes(ref bin) => &bin, - _ => unreachable!("bug in test impl"), - }, - } - } - fn body(&self) -> &Body { - match self { - ResponseBody::Body(ref b) => b, - ResponseBody::Other(ref b) => b, - } - } + assert_body_eq!(resp, b"some"); } #[actix_rt::test] async fn test_responder() { let req = TestRequest::default().to_http_request(); - let resp = "test".respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); + let res = "test".respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - - let resp = b"test".respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = b"test".respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - - let resp = "test".to_string().respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), ); - let resp = (&"test".to_string()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); + let res = "test".to_string().respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = (&"test".to_string()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); let s = String::from("test"); - let resp = Cow::Borrowed(s.as_str()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); + let res = Cow::Borrowed(s.as_str()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - - let resp = Cow::<'_, str>::Owned(s).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), ); - let resp = Cow::Borrowed("test").respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); + let res = Cow::<'_, str>::Owned(s).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); - - let resp = Bytes::from_static(b"test").respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Cow::Borrowed("test").respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Bytes::from_static(b"test").respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); - - let resp = BytesMut::from(b"test".as_ref()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = BytesMut::from(b"test".as_ref()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); // InternalError - let resp = error::InternalError::new("err", StatusCode::BAD_REQUEST).respond_to(&req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let res = error::InternalError::new("err", StatusCode::BAD_REQUEST).respond_to(&req); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); } #[actix_rt::test] @@ -389,11 +420,14 @@ pub(crate) mod tests { // Result let resp = Ok::<_, Error>("test".to_string()).respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") ); + assert_eq!( + to_bytes(resp.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); let res = Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) .respond_to(&req); @@ -410,7 +444,10 @@ pub(crate) mod tests { .respond_to(&req); assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); let res = "test" .to_string() @@ -418,11 +455,14 @@ pub(crate) mod tests { .respond_to(&req); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); assert_eq!( res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("json") ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); } #[actix_rt::test] @@ -430,17 +470,23 @@ pub(crate) mod tests { let req = TestRequest::default().to_http_request(); let res = ("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req); assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); let req = TestRequest::default().to_http_request(); let res = ("test".to_string(), StatusCode::OK) .with_header((CONTENT_TYPE, mime::APPLICATION_JSON)) .respond_to(&req); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); assert_eq!( res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/json") ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); } } diff --git a/src/response/builder.rs b/src/response/builder.rs index 56d30d9d0..b5bef2e99 100644 --- a/src/response/builder.rs +++ b/src/response/builder.rs @@ -1,14 +1,13 @@ use std::{ cell::{Ref, RefMut}, convert::TryInto, - error::Error as StdError, future::Future, pin::Pin, task::{Context, Poll}, }; use actix_http::{ - body::{AnyBody, BodyStream}, + body::{BodyStream, BoxBody, MessageBody}, http::{ header::{self, HeaderName, IntoHeaderPair, IntoHeaderValue}, ConnectionType, Error as HttpError, StatusCode, @@ -26,14 +25,14 @@ use cookie::{Cookie, CookieJar}; use crate::{ error::{Error, JsonPayloadError}, - HttpResponse, + BoxError, HttpResponse, }; /// An HTTP response builder. /// /// This type can be used to construct an instance of `Response` through a builder-like pattern. pub struct HttpResponseBuilder { - res: Option>, + res: Option>, err: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -44,7 +43,7 @@ impl HttpResponseBuilder { /// Create response builder pub fn new(status: StatusCode) -> Self { Self { - res: Some(Response::new(status)), + res: Some(Response::with_body(status, BoxBody::new(()))), err: None, #[cfg(feature = "cookies")] cookies: None, @@ -299,7 +298,6 @@ impl HttpResponseBuilder { } /// Mutable reference to a the response's extensions - #[inline] pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { self.res .as_mut() @@ -307,18 +305,20 @@ impl HttpResponseBuilder { .extensions_mut() } - /// Set a body and generate `Response`. + /// Set a body and build the `HttpResponse`. /// /// `HttpResponseBuilder` can not be used after this call. - #[inline] - pub fn body>(&mut self, body: B) -> HttpResponse { - match self.message_body(body.into()) { - Ok(res) => res, + pub fn body(&mut self, body: B) -> HttpResponse + where + B: MessageBody + 'static, + { + match self.message_body(body) { + Ok(res) => res.map_into_boxed_body(), Err(err) => HttpResponse::from_error(err), } } - /// Set a body and generate `Response`. + /// Set a body and build the `HttpResponse`. /// /// `HttpResponseBuilder` can not be used after this call. pub fn message_body(&mut self, body: B) -> Result, Error> { @@ -332,7 +332,7 @@ impl HttpResponseBuilder { .expect("cannot reuse response builder") .set_body(body); - #[allow(unused_mut)] + #[allow(unused_mut)] // mut is only unused when cookies are disabled let mut res = HttpResponse::from(res); #[cfg(feature = "cookies")] @@ -348,19 +348,19 @@ impl HttpResponseBuilder { Ok(res) } - /// Set a streaming body and generate `Response`. + /// Set a streaming body and build the `HttpResponse`. /// /// `HttpResponseBuilder` can not be used after this call. #[inline] pub fn streaming(&mut self, stream: S) -> HttpResponse where - S: Stream> + Unpin + 'static, - E: Into> + 'static, + S: Stream> + 'static, + E: Into + 'static, { - self.body(AnyBody::from_message(BodyStream::new(stream))) + self.body(BodyStream::new(stream)) } - /// Set a json body and generate `Response` + /// Set a JSON body and build the `HttpResponse`. /// /// `HttpResponseBuilder` can not be used after this call. pub fn json(&mut self, value: impl Serialize) -> HttpResponse { @@ -376,18 +376,18 @@ impl HttpResponseBuilder { self.insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); } - self.body(AnyBody::from(body)) + self.body(body) } Err(err) => HttpResponse::from_error(JsonPayloadError::Serialize(err)), } } - /// Set an empty body and generate `Response` + /// Set an empty body and build the `HttpResponse`. /// /// `HttpResponseBuilder` can not be used after this call. #[inline] pub fn finish(&mut self) -> HttpResponse { - self.body(AnyBody::Empty) + self.body(()) } /// This method construct new `HttpResponseBuilder` @@ -416,7 +416,7 @@ impl From for HttpResponse { } } -impl From for Response { +impl From for Response { fn from(mut builder: HttpResponseBuilder) -> Self { builder.finish().into() } @@ -435,12 +435,9 @@ mod tests { use actix_http::body; use super::*; - use crate::{ - dev::Body, - http::{ - header::{self, HeaderValue, CONTENT_TYPE}, - StatusCode, - }, + use crate::http::{ + header::{self, HeaderValue, CONTENT_TYPE}, + StatusCode, }; #[test] @@ -475,7 +472,7 @@ mod tests { fn test_content_type() { let resp = HttpResponseBuilder::new(StatusCode::OK) .content_type("text/plain") - .body(Body::Empty); + .body(Bytes::new()); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") } diff --git a/src/response/http_codes.rs b/src/response/http_codes.rs index d67ef3f92..44ddb78f9 100644 --- a/src/response/http_codes.rs +++ b/src/response/http_codes.rs @@ -87,13 +87,12 @@ impl HttpResponse { #[cfg(test)] mod tests { - use crate::dev::Body; use crate::http::StatusCode; use crate::HttpResponse; #[test] fn test_build() { - let resp = HttpResponse::Ok().body(Body::Empty); + let resp = HttpResponse::Ok().finish(); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/response/response.rs b/src/response/response.rs index 09515c839..97de21e42 100644 --- a/src/response/response.rs +++ b/src/response/response.rs @@ -8,7 +8,7 @@ use std::{ }; use actix_http::{ - body::{AnyBody, Body, MessageBody}, + body::{BoxBody, EitherBody, MessageBody}, http::{header::HeaderMap, StatusCode}, Extensions, Response, ResponseHead, }; @@ -25,12 +25,12 @@ use { use crate::{error::Error, HttpResponseBuilder}; /// An outgoing response. -pub struct HttpResponse { +pub struct HttpResponse { res: Response, pub(crate) error: Option, } -impl HttpResponse { +impl HttpResponse { /// Constructs a response. #[inline] pub fn new(status: StatusCode) -> Self { @@ -227,6 +227,27 @@ impl HttpResponse { } } + // TODO: docs for the body map methods below + + #[inline] + pub fn map_into_left_body(self) -> HttpResponse> { + self.map_body(|_, body| EitherBody::left(body)) + } + + #[inline] + pub fn map_into_right_body(self) -> HttpResponse> { + self.map_body(|_, body| EitherBody::right(body)) + } + + #[inline] + pub fn map_into_boxed_body(self) -> HttpResponse + where + B: MessageBody + 'static, + { + // TODO: avoid double boxing with down-casting, if it improves perf + self.map_body(|_, body| BoxBody::new(body)) + } + /// Extract response body pub fn into_body(self) -> B { self.res.into_body() @@ -270,14 +291,14 @@ impl From> for Response { } } -// Future is only implemented for Body payload type because it's the most useful for making simple -// handlers without async blocks. Making it generic over all MessageBody types requires a future -// impl on Response which would cause it's body field to be, undesirably, Option. +// Future is only implemented for BoxBody payload type because it's the most useful for making +// simple handlers without async blocks. Making it generic over all MessageBody types requires a +// future impl on Response which would cause it's body field to be, undesirably, Option. // // This impl is not particularly efficient due to the Response construction and should probably // not be invoked if performance is important. Prefer an async fn/block in such cases. -impl Future for HttpResponse { - type Output = Result, Error>; +impl Future for HttpResponse { + type Output = Result, Error>; fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { if let Some(err) = self.error.take() { @@ -293,7 +314,7 @@ impl Future for HttpResponse { #[cfg(feature = "cookies")] pub struct CookieIter<'a> { - iter: header::GetAll<'a>, + iter: header::map::GetAll<'a>, } #[cfg(feature = "cookies")] diff --git a/src/rmap.rs b/src/rmap.rs index 3c8805d57..432eaf83c 100644 --- a/src/rmap.rs +++ b/src/rmap.rs @@ -1,53 +1,86 @@ -use std::cell::RefCell; -use std::rc::{Rc, Weak}; +use std::{ + borrow::Cow, + cell::RefCell, + rc::{Rc, Weak}, +}; use actix_router::ResourceDef; use ahash::AHashMap; use url::Url; -use crate::error::UrlGenerationError; -use crate::request::HttpRequest; +use crate::{error::UrlGenerationError, request::HttpRequest}; #[derive(Clone, Debug)] pub struct ResourceMap { - root: ResourceDef, + pattern: ResourceDef, + + /// Named resources within the tree or, for external resources, + /// it points to isolated nodes outside the tree. + named: AHashMap>, + parent: RefCell>, - named: AHashMap, - patterns: Vec<(ResourceDef, Option>)>, + + /// Must be `None` for "edge" nodes. + nodes: Option>>, } impl ResourceMap { + /// Creates a _container_ node in the `ResourceMap` tree. pub fn new(root: ResourceDef) -> Self { ResourceMap { - root, - parent: RefCell::new(Weak::new()), + pattern: root, named: AHashMap::default(), - patterns: Vec::new(), + parent: RefCell::new(Weak::new()), + nodes: Some(Vec::new()), } } + /// Adds a (possibly nested) resource. + /// + /// To add a non-prefix pattern, `nested` must be `None`. + /// To add external resource, supply a pattern without a leading `/`. + /// The root pattern of `nested`, if present, should match `pattern`. pub fn add(&mut self, pattern: &mut ResourceDef, nested: Option>) { - pattern.set_id(self.patterns.len() as u16); - self.patterns.push((pattern.clone(), nested)); - if !pattern.name().is_empty() { - self.named - .insert(pattern.name().to_string(), pattern.clone()); + pattern.set_id(self.nodes.as_ref().unwrap().len() as u16); + + if let Some(new_node) = nested { + assert_eq!(&new_node.pattern, pattern, "`patern` and `nested` mismatch"); + self.named.extend(new_node.named.clone().into_iter()); + self.nodes.as_mut().unwrap().push(new_node); + } else { + let new_node = Rc::new(ResourceMap { + pattern: pattern.clone(), + named: AHashMap::default(), + parent: RefCell::new(Weak::new()), + nodes: None, + }); + + if let Some(name) = pattern.name() { + self.named.insert(name.to_owned(), Rc::clone(&new_node)); + } + + let is_external = match pattern.pattern() { + Some(p) => !p.is_empty() && !p.starts_with('/'), + None => false, + }; + + // Don't add external resources to the tree + if !is_external { + self.nodes.as_mut().unwrap().push(new_node); + } } } - pub(crate) fn finish(&self, current: Rc) { - for (_, nested) in &self.patterns { - if let Some(ref nested) = nested { - *nested.parent.borrow_mut() = Rc::downgrade(¤t); - nested.finish(nested.clone()); - } + pub(crate) fn finish(self: &Rc) { + for node in self.nodes.iter().flatten() { + node.parent.replace(Rc::downgrade(self)); + ResourceMap::finish(node); } } /// Generate url for named resource /// - /// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method. - /// url_for) for detailed information. + /// Check [`HttpRequest::url_for`] for detailed information. pub fn url_for( &self, req: &HttpRequest, @@ -58,192 +91,108 @@ impl ResourceMap { U: IntoIterator, I: AsRef, { - let mut path = String::new(); let mut elements = elements.into_iter(); - if self.patterns_for(name, &mut path, &mut elements)?.is_some() { - if path.starts_with('/') { - let conn = req.connection_info(); - Ok(Url::parse(&format!( - "{}://{}{}", - conn.scheme(), - conn.host(), - path - ))?) - } else { - Ok(Url::parse(&path)?) - } + let path = self + .named + .get(name) + .ok_or(UrlGenerationError::ResourceNotFound)? + .root_rmap_fn(String::with_capacity(24), |mut acc, node| { + node.pattern + .resource_path_from_iter(&mut acc, &mut elements) + .then(|| acc) + }) + .ok_or(UrlGenerationError::NotEnoughElements)?; + + let (base, path): (Cow<'_, _>, _) = if path.starts_with('/') { + // build full URL from connection info parts and resource path + let conn = req.connection_info(); + let base = format!("{}://{}", conn.scheme(), conn.host()); + (Cow::Owned(base), path.as_str()) } else { - Err(UrlGenerationError::ResourceNotFound) - } + // external resource; third slash would be the root slash in the path + let third_slash_index = path + .char_indices() + .filter_map(|(i, c)| (c == '/').then(|| i)) + .nth(2) + .unwrap_or_else(|| path.len()); + + ( + Cow::Borrowed(&path[..third_slash_index]), + &path[third_slash_index..], + ) + }; + + let mut url = Url::parse(&base)?; + url.set_path(path); + Ok(url) } pub fn has_resource(&self, path: &str) -> bool { - let path = if path.is_empty() { "/" } else { path }; - - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(plen) = pattern.is_prefix_match(path) { - return rmap.has_resource(&path[plen..]); - } - } else if pattern.is_match(path) || pattern.pattern() == "" && path == "/" { - return true; - } - } - false + self.find_matching_node(path).is_some() } /// Returns the name of the route that matches the given path or None if no full match - /// is possible. + /// is possible or the matching resource is not named. pub fn match_name(&self, path: &str) -> Option<&str> { - let path = if path.is_empty() { "/" } else { path }; - - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(plen) = pattern.is_prefix_match(path) { - return rmap.match_name(&path[plen..]); - } - } else if pattern.is_match(path) { - return match pattern.name() { - "" => None, - s => Some(s), - }; - } - } - - None + self.find_matching_node(path)?.pattern.name() } /// Returns the full resource pattern matched against a path or None if no full match /// is possible. pub fn match_pattern(&self, path: &str) -> Option { - let path = if path.is_empty() { "/" } else { path }; - - // ensure a full match exists - if !self.has_resource(path) { - return None; - } - - Some(self.traverse_resource_pattern(path)) + self.find_matching_node(path)?.root_rmap_fn( + String::with_capacity(24), + |mut acc, node| { + acc.push_str(node.pattern.pattern()?); + Some(acc) + }, + ) } - /// Takes remaining path and tries to match it up against a resource definition within the - /// current resource map recursively, returning a concatenation of all resource prefixes and - /// patterns matched in the tree. - /// - /// Should only be used after checking the resource exists in the map so that partial match - /// patterns are not returned. - fn traverse_resource_pattern(&self, remaining: &str) -> String { - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(prefix_len) = pattern.is_prefix_match(remaining) { - let prefix = pattern.pattern().to_owned(); - - return [ - prefix, - rmap.traverse_resource_pattern(&remaining[prefix_len..]), - ] - .concat(); - } - } else if pattern.is_match(remaining) { - return pattern.pattern().to_owned(); - } - } - - String::new() + fn find_matching_node(&self, path: &str) -> Option<&ResourceMap> { + self._find_matching_node(path).flatten() } - fn patterns_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> + /// Returns `None` if root pattern doesn't match; + /// `Some(None)` if root pattern matches but there is no matching child pattern. + /// Don't search sideways when `Some(none)` is returned. + fn _find_matching_node(&self, path: &str) -> Option> { + let matched_len = self.pattern.find_match(path)?; + let path = &path[matched_len..]; + + Some(match &self.nodes { + // find first sub-node to match remaining path + Some(nodes) => nodes + .iter() + .filter_map(|node| node._find_matching_node(path)) + .next() + .flatten(), + + // only terminate at edge nodes + None => Some(self), + }) + } + + /// Find `self`'s highest ancestor and then run `F`, providing `B`, in that rmap context. + fn root_rmap_fn(&self, init: B, mut f: F) -> Option where - U: Iterator, - I: AsRef, + F: FnMut(B, &ResourceMap) -> Option, { - if self.pattern_for(name, path, elements)?.is_some() { - Ok(Some(())) - } else { - self.parent_pattern_for(name, path, elements) - } + self._root_rmap_fn(init, &mut f) } - fn pattern_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> + /// Run `F`, providing `B`, if `self` is top-level resource map, else recurse to parent map. + fn _root_rmap_fn(&self, init: B, f: &mut F) -> Option where - U: Iterator, - I: AsRef, + F: FnMut(B, &ResourceMap) -> Option, { - if let Some(pattern) = self.named.get(name) { - if pattern.pattern().starts_with('/') { - self.fill_root(path, elements)?; - } - if pattern.resource_path(path, elements) { - Ok(Some(())) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } else { - for (_, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if rmap.pattern_for(name, path, elements)?.is_some() { - return Ok(Some(())); - } - } - } - Ok(None) - } - } + let data = match self.parent.borrow().upgrade() { + Some(ref parent) => parent._root_rmap_fn(init, f)?, + None => init, + }; - fn fill_root( - &self, - path: &mut String, - elements: &mut U, - ) -> Result<(), UrlGenerationError> - where - U: Iterator, - I: AsRef, - { - if let Some(ref parent) = self.parent.borrow().upgrade() { - parent.fill_root(path, elements)?; - } - if self.root.resource_path(path, elements) { - Ok(()) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } - - fn parent_pattern_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> - where - U: Iterator, - I: AsRef, - { - if let Some(ref parent) = self.parent.borrow().upgrade() { - if let Some(pattern) = parent.named.get(name) { - self.fill_root(path, elements)?; - if pattern.resource_path(path, elements) { - Ok(Some(())) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } else { - parent.parent_pattern_for(name, path, elements) - } - } else { - Ok(None) - } + f(data, self) } } @@ -255,7 +204,7 @@ mod tests { fn extract_matched_pattern() { let mut root = ResourceMap::new(ResourceDef::root_prefix("")); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); user_map.add(&mut ResourceDef::new("/"), None); user_map.add(&mut ResourceDef::new("/profile"), None); user_map.add(&mut ResourceDef::new("/article/{id}"), None); @@ -271,9 +220,10 @@ mod tests { &mut ResourceDef::root_prefix("/user/{id}"), Some(Rc::new(user_map)), ); + root.add(&mut ResourceDef::new("/info"), None); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // sanity check resource map setup @@ -284,7 +234,7 @@ mod tests { assert!(root.has_resource("/v2")); assert!(!root.has_resource("/v33")); - assert!(root.has_resource("/user/22")); + assert!(!root.has_resource("/user/22")); assert!(root.has_resource("/user/22/")); assert!(root.has_resource("/user/22/profile")); @@ -329,15 +279,15 @@ mod tests { let mut root = ResourceMap::new(ResourceDef::root_prefix("")); let mut rdef = ResourceDef::new("/info"); - *rdef.name_mut() = "root_info".to_owned(); + rdef.set_name("root_info"); root.add(&mut rdef, None); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); let mut rdef = ResourceDef::new("/"); user_map.add(&mut rdef, None); let mut rdef = ResourceDef::new("/post/{post_id}"); - *rdef.name_mut() = "user_post".to_owned(); + rdef.set_name("user_post"); user_map.add(&mut rdef, None); root.add( @@ -346,14 +296,14 @@ mod tests { ); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // sanity check resource map setup assert!(root.has_resource("/info")); assert!(!root.has_resource("/bar")); - assert!(root.has_resource("/user/22")); + assert!(!root.has_resource("/user/22")); assert!(root.has_resource("/user/22/")); assert!(root.has_resource("/user/22/post/55")); @@ -373,7 +323,7 @@ mod tests { // ref: https://github.com/actix/actix-web/issues/1582 let mut root = ResourceMap::new(ResourceDef::root_prefix("")); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); user_map.add(&mut ResourceDef::new("/"), None); user_map.add(&mut ResourceDef::new("/profile"), None); user_map.add(&mut ResourceDef::new("/article/{id}"), None); @@ -389,20 +339,155 @@ mod tests { ); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // check root has no parent assert!(root.parent.borrow().upgrade().is_none()); // check child has parent reference - assert!(root.patterns[0].1.is_some()); + assert!(root.nodes.as_ref().unwrap()[0] + .parent + .borrow() + .upgrade() + .is_some()); // check child's parent root id matches root's root id - assert_eq!( - root.patterns[0].1.as_ref().unwrap().root.id(), - root.root.id() - ); + assert!(Rc::ptr_eq( + &root.nodes.as_ref().unwrap()[0] + .parent + .borrow() + .upgrade() + .unwrap(), + &root + )); let output = format!("{:?}", root); assert!(output.starts_with("ResourceMap {")); assert!(output.ends_with(" }")); } + + #[test] + fn short_circuit() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut user_root = ResourceDef::prefix("/user"); + let mut user_map = ResourceMap::new(user_root.clone()); + user_map.add(&mut ResourceDef::new("/u1"), None); + user_map.add(&mut ResourceDef::new("/u2"), None); + + root.add(&mut ResourceDef::new("/user/u3"), None); + root.add(&mut user_root, Some(Rc::new(user_map))); + root.add(&mut ResourceDef::new("/user/u4"), None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(rmap.has_resource("/user/u1")); + assert!(rmap.has_resource("/user/u2")); + assert!(rmap.has_resource("/user/u3")); + assert!(!rmap.has_resource("/user/u4")); + } + + #[test] + fn url_for() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut user_scope_rdef = ResourceDef::prefix("/user"); + let mut user_scope_map = ResourceMap::new(user_scope_rdef.clone()); + + let mut user_rdef = ResourceDef::new("/{user_id}"); + let mut user_map = ResourceMap::new(user_rdef.clone()); + + let mut post_rdef = ResourceDef::new("/post/{sub_id}"); + post_rdef.set_name("post"); + + user_map.add(&mut post_rdef, None); + user_scope_map.add(&mut user_rdef, Some(Rc::new(user_map))); + root.add(&mut user_scope_rdef, Some(Rc::new(user_scope_map))); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + let url = rmap + .url_for(&req, "post", &["u123", "foobar"]) + .unwrap() + .to_string(); + assert_eq!(url, "http://localhost:8888/user/u123/post/foobar"); + + assert!(rmap.url_for(&req, "missing", &["u123"]).is_err()); + } + + #[test] + fn url_for_parser() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef_1 = ResourceDef::new("/{var}"); + rdef_1.set_name("internal"); + + let mut rdef_2 = ResourceDef::new("http://host.dom/{var}"); + rdef_2.set_name("external.1"); + + let mut rdef_3 = ResourceDef::new("{var}"); + rdef_3.set_name("external.2"); + + root.add(&mut rdef_1, None); + root.add(&mut rdef_2, None); + root.add(&mut rdef_3, None); + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + const INPUT: &[&str] = &["a/../quick brown%20fox/%nan?query#frag"]; + const OUTPUT: &str = "/quick%20brown%20fox/%nan%3Fquery%23frag"; + + let url = rmap.url_for(&req, "internal", INPUT).unwrap(); + assert_eq!(url.path(), OUTPUT); + + let url = rmap.url_for(&req, "external.1", INPUT).unwrap(); + assert_eq!(url.path(), OUTPUT); + + assert!(rmap.url_for(&req, "external.2", INPUT).is_err()); + assert!(rmap.url_for(&req, "external.2", &[""]).is_err()); + } + + #[test] + fn external_resource_with_no_name() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef = ResourceDef::new("https://duck.com/{query}"); + root.add(&mut rdef, None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(!rmap.has_resource("https://duck.com/abc")); + } + + #[test] + fn external_resource_with_name() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef = ResourceDef::new("https://duck.com/{query}"); + rdef.set_name("duck"); + root.add(&mut rdef, None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(!rmap.has_resource("https://duck.com/abc")); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + assert_eq!( + rmap.url_for(&req, "duck", &["abcd"]).unwrap().to_string(), + "https://duck.com/abcd" + ); + } } diff --git a/src/route.rs b/src/route.rs index d85b940bd..1eb323068 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,19 +1,18 @@ -#![allow(clippy::rc_buffer)] // inner value is mutated before being shared (`Rc::get_mut`) - -use std::{future::Future, rc::Rc}; +use std::{future::Future, mem, rc::Rc}; use actix_http::http::Method; use actix_service::{ - boxed::{self, BoxService, BoxServiceFactory}, - Service, ServiceFactory, ServiceFactoryExt, + boxed::{self, BoxService}, + fn_service, Service, ServiceFactory, ServiceFactoryExt, }; use futures_core::future::LocalBoxFuture; use crate::{ + body::MessageBody, guard::{self, Guard}, - handler::{Handler, HandlerService}, - service::{ServiceRequest, ServiceResponse}, - Error, FromRequest, HttpResponse, Responder, + handler::{handler_service, Handler}, + service::{BoxedHttpServiceFactory, ServiceRequest, ServiceResponse}, + BoxError, Error, FromRequest, HttpResponse, Responder, }; /// Resource route definition @@ -21,7 +20,7 @@ use crate::{ /// Route uses builder-like pattern for configuration. /// If handler is not explicitly set, default *404 Not Found* handler is used. pub struct Route { - service: BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>, + service: BoxedHttpServiceFactory, guards: Rc>>, } @@ -30,13 +29,15 @@ impl Route { #[allow(clippy::new_without_default)] pub fn new() -> Route { Route { - service: boxed::factory(HandlerService::new(HttpResponse::NotFound)), + service: boxed::factory(fn_service(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::NotFound())) + })), guards: Rc::new(Vec::new()), } } pub(crate) fn take_guards(&mut self) -> Vec> { - std::mem::take(Rc::get_mut(&mut self.guards).unwrap()) + mem::take(Rc::get_mut(&mut self.guards).unwrap()) } } @@ -181,8 +182,10 @@ impl Route { T: FromRequest + 'static, R: Future + 'static, R::Output: Responder + 'static, + ::Body: MessageBody, + <::Body as MessageBody>::Error: Into, { - self.service = boxed::factory(HandlerService::new(handler)); + self.service = handler_service(handler); self } diff --git a/src/scope.rs b/src/scope.rs index aa546c422..ff013671b 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -3,9 +3,8 @@ use std::{cell::RefCell, fmt, future::Future, mem, rc::Rc}; use actix_http::Extensions; use actix_router::{ResourceDef, Router}; use actix_service::{ - apply, apply_fn_factory, - boxed::{self, BoxService, BoxServiceFactory}, - IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt, Transform, + apply, apply_fn_factory, boxed, IntoServiceFactory, Service, ServiceFactory, + ServiceFactoryExt, Transform, }; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; @@ -13,16 +12,17 @@ use futures_util::future::join_all; use crate::{ config::ServiceConfig, data::Data, - dev::{AppService, HttpServiceFactory}, + dev::AppService, guard::Guard, rmap::ResourceMap, - service::{AppServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse}, + service::{ + AppServiceFactory, BoxedHttpService, BoxedHttpServiceFactory, HttpServiceFactory, + ServiceFactoryWrapper, ServiceRequest, ServiceResponse, + }, Error, Resource, Route, }; type Guards = Vec>; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Resources scope. /// @@ -41,9 +41,9 @@ type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Err /// fn main() { /// let app = App::new().service( /// web::scope("/{project_id}/") -/// .service(web::resource("/path1").to(|| async { HttpResponse::Ok() })) +/// .service(web::resource("/path1").to(|| async { "OK" })) /// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) -/// .service(web::resource("/path3").route(web::head().to(|| HttpResponse::MethodNotAllowed()))) +/// .service(web::resource("/path3").route(web::head().to(HttpResponse::MethodNotAllowed))) /// ); /// } /// ``` @@ -58,7 +58,7 @@ pub struct Scope { app_data: Option, services: Vec>, guards: Vec>, - default: Option>, + default: Option>, external: Vec, factory_ref: Rc>>, } @@ -470,8 +470,14 @@ where } pub struct ScopeFactory { - services: Rc<[(ResourceDef, HttpNewService, RefCell>)]>, - default: Rc, + services: Rc< + [( + ResourceDef, + BoxedHttpServiceFactory, + RefCell>, + )], + >, + default: Rc, } impl ServiceFactory for ScopeFactory { @@ -518,8 +524,8 @@ impl ServiceFactory for ScopeFactory { } pub struct ScopeService { - router: Router>>, - default: HttpService, + router: Router>>, + default: BoxedHttpService, } impl Service for ScopeService { @@ -530,7 +536,7 @@ impl Service for ScopeService { actix_service::always_ready!(); fn call(&self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_checked(&mut req, |req, guards| { + let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { for f in guards { if !f.check(req.head()) { @@ -580,12 +586,11 @@ mod tests { use bytes::Bytes; use crate::{ - dev::Body, guard, http::{header, HeaderValue, Method, StatusCode}, middleware::DefaultHeaders, service::{ServiceRequest, ServiceResponse}, - test::{call_service, init_service, read_body, TestRequest}, + test::{assert_body_eq, call_service, init_service, read_body, TestRequest}, web, App, HttpMessage, HttpRequest, HttpResponse, }; @@ -748,20 +753,13 @@ mod tests { .await; let req = TestRequest::with_uri("/ab-project1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - match resp.response().body() { - Body::Bytes(ref b) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: project1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_body_eq!(res, b"project: project1"); let req = TestRequest::with_uri("/aa-project1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] @@ -849,16 +847,9 @@ mod tests { .await; let req = TestRequest::with_uri("/app/project_1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - match resp.response().body() { - Body::Bytes(ref b) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: project_1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::CREATED); + assert_body_eq!(res, b"project: project_1"); } #[actix_rt::test] @@ -877,20 +868,13 @@ mod tests { .await; let req = TestRequest::with_uri("/app/test/1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - match resp.response().body() { - Body::Bytes(ref b) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::CREATED); + assert_body_eq!(res, b"project: test - 1"); let req = TestRequest::with_uri("/app/test/1/path2").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] @@ -1153,4 +1137,70 @@ mod tests { Bytes::from_static(b"http://localhost:8080/a/b/c/12345") ); } + + #[actix_rt::test] + async fn dynamic_scopes() { + let srv = init_service( + App::new().service( + web::scope("/{a}/").service( + web::scope("/{b}/") + .route("", web::get().to(|_: HttpRequest| HttpResponse::Created())) + .route( + "/", + web::get().to(|_: HttpRequest| HttpResponse::Accepted()), + ) + .route("/{c}", web::get().to(|_: HttpRequest| HttpResponse::Ok())), + ), + ), + ) + .await; + + // note the unintuitive behavior with trailing slashes on scopes with dynamic segments + let req = TestRequest::with_uri("/a//b//c").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/a//b/").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/a//b//").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::ACCEPTED); + + let req = TestRequest::with_uri("/a//b//c/d").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let srv = init_service( + App::new().service( + web::scope("/{a}").service( + web::scope("/{b}") + .route("", web::get().to(|_: HttpRequest| HttpResponse::Created())) + .route( + "/", + web::get().to(|_: HttpRequest| HttpResponse::Accepted()), + ) + .route("/{c}", web::get().to(|_: HttpRequest| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/a/b/c").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/a/b").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/a/b/").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::ACCEPTED); + + let req = TestRequest::with_uri("/a/b/c/d").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } } diff --git a/src/server.rs b/src/server.rs index c302f0352..c26501123 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,6 @@ use std::{ any::Any, - cmp, - error::Error as StdError, - fmt, io, + cmp, fmt, io, marker::PhantomData, net, sync::{Arc, Mutex}, @@ -17,9 +15,9 @@ use actix_service::{ }; #[cfg(feature = "openssl")] -use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +use actix_tls::accept::openssl::reexports::{AlpnError, SslAcceptor, SslAcceptorBuilder}; #[cfg(feature = "rustls")] -use actix_tls::accept::rustls::ServerConfig as RustlsServerConfig; +use actix_tls::accept::rustls::reexports::ServerConfig as RustlsServerConfig; use crate::{config::AppConfig, Error}; @@ -77,15 +75,13 @@ where I: IntoServiceFactory, S: ServiceFactory + 'static, - // S::Future: 'static, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, S::Service: 'static, - // S::Service: 'static, + B: MessageBody + 'static, - B::Error: Into>, { /// Create new HTTP server with application factory pub fn new(factory: F) -> Self { @@ -111,11 +107,11 @@ where /// and handlers. /// /// # Connection Types - /// - `actix_web::rt::net::TcpStream` when no TLS layer is used. - /// - `actix_tls::rustls::TlsStream` when using rustls. - /// - `actix_tls::openssl::SslStream` when using openssl. + /// - `actix_tls::accept::openssl::TlsStream` when using openssl. + /// - `actix_tls::accept::rustls::TlsStream` when using rustls. + /// - `actix_web::rt::net::TcpStream` when no encryption is used. /// - /// See `on_connect` example for additional details. + /// See the `on_connect` example for additional details. pub fn on_connect(self, f: CB) -> HttpServer where CB: Fn(&dyn Any, &mut CloneableExtensions) + Send + Sync + 'static, @@ -162,7 +158,7 @@ where /// /// By default max connections is set to a 25k. pub fn max_connections(mut self, num: usize) -> Self { - self.builder = self.builder.maxconn(num); + self.builder = self.builder.max_concurrent_connections(num); self } @@ -236,7 +232,7 @@ where self } - /// Stop actix system. + /// Stop Actix `System` after server shutdown. pub fn system_exit(mut self) -> Self { self.builder = self.builder.system_exit(); self @@ -659,8 +655,8 @@ fn create_tcp_listener(addr: net::SocketAddr, backlog: u32) -> io::Result io::Result { builder.set_alpn_select_callback(|_, protocols| { const H2: &[u8] = b"\x02h2"; diff --git a/src/service.rs b/src/service.rs index 47e7e4acc..df9e809e4 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,26 +1,35 @@ -use std::cell::{Ref, RefMut}; -use std::rc::Rc; -use std::{fmt, net}; +use std::{ + cell::{Ref, RefMut}, + fmt, net, + rc::Rc, +}; use actix_http::{ - body::{AnyBody, MessageBody}, + body::{BoxBody, EitherBody, MessageBody}, http::{HeaderMap, Method, StatusCode, Uri, Version}, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, ResponseHead, }; -use actix_router::{IntoPattern, Path, Resource, ResourceDef, Url}; -use actix_service::{IntoServiceFactory, ServiceFactory}; +use actix_router::{IntoPatterns, Path, Patterns, Resource, ResourceDef, Url}; +use actix_service::{ + boxed::{BoxService, BoxServiceFactory}, + IntoServiceFactory, ServiceFactory, +}; #[cfg(feature = "cookies")] use cookie::{Cookie, ParseError as CookieParseError}; use crate::{ config::{AppConfig, AppService}, - dev::insert_leading_slash, + dev::ensure_leading_slash, guard::Guard, info::ConnectionInfo, rmap::ResourceMap, Error, HttpRequest, HttpResponse, }; +pub(crate) type BoxedHttpService = BoxService, Error>; +pub(crate) type BoxedHttpServiceFactory = + BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; + pub trait HttpServiceFactory { fn register(self, config: &mut AppService); } @@ -117,7 +126,7 @@ impl ServiceRequest { /// This method returns reference to the request head #[inline] pub fn head(&self) -> &RequestHead { - &self.req.head() + self.req.head() } /// This method returns reference to the request head @@ -212,14 +221,14 @@ impl ServiceRequest { self.req.match_pattern() } - #[inline] /// Get a mutable reference to the Path parameters. + #[inline] pub fn match_info_mut(&mut self) -> &mut Path { self.req.match_info_mut() } - #[inline] /// Get a reference to a `ResourceMap` of current application. + #[inline] pub fn resource_map(&self) -> &ResourceMap { self.req.resource_map() } @@ -326,12 +335,12 @@ impl fmt::Debug for ServiceRequest { } /// A service level response wrapper. -pub struct ServiceResponse { +pub struct ServiceResponse { request: HttpRequest, response: HttpResponse, } -impl ServiceResponse { +impl ServiceResponse { /// Create service response from the error pub fn from_err>(err: E, request: HttpRequest) -> Self { let response = HttpResponse::from_error(err); @@ -393,16 +402,6 @@ impl ServiceResponse { self.response.headers_mut() } - /// Execute closure and in case of error convert it to response. - pub fn checked_expr(mut self, f: F) -> Result - where - F: FnOnce(&mut Self) -> Result<(), E>, - E: Into, - { - f(&mut self).map_err(Into::into)?; - Ok(self) - } - /// Extract response body pub fn into_body(self) -> B { self.response.into_body() @@ -411,6 +410,7 @@ impl ServiceResponse { impl ServiceResponse { /// Set a new body + #[inline] pub fn map_body(self, f: F) -> ServiceResponse where F: FnOnce(&mut ResponseHead, B) -> B2, @@ -422,6 +422,24 @@ impl ServiceResponse { request: self.request, } } + + #[inline] + pub fn map_into_left_body(self) -> ServiceResponse> { + self.map_body(|_, body| EitherBody::left(body)) + } + + #[inline] + pub fn map_into_right_body(self) -> ServiceResponse> { + self.map_body(|_, body| EitherBody::right(body)) + } + + #[inline] + pub fn map_into_boxed_body(self) -> ServiceResponse + where + B: MessageBody + 'static, + { + self.map_body(|_, body| BoxBody::new(body)) + } } impl From> for HttpResponse { @@ -459,14 +477,14 @@ where } pub struct WebService { - rdef: Vec, + rdef: Patterns, name: Option, guards: Vec>, } impl WebService { /// Create new `WebService` instance. - pub fn new(path: T) -> Self { + pub fn new(path: T) -> Self { WebService { rdef: path.patterns(), name: None, @@ -476,7 +494,7 @@ impl WebService { /// Set service name. /// - /// Name is used for url generation. + /// Name is used for URL generation. pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_string()); self @@ -528,7 +546,7 @@ impl WebService { struct WebServiceImpl { srv: T, - rdef: Vec, + rdef: Patterns, name: Option, guards: Vec>, } @@ -551,13 +569,15 @@ where }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(insert_leading_slash(self.rdef)) + ResourceDef::new(ensure_leading_slash(self.rdef)) } else { ResourceDef::new(self.rdef) }; + if let Some(ref name) = self.name { - *rdef.name_mut() = name.clone(); + rdef.set_name(name); } + config.register_service(rdef, guards, self.srv, None) } } @@ -569,7 +589,6 @@ where /// The max number of services can be grouped together is 12. /// /// # Examples -/// /// ``` /// use actix_web::{services, web, App}; /// diff --git a/src/test.rs b/src/test.rs index 634826d19..2cd01039d 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,10 +1,9 @@ //! Various helpers for Actix applications to use during testing. -use std::{net::SocketAddr, rc::Rc}; +use std::{borrow::Cow, net::SocketAddr, rc::Rc}; pub use actix_http::test::TestBuffer; use actix_http::{ - body, http::{header::IntoHeaderPair, Method, StatusCode, Uri, Version}, test::TestRequest as HttpTestRequest, Extensions, Request, @@ -20,9 +19,10 @@ use serde::{de::DeserializeOwned, Serialize}; use crate::cookie::{Cookie, CookieJar}; use crate::{ app_service::AppInitServiceState, + body::{self, BoxBody, MessageBody}, config::AppConfig, data::Data, - dev::{Body, MessageBody, Payload}, + dev::Payload, http::header::ContentType, rmap::ResourceMap, service::{ServiceRequest, ServiceResponse}, @@ -32,14 +32,14 @@ use crate::{ /// Create service that always responds with `HttpResponse::Ok()` and no body. pub fn ok_service( -) -> impl Service, Error = Error> { +) -> impl Service, Error = Error> { default_service(StatusCode::OK) } /// Create service that always responds with given status code and no body. pub fn default_service( status_code: StatusCode, -) -> impl Service, Error = Error> { +) -> impl Service, Error = Error> { (move |req: ServiceRequest| { ok(req.into_response(HttpResponseBuilder::new(status_code).finish())) }) @@ -52,11 +52,11 @@ pub fn default_service( /// use actix_service::Service; /// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_init_service() { /// let app = test::init_service( /// App::new() -/// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) +/// .service(web::resource("/test").to(|| async { "OK" })) /// ).await; /// /// // Create request object @@ -98,7 +98,7 @@ where /// ``` /// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_response() { /// let app = test::init_service( /// App::new() @@ -129,7 +129,7 @@ where /// use actix_web::{test, web, App, HttpResponse, http::header}; /// use bytes::Bytes; /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_index() { /// let app = test::init_service( /// App::new().service( @@ -176,7 +176,7 @@ where /// use actix_web::{test, web, App, HttpResponse, http::header}; /// use bytes::Bytes; /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_index() { /// let app = test::init_service( /// App::new().service( @@ -224,7 +224,7 @@ where /// name: String, /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_post_person() { /// let app = test::init_service( /// App::new().service( @@ -296,7 +296,7 @@ where /// name: String /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_add_person() { /// let app = test::init_service( /// App::new().service( @@ -356,8 +356,8 @@ where /// } /// } /// -/// #[test] -/// fn test_index() { +/// #[actix_web::test] +/// async fn test_index() { /// let req = test::TestRequest::default().insert_header("content-type", "text/plain") /// .to_http_request(); /// @@ -470,19 +470,31 @@ impl TestRequest { self } - /// Set request path pattern parameter - pub fn param(mut self, name: &'static str, value: &'static str) -> Self { + /// Set request path pattern parameter. + /// + /// # Examples + /// ``` + /// use actix_web::test::TestRequest; + /// + /// let req = TestRequest::default().param("foo", "bar"); + /// let req = TestRequest::default().param("foo".to_owned(), "bar".to_owned()); + /// ``` + pub fn param( + mut self, + name: impl Into>, + value: impl Into>, + ) -> Self { self.path.add_static(name, value); self } - /// Set peer addr + /// Set peer addr. pub fn peer_addr(mut self, addr: SocketAddr) -> Self { self.peer_addr = Some(addr); self } - /// Set request payload + /// Set request payload. pub fn set_payload>(mut self, data: B) -> Self { self.req.set_payload(data); self @@ -620,6 +632,22 @@ impl TestRequest { } } +/// Reduces boilerplate code when testing expected response payloads. +#[cfg(test)] +macro_rules! assert_body_eq { + ($res:ident, $expected:expr) => { + assert_eq!( + ::actix_http::body::to_bytes($res.into_body()) + .await + .expect("body read should have succeeded"), + Bytes::from_static($expected), + ) + }; +} + +#[cfg(test)] +pub(crate) use assert_body_eq; + #[cfg(test)] mod tests { use std::time::SystemTime; diff --git a/src/types/either.rs b/src/types/either.rs index d3b003587..3c759736e 100644 --- a/src/types/either.rs +++ b/src/types/either.rs @@ -9,9 +9,10 @@ use std::{ use bytes::Bytes; use futures_core::ready; +use pin_project_lite::pin_project; use crate::{ - dev, + body, dev, web::{Form, Json}, Error, FromRequest, HttpRequest, HttpResponse, Responder, }; @@ -145,10 +146,12 @@ where L: Responder, R: Responder, { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = body::EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Either::Left(a) => a.respond_to(req), - Either::Right(b) => b.respond_to(req), + Either::Left(a) => a.respond_to(req).map_into_left_body(), + Either::Right(b) => b.respond_to(req).map_into_right_body(), } } } @@ -187,7 +190,6 @@ where { type Error = EitherExtractError; type Future = EitherExtractFut; - type Config = (); fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { EitherExtractFut { @@ -199,37 +201,40 @@ where } } -#[pin_project::pin_project] -pub struct EitherExtractFut -where - R: FromRequest, - L: FromRequest, -{ - req: HttpRequest, - #[pin] - state: EitherExtractState, +pin_project! { + pub struct EitherExtractFut + where + R: FromRequest, + L: FromRequest, + { + req: HttpRequest, + #[pin] + state: EitherExtractState, + } } -#[pin_project::pin_project(project = EitherExtractProj)] -pub enum EitherExtractState -where - L: FromRequest, - R: FromRequest, -{ - Bytes { - #[pin] - bytes: ::Future, - }, - Left { - #[pin] - left: L::Future, - fallback: Bytes, - }, - Right { - #[pin] - right: R::Future, - left_err: Option, - }, +pin_project! { + #[project = EitherExtractProj] + pub enum EitherExtractState + where + L: FromRequest, + R: FromRequest, + { + Bytes { + #[pin] + bytes: ::Future, + }, + Left { + #[pin] + left: L::Future, + fallback: Bytes, + }, + Right { + #[pin] + right: R::Future, + left_err: Option, + }, + } } impl Future for EitherExtractFut @@ -253,7 +258,7 @@ where Ok(bytes) => { let fallback = bytes.clone(); let left = - L::from_request(&this.req, &mut payload_from_bytes(bytes)); + L::from_request(this.req, &mut payload_from_bytes(bytes)); EitherExtractState::Left { left, fallback } } Err(err) => break Err(EitherExtractError::Bytes(err)), @@ -265,7 +270,7 @@ where Ok(extracted) => break Ok(Either::Left(extracted)), Err(left_err) => { let right = R::from_request( - &this.req, + this.req, &mut payload_from_bytes(mem::take(fallback)), ); EitherExtractState::Right { diff --git a/src/types/form.rs b/src/types/form.rs index c81f73554..9c09c6b73 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -20,8 +20,9 @@ use serde::{de::DeserializeOwned, Serialize}; #[cfg(feature = "__compress")] use crate::dev::Decompress; use crate::{ - error::UrlencodedError, extract::FromRequest, http::header::CONTENT_LENGTH, web, Error, - HttpMessage, HttpRequest, HttpResponse, Responder, + body::EitherBody, error::UrlencodedError, extract::FromRequest, + http::header::CONTENT_LENGTH, web, Error, HttpMessage, HttpRequest, HttpResponse, + Responder, }; /// URL encoded payload extractor and responder. @@ -30,9 +31,9 @@ use crate::{ /// /// # Extractor /// To extract typed data from a request body, the inner type `T` must implement the -/// [`serde::Deserialize`] trait. +/// [`DeserializeOwned`] trait. /// -/// Use [`FormConfig`] to configure extraction process. +/// Use [`FormConfig`] to configure extraction options. /// /// ``` /// use actix_web::{post, web}; @@ -126,20 +127,12 @@ impl FromRequest for Form where T: DeserializeOwned + 'static, { - type Config = FormConfig; type Error = Error; type Future = FormExtractFut; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - let (limit, err_handler) = req - .app_data::() - .or_else(|| { - req.app_data::>() - .map(|d| d.as_ref()) - }) - .map(|c| (c.limit, c.err_handler.clone())) - .unwrap_or((16384, None)); + let FormConfig { limit, err_handler } = FormConfig::from_req(req).clone(); FormExtractFut { fut: UrlEncoded::new(req, payload).limit(limit), @@ -188,12 +181,21 @@ impl fmt::Display for Form { /// See [here](#responder) for example of usage as a handler return type. impl Responder for Form { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { match serde_urlencoded::to_string(&self.0) { - Ok(body) => HttpResponse::Ok() + Ok(body) => match HttpResponse::Ok() .content_type(mime::APPLICATION_WWW_FORM_URLENCODED) - .body(body), - Err(err) => HttpResponse::from_error(UrlencodedError::Serialize(err)), + .message_body(body) + { + Ok(res) => res.map_into_left_body(), + Err(err) => HttpResponse::from_error(err).map_into_right_body(), + }, + + Err(err) => { + HttpResponse::from_error(UrlencodedError::Serialize(err)).map_into_right_body() + } } } } @@ -241,14 +243,26 @@ impl FormConfig { self.err_handler = Some(Rc::new(f)); self } + + /// Extract payload config from app data. + /// + /// Checks both `T` and `Data`, in that order, and falls back to the default payload config. + fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } } +/// Allow shared refs used as default. +const DEFAULT_CONFIG: FormConfig = FormConfig { + limit: 16_384, // 2^14 bytes (~16kB) + err_handler: None, +}; + impl Default for FormConfig { fn default() -> Self { - FormConfig { - limit: 16_384, // 2^14 bytes (~16kB) - err_handler: None, - } + DEFAULT_CONFIG } } @@ -404,11 +418,14 @@ mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::http::{ - header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, - StatusCode, - }; use crate::test::TestRequest; + use crate::{ + http::{ + header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, + StatusCode, + }, + test::assert_body_eq, + }; #[derive(Deserialize, Serialize, Debug, PartialEq)] struct Info { @@ -516,15 +533,13 @@ mod tests { hello: "world".to_string(), counter: 123, }); - let resp = form.respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); + let res = form.respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/x-www-form-urlencoded") ); - - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123"); + assert_body_eq!(res, b"hello=world&counter=123"); } #[actix_rt::test] diff --git a/src/types/header.rs b/src/types/header.rs index 9b64f445d..6ea77faf6 100644 --- a/src/types/header.rs +++ b/src/types/header.rs @@ -62,7 +62,6 @@ where { type Error = ParseError; type Future = Ready>; - type Config = (); #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { diff --git a/src/types/json.rs b/src/types/json.rs index fc02c8854..2b4d220e2 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -19,6 +19,7 @@ use actix_http::Payload; #[cfg(feature = "__compress")] use crate::dev::Decompress; use crate::{ + body::EitherBody, error::{Error, JsonPayloadError}, extract::FromRequest, http::header::CONTENT_LENGTH, @@ -34,7 +35,7 @@ use crate::{ /// To extract typed data from a request body, the inner type `T` must implement the /// [`serde::Deserialize`] trait. /// -/// Use [`JsonConfig`] to configure extraction process. +/// Use [`JsonConfig`] to configure extraction options. /// /// ``` /// use actix_web::{post, web, App}; @@ -97,19 +98,13 @@ impl ops::DerefMut for Json { } } -impl fmt::Display for Json -where - T: fmt::Display, -{ +impl fmt::Display for Json { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) } } -impl Serialize for Json -where - T: Serialize, -{ +impl Serialize for Json { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -122,36 +117,42 @@ where /// /// If serialization failed impl Responder for Json { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { match serde_json::to_string(&self.0) { - Ok(body) => HttpResponse::Ok() + Ok(body) => match HttpResponse::Ok() .content_type(mime::APPLICATION_JSON) - .body(body), - Err(err) => HttpResponse::from_error(JsonPayloadError::Serialize(err)), + .message_body(body) + { + Ok(res) => res.map_into_left_body(), + Err(err) => HttpResponse::from_error(err).map_into_right_body(), + }, + + Err(err) => { + HttpResponse::from_error(JsonPayloadError::Serialize(err)).map_into_right_body() + } } } } /// See [here](#extractor) for example of usage as an extractor. -impl FromRequest for Json -where - T: DeserializeOwned + 'static, -{ +impl FromRequest for Json { type Error = Error; type Future = JsonExtractFut; - type Config = JsonConfig; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { let config = JsonConfig::from_req(req); let limit = config.limit; - let ctype = config.content_type.as_deref(); + let ctype_required = config.content_type_required; + let ctype_fn = config.content_type.as_deref(); let err_handler = config.err_handler.clone(); JsonExtractFut { req: Some(req.clone()), - fut: JsonBody::new(req, payload, ctype).limit(limit), + fut: JsonBody::new(req, payload, ctype_fn, ctype_required).limit(limit), err_handler, } } @@ -166,10 +167,7 @@ pub struct JsonExtractFut { err_handler: JsonErrorHandler, } -impl Future for JsonExtractFut -where - T: DeserializeOwned + 'static, -{ +impl Future for JsonExtractFut { type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -237,6 +235,7 @@ pub struct JsonConfig { limit: usize, err_handler: JsonErrorHandler, content_type: Option bool + Send + Sync>>, + content_type_required: bool, } impl JsonConfig { @@ -264,6 +263,12 @@ impl JsonConfig { self } + /// Sets whether or not the request must have a `Content-Type` header to be parsed. + pub fn content_type_required(mut self, content_type_required: bool) -> Self { + self.content_type_required = content_type_required; + self + } + /// Extract payload config from app data. Check both `T` and `Data`, in that order, and fall /// back to the default payload config. fn from_req(req: &HttpRequest) -> &Self { @@ -280,6 +285,7 @@ const DEFAULT_CONFIG: JsonConfig = JsonConfig { limit: DEFAULT_LIMIT, err_handler: None, content_type: None, + content_type_required: true, }; impl Default for JsonConfig { @@ -290,15 +296,18 @@ impl Default for JsonConfig { /// Future that resolves to some `T` when parsed from a JSON payload. /// -/// Form can be deserialized from any type `T` that implements [`serde::Deserialize`]. +/// Can deserialize any type `T` that implements [`Deserialize`][serde::Deserialize]. /// /// Returns error if: -/// - content type is not `application/json` -/// - content length is greater than [limit](JsonBody::limit()) +/// - `Content-Type` is not `application/json` when `ctype_required` (passed to [`new`][Self::new]) +/// is `true`. +/// - `Content-Length` is greater than [limit](JsonBody::limit()). +/// - The payload, when consumed, is not valid JSON. pub enum JsonBody { Error(Option), Body { limit: usize, + /// Length as reported by `Content-Length` header, if present. length: Option, #[cfg(feature = "__compress")] payload: Decompress, @@ -311,27 +320,27 @@ pub enum JsonBody { impl Unpin for JsonBody {} -impl JsonBody -where - T: DeserializeOwned + 'static, -{ +impl JsonBody { /// Create a new future to decode a JSON request payload. #[allow(clippy::borrow_interior_mutable_const)] pub fn new( req: &HttpRequest, payload: &mut Payload, - ctype: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>, + ctype_fn: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>, + ctype_required: bool, ) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let can_parse_json = if let Ok(Some(mime)) = req.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - || ctype.map_or(false, |predicate| predicate(mime)) + || ctype_fn.map_or(false, |predicate| predicate(mime)) } else { - false + // if `ctype_required` is false, assume payload is + // json even when content-type header is missing + !ctype_required }; - if !json { + if !can_parse_json { return JsonBody::Error(Some(JsonPayloadError::ContentType)); } @@ -341,7 +350,7 @@ where .and_then(|l| l.to_str().ok()) .and_then(|s| s.parse::().ok()); - // Notice the content_length is not checked against limit of json config here. + // Notice the content-length is not checked against limit of json config here. // As the internal usage always call JsonBody::limit after JsonBody::new. // And limit check to return an error variant of JsonBody happens there. @@ -395,10 +404,7 @@ where } } -impl Future for JsonBody -where - T: DeserializeOwned + 'static, -{ +impl Future for JsonBody { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -425,7 +431,7 @@ where } } None => { - let json = serde_json::from_slice::(&buf) + let json = serde_json::from_slice::(buf) .map_err(JsonPayloadError::Deserialize)?; return Poll::Ready(Ok(json)); } @@ -448,7 +454,7 @@ mod tests { header::{self, CONTENT_LENGTH, CONTENT_TYPE}, StatusCode, }, - test::{load_body, TestRequest}, + test::{assert_body_eq, load_body, TestRequest}, }; #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -476,15 +482,13 @@ mod tests { let j = Json(MyObject { name: "test".to_string(), }); - let resp = j.respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); + let res = j.respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), + res.headers().get(header::CONTENT_TYPE).unwrap(), header::HeaderValue::from_static("application/json") ); - - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); + assert_body_eq!(res, b"{\"name\":\"test\"}"); } #[actix_rt::test] @@ -581,7 +585,7 @@ mod tests { #[actix_rt::test] async fn test_json_body() { let (req, mut pl) = TestRequest::default().to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -590,7 +594,7 @@ mod tests { header::HeaderValue::from_static("application/text"), )) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -604,7 +608,7 @@ mod tests { )) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None) + let json = JsonBody::::new(&req, &mut pl, None, true) .limit(100) .await; assert!(json_eq( @@ -623,7 +627,7 @@ mod tests { .set_payload(Bytes::from_static(&[0u8; 1000])) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None) + let json = JsonBody::::new(&req, &mut pl, None, true) .limit(100) .await; @@ -644,7 +648,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert_eq!( json.ok().unwrap(), MyObject { @@ -714,6 +718,21 @@ mod tests { assert!(s.is_err()) } + #[actix_rt::test] + async fn test_json_with_no_content_type() { + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .app_data(JsonConfig::default().content_type_required(false)) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_ok()) + } + #[actix_rt::test] async fn test_with_config_in_data_wrapper() { let (req, mut pl) = TestRequest::default() diff --git a/src/types/path.rs b/src/types/path.rs index f2273a59b..4b60d27c0 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -14,7 +14,7 @@ use crate::{ /// Extract typed data from request path segments. /// -/// Use [`PathConfig`] to configure extraction process. +/// Use [`PathConfig`] to configure extraction option. /// /// # Examples /// ``` @@ -90,20 +90,19 @@ impl fmt::Display for Path { } } -/// See [here](#usage) for example of usage as an extractor. +/// See [here](#Examples) for example of usage as an extractor. impl FromRequest for Path where T: de::DeserializeOwned, { type Error = Error; type Future = Ready>; - type Config = PathConfig; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let error_handler = req - .app_data::() - .and_then(|c| c.ehandler.clone()); + .app_data::() + .and_then(|c| c.err_handler.clone()); ready( de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) @@ -159,9 +158,9 @@ where /// ); /// } /// ``` -#[derive(Clone)] +#[derive(Clone, Default)] pub struct PathConfig { - ehandler: Option Error + Send + Sync>>, + err_handler: Option Error + Send + Sync>>, } impl PathConfig { @@ -170,17 +169,11 @@ impl PathConfig { where F: Fn(PathError, &HttpRequest) -> Error + Send + Sync + 'static, { - self.ehandler = Some(Arc::new(f)); + self.err_handler = Some(Arc::new(f)); self } } -impl Default for PathConfig { - fn default() -> Self { - PathConfig { ehandler: None } - } -} - #[cfg(test)] mod tests { use actix_router::ResourceDef; @@ -209,7 +202,7 @@ mod tests { let resource = ResourceDef::new("/{value}/"); let mut req = TestRequest::with_uri("/32/").to_srv_request(); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); assert_eq!(*Path::::from_request(&req, &mut pl).await.unwrap(), 32); @@ -221,7 +214,7 @@ mod tests { let resource = ResourceDef::new("/{key}/{value}/"); let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let (Path(res),) = <(Path<(String, String)>,)>::from_request(&req, &mut pl) @@ -247,7 +240,7 @@ mod tests { let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let mut s = Path::::from_request(&req, &mut pl).await.unwrap(); @@ -270,7 +263,7 @@ mod tests { let mut req = TestRequest::with_uri("/name/32/").to_srv_request(); let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let s = Path::::from_request(&req, &mut pl).await.unwrap(); diff --git a/src/types/payload.rs b/src/types/payload.rs index 188da6201..73987def5 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -43,11 +43,12 @@ use crate::{ /// Ok(format!("Request Body Bytes:\n{:?}", bytes)) /// } /// ``` -pub struct Payload(pub crate::dev::Payload); +pub struct Payload(dev::Payload); impl Payload { /// Unwrap to inner Payload type. - pub fn into_inner(self) -> crate::dev::Payload { + #[inline] + pub fn into_inner(self) -> dev::Payload { self.0 } } @@ -61,9 +62,8 @@ impl Stream for Payload { } } -/// See [here](#usage) for example of usage as an extractor. +/// See [here](#Examples) for example of usage as an extractor. impl FromRequest for Payload { - type Config = PayloadConfig; type Error = Error; type Future = Ready>; @@ -90,7 +90,6 @@ impl FromRequest for Payload { /// } /// ``` impl FromRequest for Bytes { - type Config = PayloadConfig; type Error = Error; type Future = Either>>; @@ -126,8 +125,7 @@ impl<'a> Future for BytesExtractFut { /// /// Text extractor automatically decode body according to the request's charset. /// -/// [**PayloadConfig**](PayloadConfig) allows to configure -/// extraction process. +/// Use [`PayloadConfig`] to configure extraction process. /// /// # Examples /// ``` @@ -139,7 +137,6 @@ impl<'a> Future for BytesExtractFut { /// format!("Body {}!", text) /// } impl FromRequest for String { - type Config = PayloadConfig; type Error = Error; type Future = Either>>; @@ -198,14 +195,15 @@ fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result) -> String { -/// dbg!("Authorization object={:?}", info.into_inner()); +/// dbg!("Authorization object = {:?}", info.into_inner()); /// "OK".to_string() /// } /// -/// // Or use `.0`, which is equivalent to `.into_inner()`. +/// // Or use destructuring, which is equivalent to `.into_inner()`. /// #[get("/debug2")] -/// async fn debug2(info: web::Query) -> String { -/// dbg!("Authorization object={:?}", info.0); +/// async fn debug2(web::Query(info): web::Query) -> String { +/// dbg!("Authorization object = {:?}", info); /// "OK".to_string() /// } /// ``` -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Query(pub T); impl Query { @@ -65,8 +65,10 @@ impl Query { pub fn into_inner(self) -> T { self.0 } +} - /// Deserialize `T` from a URL encoded query parameter string. +impl Query { + /// Deserialize a `T` from the URL encoded query parameter string. /// /// ``` /// # use std::collections::HashMap; @@ -76,10 +78,7 @@ impl Query { /// assert_eq!(numbers.get("two"), Some(&2)); /// assert!(numbers.get("three").is_none()); /// ``` - pub fn from_query(query_str: &str) -> Result - where - T: de::DeserializeOwned, - { + pub fn from_query(query_str: &str) -> Result { serde_urlencoded::from_str::(query_str) .map(Self) .map_err(QueryPayloadError::Deserialize) @@ -106,19 +105,15 @@ impl fmt::Display for Query { } } -/// See [here](#usage) for example of usage as an extractor. -impl FromRequest for Query -where - T: de::DeserializeOwned, -{ +/// See [here](#Examples) for example of usage as an extractor. +impl FromRequest for Query { type Error = Error; type Future = Ready>; - type Config = QueryConfig; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let error_handler = req - .app_data::() + .app_data::() .and_then(|c| c.err_handler.clone()); serde_urlencoded::from_str::(req.query_string()) @@ -165,14 +160,14 @@ where /// let query_cfg = web::QueryConfig::default() /// // use custom error handler /// .error_handler(|err, req| { -/// error::InternalError::from_response(err, HttpResponse::Conflict().into()).into() +/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() /// }); /// /// App::new() /// .app_data(query_cfg) /// .service(index); /// ``` -#[derive(Clone)] +#[derive(Clone, Default)] pub struct QueryConfig { err_handler: Option Error + Send + Sync>>, } @@ -188,12 +183,6 @@ impl QueryConfig { } } -impl Default for QueryConfig { - fn default() -> Self { - QueryConfig { err_handler: None } - } -} - #[cfg(test)] mod tests { use actix_http::http::StatusCode; @@ -213,10 +202,10 @@ mod tests { #[actix_rt::test] async fn test_service_request_extract() { let req = TestRequest::with_uri("/name/user1/").to_srv_request(); - assert!(Query::::from_query(&req.query_string()).is_err()); + assert!(Query::::from_query(req.query_string()).is_err()); let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - let mut s = Query::::from_query(&req.query_string()).unwrap(); + let mut s = Query::::from_query(req.query_string()).unwrap(); assert_eq!(s.id, "test"); assert_eq!( diff --git a/src/web.rs b/src/web.rs index 40ac46275..b58adc2f8 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,46 +1,38 @@ //! Essentials helper functions and types for application registration. -use actix_http::http::Method; -use actix_router::IntoPattern; -use std::future::Future; +use std::{error::Error as StdError, future::Future}; -pub use actix_http::Response as HttpResponse; +use actix_http::http::Method; +use actix_router::IntoPatterns; pub use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::error::BlockingError; -use crate::extract::FromRequest; -use crate::handler::Handler; -use crate::resource::Resource; -use crate::responder::Responder; -use crate::route::Route; -use crate::scope::Scope; -use crate::service::WebService; +use crate::{ + body::MessageBody, error::BlockingError, extract::FromRequest, handler::Handler, + resource::Resource, responder::Responder, route::Route, scope::Scope, service::WebService, +}; pub use crate::config::ServiceConfig; pub use crate::data::Data; pub use crate::request::HttpRequest; pub use crate::request_data::ReqData; +pub use crate::response::HttpResponse; pub use crate::types::*; -/// Create resource for a specific path. +/// Creates a new resource for a specific path. /// -/// Resources may have variable path segments. For example, a -/// resource with the path `/a/{name}/c` would match all incoming -/// requests with paths such as `/a/b/c`, `/a/1/c`, or `/a/etc/c`. +/// Resources may have dynamic path segments. For example, a resource with the path `/a/{name}/c` +/// would match all incoming requests with paths such as `/a/b/c`, `/a/1/c`, or `/a/etc/c`. /// -/// A variable segment is specified in the form `{identifier}`, -/// where the identifier can be used later in a request handler to -/// access the matched value for that segment. This is done by -/// looking up the identifier in the `Params` object returned by -/// `HttpRequest.match_info()` method. +/// A dynamic segment is specified in the form `{identifier}`, where the identifier can be used +/// later in a request handler to access the matched value for that segment. This is done by looking +/// up the identifier in the `Path` object returned by [`HttpRequest.match_info()`] method. /// /// By default, each segment matches the regular expression `[^{}/]+`. /// /// You can also specify a custom regex in the form `{identifier:regex}`: /// -/// For instance, to route `GET`-requests on any route matching -/// `/users/{userid}/{friend}` and store `userid` and `friend` in -/// the exposed `Params` object: +/// For instance, to route `GET`-requests on any route matching `/users/{userid}/{friend}` and store +/// `userid` and `friend` in the exposed `Path` object: /// /// ``` /// use actix_web::{web, App, HttpResponse}; @@ -51,14 +43,20 @@ pub use crate::types::*; /// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` -pub fn resource(path: T) -> Resource { +pub fn resource(path: T) -> Resource { Resource::new(path) } -/// Configure scope for common root path. +/// Creates scope for common path prefix. /// -/// Scopes collect multiple paths under a common path prefix. -/// Scope path can contain variable path segments as resources. +/// Scopes collect multiple paths under a common path prefix. The scope's path can contain dynamic +/// path segments. +/// +/// # Examples +/// In this example, three routes are set up (and will handle any method): +/// * `/{project_id}/path1` +/// * `/{project_id}/path2` +/// * `/{project_id}/path3` /// /// ``` /// use actix_web::{web, App, HttpResponse}; @@ -70,148 +68,50 @@ pub fn resource(path: T) -> Resource { /// .service(web::resource("/path3").to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` -/// -/// In the above example, three routes get added: -/// * /{project_id}/path1 -/// * /{project_id}/path2 -/// * /{project_id}/path3 -/// pub fn scope(path: &str) -> Scope { Scope::new(path) } -/// Create *route* without configuration. +/// Creates a new un-configured route. pub fn route() -> Route { Route::new() } -/// Create *route* with `GET` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `GET` route gets added: -/// * /{project_id} -/// -pub fn get() -> Route { - method(Method::GET) +macro_rules! method_route { + ($method_fn:ident, $method_const:ident) => { + paste::paste! { + #[doc = " Creates a new route with `" $method_const "` method guard."] + /// + /// # Examples + #[doc = " In this example, one `" $method_const " /{project_id}` route is set up:"] + /// ``` + /// use actix_web::{web, App, HttpResponse}; + /// + /// let app = App::new().service( + /// web::resource("/{project_id}") + #[doc = " .route(web::" $method_fn "().to(|| HttpResponse::Ok()))"] + /// + /// ); + /// ``` + pub fn $method_fn() -> Route { + method(Method::$method_const) + } + } + }; } -/// Create *route* with `POST` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::post().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `POST` route gets added: -/// * /{project_id} -/// -pub fn post() -> Route { - method(Method::POST) -} +method_route!(get, GET); +method_route!(post, POST); +method_route!(put, PUT); +method_route!(patch, PATCH); +method_route!(delete, DELETE); +method_route!(head, HEAD); +method_route!(trace, TRACE); -/// Create *route* with `PUT` method guard. +/// Creates a new route with specified method guard. /// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::put().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `PUT` route gets added: -/// * /{project_id} -/// -pub fn put() -> Route { - method(Method::PUT) -} - -/// Create *route* with `PATCH` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::patch().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `PATCH` route gets added: -/// * /{project_id} -/// -pub fn patch() -> Route { - method(Method::PATCH) -} - -/// Create *route* with `DELETE` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::delete().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `DELETE` route gets added: -/// * /{project_id} -/// -pub fn delete() -> Route { - method(Method::DELETE) -} - -/// Create *route* with `HEAD` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::head().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `HEAD` route gets added: -/// * /{project_id} -/// -pub fn head() -> Route { - method(Method::HEAD) -} - -/// Create *route* with `TRACE` method guard. -/// -/// ``` -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::trace().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `HEAD` route gets added: -/// * /{project_id} -/// -pub fn trace() -> Route { - method(Method::TRACE) -} - -/// Create *route* and add method guard. +/// # Examples +/// In this example, one `GET /{project_id}` route is set up: /// /// ``` /// use actix_web::{web, http, App, HttpResponse}; @@ -221,15 +121,11 @@ pub fn trace() -> Route { /// .route(web::method(http::Method::GET).to(|| HttpResponse::Ok())) /// ); /// ``` -/// -/// In the above example, one `GET` route gets added: -/// * /{project_id} -/// pub fn method(method: Method) -> Route { Route::new().method(method) } -/// Create a new route and add handler. +/// Creates a new any-method route with handler. /// /// ``` /// use actix_web::{web, App, HttpResponse, Responder}; @@ -249,11 +145,13 @@ where I: FromRequest + 'static, R: Future + 'static, R::Output: Responder + 'static, + ::Body: MessageBody + 'static, + <::Body as MessageBody>::Error: Into>, { Route::new().to(handler) } -/// Create raw service for a specific path. +/// Creates a raw service for a specific path. /// /// ``` /// use actix_web::{dev, web, guard, App, Error, HttpResponse}; @@ -268,12 +166,12 @@ where /// .finish(my_service) /// ); /// ``` -pub fn service(path: T) -> WebService { +pub fn service(path: T) -> WebService { WebService::new(path) } -/// Execute blocking function on a thread pool, returns future that resolves -/// to result of the function execution. +/// Executes blocking function on a thread pool, returns future that resolves to result of the +/// function execution. pub fn block(f: F) -> impl Future> where F: FnOnce() -> R + Send + 'static, diff --git a/tests/test-macro-import-conflict.rs b/tests/test-macro-import-conflict.rs new file mode 100644 index 000000000..0d23bb41d --- /dev/null +++ b/tests/test-macro-import-conflict.rs @@ -0,0 +1,15 @@ +//! Checks that test macro does not cause problems in the presence of imports named "test" that +//! could be either a module with test items or the "test with runtime" macro itself. +//! +//! Before actix/actix-net#399 was implemented, this macro was running twice. The first run output +//! `#[test]` and it got run again and since it was in scope. +//! +//! Prevented by using the fully-qualified test marker (`#[::core::prelude::v1::test]`). + +use actix_web::test; + +#[actix_web::test] +async fn test_macro_naming_conflict() { + let _req = test::TestRequest::default(); + assert_eq!(async { 1 }.await, 1); +} diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs index 881c6ce94..887b51d41 100644 --- a/tests/test_httpserver.rs +++ b/tests/test_httpserver.rs @@ -14,57 +14,45 @@ async fn test_start() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new(); + actix_rt::System::new() + .block_on(async { + let srv = HttpServer::new(|| { + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), + ) + }) + .workers(1) + .backlog(1) + .max_connections(10) + .max_connection_rate(10) + .keep_alive(10) + .client_timeout(5000) + .client_shutdown(0) + .server_hostname("localhost") + .system_exit() + .disable_signals() + .bind(format!("{}", addr)) + .unwrap() + .run(); - sys.block_on(async { - let srv = HttpServer::new(|| { - App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), - ) + tx.send(srv.handle()).unwrap(); + + srv.await }) - .workers(1) - .backlog(1) - .max_connections(10) - .max_connection_rate(10) - .keep_alive(10) - .client_timeout(5000) - .client_shutdown(0) - .server_hostname("localhost") - .system_exit() - .disable_signals() - .bind(format!("{}", addr)) - .unwrap() - .run(); - - let _ = tx.send((srv, actix_rt::System::current())); - }); - - let _ = sys.run(); + .unwrap(); }); - let (srv, sys) = rx.recv().unwrap(); - #[cfg(feature = "client")] - { - use actix_http::client; + let srv = rx.recv().unwrap(); - let client = awc::Client::builder() - .connector( - client::Connector::new() - .timeout(Duration::from_millis(100)) - .finish(), - ) - .finish(); + let client = awc::Client::builder() + .connector(awc::Connector::new().timeout(Duration::from_millis(100))) + .finish(); - let host = format!("http://{}", addr); - let response = client.get(host.clone()).send().await.unwrap(); - assert!(response.status().is_success()); - } + let host = format!("http://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); - // stop - let _ = srv.stop(false); - - thread::sleep(Duration::from_millis(100)); - let _ = sys.stop(); + srv.stop(false).await; } #[cfg(feature = "openssl")] @@ -92,37 +80,38 @@ fn ssl_acceptor() -> openssl::ssl::SslAcceptorBuilder { #[cfg(feature = "openssl")] async fn test_start_ssl() { use actix_web::HttpRequest; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let addr = actix_test::unused_addr(); let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new(); - let builder = ssl_acceptor(); + actix_rt::System::new() + .block_on(async { + let builder = ssl_acceptor(); - let srv = HttpServer::new(|| { - App::new().service(web::resource("/").route(web::to(|req: HttpRequest| { - assert!(req.app_config().secure()); - HttpResponse::Ok().body("test") - }))) - }) - .workers(1) - .shutdown_timeout(1) - .system_exit() - .disable_signals() - .bind_openssl(format!("{}", addr), builder) - .unwrap(); + let srv = HttpServer::new(|| { + App::new().service(web::resource("/").route(web::to(|req: HttpRequest| { + assert!(req.app_config().secure()); + HttpResponse::Ok().body("test") + }))) + }) + .workers(1) + .shutdown_timeout(1) + .system_exit() + .disable_signals() + .bind_openssl(format!("{}", addr), builder) + .unwrap(); - sys.block_on(async { - let srv = srv.run(); - let _ = tx.send((srv, actix_rt::System::current())); - }); + let srv = srv.run(); + tx.send(srv.handle()).unwrap(); - let _ = sys.run(); + srv.await + }) + .unwrap() }); - let (srv, sys) = rx.recv().unwrap(); + let srv = rx.recv().unwrap(); - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); let _ = builder @@ -141,9 +130,5 @@ async fn test_start_ssl() { let response = client.get(host.clone()).send().await.unwrap(); assert!(response.status().is_success()); - // stop - let _ = srv.stop(false); - - thread::sleep(Duration::from_millis(100)); - let _ = sys.stop(); + srv.stop(false).await; } diff --git a/tests/test_server.rs b/tests/test_server.rs index afea39dd9..a850f228d 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -127,6 +127,8 @@ async fn test_body() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -154,6 +156,8 @@ async fn test_body_gzip() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -181,6 +185,8 @@ async fn test_body_gzip2() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -194,13 +200,10 @@ async fn test_body_encoding_override() { .body(STR) }))) .service(web::resource("/raw").route(web::to(|| { - let body = actix_web::dev::Body::Bytes(STR.into()); let mut response = - HttpResponse::with_body(actix_web::http::StatusCode::OK, body); - + HttpResponse::with_body(actix_web::http::StatusCode::OK, STR); response.encoding(ContentEncoding::Deflate); - - response + response.map_into_boxed_body() }))) }); @@ -241,6 +244,8 @@ async fn test_body_encoding_override() { e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -275,6 +280,8 @@ async fn test_body_gzip_large() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -314,6 +321,8 @@ async fn test_body_gzip_large_random() { e.read_to_end(&mut dec).unwrap(); assert_eq!(dec.len(), data.len()); assert_eq!(Bytes::from(dec), Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -348,6 +357,8 @@ async fn test_body_chunked_implicit() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -380,6 +391,8 @@ async fn test_body_br_streaming() { let dec = e.finish().unwrap(); println!("T: {:?}", Bytes::copy_from_slice(&dec)); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -401,6 +414,8 @@ async fn test_head_binary() { // read response let bytes = response.body().await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] @@ -420,6 +435,8 @@ async fn test_no_chunking() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -447,6 +464,8 @@ async fn test_body_deflate() { e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -475,6 +494,8 @@ async fn test_body_brotli() { e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -503,6 +524,8 @@ async fn test_body_zstd() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -534,6 +557,8 @@ async fn test_body_zstd_streaming() { let mut dec = Vec::new(); e.read_to_end(&mut dec).unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -559,6 +584,8 @@ async fn test_zstd_encoding() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -594,6 +621,8 @@ async fn test_zstd_encoding_large() { // read response let bytes = response.body().limit(320_000).await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -619,6 +648,8 @@ async fn test_encoding() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -644,6 +675,8 @@ async fn test_gzip_encoding() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -670,6 +703,8 @@ async fn test_gzip_encoding_large() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -702,6 +737,8 @@ async fn test_reading_gzip_encoding_large_random() { let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -727,6 +764,8 @@ async fn test_reading_deflate_encoding() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -753,6 +792,8 @@ async fn test_reading_deflate_encoding_large() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -785,6 +826,8 @@ async fn test_reading_deflate_encoding_large_random() { let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -810,6 +853,8 @@ async fn test_brotli_encoding() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -845,6 +890,8 @@ async fn test_brotli_encoding_large() { // read response let bytes = response.body().limit(320_000).await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[cfg(feature = "openssl")] @@ -861,9 +908,9 @@ async fn test_brotli_encoding_large_openssl() { }); // body - let mut e = BrotliEncoder::new(Vec::new(), 3); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); + let mut enc = BrotliEncoder::new(Vec::new(), 3); + enc.write_all(data.as_ref()).unwrap(); + let enc = enc.finish().unwrap(); // client request let mut response = srv @@ -877,33 +924,39 @@ async fn test_brotli_encoding_large_openssl() { // read response let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[cfg(feature = "rustls")] mod plus_rustls { use std::io::BufReader; - use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, - }; + use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig}; + use rustls_pemfile::{certs, pkcs8_private_keys}; use super::*; - fn rustls_config() -> RustlsServerConfig { + fn tls_config() -> RustlsServerConfig { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = RustlsServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - config + RustlsServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap() } #[actix_rt::test] @@ -914,7 +967,7 @@ mod plus_rustls { .map(char::from) .collect::(); - let srv = actix_test::start_with(actix_test::config().rustls(rustls_config()), || { + let srv = actix_test::start_with(actix_test::config().rustls(tls_config()), || { App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { HttpResponse::Ok() .encoding(actix_web::http::ContentEncoding::Identity) @@ -940,6 +993,8 @@ mod plus_rustls { let bytes = response.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } } @@ -994,6 +1049,8 @@ async fn test_server_cookies() { assert_eq!(cookies[0], second_cookie); assert_eq!(cookies[1], first_cookie); } + + srv.stop().await; } #[actix_rt::test] @@ -1014,6 +1071,8 @@ async fn test_slow_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + srv.stop().await; } #[actix_rt::test] @@ -1026,6 +1085,8 @@ async fn test_normalize() { let response = srv.get("/one/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await } // allow deprecated App::data @@ -1077,3 +1138,24 @@ async fn test_data_drop() { assert_eq!(num.load(Ordering::SeqCst), 0); } + +#[actix_rt::test] +async fn test_accept_encoding_no_match() { + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new() + .wrap(Compress::default()) + .service(web::resource("/").route(web::to(move || HttpResponse::Ok().finish()))) + }); + + let response = srv + .get("/") + .append_header((ACCEPT_ENCODING, "compress, identity;q=0")) + .no_decompress() + .send() + .await + .unwrap(); + + assert_eq!(response.status().as_u16(), 406); + + srv.stop().await; +}