1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-06-30 00:14:58 +02:00

Merge remote-tracking branch 'origin/master' into on-connect-fix

This commit is contained in:
Rob Ede
2021-12-05 04:46:19 +00:00
233 changed files with 14067 additions and 5964 deletions

View File

@ -2,12 +2,123 @@
## Unreleased - 2021-xx-xx
### Added
* Add timeout for canceling HTTP/2 server side connection handshake. Default to 5 seconds. [#2483]
* HTTP/2 handshake timeout can be configured with `ServiceConfig::client_timeout`. [#2483]
* `Response::map_into_boxed_body`. [#2468]
* `body::EitherBody` enum. [#2468]
* `body::None` struct. [#2468]
* Impl `MessageBody` for `bytestring::ByteString`. [#2468]
* `impl Clone for ws::HandshakeError`. [#2468]
* `#[must_use]` for `ws::Codec` to prevent subtle bugs. [#1920]
* `impl Default ` for `ws::Codec`. [#1920]
* `header::QualityItem::{max, min}`. [#2486]
* `header::Quality::{MAX, MIN}`. [#2486]
* `impl Display` for `header::Quality`. [#2486]
* `CloneableExtensions` object for use in `on_connect` handlers. [#2327]
### Changed
* Rename `body::BoxBody::{from_body => new}`. [#2468]
* Body type for `Responses` returned from `Response::{new, ok, etc...}` is now `BoxBody`. [#2468]
* The `Error` associated type on `MessageBody` type now requires `impl Error` (or similar). [#2468]
* Error types using in service builders now require `Into<Response<BoxBody>>`. [#2468]
* `From` implementations on error types now return a `Response<BoxBody>`. [#2468]
* `ResponseBuilder::body(B)` now returns `Response<EitherBody<B>>`. [#2468]
* `ResponseBuilder::finish()` now returns `Response<EitherBody<()>>`. [#2468]
* `on_connect_ext` methods now receive a `CloneableExtensions` object. [#2327]
### Removed
* `ResponseBuilder::streaming`. [#2468]
* `impl Future` for `ResponseBuilder`. [#2468]
* Remove unnecessary `MessageBody` bound on types passed to `body::AnyBody::new`. [#2468]
* Move `body::AnyBody` to `awc`. Replaced with `EitherBody` and `BoxBody`. [#2468]
* `impl Copy` for `ws::Codec`. [#1920]
* `header::qitem` helper. Replaced with `header::QualityItem::max` [#2486]
* `impl TryFrom<u16>` for `header::Quality` [#2486]
[#2327]: https://github.com/actix/actix-web/pull/2327
[#2483]: https://github.com/actix/actix-web/pull/2483
[#2468]: https://github.com/actix/actix-web/pull/2468
[#1920]: https://github.com/actix/actix-web/pull/1920
[#2486]: https://github.com/actix/actix-web/pull/2486
## 3.0.0-beta.14 - 2021-11-30
### Changed
* Guarantee ordering of `header::GetAll` iterator to be same as insertion order. [#2467]
* Expose `header::map` module. [#2467]
* Implement `ExactSizeIterator` and `FusedIterator` for all `HeaderMap` iterators. [#2470]
* Update `actix-tls` to `3.0.0-rc.1`. [#2474]
[#2467]: https://github.com/actix/actix-web/pull/2467
[#2470]: https://github.com/actix/actix-web/pull/2470
[#2474]: https://github.com/actix/actix-web/pull/2474
## 3.0.0-beta.13 - 2021-11-22
### Added
* `body::AnyBody::empty` for quickly creating an empty body. [#2446]
* `body::AnyBody::none` for quickly creating a "none" body. [#2456]
* `impl Clone` for `body::AnyBody<S> where S: Clone`. [#2448]
* `body::AnyBody::into_boxed` for quickly converting to a type-erased, boxed body type. [#2448]
### Changed
* Rename `body::AnyBody::{Message => Body}`. [#2446]
* Rename `body::AnyBody::{from_message => new_boxed}`. [#2448]
* Rename `body::AnyBody::{from_slice => copy_from_slice}`. [#2448]
* Rename `body::{BoxAnyBody => BoxBody}`. [#2448]
* Change representation of `AnyBody` to include a type parameter in `Body` variant. Defaults to `BoxBody`. [#2448]
* `Encoder::response` now returns `AnyBody<Encoder<B>>`. [#2448]
### Removed
* `body::AnyBody::Empty`; an empty body can now only be represented as a zero-length `Bytes` variant. [#2446]
* `body::BodySize::Empty`; an empty body can now only be represented as a `Sized(0)` variant. [#2446]
* `EncoderError::Boxed`; it is no longer required. [#2446]
* `body::ResponseBody`; is function is replaced by the new `body::AnyBody` enum. [#2446]
[#2446]: https://github.com/actix/actix-web/pull/2446
[#2448]: https://github.com/actix/actix-web/pull/2448
[#2456]: https://github.com/actix/actix-web/pull/2456
## 3.0.0-beta.12 - 2021-11-15
### Changed
* Update `actix-server` to `2.0.0-beta.9`. [#2442]
### Removed
* `client` module. [#2425]
* `trust-dns` feature. [#2425]
[#2425]: https://github.com/actix/actix-web/pull/2425
[#2442]: https://github.com/actix/actix-web/pull/2442
## 3.0.0-beta.11 - 2021-10-20
### Changed
* Updated rustls to v0.20. [#2414]
* Minimum supported Rust version (MSRV) is now 1.52.
[#2414]: https://github.com/actix/actix-web/pull/2414
## 3.0.0-beta.10 - 2021-09-09
### Changed
* `ContentEncoding` is now marked `#[non_exhaustive]`. [#2377]
* Minimum supported Rust version (MSRV) is now 1.51.
### Fixed
* Remove slice creation pointing to potential uninitialized data on h1 encoder. [#2364]
* Remove `Into<Error>` bound on `Encoder` body types. [#2375]
* Fix quality parse error in Accept-Encoding header. [#2344]
[#2364]: https://github.com/actix/actix-web/pull/2364
[#2375]: https://github.com/actix/actix-web/pull/2375
[#2344]: https://github.com/actix/actix-web/pull/2344
[#2377]: https://github.com/actix/actix-web/pull/2377
## 3.0.0-beta.9 - 2021-08-09
### Fixed
* Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977)
## 3.0.0-beta.8 - 2021-06-26
@ -217,6 +328,11 @@
[#1878]: https://github.com/actix/actix-web/pull/1878
## 2.2.1 - 2021-08-09
### Fixed
* Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977)
## 2.2.0 - 2020-11-25
### Added
* HttpResponse builders for 1xx status codes. [#1768]

View File

@ -1,14 +1,17 @@
[package]
name = "actix-http"
version = "3.0.0-beta.8"
version = "3.0.0-beta.14"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "HTTP primitives for the Actix ecosystem"
keywords = ["actix", "http", "framework", "async", "futures"]
homepage = "https://actix.rs"
repository = "https://github.com/actix/actix-web"
categories = ["network-programming", "asynchronous",
"web-programming::http-server",
"web-programming::websocket"]
repository = "https://github.com/actix/actix-web.git"
categories = [
"network-programming",
"asynchronous",
"web-programming::http-server",
"web-programming::websocket",
]
license = "MIT OR Apache-2.0"
edition = "2018"
@ -24,29 +27,25 @@ path = "src/lib.rs"
default = []
# openssl
openssl = ["actix-tls/openssl"]
openssl = ["actix-tls/accept", "actix-tls/openssl"]
# rustls support
rustls = ["actix-tls/rustls"]
rustls = ["actix-tls/accept", "actix-tls/rustls"]
# enable compression support
compress-brotli = ["brotli2", "__compress"]
compress-gzip = ["flate2", "__compress"]
compress-zstd = ["zstd", "__compress"]
# trust-dns as client dns resolver
trust-dns = ["trust-dns-resolver"]
# Internal (PRIVATE!) features used to aid testing and cheking feature status.
# Don't rely on these whatsoever. They may disappear at anytime.
__compress = []
[dependencies]
actix-service = "2.0.0"
actix-codec = "0.4.0"
actix-codec = "0.4.1"
actix-utils = "3.0.0"
actix-rt = "2.2"
actix-tls = { version = "3.0.0-beta.5", features = ["accept", "connect"] }
ahash = "0.7"
base64 = "0.13"
@ -58,45 +57,45 @@ encoding_rs = "0.8"
futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] }
futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] }
h2 = "0.3.1"
http = "0.2.2"
httparse = "1.3"
http = "0.2.5"
httparse = "1.5.1"
httpdate = "1.0.1"
itoa = "0.4"
language-tags = "0.3"
local-channel = "0.1"
once_cell = "1.5"
log = "0.4"
mime = "0.3"
percent-encoding = "2.1"
pin-project = "1.0.0"
pin-project-lite = "0.2"
rand = "0.8"
regex = "1.3"
serde = "1.0"
sha-1 = "0.9"
smallvec = "1.6"
time = { version = "0.2.23", default-features = false, features = ["std"] }
tokio = { version = "1.2", features = ["sync"] }
smallvec = "1.6.1"
# tls
actix-tls = { version = "3.0.0-rc.1", default-features = false, optional = true }
# compression
brotli2 = { version="0.3.2", optional = true }
flate2 = { version = "1.0.13", optional = true }
zstd = { version = "0.7", optional = true }
trust-dns-resolver = { version = "0.20.0", optional = true }
zstd = { version = "0.9", optional = true }
[dev-dependencies]
actix-server = "2.0.0-beta.3"
actix-http-test = { version = "3.0.0-beta.4", features = ["openssl"] }
actix-tls = { version = "3.0.0-beta.5", features = ["openssl"] }
actix-server = "2.0.0-beta.9"
actix-http-test = { version = "3.0.0-beta.7", features = ["openssl"] }
actix-tls = { version = "3.0.0-rc.1", features = ["openssl"] }
async-stream = "0.3"
criterion = { version = "0.3", features = ["html_reports"] }
env_logger = "0.8"
env_logger = "0.9"
rcgen = "0.8"
regex = "1.3"
rustls-pemfile = "0.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tls-openssl = { version = "0.10", package = "openssl" }
tls-rustls = { version = "0.19", package = "rustls" }
webpki = { version = "0.21.0" }
static_assertions = "1"
tls-openssl = { package = "openssl", version = "0.10.9" }
tls-rustls = { package = "rustls", version = "0.20.0" }
tokio = { version = "1.2", features = ["net", "rt"] }
[[example]]
name = "ws"
@ -113,3 +112,7 @@ harness = false
[[bench]]
name = "uninit-headers"
harness = false
[[bench]]
name = "quality-value"
harness = false

View File

@ -3,18 +3,18 @@
> HTTP primitives for the Actix ecosystem.
[![crates.io](https://img.shields.io/crates/v/actix-http?label=latest)](https://crates.io/crates/actix-http)
[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.8)](https://docs.rs/actix-http/3.0.0-beta.8)
[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html)
[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.14)](https://docs.rs/actix-http/3.0.0-beta.14)
[![Version](https://img.shields.io/badge/rustc-1.52+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.52.0.html)
![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http.svg)
<br />
[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.8/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.8)
[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.14/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.14)
[![Download](https://img.shields.io/crates/d/actix-http.svg)](https://crates.io/crates/actix-http)
[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x)
## Documentation & Resources
- [API Documentation](https://docs.rs/actix-http)
- Minimum Supported Rust Version (MSRV): 1.46.0
- Minimum Supported Rust Version (MSRV): 1.52
## Example

View File

@ -0,0 +1,90 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
const CODES: &[u16] = &[0, 1000, 201, 800, 550];
fn bench_quality_display_impls(c: &mut Criterion) {
let mut group = c.benchmark_group("quality value display impls");
for i in CODES.iter() {
group.bench_with_input(BenchmarkId::new("New (fast?)", i), i, |b, &i| {
b.iter(|| _new::Quality(i).to_string())
});
group.bench_with_input(BenchmarkId::new("Naive", i), i, |b, &i| {
b.iter(|| _naive::Quality(i).to_string())
});
}
group.finish();
}
criterion_group!(benches, bench_quality_display_impls);
criterion_main!(benches);
mod _new {
use std::fmt;
pub struct Quality(pub(crate) u16);
impl fmt::Display for Quality {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
0 => f.write_str("0"),
1000 => f.write_str("1"),
// some number in the range 1999
x => {
f.write_str("0.")?;
// this implementation avoids string allocation otherwise required
// for `.trim_end_matches('0')`
if x < 10 {
f.write_str("00")?;
// 0 is handled so it's not possible to have a trailing 0, we can just return
itoa::fmt(f, x)
} else if x < 100 {
f.write_str("0")?;
if x % 10 == 0 {
// trailing 0, divide by 10 and write
itoa::fmt(f, x / 10)
} else {
itoa::fmt(f, x)
}
} else {
// x is in range 101999
if x % 100 == 0 {
// two trailing 0s, divide by 100 and write
itoa::fmt(f, x / 100)
} else if x % 10 == 0 {
// one trailing 0, divide by 10 and write
itoa::fmt(f, x / 10)
} else {
itoa::fmt(f, x)
}
}
}
}
}
}
}
mod _naive {
use std::fmt;
pub struct Quality(pub(crate) u16);
impl fmt::Display for Quality {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
0 => f.write_str("0"),
1000 => f.write_str("1"),
x => {
write!(f, "{}", format!("{:03}", x).trim_end_matches('0'))
}
}
}
}
}

View File

@ -18,7 +18,8 @@ fn bench_write_camel_case(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("New", i), bts, |b, bts| {
b.iter(|| {
let mut buf = black_box([0; 24]);
_new::write_camel_case(black_box(bts), &mut buf)
let len = black_box(bts.len());
_new::write_camel_case(black_box(bts), buf.as_mut_ptr(), len)
});
});
}
@ -30,9 +31,12 @@ criterion_group!(benches, bench_write_camel_case);
criterion_main!(benches);
mod _new {
pub fn write_camel_case(value: &[u8], buffer: &mut [u8]) {
pub fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
// first copy entire (potentially wrong) slice to output
buffer[..value.len()].copy_from_slice(value);
let buffer = unsafe {
std::ptr::copy_nonoverlapping(value.as_ptr(), buf, len);
std::slice::from_raw_parts_mut(buf, len)
};
let mut iter = value.iter();

View File

@ -1,12 +1,14 @@
use std::io;
use actix_http::{body::Body, http::HeaderValue, http::StatusCode};
use actix_http::{Error, HttpService, Request, Response};
use actix_http::{
body::MessageBody, http::HeaderValue, http::StatusCode, Error, HttpService, Request,
Response,
};
use actix_server::Server;
use bytes::BytesMut;
use futures_util::StreamExt as _;
async fn handle_request(mut req: Request) -> Result<Response<Body>, Error> {
async fn handle_request(mut req: Request) -> Result<Response<impl MessageBody>, Error> {
let mut body = BytesMut::new();
while let Some(item) = req.payload().next().await {
body.extend_from_slice(&item?)

View File

@ -85,22 +85,31 @@ impl Stream for Heartbeat {
fn tls_config() -> rustls::ServerConfig {
use std::io::BufReader;
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig,
};
use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{certs, pkcs8_private_keys};
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
let cert_file = cert.serialize_pem().unwrap();
let key_file = cert.serialize_private_key_pem();
let mut config = ServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(cert_file.as_bytes());
let key_file = &mut BufReader::new(key_file.as_bytes());
let cert_chain = certs(cert_file).unwrap();
let cert_chain = certs(cert_file)
.unwrap()
.into_iter()
.map(Certificate)
.collect();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert_chain, PrivateKey(keys.remove(0)))
.unwrap();
config.alpn_protocols.push(b"http/1.1".to_vec());
config.alpn_protocols.push(b"h2".to_vec());
config
}

View File

@ -1,233 +0,0 @@
use std::{
borrow::Cow,
error::Error as StdError,
fmt, mem,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures_core::{ready, Stream};
use crate::error::Error;
use super::{BodySize, BodyStream, MessageBody, MessageBodyMapErr, SizedStream};
pub type Body = AnyBody;
/// Represents various types of HTTP message body.
pub enum AnyBody {
/// Empty response. `Content-Length` header is not set.
None,
/// Zero sized response body. `Content-Length` header is set to `0`.
Empty,
/// Specific response body.
Bytes(Bytes),
/// Generic message body.
Message(BoxAnyBody),
}
impl AnyBody {
/// Create body from slice (copy)
pub fn from_slice(s: &[u8]) -> Self {
Self::Bytes(Bytes::copy_from_slice(s))
}
/// Create body from generic message body.
pub fn from_message<B>(body: B) -> Self
where
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError + 'static>>,
{
Self::Message(BoxAnyBody::from_body(body))
}
}
impl MessageBody for AnyBody {
type Error = Error;
fn size(&self) -> BodySize {
match self {
AnyBody::None => BodySize::None,
AnyBody::Empty => BodySize::Empty,
AnyBody::Bytes(ref bin) => BodySize::Sized(bin.len() as u64),
AnyBody::Message(ref body) => body.size(),
}
}
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
match self.get_mut() {
AnyBody::None => Poll::Ready(None),
AnyBody::Empty => Poll::Ready(None),
AnyBody::Bytes(ref mut bin) => {
let len = bin.len();
if len == 0 {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(mem::take(bin))))
}
}
// TODO: MSRV 1.51: poll_map_err
AnyBody::Message(body) => match ready!(body.as_pin_mut().poll_next(cx)) {
Some(Err(err)) => {
Poll::Ready(Some(Err(Error::new_body().with_cause(err))))
}
Some(Ok(val)) => Poll::Ready(Some(Ok(val))),
None => Poll::Ready(None),
},
}
}
}
impl PartialEq for AnyBody {
fn eq(&self, other: &Body) -> bool {
match *self {
AnyBody::None => matches!(*other, AnyBody::None),
AnyBody::Empty => matches!(*other, AnyBody::Empty),
AnyBody::Bytes(ref b) => match *other {
AnyBody::Bytes(ref b2) => b == b2,
_ => false,
},
AnyBody::Message(_) => false,
}
}
}
impl fmt::Debug for AnyBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
AnyBody::None => write!(f, "AnyBody::None"),
AnyBody::Empty => write!(f, "AnyBody::Empty"),
AnyBody::Bytes(ref b) => write!(f, "AnyBody::Bytes({:?})", b),
AnyBody::Message(_) => write!(f, "AnyBody::Message(_)"),
}
}
}
impl From<&'static str> for AnyBody {
fn from(s: &'static str) -> Body {
AnyBody::Bytes(Bytes::from_static(s.as_ref()))
}
}
impl From<&'static [u8]> for AnyBody {
fn from(s: &'static [u8]) -> Body {
AnyBody::Bytes(Bytes::from_static(s))
}
}
impl From<Vec<u8>> for AnyBody {
fn from(vec: Vec<u8>) -> Body {
AnyBody::Bytes(Bytes::from(vec))
}
}
impl From<String> for AnyBody {
fn from(s: String) -> Body {
s.into_bytes().into()
}
}
impl From<&'_ String> for AnyBody {
fn from(s: &String) -> Body {
AnyBody::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s)))
}
}
impl From<Cow<'_, str>> for AnyBody {
fn from(s: Cow<'_, str>) -> Body {
match s {
Cow::Owned(s) => AnyBody::from(s),
Cow::Borrowed(s) => {
AnyBody::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(s)))
}
}
}
}
impl From<Bytes> for AnyBody {
fn from(s: Bytes) -> Body {
AnyBody::Bytes(s)
}
}
impl From<BytesMut> for AnyBody {
fn from(s: BytesMut) -> Body {
AnyBody::Bytes(s.freeze())
}
}
impl<S, E> From<SizedStream<S>> for AnyBody
where
S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Box<dyn StdError>> + 'static,
{
fn from(s: SizedStream<S>) -> Body {
AnyBody::from_message(s)
}
}
impl<S, E> From<BodyStream<S>> for AnyBody
where
S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Box<dyn StdError>> + 'static,
{
fn from(s: BodyStream<S>) -> Body {
AnyBody::from_message(s)
}
}
/// A boxed message body with boxed errors.
pub struct BoxAnyBody(Pin<Box<dyn MessageBody<Error = Box<dyn StdError + 'static>>>>);
impl BoxAnyBody {
/// Boxes a `MessageBody` and any errors it generates.
pub fn from_body<B>(body: B) -> Self
where
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError + 'static>>,
{
let body = MessageBodyMapErr::new(body, Into::into);
Self(Box::pin(body))
}
/// Returns a mutable pinned reference to the inner message body type.
pub fn as_pin_mut(
&mut self,
) -> Pin<&mut (dyn MessageBody<Error = Box<dyn StdError + 'static>>)> {
self.0.as_mut()
}
}
impl fmt::Debug for BoxAnyBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("BoxAnyBody(dyn MessageBody)")
}
}
impl MessageBody for BoxAnyBody {
type Error = Error;
fn size(&self) -> BodySize {
self.0.size()
}
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
// TODO: MSRV 1.51: poll_map_err
match ready!(self.0.as_mut().poll_next(cx)) {
Some(Err(err)) => Poll::Ready(Some(Err(Error::new_body().with_cause(err)))),
Some(Ok(val)) => Poll::Ready(Some(Ok(val))),
None => Poll::Ready(None),
}
}
}

View File

@ -20,6 +20,8 @@ pin_project! {
}
}
// TODO: from_infallible method
impl<S, E> BodyStream<S>
where
S: Stream<Item = Result<Bytes, E>>,
@ -75,10 +77,23 @@ mod tests {
use derive_more::{Display, Error};
use futures_core::ready;
use futures_util::{stream, FutureExt as _};
use pin_project_lite::pin_project;
use static_assertions::{assert_impl_all, assert_not_impl_all};
use super::*;
use crate::body::to_bytes;
assert_impl_all!(BodyStream<stream::Empty<Result<Bytes, crate::Error>>>: MessageBody);
assert_impl_all!(BodyStream<stream::Empty<Result<Bytes, &'static str>>>: MessageBody);
assert_impl_all!(BodyStream<stream::Repeat<Result<Bytes, &'static str>>>: MessageBody);
assert_impl_all!(BodyStream<stream::Empty<Result<Bytes, Infallible>>>: MessageBody);
assert_impl_all!(BodyStream<stream::Repeat<Result<Bytes, Infallible>>>: MessageBody);
assert_not_impl_all!(BodyStream<stream::Empty<Bytes>>: MessageBody);
assert_not_impl_all!(BodyStream<stream::Repeat<Bytes>>: MessageBody);
// crate::Error is not Clone
assert_not_impl_all!(BodyStream<stream::Repeat<Result<Bytes, crate::Error>>>: MessageBody);
#[actix_rt::test]
async fn skips_empty_chunks() {
let body = BodyStream::new(stream::iter(
@ -124,18 +139,44 @@ mod tests {
assert!(matches!(to_bytes(body).await, Err(StreamErr)));
}
#[actix_rt::test]
async fn stream_string_error() {
// `&'static str` does not impl `Error`
// but it does impl `Into<Box<dyn Error>>`
let body = BodyStream::new(stream::once(async { Err("stringy error") }));
assert!(matches!(to_bytes(body).await, Err("stringy error")));
}
#[actix_rt::test]
async fn stream_boxed_error() {
// `Box<dyn Error>` does not impl `Error`
// but it does impl `Into<Box<dyn Error>>`
let body = BodyStream::new(stream::once(async {
Err(Box::<dyn StdError>::from("stringy error"))
}));
assert_eq!(
to_bytes(body).await.unwrap_err().to_string(),
"stringy error"
);
}
#[actix_rt::test]
async fn stream_delayed_error() {
let body =
BodyStream::new(stream::iter(vec![Ok(Bytes::from("1")), Err(StreamErr)]));
assert!(matches!(to_bytes(body).await, Err(StreamErr)));
#[pin_project::pin_project(project = TimeDelayStreamProj)]
#[derive(Debug)]
enum TimeDelayStream {
Start,
Sleep(Pin<Box<Sleep>>),
Done,
pin_project! {
#[derive(Debug)]
#[project = TimeDelayStreamProj]
enum TimeDelayStream {
Start,
Sleep { delay: Pin<Box<Sleep>> },
Done,
}
}
impl Stream for TimeDelayStream {
@ -148,12 +189,14 @@ mod tests {
match self.as_mut().get_mut() {
TimeDelayStream::Start => {
let sleep = sleep(Duration::from_millis(1));
self.as_mut().set(TimeDelayStream::Sleep(Box::pin(sleep)));
self.as_mut().set(TimeDelayStream::Sleep {
delay: Box::pin(sleep),
});
cx.waker().wake_by_ref();
Poll::Pending
}
TimeDelayStream::Sleep(ref mut delay) => {
TimeDelayStream::Sleep { ref mut delay } => {
ready!(delay.poll_unpin(cx));
self.set(TimeDelayStream::Done);
cx.waker().wake_by_ref();

View File

@ -0,0 +1,80 @@
use std::{
error::Error as StdError,
fmt,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use super::{BodySize, MessageBody, MessageBodyMapErr};
use crate::Error;
/// A boxed message body with boxed errors.
pub struct BoxBody(Pin<Box<dyn MessageBody<Error = Box<dyn StdError>>>>);
impl BoxBody {
/// Boxes a `MessageBody` and any errors it generates.
pub fn new<B>(body: B) -> Self
where
B: MessageBody + 'static,
{
let body = MessageBodyMapErr::new(body, Into::into);
Self(Box::pin(body))
}
/// Returns a mutable pinned reference to the inner message body type.
pub fn as_pin_mut(
&mut self,
) -> Pin<&mut (dyn MessageBody<Error = Box<dyn StdError>>)> {
self.0.as_mut()
}
}
impl fmt::Debug for BoxBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("BoxBody(dyn MessageBody)")
}
}
impl MessageBody for BoxBody {
type Error = Error;
fn size(&self) -> BodySize {
self.0.size()
}
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
self.0
.as_mut()
.poll_next(cx)
.map_err(|err| Error::new_body().with_cause(err))
}
}
#[cfg(test)]
mod tests {
use static_assertions::{assert_impl_all, assert_not_impl_all};
use super::*;
use crate::body::to_bytes;
assert_impl_all!(BoxBody: MessageBody, fmt::Debug, Unpin);
assert_not_impl_all!(BoxBody: Send, Sync, Unpin);
#[actix_rt::test]
async fn nested_boxed_body() {
let body = Bytes::from_static(&[1, 2, 3]);
let boxed_body = BoxBody::new(BoxBody::new(body));
assert_eq!(
to_bytes(boxed_body).await.unwrap(),
Bytes::from(vec![1, 2, 3]),
);
}
}

View File

@ -0,0 +1,83 @@
use std::{
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use pin_project_lite::pin_project;
use super::{BodySize, BoxBody, MessageBody};
use crate::Error;
pin_project! {
#[project = EitherBodyProj]
#[derive(Debug, Clone)]
pub enum EitherBody<L, R = BoxBody> {
/// A body of type `L`.
Left { #[pin] body: L },
/// A body of type `R`.
Right { #[pin] body: R },
}
}
impl<L> EitherBody<L, BoxBody> {
/// Creates new `EitherBody` using left variant and boxed right variant.
pub fn new(body: L) -> Self {
Self::Left { body }
}
}
impl<L, R> EitherBody<L, R> {
/// Creates new `EitherBody` using left variant.
pub fn left(body: L) -> Self {
Self::Left { body }
}
/// Creates new `EitherBody` using right variant.
pub fn right(body: R) -> Self {
Self::Right { body }
}
}
impl<L, R> MessageBody for EitherBody<L, R>
where
L: MessageBody + 'static,
R: MessageBody + 'static,
{
type Error = Error;
fn size(&self) -> BodySize {
match self {
EitherBody::Left { body } => body.size(),
EitherBody::Right { body } => body.size(),
}
}
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
match self.project() {
EitherBodyProj::Left { body } => body
.poll_next(cx)
.map_err(|err| Error::new_body().with_cause(err)),
EitherBodyProj::Right { body } => body
.poll_next(cx)
.map_err(|err| Error::new_body().with_cause(err)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn type_parameter_inference() {
let _body: EitherBody<(), _> = EitherBody::new(());
let _body: EitherBody<_, ()> = EitherBody::left(());
let _body: EitherBody<(), _> = EitherBody::right(());
}
}

View File

@ -2,6 +2,7 @@
use std::{
convert::Infallible,
error::Error as StdError,
mem,
pin::Pin,
task::{Context, Poll},
@ -11,13 +12,14 @@ use bytes::{Bytes, BytesMut};
use futures_core::ready;
use pin_project_lite::pin_project;
use crate::error::Error;
use super::BodySize;
/// An interface for response bodies.
/// An interface types that can converted to bytes and used as response bodies.
// TODO: examples
pub trait MessageBody {
type Error;
// TODO: consider this bound to only fmt::Display since the error type is not really used
// and there is an impl for Into<Box<StdError>> on String
type Error: Into<Box<dyn StdError>>;
/// Body size hint.
fn size(&self) -> BodySize;
@ -29,154 +31,218 @@ pub trait MessageBody {
) -> Poll<Option<Result<Bytes, Self::Error>>>;
}
impl MessageBody for () {
type Error = Infallible;
mod foreign_impls {
use super::*;
fn size(&self) -> BodySize {
BodySize::Empty
}
impl MessageBody for Infallible {
type Error = Infallible;
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Poll::Ready(None)
}
}
#[inline]
fn size(&self) -> BodySize {
match *self {}
}
impl<B> MessageBody for Box<B>
where
B: MessageBody + Unpin,
B::Error: Into<Error>,
{
type Error = B::Error;
fn size(&self) -> BodySize {
self.as_ref().size()
}
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Pin::new(self.get_mut().as_mut()).poll_next(cx)
}
}
impl<B> MessageBody for Pin<Box<B>>
where
B: MessageBody,
B::Error: Into<Error>,
{
type Error = B::Error;
fn size(&self) -> BodySize {
self.as_ref().size()
}
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
self.as_mut().poll_next(cx)
}
}
impl MessageBody for Bytes {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(mem::take(self.get_mut()))))
#[inline]
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
match *self {}
}
}
}
impl MessageBody for BytesMut {
type Error = Infallible;
impl MessageBody for () {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
#[inline]
fn size(&self) -> BodySize {
BodySize::Sized(0)
}
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
#[inline]
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze())))
}
}
}
impl MessageBody for &'static str {
type Error = Infallible;
impl<B> MessageBody for Box<B>
where
B: MessageBody + Unpin,
{
type Error = B::Error;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
#[inline]
fn size(&self) -> BodySize {
self.as_ref().size()
}
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(Bytes::from_static(
mem::take(self.get_mut()).as_ref(),
))))
#[inline]
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Pin::new(self.get_mut().as_mut()).poll_next(cx)
}
}
}
impl MessageBody for Vec<u8> {
type Error = Infallible;
impl<B> MessageBody for Pin<Box<B>>
where
B: MessageBody,
{
type Error = B::Error;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
#[inline]
fn size(&self) -> BodySize {
self.as_ref().size()
}
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(Bytes::from(mem::take(self.get_mut())))))
#[inline]
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
self.as_mut().poll_next(cx)
}
}
}
impl MessageBody for String {
type Error = Infallible;
impl MessageBody for &'static [u8] {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let bytes = mem::take(self.get_mut());
let bytes = Bytes::from_static(bytes);
Poll::Ready(Some(Ok(bytes)))
}
}
}
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(Bytes::from(
mem::take(self.get_mut()).into_bytes(),
))))
impl MessageBody for Bytes {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let bytes = mem::take(self.get_mut());
Poll::Ready(Some(Ok(bytes)))
}
}
}
impl MessageBody for BytesMut {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let bytes = mem::take(self.get_mut()).freeze();
Poll::Ready(Some(Ok(bytes)))
}
}
}
impl MessageBody for Vec<u8> {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let bytes = mem::take(self.get_mut());
Poll::Ready(Some(Ok(Bytes::from(bytes))))
}
}
}
impl MessageBody for &'static str {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let string = mem::take(self.get_mut());
let bytes = Bytes::from_static(string.as_bytes());
Poll::Ready(Some(Ok(bytes)))
}
}
}
impl MessageBody for String {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
if self.is_empty() {
Poll::Ready(None)
} else {
let string = mem::take(self.get_mut());
Poll::Ready(Some(Ok(Bytes::from(string))))
}
}
}
impl MessageBody for bytestring::ByteString {
type Error = Infallible;
fn size(&self) -> BodySize {
BodySize::Sized(self.len() as u64)
}
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
let string = mem::take(self.get_mut());
Poll::Ready(Some(Ok(string.into_bytes())))
}
}
}
@ -206,6 +272,7 @@ impl<B, F, E> MessageBody for MessageBodyMapErr<B, F>
where
B: MessageBody,
F: FnOnce(B::Error) -> E,
E: Into<Box<dyn StdError>>,
{
type Error = E;
@ -230,3 +297,129 @@ where
}
}
}
#[cfg(test)]
mod tests {
use actix_rt::pin;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use super::*;
macro_rules! assert_poll_next {
($pin:expr, $exp:expr) => {
assert_eq!(
poll_fn(|cx| $pin.as_mut().poll_next(cx))
.await
.unwrap() // unwrap option
.unwrap(), // unwrap result
$exp
);
};
}
macro_rules! assert_poll_next_none {
($pin:expr) => {
assert!(poll_fn(|cx| $pin.as_mut().poll_next(cx)).await.is_none());
};
}
#[actix_rt::test]
async fn boxing_equivalence() {
assert_eq!(().size(), BodySize::Sized(0));
assert_eq!(().size(), Box::new(()).size());
assert_eq!(().size(), Box::pin(()).size());
let pl = Box::new(());
pin!(pl);
assert_poll_next_none!(pl);
let mut pl = Box::pin(());
assert_poll_next_none!(pl);
}
#[actix_rt::test]
async fn test_unit() {
let pl = ();
assert_eq!(pl.size(), BodySize::Sized(0));
pin!(pl);
assert_poll_next_none!(pl);
}
#[actix_rt::test]
async fn test_static_str() {
assert_eq!("".size(), BodySize::Sized(0));
assert_eq!("test".size(), BodySize::Sized(4));
let pl = "test";
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
#[actix_rt::test]
async fn test_static_bytes() {
assert_eq!(b"".as_ref().size(), BodySize::Sized(0));
assert_eq!(b"test".as_ref().size(), BodySize::Sized(4));
let pl = b"test".as_ref();
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
#[actix_rt::test]
async fn test_vec() {
assert_eq!(vec![0; 0].size(), BodySize::Sized(0));
assert_eq!(Vec::from("test").size(), BodySize::Sized(4));
let pl = Vec::from("test");
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
#[actix_rt::test]
async fn test_bytes() {
assert_eq!(Bytes::new().size(), BodySize::Sized(0));
assert_eq!(Bytes::from_static(b"test").size(), BodySize::Sized(4));
let pl = Bytes::from_static(b"test");
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
#[actix_rt::test]
async fn test_bytes_mut() {
assert_eq!(BytesMut::new().size(), BodySize::Sized(0));
assert_eq!(BytesMut::from(b"test".as_ref()).size(), BodySize::Sized(4));
let pl = BytesMut::from("test");
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
#[actix_rt::test]
async fn test_string() {
assert_eq!(String::new().size(), BodySize::Sized(0));
assert_eq!("test".to_owned().size(), BodySize::Sized(4));
let pl = "test".to_owned();
pin!(pl);
assert_poll_next!(pl, Bytes::from("test"));
}
// down-casting used to be done with a method on MessageBody trait
// test is kept to demonstrate equivalence of Any trait
#[actix_rt::test]
async fn test_body_casting() {
let mut body = String::from("hello cast");
// let mut resp_body: &mut dyn MessageBody<Error = Error> = &mut body;
let resp_body: &mut dyn std::any::Any = &mut body;
let body = resp_body.downcast_ref::<String>().unwrap();
assert_eq!(body, "hello cast");
let body = &mut resp_body.downcast_mut::<String>().unwrap();
body.push('!');
let body = resp_body.downcast_ref::<String>().unwrap();
assert_eq!(body, "hello cast!");
let not_body = resp_body.downcast_ref::<()>();
assert!(not_body.is_none());
}
}

View File

@ -1,263 +1,20 @@
//! Traits and structures to aid consuming and writing HTTP payloads.
use std::task::Poll;
use actix_rt::pin;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use futures_core::ready;
#[allow(clippy::module_inception)]
mod body;
mod body_stream;
mod boxed;
mod either;
mod message_body;
mod response_body;
mod none;
mod size;
mod sized_stream;
mod utils;
pub use self::body::{AnyBody, Body, BoxAnyBody};
pub use self::body_stream::BodyStream;
pub use self::boxed::BoxBody;
pub use self::either::EitherBody;
pub use self::message_body::MessageBody;
pub(crate) use self::message_body::MessageBodyMapErr;
pub use self::response_body::ResponseBody;
pub use self::none::None;
pub use self::size::BodySize;
pub use self::sized_stream::SizedStream;
/// Collects the body produced by a `MessageBody` implementation into `Bytes`.
///
/// Any errors produced by the body stream are returned immediately.
///
/// # Examples
/// ```
/// use actix_http::body::{Body, to_bytes};
/// use bytes::Bytes;
///
/// # async fn test_to_bytes() {
/// let body = Body::Empty;
/// let bytes = to_bytes(body).await.unwrap();
/// assert!(bytes.is_empty());
///
/// let body = Body::Bytes(Bytes::from_static(b"123"));
/// let bytes = to_bytes(body).await.unwrap();
/// assert_eq!(bytes, b"123"[..]);
/// # }
/// ```
pub async fn to_bytes<B: MessageBody>(body: B) -> Result<Bytes, B::Error> {
let cap = match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => return Ok(Bytes::new()),
BodySize::Sized(size) => size as usize,
BodySize::Stream => 32_768,
};
let mut buf = BytesMut::with_capacity(cap);
pin!(body);
poll_fn(|cx| loop {
let body = body.as_mut();
match ready!(body.poll_next(cx)) {
Some(Ok(bytes)) => buf.extend_from_slice(&*bytes),
None => return Poll::Ready(Ok(())),
Some(Err(err)) => return Poll::Ready(Err(err)),
}
})
.await?;
Ok(buf.freeze())
}
#[cfg(test)]
mod tests {
use std::pin::Pin;
use actix_rt::pin;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use super::*;
impl Body {
pub(crate) fn get_ref(&self) -> &[u8] {
match *self {
Body::Bytes(ref bin) => &bin,
_ => panic!(),
}
}
}
#[actix_rt::test]
async fn test_static_str() {
assert_eq!(Body::from("").size(), BodySize::Sized(0));
assert_eq!(Body::from("test").size(), BodySize::Sized(4));
assert_eq!(Body::from("test").get_ref(), b"test");
assert_eq!("test".size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| Pin::new(&mut "test").poll_next(cx))
.await
.unwrap()
.ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_static_bytes() {
assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test");
assert_eq!(
Body::from_slice(b"test".as_ref()).size(),
BodySize::Sized(4)
);
assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test");
let sb = Bytes::from(&b"test"[..]);
pin!(sb);
assert_eq!(sb.size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| sb.as_mut().poll_next(cx)).await.unwrap().ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_vec() {
assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4));
assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test");
let test_vec = Vec::from("test");
pin!(test_vec);
assert_eq!(test_vec.size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| test_vec.as_mut().poll_next(cx))
.await
.unwrap()
.ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_bytes() {
let b = Bytes::from("test");
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test");
pin!(b);
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_bytes_mut() {
let b = BytesMut::from("test");
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test");
pin!(b);
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_string() {
let b = "test".to_owned();
assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4));
assert_eq!(Body::from(b.clone()).get_ref(), b"test");
assert_eq!(Body::from(&b).size(), BodySize::Sized(4));
assert_eq!(Body::from(&b).get_ref(), b"test");
pin!(b);
assert_eq!(b.size(), BodySize::Sized(4));
assert_eq!(
poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(),
Some(Bytes::from("test"))
);
}
#[actix_rt::test]
async fn test_unit() {
assert_eq!(().size(), BodySize::Empty);
assert!(poll_fn(|cx| Pin::new(&mut ()).poll_next(cx))
.await
.is_none());
}
#[actix_rt::test]
async fn test_box_and_pin() {
let val = Box::new(());
pin!(val);
assert_eq!(val.size(), BodySize::Empty);
assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none());
let mut val = Box::pin(());
assert_eq!(val.size(), BodySize::Empty);
assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none());
}
#[actix_rt::test]
async fn test_body_eq() {
assert!(
Body::Bytes(Bytes::from_static(b"1"))
== Body::Bytes(Bytes::from_static(b"1"))
);
assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None);
}
#[actix_rt::test]
async fn test_body_debug() {
assert!(format!("{:?}", Body::None).contains("Body::None"));
assert!(format!("{:?}", Body::Empty).contains("Body::Empty"));
assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1'));
}
#[actix_rt::test]
async fn test_serde_json() {
use serde_json::{json, Value};
assert_eq!(
Body::from(serde_json::to_vec(&Value::String("test".to_owned())).unwrap())
.size(),
BodySize::Sized(6)
);
assert_eq!(
Body::from(serde_json::to_vec(&json!({"test-key":"test-value"})).unwrap())
.size(),
BodySize::Sized(25)
);
}
// down-casting used to be done with a method on MessageBody trait
// test is kept to demonstrate equivalence of Any trait
#[actix_rt::test]
async fn test_body_casting() {
let mut body = String::from("hello cast");
// let mut resp_body: &mut dyn MessageBody<Error = Error> = &mut body;
let resp_body: &mut dyn std::any::Any = &mut body;
let body = resp_body.downcast_ref::<String>().unwrap();
assert_eq!(body, "hello cast");
let body = &mut resp_body.downcast_mut::<String>().unwrap();
body.push('!');
let body = resp_body.downcast_ref::<String>().unwrap();
assert_eq!(body, "hello cast!");
let not_body = resp_body.downcast_ref::<()>();
assert!(not_body.is_none());
}
#[actix_rt::test]
async fn test_to_bytes() {
let body = Body::Empty;
let bytes = to_bytes(body).await.unwrap();
assert!(bytes.is_empty());
let body = Body::Bytes(Bytes::from_static(b"123"));
let bytes = to_bytes(body).await.unwrap();
assert_eq!(bytes, b"123"[..]);
}
}
pub use self::utils::to_bytes;

View File

@ -0,0 +1,43 @@
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use super::{BodySize, MessageBody};
/// Body type for responses that forbid payloads.
///
/// Distinct from an empty response which would contain a Content-Length header.
///
/// For an "empty" body, use `()` or `Bytes::new()`.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct None;
impl None {
/// Constructs new "none" body.
#[inline]
pub fn new() -> Self {
None
}
}
impl MessageBody for None {
type Error = Infallible;
#[inline]
fn size(&self) -> BodySize {
BodySize::None
}
#[inline]
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Poll::Ready(Option::None)
}
}

View File

@ -1,89 +0,0 @@
use std::{
mem,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use futures_core::{ready, Stream};
use pin_project::pin_project;
use crate::error::Error;
use super::{Body, BodySize, MessageBody};
#[pin_project(project = ResponseBodyProj)]
pub enum ResponseBody<B> {
Body(#[pin] B),
Other(Body),
}
impl ResponseBody<Body> {
pub fn into_body<B>(self) -> ResponseBody<B> {
match self {
ResponseBody::Body(b) => ResponseBody::Other(b),
ResponseBody::Other(b) => ResponseBody::Other(b),
}
}
}
impl<B> ResponseBody<B> {
pub fn take_body(&mut self) -> ResponseBody<B> {
mem::replace(self, ResponseBody::Other(Body::None))
}
}
impl<B: MessageBody> ResponseBody<B> {
pub fn as_ref(&self) -> Option<&B> {
if let ResponseBody::Body(ref b) = self {
Some(b)
} else {
None
}
}
}
impl<B> MessageBody for ResponseBody<B>
where
B: MessageBody,
B::Error: Into<Error>,
{
type Error = Error;
fn size(&self) -> BodySize {
match self {
ResponseBody::Body(ref body) => body.size(),
ResponseBody::Other(ref body) => body.size(),
}
}
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
Stream::poll_next(self, cx)
}
}
impl<B> Stream for ResponseBody<B>
where
B: MessageBody,
B::Error: Into<Error>,
{
type Item = Result<Bytes, Error>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.project() {
// TODO: MSRV 1.51: poll_map_err
ResponseBodyProj::Body(body) => match ready!(body.poll_next(cx)) {
Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
Some(Ok(val)) => Poll::Ready(Some(Ok(val))),
None => Poll::Ready(None),
},
ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx),
}
}
}

View File

@ -6,14 +6,9 @@ pub enum BodySize {
/// Will skip writing Content-Length header.
None,
/// Zero size body.
///
/// Will write `Content-Length: 0` header.
Empty,
/// Known size body.
///
/// Will write `Content-Length: N` header. `Sized(0)` is treated the same as `Empty`.
/// Will write `Content-Length: N` header.
Sized(u64),
/// Unknown size body.
@ -23,18 +18,19 @@ pub enum BodySize {
}
impl BodySize {
/// Returns true if size hint indicates no or empty body.
/// Returns true if size hint indicates omitted or empty body.
///
/// Streams will return false because it cannot be known without reading the stream.
///
/// ```
/// # use actix_http::body::BodySize;
/// assert!(BodySize::None.is_eof());
/// assert!(BodySize::Empty.is_eof());
/// assert!(BodySize::Sized(0).is_eof());
///
/// assert!(!BodySize::Sized(64).is_eof());
/// assert!(!BodySize::Stream.is_eof());
/// ```
pub fn is_eof(&self) -> bool {
matches!(self, BodySize::None | BodySize::Empty | BodySize::Sized(0))
matches!(self, BodySize::None | BodySize::Sized(0))
}
}

View File

@ -32,6 +32,8 @@ where
}
}
// TODO: from_infallible method
impl<S, E> MessageBody for SizedStream<S>
where
S: Stream<Item = Result<Bytes, E>>,
@ -72,10 +74,22 @@ mod tests {
use actix_rt::pin;
use actix_utils::future::poll_fn;
use futures_util::stream;
use static_assertions::{assert_impl_all, assert_not_impl_all};
use super::*;
use crate::body::to_bytes;
assert_impl_all!(SizedStream<stream::Empty<Result<Bytes, crate::Error>>>: MessageBody);
assert_impl_all!(SizedStream<stream::Empty<Result<Bytes, &'static str>>>: MessageBody);
assert_impl_all!(SizedStream<stream::Repeat<Result<Bytes, &'static str>>>: MessageBody);
assert_impl_all!(SizedStream<stream::Empty<Result<Bytes, Infallible>>>: MessageBody);
assert_impl_all!(SizedStream<stream::Repeat<Result<Bytes, Infallible>>>: MessageBody);
assert_not_impl_all!(SizedStream<stream::Empty<Bytes>>: MessageBody);
assert_not_impl_all!(SizedStream<stream::Repeat<Bytes>>: MessageBody);
// crate::Error is not Clone
assert_not_impl_all!(SizedStream<stream::Repeat<Result<Bytes, crate::Error>>>: MessageBody);
#[actix_rt::test]
async fn skips_empty_chunks() {
let body = SizedStream::new(
@ -119,4 +133,37 @@ mod tests {
assert_eq!(to_bytes(body).await.ok(), Some(Bytes::from("12")));
}
#[actix_rt::test]
async fn stream_string_error() {
// `&'static str` does not impl `Error`
// but it does impl `Into<Box<dyn Error>>`
let body = SizedStream::new(0, stream::once(async { Err("stringy error") }));
assert_eq!(to_bytes(body).await, Ok(Bytes::new()));
let body = SizedStream::new(1, stream::once(async { Err("stringy error") }));
assert!(matches!(to_bytes(body).await, Err("stringy error")));
}
#[actix_rt::test]
async fn stream_boxed_error() {
// `Box<dyn Error>` does not impl `Error`
// but it does impl `Into<Box<dyn Error>>`
let body = SizedStream::new(
0,
stream::once(async { Err(Box::<dyn StdError>::from("stringy error")) }),
);
assert_eq!(to_bytes(body).await.unwrap(), Bytes::new());
let body = SizedStream::new(
1,
stream::once(async { Err(Box::<dyn StdError>::from("stringy error")) }),
);
assert_eq!(
to_bytes(body).await.unwrap_err().to_string(),
"stringy error"
);
}
}

View File

@ -0,0 +1,78 @@
use std::task::Poll;
use actix_rt::pin;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use futures_core::ready;
use super::{BodySize, MessageBody};
/// Collects the body produced by a `MessageBody` implementation into `Bytes`.
///
/// Any errors produced by the body stream are returned immediately.
///
/// # Examples
/// ```
/// use actix_http::body::{self, to_bytes};
/// use bytes::Bytes;
///
/// # async fn test_to_bytes() {
/// let body = body::None::new();
/// let bytes = to_bytes(body).await.unwrap();
/// assert!(bytes.is_empty());
///
/// let body = Bytes::from_static(b"123");
/// let bytes = to_bytes(body).await.unwrap();
/// assert_eq!(bytes, b"123"[..]);
/// # }
/// ```
pub async fn to_bytes<B: MessageBody>(body: B) -> Result<Bytes, B::Error> {
let cap = match body.size() {
BodySize::None | BodySize::Sized(0) => return Ok(Bytes::new()),
BodySize::Sized(size) => size as usize,
// good enough first guess for chunk size
BodySize::Stream => 32_768,
};
let mut buf = BytesMut::with_capacity(cap);
pin!(body);
poll_fn(|cx| loop {
let body = body.as_mut();
match ready!(body.poll_next(cx)) {
Some(Ok(bytes)) => buf.extend_from_slice(&*bytes),
None => return Poll::Ready(Ok(())),
Some(Err(err)) => return Poll::Ready(Err(err)),
}
})
.await?;
Ok(buf.freeze())
}
#[cfg(test)]
mod test {
use futures_util::{stream, StreamExt as _};
use super::*;
use crate::{body::BodyStream, Error};
#[actix_rt::test]
async fn test_to_bytes() {
let bytes = to_bytes(()).await.unwrap();
assert!(bytes.is_empty());
let body = Bytes::from_static(b"123");
let bytes = to_bytes(body).await.unwrap();
assert_eq!(bytes, b"123"[..]);
let stream =
stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
.map(Ok::<_, Error>);
let body = BodyStream::new(stream);
let bytes = to_bytes(body).await.unwrap();
assert_eq!(bytes, b"123abc"[..]);
}
}

View File

@ -1,10 +1,10 @@
use std::{error::Error as StdError, fmt, marker::PhantomData, net, rc::Rc};
use std::{fmt, marker::PhantomData, net, rc::Rc};
use actix_codec::Framed;
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use crate::{
body::{AnyBody, MessageBody},
body::{BoxBody, MessageBody},
config::{KeepAlive, ServiceConfig},
extensions::CloneableExtensions,
h1::{self, ExpectHandler, H1Service, UpgradeHandler},
@ -32,7 +32,7 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler> {
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
<S::Service as Service<Request>>::Future: 'static,
{
@ -55,11 +55,11 @@ where
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
<S::Service as Service<Request>>::Future: 'static,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
U::Error: fmt::Display,
@ -121,7 +121,7 @@ where
where
F: IntoServiceFactory<X1, Request>,
X1: ServiceFactory<Request, Config = (), Response = Request>,
X1::Error: Into<Response<AnyBody>>,
X1::Error: Into<Response<BoxBody>>,
X1::InitError: fmt::Debug,
{
HttpServiceBuilder {
@ -179,7 +179,7 @@ where
where
B: MessageBody,
F: IntoServiceFactory<S, Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
{
@ -201,12 +201,11 @@ where
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
where
F: IntoServiceFactory<S, Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
let cfg = ServiceConfig::new(
self.keep_alive,
@ -224,12 +223,11 @@ where
pub fn finish<F, B>(self, service: F) -> HttpService<T, S, B, X, U>
where
F: IntoServiceFactory<S, Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
let cfg = ServiceConfig::new(
self.keep_alive,

View File

@ -1,43 +0,0 @@
use std::net::IpAddr;
use std::time::Duration;
const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB
const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB
/// Connector configuration
#[derive(Clone)]
pub(crate) struct ConnectorConfig {
pub(crate) timeout: Duration,
pub(crate) handshake_timeout: Duration,
pub(crate) conn_lifetime: Duration,
pub(crate) conn_keep_alive: Duration,
pub(crate) disconnect_timeout: Option<Duration>,
pub(crate) limit: usize,
pub(crate) conn_window_size: u32,
pub(crate) stream_window_size: u32,
pub(crate) local_address: Option<IpAddr>,
}
impl Default for ConnectorConfig {
fn default() -> Self {
Self {
timeout: Duration::from_secs(5),
handshake_timeout: Duration::from_secs(5),
conn_lifetime: Duration::from_secs(75),
conn_keep_alive: Duration::from_secs(15),
disconnect_timeout: Some(Duration::from_millis(3000)),
limit: 100,
conn_window_size: DEFAULT_H2_CONN_WINDOW,
stream_window_size: DEFAULT_H2_STREAM_WINDOW,
local_address: None,
}
}
}
impl ConnectorConfig {
pub(crate) fn no_disconnect_timeout(&self) -> Self {
let mut res = self.clone();
res.disconnect_timeout = None;
res
}
}

View File

@ -1,475 +0,0 @@
use std::{
io,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time,
};
use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use actix_rt::task::JoinHandle;
use bytes::Bytes;
use futures_core::future::LocalBoxFuture;
use h2::client::SendRequest;
use crate::h1::ClientCodec;
use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::Payload;
use crate::{body::MessageBody, Error};
use super::error::SendRequestError;
use super::pool::Acquired;
use super::{h1proto, h2proto};
/// Trait alias for types impl [tokio::io::AsyncRead] and [tokio::io::AsyncWrite].
pub trait ConnectionIo: AsyncRead + AsyncWrite + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> ConnectionIo for T {}
/// HTTP client connection
pub struct H1Connection<Io: ConnectionIo> {
io: Option<Io>,
created: time::Instant,
acquired: Acquired<Io>,
}
impl<Io: ConnectionIo> H1Connection<Io> {
/// close or release the connection to pool based on flag input
pub(super) fn on_release(&mut self, keep_alive: bool) {
if keep_alive {
self.release();
} else {
self.close();
}
}
/// Close connection
fn close(&mut self) {
let io = self.io.take().unwrap();
self.acquired.close(ConnectionInnerType::H1(io));
}
/// Release this connection to the connection pool
fn release(&mut self) {
let io = self.io.take().unwrap();
self.acquired
.release(ConnectionInnerType::H1(io), self.created);
}
fn io_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io> {
Pin::new(self.get_mut().io.as_mut().unwrap())
}
}
impl<Io: ConnectionIo> AsyncRead for H1Connection<Io> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.io_pin_mut().poll_read(cx, buf)
}
}
impl<Io: ConnectionIo> AsyncWrite for H1Connection<Io> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.io_pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io_pin_mut().poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
self.io_pin_mut().poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.io_pin_mut().poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.io.as_ref().unwrap().is_write_vectored()
}
}
/// HTTP2 client connection
pub struct H2Connection<Io: ConnectionIo> {
io: Option<H2ConnectionInner>,
created: time::Instant,
acquired: Acquired<Io>,
}
impl<Io: ConnectionIo> Deref for H2Connection<Io> {
type Target = SendRequest<Bytes>;
fn deref(&self) -> &Self::Target {
&self.io.as_ref().unwrap().sender
}
}
impl<Io: ConnectionIo> DerefMut for H2Connection<Io> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.io.as_mut().unwrap().sender
}
}
impl<Io: ConnectionIo> H2Connection<Io> {
/// close or release the connection to pool based on flag input
pub(super) fn on_release(&mut self, close: bool) {
if close {
self.close();
} else {
self.release();
}
}
/// Close connection
fn close(&mut self) {
let io = self.io.take().unwrap();
self.acquired.close(ConnectionInnerType::H2(io));
}
/// Release this connection to the connection pool
fn release(&mut self) {
let io = self.io.take().unwrap();
self.acquired
.release(ConnectionInnerType::H2(io), self.created);
}
}
/// `H2ConnectionInner` has two parts: `SendRequest` and `Connection`.
///
/// `Connection` is spawned as an async task on runtime and `H2ConnectionInner` holds a handle
/// for this task. Therefore, it can wake up and quit the task when SendRequest is dropped.
pub(super) struct H2ConnectionInner {
handle: JoinHandle<()>,
sender: SendRequest<Bytes>,
}
impl H2ConnectionInner {
pub(super) fn new<Io: ConnectionIo>(
sender: SendRequest<Bytes>,
connection: h2::client::Connection<Io>,
) -> Self {
let handle = actix_rt::spawn(async move {
let _ = connection.await;
});
Self { handle, sender }
}
}
/// Cancel spawned connection task on drop.
impl Drop for H2ConnectionInner {
fn drop(&mut self) {
if self
.sender
.send_request(http::Request::new(()), true)
.is_err()
{
self.handle.abort();
}
}
}
#[allow(dead_code)]
/// Unified connection type cover Http1 Plain/Tls and Http2 protocols
pub enum Connection<A, B = Box<dyn ConnectionIo>>
where
A: ConnectionIo,
B: ConnectionIo,
{
Tcp(ConnectionType<A>),
Tls(ConnectionType<B>),
}
/// Unified connection type cover Http1/2 protocols
pub enum ConnectionType<Io: ConnectionIo> {
H1(H1Connection<Io>),
H2(H2Connection<Io>),
}
/// Helper type for storing connection types in pool.
pub(super) enum ConnectionInnerType<Io> {
H1(Io),
H2(H2ConnectionInner),
}
impl<Io: ConnectionIo> ConnectionType<Io> {
pub(super) fn from_pool(
inner: ConnectionInnerType<Io>,
created: time::Instant,
acquired: Acquired<Io>,
) -> Self {
match inner {
ConnectionInnerType::H1(io) => Self::from_h1(io, created, acquired),
ConnectionInnerType::H2(io) => Self::from_h2(io, created, acquired),
}
}
pub(super) fn from_h1(
io: Io,
created: time::Instant,
acquired: Acquired<Io>,
) -> Self {
Self::H1(H1Connection {
io: Some(io),
created,
acquired,
})
}
pub(super) fn from_h2(
io: H2ConnectionInner,
created: time::Instant,
acquired: Acquired<Io>,
) -> Self {
Self::H2(H2Connection {
io: Some(io),
created,
acquired,
})
}
}
impl<A, B> Connection<A, B>
where
A: ConnectionIo,
B: ConnectionIo,
{
/// Send a request through connection.
pub fn send_request<RB, H>(
self,
head: H,
body: RB,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
where
H: Into<RequestHeadType> + 'static,
RB: MessageBody + 'static,
RB::Error: Into<Error>,
{
Box::pin(async move {
match self {
Connection::Tcp(ConnectionType::H1(conn)) => {
h1proto::send_request(conn, head.into(), body).await
}
Connection::Tls(ConnectionType::H1(conn)) => {
h1proto::send_request(conn, head.into(), body).await
}
Connection::Tls(ConnectionType::H2(conn)) => {
h2proto::send_request(conn, head.into(), body).await
}
_ => unreachable!(
"Plain Tcp connection can be used only in Http1 protocol"
),
}
})
}
/// Send request, returns Response and Framed tunnel.
pub fn open_tunnel<H: Into<RequestHeadType> + 'static>(
self,
head: H,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Connection<A, B>, ClientCodec>), SendRequestError>,
> {
Box::pin(async move {
match self {
Connection::Tcp(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
}
Connection::Tls(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
}
Connection::Tls(ConnectionType::H2(mut conn)) => {
conn.release();
Err(SendRequestError::TunnelNotSupported)
}
Connection::Tcp(ConnectionType::H2(_)) => {
unreachable!(
"Plain Tcp connection can be used only in Http1 protocol"
)
}
}
})
}
}
impl<A, B> AsyncRead for Connection<A, B>
where
A: ConnectionIo,
B: ConnectionIo,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Connection::Tcp(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_read(cx, buf)
}
Connection::Tls(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_read(cx, buf)
}
_ => unreachable!("H2Connection can not impl AsyncRead trait"),
}
}
}
const H2_UNREACHABLE_WRITE: &str = "H2Connection can not impl AsyncWrite trait";
impl<A, B> AsyncWrite for Connection<A, B>
where
A: ConnectionIo,
B: ConnectionIo,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Connection::Tcp(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_write(cx, buf)
}
Connection::Tls(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_write(cx, buf)
}
_ => unreachable!(H2_UNREACHABLE_WRITE),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx),
Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx),
_ => unreachable!(H2_UNREACHABLE_WRITE),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Connection::Tcp(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_shutdown(cx)
}
Connection::Tls(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_shutdown(cx)
}
_ => unreachable!(H2_UNREACHABLE_WRITE),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Connection::Tcp(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_write_vectored(cx, bufs)
}
Connection::Tls(ConnectionType::H1(conn)) => {
Pin::new(conn).poll_write_vectored(cx, bufs)
}
_ => unreachable!(H2_UNREACHABLE_WRITE),
}
}
fn is_write_vectored(&self) -> bool {
match *self {
Connection::Tcp(ConnectionType::H1(ref conn)) => conn.is_write_vectored(),
Connection::Tls(ConnectionType::H1(ref conn)) => conn.is_write_vectored(),
_ => unreachable!(H2_UNREACHABLE_WRITE),
}
}
}
#[cfg(test)]
mod test {
use std::{
future::Future,
net,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
};
use actix_rt::{
net::TcpStream,
time::{interval, Interval},
};
use super::*;
#[actix_rt::test]
async fn test_h2_connection_drop() {
let addr = "127.0.0.1:0".parse::<net::SocketAddr>().unwrap();
let listener = net::TcpListener::bind(addr).unwrap();
let local = listener.local_addr().unwrap();
std::thread::spawn(move || while listener.accept().is_ok() {});
let tcp = TcpStream::connect(local).await.unwrap();
let (sender, connection) = h2::client::handshake(tcp).await.unwrap();
let conn = H2ConnectionInner::new(sender.clone(), connection);
assert!(sender.clone().ready().await.is_ok());
assert!(h2::client::SendRequest::clone(&conn.sender)
.ready()
.await
.is_ok());
drop(conn);
struct DropCheck {
sender: h2::client::SendRequest<Bytes>,
interval: Interval,
start_from: Instant,
}
impl Future for DropCheck {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match futures_core::ready!(this.sender.poll_ready(cx)) {
Ok(()) => {
if this.start_from.elapsed() > Duration::from_secs(10) {
panic!("connection should be gone and can not be ready");
} else {
let _ = this.interval.poll_tick(cx);
Poll::Pending
}
}
Err(_) => Poll::Ready(()),
}
}
}
DropCheck {
sender,
interval: interval(Duration::from_millis(100)),
start_from: Instant::now(),
}
.await;
}
}

View File

@ -1,764 +0,0 @@
use std::{
fmt,
future::Future,
net::IpAddr,
pin::Pin,
rc::Rc,
task::{Context, Poll},
time::Duration,
};
use actix_rt::{
net::{ActixStream, TcpStream},
time::{sleep, Sleep},
};
use actix_service::Service;
use actix_tls::connect::{
new_connector, Connect as TcpConnect, ConnectError as TcpConnectError,
Connection as TcpConnection, Resolver,
};
use futures_core::{future::LocalBoxFuture, ready};
use http::Uri;
use pin_project::pin_project;
use super::config::ConnectorConfig;
use super::connection::{Connection, ConnectionIo};
use super::error::ConnectError;
use super::pool::ConnectionPool;
use super::Connect;
use super::Protocol;
#[cfg(feature = "openssl")]
use actix_tls::connect::ssl::openssl::SslConnector as OpensslConnector;
#[cfg(feature = "rustls")]
use actix_tls::connect::ssl::rustls::ClientConfig;
enum SslConnector {
#[allow(dead_code)]
None,
#[cfg(feature = "openssl")]
Openssl(OpensslConnector),
#[cfg(feature = "rustls")]
Rustls(std::sync::Arc<ClientConfig>),
}
/// Manages HTTP client network connectivity.
///
/// The `Connector` type uses a builder-like combinator pattern for service
/// construction that finishes by calling the `.finish()` method.
///
/// ```ignore
/// use std::time::Duration;
/// use actix_http::client::Connector;
///
/// let connector = Connector::new()
/// .timeout(Duration::from_secs(5))
/// .finish();
/// ```
pub struct Connector<T> {
connector: T,
config: ConnectorConfig,
#[allow(dead_code)]
ssl: SslConnector,
}
impl Connector<()> {
#[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
pub fn new() -> Connector<
impl Service<
TcpConnect<Uri>,
Response = TcpConnection<Uri, TcpStream>,
Error = actix_tls::connect::ConnectError,
> + Clone,
> {
Connector {
ssl: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
connector: new_connector(resolver::resolver()),
config: ConnectorConfig::default(),
}
}
// Build Ssl connector with openssl, based on supplied alpn protocols
#[cfg(feature = "openssl")]
fn build_ssl(protocols: Vec<Vec<u8>>) -> SslConnector {
use actix_tls::connect::ssl::openssl::SslMethod;
use bytes::{BufMut, BytesMut};
let mut alpn = BytesMut::with_capacity(20);
for proto in &protocols {
alpn.put_u8(proto.len() as u8);
alpn.put(proto.as_slice());
}
let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
let _ = ssl
.set_alpn_protos(&alpn)
.map_err(|e| error!("Can not set alpn protocol: {:?}", e));
SslConnector::Openssl(ssl.build())
}
// Build Ssl connector with rustls, based on supplied alpn protocols
#[cfg(all(not(feature = "openssl"), feature = "rustls"))]
fn build_ssl(protocols: Vec<Vec<u8>>) -> SslConnector {
let mut config = ClientConfig::new();
config.set_protocols(&protocols);
config.root_store.add_server_trust_anchors(
&actix_tls::connect::ssl::rustls::TLS_SERVER_ROOTS,
);
SslConnector::Rustls(std::sync::Arc::new(config))
}
// ssl turned off, provides empty ssl connector
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
fn build_ssl(_: Vec<Vec<u8>>) -> SslConnector {
SslConnector::None
}
}
impl<S> Connector<S> {
/// Use custom connector.
pub fn connector<S1, Io1>(self, connector: S1) -> Connector<S1>
where
Io1: ActixStream + fmt::Debug + 'static,
S1: Service<
TcpConnect<Uri>,
Response = TcpConnection<Uri, Io1>,
Error = TcpConnectError,
> + Clone,
{
Connector {
connector,
config: self.config,
ssl: self.ssl,
}
}
}
impl<S, Io> Connector<S>
where
// Note:
// Input Io type is bound to ActixStream trait but internally in client module they
// are bound to ConnectionIo trait alias. And latter is the trait exposed to public
// in the form of Box<dyn ConnectionIo> type.
//
// This remap is to hide ActixStream's trait methods. They are not meant to be called
// from user code.
Io: ActixStream + fmt::Debug + 'static,
S: Service<
TcpConnect<Uri>,
Response = TcpConnection<Uri, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
/// Tcp connection timeout, i.e. max time to connect to remote host including dns name
/// resolution. Set to 5 second by default.
pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout;
self
}
/// Tls handshake timeout, i.e. max time to do tls handshake with remote host after tcp
/// connection established. Set to 5 second by default.
pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.config.handshake_timeout = timeout;
self
}
#[cfg(feature = "openssl")]
/// Use custom `SslConnector` instance.
pub fn ssl(mut self, connector: OpensslConnector) -> Self {
self.ssl = SslConnector::Openssl(connector);
self
}
#[cfg(feature = "rustls")]
/// Use custom `SslConnector` instance.
pub fn rustls(mut self, connector: std::sync::Arc<ClientConfig>) -> Self {
self.ssl = SslConnector::Rustls(connector);
self
}
/// Maximum supported HTTP major version.
///
/// Supported versions are HTTP/1.1 and HTTP/2.
pub fn max_http_version(mut self, val: http::Version) -> Self {
let versions = match val {
http::Version::HTTP_11 => vec![b"http/1.1".to_vec()],
http::Version::HTTP_2 => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
_ => {
unimplemented!("actix-http:client: supported versions http/1.1, http/2")
}
};
self.ssl = Connector::build_ssl(versions);
self
}
/// Indicates the initial window size (in octets) for
/// HTTP2 stream-level flow control for received data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_window_size(mut self, size: u32) -> Self {
self.config.stream_window_size = size;
self
}
/// Indicates the initial window size (in octets) for
/// HTTP2 connection-level flow control for received data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.config.conn_window_size = size;
self
}
/// Set total number of simultaneous connections per type of scheme.
///
/// If limit is 0, the connector has no limit.
/// The default limit size is 100.
pub fn limit(mut self, limit: usize) -> Self {
self.config.limit = limit;
self
}
/// Set keep-alive period for opened connection.
///
/// Keep-alive period is the period between connection usage. If
/// the delay between repeated usages of the same connection
/// exceeds this period, the connection is closed.
/// Default keep-alive period is 15 seconds.
pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
self.config.conn_keep_alive = dur;
self
}
/// Set max lifetime period for connection.
///
/// Connection lifetime is max lifetime of any opened connection
/// until it is closed regardless of keep-alive period.
/// Default lifetime period is 75 seconds.
pub fn conn_lifetime(mut self, dur: Duration) -> Self {
self.config.conn_lifetime = dur;
self
}
/// Set server connection disconnect timeout in milliseconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the socket get dropped. This timeout affects only secure connections.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 3000 milliseconds.
pub fn disconnect_timeout(mut self, dur: Duration) -> Self {
self.config.disconnect_timeout = Some(dur);
self
}
/// Set local IP Address the connector would use for establishing connection.
pub fn local_address(mut self, addr: IpAddr) -> Self {
self.config.local_address = Some(addr);
self
}
/// Finish configuration process and create connector service.
/// The Connector builder always concludes by calling `finish()` last in
/// its combinator chain.
pub fn finish(self) -> ConnectorService<S, Io> {
let local_address = self.config.local_address;
let timeout = self.config.timeout;
let tcp_service_inner =
TcpConnectorInnerService::new(self.connector, timeout, local_address);
#[allow(clippy::redundant_clone)]
let tcp_service = TcpConnectorService {
service: tcp_service_inner.clone(),
};
let tls_service = match self.ssl {
SslConnector::None => None,
#[cfg(feature = "openssl")]
SslConnector::Openssl(tls) => {
const H2: &[u8] = b"h2";
use actix_tls::connect::ssl::openssl::{OpensslConnector, SslStream};
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, SslStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
.ssl()
.selected_alpn_protocol()
.map_or(false, |protos| protos.windows(2).any(|w| w == H2));
if h2 {
(Box::new(sock), Protocol::Http2)
} else {
(Box::new(sock), Protocol::Http1)
}
}
}
let handshake_timeout = self.config.handshake_timeout;
let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner,
tls_service: OpensslConnector::service(tls),
timeout: handshake_timeout,
};
Some(actix_service::boxed::rc_service(tls_service))
}
#[cfg(feature = "rustls")]
SslConnector::Rustls(tls) => {
const H2: &[u8] = b"h2";
use actix_tls::connect::ssl::rustls::{
RustlsConnector, Session, TlsStream,
};
impl<Io: ConnectionIo> IntoConnectionIo for TcpConnection<Uri, TlsStream<Io>> {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol) {
let sock = self.into_parts().0;
let h2 = sock
.get_ref()
.1
.get_alpn_protocol()
.map_or(false, |protos| protos.windows(2).any(|w| w == H2));
if h2 {
(Box::new(sock), Protocol::Http2)
} else {
(Box::new(sock), Protocol::Http1)
}
}
}
let handshake_timeout = self.config.handshake_timeout;
let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner,
tls_service: RustlsConnector::service(tls),
timeout: handshake_timeout,
};
Some(actix_service::boxed::rc_service(tls_service))
}
};
let tcp_config = self.config.no_disconnect_timeout();
let tcp_pool = ConnectionPool::new(tcp_service, tcp_config);
let tls_config = self.config;
let tls_pool = tls_service
.map(move |tls_service| ConnectionPool::new(tls_service, tls_config));
ConnectorServicePriv { tcp_pool, tls_pool }
}
}
/// tcp service for map `TcpConnection<Uri, Io>` type to `(Io, Protocol)`
#[derive(Clone)]
pub struct TcpConnectorService<S: Clone> {
service: S,
}
impl<S, Io> Service<Connect> for TcpConnectorService<S>
where
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError>
+ Clone
+ 'static,
{
type Response = (Io, Protocol);
type Error = ConnectError;
type Future = TcpConnectorFuture<S::Future>;
actix_service::forward_ready!(service);
fn call(&self, req: Connect) -> Self::Future {
TcpConnectorFuture {
fut: self.service.call(req),
}
}
}
#[pin_project]
pub struct TcpConnectorFuture<Fut> {
#[pin]
fut: Fut,
}
impl<Fut, Io> Future for TcpConnectorFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
{
type Output = Result<(Io, Protocol), ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project()
.fut
.poll(cx)
.map_ok(|res| (res.into_parts().0, Protocol::Http1))
}
}
/// service for establish tcp connection and do client tls handshake.
/// operation is canceled when timeout limit reached.
struct TlsConnectorService<S, St> {
/// tcp connection is canceled on `TcpConnectorInnerService`'s timeout setting.
tcp_service: S,
/// tls connection is canceled on `TlsConnectorService`'s timeout setting.
tls_service: St,
timeout: Duration,
}
impl<S, St, Io> Service<Connect> for TlsConnectorService<S, St>
where
S: Service<Connect, Response = TcpConnection<Uri, Io>, Error = ConnectError>
+ Clone
+ 'static,
St: Service<TcpConnection<Uri, Io>, Error = std::io::Error> + Clone + 'static,
Io: ConnectionIo,
St::Response: IntoConnectionIo,
{
type Response = (Box<dyn ConnectionIo>, Protocol);
type Error = ConnectError;
type Future = TlsConnectorFuture<St, S::Future, St::Future>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.tcp_service.poll_ready(cx))?;
ready!(self.tls_service.poll_ready(cx))?;
Poll::Ready(Ok(()))
}
fn call(&self, req: Connect) -> Self::Future {
let fut = self.tcp_service.call(req);
let tls_service = self.tls_service.clone();
let timeout = self.timeout;
TlsConnectorFuture::TcpConnect {
fut,
tls_service: Some(tls_service),
timeout,
}
}
}
#[pin_project(project = TlsConnectorProj)]
#[allow(clippy::large_enum_variant)]
enum TlsConnectorFuture<S, Fut1, Fut2> {
TcpConnect {
#[pin]
fut: Fut1,
tls_service: Option<S>,
timeout: Duration,
},
TlsConnect {
#[pin]
fut: Fut2,
#[pin]
timeout: Sleep,
},
}
/// helper trait for generic over different TlsStream types between tls crates.
trait IntoConnectionIo {
fn into_connection_io(self) -> (Box<dyn ConnectionIo>, Protocol);
}
impl<S, Io, Fut1, Fut2, Res> Future for TlsConnectorFuture<S, Fut1, Fut2>
where
S: Service<
TcpConnection<Uri, Io>,
Response = Res,
Error = std::io::Error,
Future = Fut2,
>,
S::Response: IntoConnectionIo,
Fut1: Future<Output = Result<TcpConnection<Uri, Io>, ConnectError>>,
Fut2: Future<Output = Result<S::Response, S::Error>>,
Io: ConnectionIo,
{
type Output = Result<(Box<dyn ConnectionIo>, Protocol), ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
TlsConnectorProj::TcpConnect {
fut,
tls_service,
timeout,
} => {
let res = ready!(fut.poll(cx))?;
let fut = tls_service
.take()
.expect("TlsConnectorFuture polled after complete")
.call(res);
let timeout = sleep(*timeout);
self.set(TlsConnectorFuture::TlsConnect { fut, timeout });
self.poll(cx)
}
TlsConnectorProj::TlsConnect { fut, timeout } => match fut.poll(cx)? {
Poll::Ready(res) => Poll::Ready(Ok(res.into_connection_io())),
Poll::Pending => timeout.poll(cx).map(|_| Err(ConnectError::Timeout)),
},
}
}
}
/// service for establish tcp connection.
/// operation is canceled when timeout limit reached.
#[derive(Clone)]
pub struct TcpConnectorInnerService<S: Clone> {
service: S,
timeout: Duration,
local_address: Option<std::net::IpAddr>,
}
impl<S: Clone> TcpConnectorInnerService<S> {
fn new(
service: S,
timeout: Duration,
local_address: Option<std::net::IpAddr>,
) -> Self {
Self {
service,
timeout,
local_address,
}
}
}
impl<S, Io> Service<Connect> for TcpConnectorInnerService<S>
where
S: Service<
TcpConnect<Uri>,
Response = TcpConnection<Uri, Io>,
Error = TcpConnectError,
> + Clone
+ 'static,
{
type Response = S::Response;
type Error = ConnectError;
type Future = TcpConnectorInnerFuture<S::Future>;
actix_service::forward_ready!(service);
fn call(&self, req: Connect) -> Self::Future {
let mut req = TcpConnect::new(req.uri).set_addr(req.addr);
if let Some(local_addr) = self.local_address {
req = req.set_local_addr(local_addr);
}
TcpConnectorInnerFuture {
fut: self.service.call(req),
timeout: sleep(self.timeout),
}
}
}
#[pin_project]
pub struct TcpConnectorInnerFuture<Fut> {
#[pin]
fut: Fut,
#[pin]
timeout: Sleep,
}
impl<Fut, Io> Future for TcpConnectorInnerFuture<Fut>
where
Fut: Future<Output = Result<TcpConnection<Uri, Io>, TcpConnectError>>,
{
type Output = Result<TcpConnection<Uri, Io>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.fut.poll(cx) {
Poll::Ready(res) => Poll::Ready(res.map_err(ConnectError::from)),
Poll::Pending => this.timeout.poll(cx).map(|_| Err(ConnectError::Timeout)),
}
}
}
/// Connector service for pooled Plain/Tls Tcp connections.
pub type ConnectorService<S, Io> = ConnectorServicePriv<
TcpConnectorService<TcpConnectorInnerService<S>>,
Rc<
dyn Service<
Connect,
Response = (Box<dyn ConnectionIo>, Protocol),
Error = ConnectError,
Future = LocalBoxFuture<
'static,
Result<(Box<dyn ConnectionIo>, Protocol), ConnectError>,
>,
>,
>,
Io,
Box<dyn ConnectionIo>,
>;
pub struct ConnectorServicePriv<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>,
Io1: ConnectionIo,
Io2: ConnectionIo,
{
tcp_pool: ConnectionPool<S1, Io1>,
tls_pool: Option<ConnectionPool<S2, Io2>>,
}
impl<S1, S2, Io1, Io2> Service<Connect> for ConnectorServicePriv<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io1: ConnectionIo,
Io2: ConnectionIo,
{
type Response = Connection<Io1, Io2>;
type Error = ConnectError;
type Future = ConnectorServiceFuture<S1, S2, Io1, Io2>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.tcp_pool.poll_ready(cx))?;
if let Some(ref tls_pool) = self.tls_pool {
ready!(tls_pool.poll_ready(cx))?;
}
Poll::Ready(Ok(()))
}
fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => match self.tls_pool {
None => ConnectorServiceFuture::SslIsNotSupported,
Some(ref pool) => ConnectorServiceFuture::Tls(pool.call(req)),
},
_ => ConnectorServiceFuture::Tcp(self.tcp_pool.call(req)),
}
}
}
#[pin_project(project = ConnectorServiceProj)]
pub enum ConnectorServiceFuture<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io1: ConnectionIo,
Io2: ConnectionIo,
{
Tcp(#[pin] <ConnectionPool<S1, Io1> as Service<Connect>>::Future),
Tls(#[pin] <ConnectionPool<S2, Io2> as Service<Connect>>::Future),
SslIsNotSupported,
}
impl<S1, S2, Io1, Io2> Future for ConnectorServiceFuture<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError>
+ Clone
+ 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io1: ConnectionIo,
Io2: ConnectionIo,
{
type Output = Result<Connection<Io1, Io2>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
ConnectorServiceProj::Tcp(fut) => fut.poll(cx).map_ok(Connection::Tcp),
ConnectorServiceProj::Tls(fut) => fut.poll(cx).map_ok(Connection::Tls),
ConnectorServiceProj::SslIsNotSupported => {
Poll::Ready(Err(ConnectError::SslIsNotSupported))
}
}
}
}
#[cfg(not(feature = "trust-dns"))]
mod resolver {
use super::*;
pub(super) fn resolver() -> Resolver {
Resolver::Default
}
}
#[cfg(feature = "trust-dns")]
mod resolver {
use std::{cell::RefCell, net::SocketAddr};
use actix_tls::connect::Resolve;
use futures_core::future::LocalBoxFuture;
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
system_conf::read_system_conf,
TokioAsyncResolver,
};
use super::*;
pub(super) fn resolver() -> Resolver {
// new type for impl Resolve trait for TokioAsyncResolver.
struct TrustDnsResolver(TokioAsyncResolver);
impl Resolve for TrustDnsResolver {
fn lookup<'a>(
&'a self,
host: &'a str,
port: u16,
) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>>
{
Box::pin(async move {
let res = self
.0
.lookup_ip(host)
.await?
.iter()
.map(|ip| SocketAddr::new(ip, port))
.collect();
Ok(res)
})
}
}
// dns struct is cached in thread local.
// so new client constructor can reuse the existing dns resolver.
thread_local! {
static TRUST_DNS_RESOLVER: RefCell<Option<Resolver>> = RefCell::new(None);
}
// get from thread local or construct a new trust-dns resolver.
TRUST_DNS_RESOLVER.with(|local| {
let resolver = local.borrow().as_ref().map(Clone::clone);
match resolver {
Some(resolver) => resolver,
None => {
let (cfg, opts) = match read_system_conf() {
Ok((cfg, opts)) => (cfg, opts),
Err(e) => {
log::error!("TRust-DNS can not load system config: {}", e);
(ResolverConfig::default(), ResolverOpts::default())
}
};
let resolver = TokioAsyncResolver::tokio(cfg, opts).unwrap();
// box trust dns resolver and put it in thread local.
let resolver = Resolver::new_custom(TrustDnsResolver(resolver));
*local.borrow_mut() = Some(resolver.clone());
resolver
}
}
})
}
}

View File

@ -1,156 +0,0 @@
use std::{error::Error as StdError, fmt, io};
use derive_more::{Display, From};
#[cfg(feature = "openssl")]
use actix_tls::accept::openssl::SslError;
use crate::error::{Error, ParseError};
use crate::http::Error as HttpError;
/// A set of errors that can occur while connecting to an HTTP host
#[derive(Debug, Display, From)]
#[non_exhaustive]
pub enum ConnectError {
/// SSL feature is not enabled
#[display(fmt = "SSL is not supported")]
SslIsNotSupported,
/// SSL error
#[cfg(feature = "openssl")]
#[display(fmt = "{}", _0)]
SslError(SslError),
/// Failed to resolve the hostname
#[display(fmt = "Failed resolving hostname: {}", _0)]
Resolver(Box<dyn std::error::Error>),
/// No dns records
#[display(fmt = "No DNS records found for the input")]
NoRecords,
/// Http2 error
#[display(fmt = "{}", _0)]
H2(h2::Error),
/// Connecting took too long
#[display(fmt = "Timeout while establishing connection")]
Timeout,
/// Connector has been disconnected
#[display(fmt = "Internal error: connector has been disconnected")]
Disconnected,
/// Unresolved host name
#[display(fmt = "Connector received `Connect` method with unresolved host")]
Unresolved,
/// Connection io error
#[display(fmt = "{}", _0)]
Io(io::Error),
}
impl std::error::Error for ConnectError {}
impl From<actix_tls::connect::ConnectError> for ConnectError {
fn from(err: actix_tls::connect::ConnectError) -> ConnectError {
match err {
actix_tls::connect::ConnectError::Resolver(e) => ConnectError::Resolver(e),
actix_tls::connect::ConnectError::NoRecords => ConnectError::NoRecords,
actix_tls::connect::ConnectError::InvalidInput => panic!(),
actix_tls::connect::ConnectError::Unresolved => ConnectError::Unresolved,
actix_tls::connect::ConnectError::Io(e) => ConnectError::Io(e),
}
}
}
#[derive(Debug, Display, From)]
#[non_exhaustive]
pub enum InvalidUrl {
#[display(fmt = "Missing URL scheme")]
MissingScheme,
#[display(fmt = "Unknown URL scheme")]
UnknownScheme,
#[display(fmt = "Missing host name")]
MissingHost,
#[display(fmt = "URL parse error: {}", _0)]
HttpError(http::Error),
}
impl std::error::Error for InvalidUrl {}
/// A set of errors that can occur during request sending and response reading
#[derive(Debug, Display, From)]
#[non_exhaustive]
pub enum SendRequestError {
/// Invalid URL
#[display(fmt = "Invalid URL: {}", _0)]
Url(InvalidUrl),
/// Failed to connect to host
#[display(fmt = "Failed to connect to host: {}", _0)]
Connect(ConnectError),
/// Error sending request
Send(io::Error),
/// Error parsing response
Response(ParseError),
/// Http error
#[display(fmt = "{}", _0)]
Http(HttpError),
/// Http2 error
#[display(fmt = "{}", _0)]
H2(h2::Error),
/// Response took too long
#[display(fmt = "Timeout while waiting for response")]
Timeout,
/// Tunnels are not supported for HTTP/2 connection
#[display(fmt = "Tunnels are not supported for http2 connection")]
TunnelNotSupported,
/// Error sending request body
Body(Error),
/// Other errors that can occur after submitting a request.
#[display(fmt = "{:?}: {}", _1, _0)]
Custom(Box<dyn StdError>, Box<dyn fmt::Debug>),
}
impl std::error::Error for SendRequestError {}
/// A set of errors that can occur during freezing a request
#[derive(Debug, Display, From)]
#[non_exhaustive]
pub enum FreezeRequestError {
/// Invalid URL
#[display(fmt = "Invalid URL: {}", _0)]
Url(InvalidUrl),
/// HTTP error
#[display(fmt = "{}", _0)]
Http(HttpError),
/// Other errors that can occur after submitting a request.
#[display(fmt = "{:?}: {}", _1, _0)]
Custom(Box<dyn StdError>, Box<dyn fmt::Debug>),
}
impl std::error::Error for FreezeRequestError {}
impl From<FreezeRequestError> for SendRequestError {
fn from(err: FreezeRequestError) -> Self {
match err {
FreezeRequestError::Url(err) => err.into(),
FreezeRequestError::Http(err) => err.into(),
FreezeRequestError::Custom(err, msg) => SendRequestError::Custom(err, msg),
}
}
}

View File

@ -1,230 +0,0 @@
use std::{
io::Write,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::Framed;
use actix_utils::future::poll_fn;
use bytes::buf::BufMut;
use bytes::{Bytes, BytesMut};
use futures_core::{ready, Stream};
use futures_util::SinkExt as _;
use crate::h1;
use crate::http::{
header::{HeaderMap, IntoHeaderValue, EXPECT, HOST},
StatusCode,
};
use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::Payload;
use crate::{error::PayloadError, Error};
use super::connection::{ConnectionIo, H1Connection};
use super::error::{ConnectError, SendRequestError};
use crate::body::{BodySize, MessageBody};
pub(crate) async fn send_request<Io, B>(
io: H1Connection<Io>,
mut head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
B::Error: Into<Error>,
{
// set request host header
if !head.as_ref().headers.contains_key(HOST)
&& !head.extra_headers().iter().any(|h| h.contains_key(HOST))
{
if let Some(host) = head.as_ref().uri.host() {
let mut wrt = BytesMut::with_capacity(host.len() + 5).writer();
match head.as_ref().uri.port_u16() {
None | Some(80) | Some(443) => write!(wrt, "{}", host)?,
Some(port) => write!(wrt, "{}:{}", host, port)?,
};
match wrt.get_mut().split().freeze().try_into_value() {
Ok(value) => match head {
RequestHeadType::Owned(ref mut head) => {
head.headers.insert(HOST, value);
}
RequestHeadType::Rc(_, ref mut extra_headers) => {
let headers = extra_headers.get_or_insert(HeaderMap::new());
headers.insert(HOST, value);
}
},
Err(e) => log::error!("Can not set HOST header {}", e),
}
}
}
// create Framed and prepare sending request
let mut framed = Framed::new(io, h1::ClientCodec::default());
// Check EXPECT header and enable expect handle flag accordingly.
//
// RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1
let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
let keep_alive = framed.codec_ref().keepalive();
framed.io_mut().on_release(keep_alive);
// TODO: use a new variant or a new type better describing error violate
// `Requirements for clients` session of above RFC
return Err(SendRequestError::Connect(ConnectError::Disconnected));
}
_ => true,
}
} else {
false
};
framed.send((head, body.size()).into()).await?;
let mut pin_framed = Pin::new(&mut framed);
// special handle for EXPECT request.
let (do_send, mut res_head) = if is_expect {
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;
// return response head in case status code is not continue
// and current head would be used as final response head.
(head.status == StatusCode::CONTINUE, Some(head))
} else {
(true, None)
};
if do_send {
// send request body
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {}
_ => send_body(body, pin_framed.as_mut()).await?,
};
// read response and init read body
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;
res_head = Some(head);
}
let head = res_head.unwrap();
match pin_framed.codec_ref().message_type() {
h1::MessageType::None => {
let keep_alive = pin_framed.codec_ref().keepalive();
pin_framed.io_mut().on_release(keep_alive);
Ok((head, Payload::None))
}
_ => Ok((head, Payload::Stream(Box::pin(PlStream::new(framed))))),
}
}
pub(crate) async fn open_tunnel<Io>(
io: Io,
head: RequestHeadType,
) -> Result<(ResponseHead, Framed<Io, h1::ClientCodec>), SendRequestError>
where
Io: ConnectionIo,
{
// create Framed and send request.
let mut framed = Framed::new(io, h1::ClientCodec::default());
framed.send((head, BodySize::None).into()).await?;
// read response head.
let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;
Ok((head, framed))
}
/// send request body to the peer
pub(crate) async fn send_body<Io, B>(
body: B,
mut framed: Pin<&mut Framed<Io, h1::ClientCodec>>,
) -> Result<(), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
B::Error: Into<Error>,
{
actix_rt::pin!(body);
let mut eof = false;
while !eof {
while !eof && !framed.as_ref().is_write_buf_full() {
match poll_fn(|cx| body.as_mut().poll_next(cx)).await {
Some(Ok(chunk)) => {
framed.as_mut().write(h1::Message::Chunk(Some(chunk)))?;
}
Some(Err(err)) => return Err(err.into().into()),
None => {
eof = true;
framed.as_mut().write(h1::Message::Chunk(None))?;
}
}
}
if !framed.as_ref().is_write_buf_empty() {
poll_fn(|cx| match framed.as_mut().flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.as_ref().is_write_buf_full() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
})
.await?;
}
}
framed.get_mut().flush().await?;
Ok(())
}
#[pin_project::pin_project]
pub(crate) struct PlStream<Io: ConnectionIo> {
#[pin]
framed: Framed<H1Connection<Io>, h1::ClientPayloadCodec>,
}
impl<Io: ConnectionIo> PlStream<Io> {
fn new(framed: Framed<H1Connection<Io>, h1::ClientCodec>) -> Self {
let framed = framed.into_map_codec(|codec| codec.into_payload_codec());
PlStream { framed }
}
}
impl<Io: ConnectionIo> Stream for PlStream<Io> {
type Item = Result<Bytes, PayloadError>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut this = self.project();
match ready!(this.framed.as_mut().next_item(cx)?) {
Some(Some(chunk)) => Poll::Ready(Some(Ok(chunk))),
Some(None) => {
let keep_alive = this.framed.codec_ref().keepalive();
this.framed.io_mut().on_release(keep_alive);
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}

View File

@ -1,195 +0,0 @@
use std::future::Future;
use actix_utils::future::poll_fn;
use bytes::Bytes;
use h2::{
client::{Builder, Connection, SendRequest},
SendStream,
};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING};
use http::{request::Request, Method, Version};
use crate::{
body::{BodySize, MessageBody},
header::HeaderMap,
message::{RequestHeadType, ResponseHead},
payload::Payload,
Error,
};
use super::{
config::ConnectorConfig,
connection::{ConnectionIo, H2Connection},
error::SendRequestError,
};
pub(crate) async fn send_request<Io, B>(
mut io: H2Connection<Io>,
head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
B::Error: Into<Error>,
{
trace!("Sending client request: {:?} {:?}", head, body.size());
let head_req = head.as_ref().method == Method::HEAD;
let length = body.size();
let eof = matches!(
length,
BodySize::None | BodySize::Empty | BodySize::Sized(0)
);
let mut req = Request::new(());
*req.uri_mut() = head.as_ref().uri.clone();
*req.method_mut() = head.as_ref().method.clone();
*req.version_mut() = Version::HTTP_2;
let mut skip_len = true;
// let mut has_date = false;
// Content length
let _ = match length {
BodySize::None => None,
BodySize::Stream => {
skip_len = false;
None
}
BodySize::Empty => req
.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
BodySize::Sized(len) => {
let mut buf = itoa::Buffer::new();
req.headers_mut().insert(
CONTENT_LENGTH,
HeaderValue::from_str(buf.format(len)).unwrap(),
)
}
};
// Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate.
let (head, extra_headers) = match head {
RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()),
RequestHeadType::Rc(head, extra_headers) => (
RequestHeadType::Rc(head, None),
extra_headers.unwrap_or_else(HeaderMap::new),
),
};
// merging headers from head and extra headers.
let headers = head
.as_ref()
.headers
.iter()
.filter(|(name, _)| !extra_headers.contains_key(*name))
.chain(extra_headers.iter());
// copy headers
for (key, value) in headers {
match *key {
// TODO: consider skipping other headers according to:
// https://tools.ietf.org/html/rfc7540#section-8.1.2.2
// omit HTTP/1.x only headers
CONNECTION | TRANSFER_ENCODING => continue,
CONTENT_LENGTH if skip_len => continue,
// DATE => has_date = true,
_ => {}
}
req.headers_mut().append(key, value.clone());
}
let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res {
io.on_release(e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
io.on_release(false);
if !eof {
send_body(body, send).await?;
}
fut.await.map_err(SendRequestError::from)?
}
Err(e) => {
io.on_release(e.is_io());
return Err(e.into());
}
};
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };
let mut head = ResponseHead::new(parts.status);
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
}
async fn send_body<B>(
body: B,
mut send: SendStream<Bytes>,
) -> Result<(), SendRequestError>
where
B: MessageBody,
B::Error: Into<Error>,
{
let mut buf = None;
actix_rt::pin!(body);
loop {
if buf.is_none() {
match poll_fn(|cx| body.as_mut().poll_next(cx)).await {
Some(Ok(b)) => {
send.reserve_capacity(b.len());
buf = Some(b);
}
Some(Err(e)) => return Err(e.into().into()),
None => {
if let Err(e) = send.send_data(Bytes::new(), true) {
return Err(e.into());
}
send.reserve_capacity(0);
return Ok(());
}
}
}
match poll_fn(|cx| send.poll_capacity(cx)).await {
None => return Ok(()),
Some(Ok(cap)) => {
let b = buf.as_mut().unwrap();
let len = b.len();
let bytes = b.split_to(std::cmp::min(cap, len));
if let Err(e) = send.send_data(bytes, false) {
return Err(e.into());
}
if !b.is_empty() {
send.reserve_capacity(b.len());
} else {
buf = None;
}
continue;
}
Some(Err(e)) => return Err(e.into()),
}
}
}
pub(crate) fn handshake<Io: ConnectionIo>(
io: Io,
config: &ConnectorConfig,
) -> impl Future<Output = Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>>
{
let mut builder = Builder::new();
builder
.initial_window_size(config.stream_window_size)
.initial_connection_window_size(config.conn_window_size)
.enable_push(false);
builder.handshake(io)
}

View File

@ -1,26 +0,0 @@
//! HTTP client.
use http::Uri;
mod config;
mod connection;
mod connector;
mod error;
mod h1proto;
mod h2proto;
mod pool;
pub use actix_tls::connect::{
Connect as TcpConnect, ConnectError as TcpConnectError, Connection as TcpConnection,
};
pub use self::connection::{Connection, ConnectionIo};
pub use self::connector::{Connector, ConnectorService};
pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError};
pub use crate::Protocol;
#[derive(Clone)]
pub struct Connect {
pub uri: Uri,
pub addr: Option<std::net::SocketAddr>,
}

View File

@ -1,669 +0,0 @@
//! Client connection pooling keyed on the authority part of the connection URI.
use std::{
cell::RefCell,
collections::VecDeque,
future::Future,
io,
ops::Deref,
pin::Pin,
rc::Rc,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use actix_codec::{AsyncRead, AsyncWrite, ReadBuf};
use actix_rt::time::{sleep, Sleep};
use actix_service::Service;
use ahash::AHashMap;
use futures_core::future::LocalBoxFuture;
use http::uri::Authority;
use pin_project::pin_project;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use super::config::ConnectorConfig;
use super::connection::{
ConnectionInnerType, ConnectionIo, ConnectionType, H2ConnectionInner,
};
use super::error::ConnectError;
use super::h2proto::handshake;
use super::Connect;
use super::Protocol;
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct Key {
authority: Authority,
}
impl From<Authority> for Key {
fn from(authority: Authority) -> Key {
Key { authority }
}
}
#[doc(hidden)]
/// Connections pool for reuse Io type for certain [`http::uri::Authority`] as key.
pub struct ConnectionPool<S, Io>
where
Io: AsyncWrite + Unpin + 'static,
{
connector: S,
inner: ConnectionPoolInner<Io>,
}
/// wrapper type for check the ref count of Rc.
pub struct ConnectionPoolInner<Io>(Rc<ConnectionPoolInnerPriv<Io>>)
where
Io: AsyncWrite + Unpin + 'static;
impl<Io> ConnectionPoolInner<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
fn new(config: ConnectorConfig) -> Self {
let permits = Arc::new(Semaphore::new(config.limit));
let available = RefCell::new(AHashMap::default());
Self(Rc::new(ConnectionPoolInnerPriv {
config,
available,
permits,
}))
}
/// spawn a async for graceful shutdown h1 Io type with a timeout.
fn close(&self, conn: ConnectionInnerType<Io>) {
if let Some(timeout) = self.config.disconnect_timeout {
if let ConnectionInnerType::H1(io) = conn {
actix_rt::spawn(CloseConnection::new(io, timeout));
}
}
}
}
impl<Io> Clone for ConnectionPoolInner<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
fn clone(&self) -> Self {
Self(Rc::clone(&self.0))
}
}
impl<Io> Deref for ConnectionPoolInner<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
type Target = ConnectionPoolInnerPriv<Io>;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl<Io> Drop for ConnectionPoolInner<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
fn drop(&mut self) {
// When strong count is one it means the pool is dropped
// remove and drop all Io types.
if Rc::strong_count(&self.0) == 1 {
self.permits.close();
std::mem::take(&mut *self.available.borrow_mut())
.into_iter()
.for_each(|(_, conns)| {
conns.into_iter().for_each(|pooled| self.close(pooled.conn))
});
}
}
}
pub struct ConnectionPoolInnerPriv<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
config: ConnectorConfig,
available: RefCell<AHashMap<Key, VecDeque<PooledConnection<Io>>>>,
permits: Arc<Semaphore>,
}
impl<S, Io> ConnectionPool<S, Io>
where
Io: AsyncWrite + Unpin + 'static,
{
/// Construct a new connection pool.
///
/// [`super::config::ConnectorConfig`]'s `limit` is used as the max permits allowed for
/// in-flight connections.
///
/// The pool can only have equal to `limit` amount of requests spawning/using Io type
/// concurrently.
///
/// Any requests beyond limit would be wait in fifo order and get notified in async manner
/// by [`tokio::sync::Semaphore`]
pub(crate) fn new(connector: S, config: ConnectorConfig) -> Self {
let inner = ConnectionPoolInner::new(config);
Self { connector, inner }
}
}
impl<S, Io> Service<Connect> for ConnectionPool<S, Io>
where
S: Service<Connect, Response = (Io, Protocol), Error = ConnectError>
+ Clone
+ 'static,
Io: ConnectionIo,
{
type Response = ConnectionType<Io>;
type Error = ConnectError;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::forward_ready!(connector);
fn call(&self, req: Connect) -> Self::Future {
let connector = self.connector.clone();
let inner = self.inner.clone();
Box::pin(async move {
let key = if let Some(authority) = req.uri.authority() {
authority.clone().into()
} else {
return Err(ConnectError::Unresolved);
};
// acquire an owned permit and carry it with connection
let permit = inner.permits.clone().acquire_owned().await.map_err(|_| {
ConnectError::Io(io::Error::new(
io::ErrorKind::Other,
"failed to acquire semaphore on client connection pool",
))
})?;
let conn = {
let mut conn = None;
// check if there is idle connection for given key.
let mut map = inner.available.borrow_mut();
if let Some(conns) = map.get_mut(&key) {
let now = Instant::now();
while let Some(mut c) = conns.pop_front() {
let config = &inner.config;
let idle_dur = now - c.used;
let age = now - c.created;
let conn_ineligible = idle_dur > config.conn_keep_alive
|| age > config.conn_lifetime;
if conn_ineligible {
// drop connections that are too old
inner.close(c.conn);
} else {
// check if the connection is still usable
if let ConnectionInnerType::H1(ref mut io) = c.conn {
let check = ConnectionCheckFuture { io };
match check.await {
ConnectionState::Tainted => {
inner.close(c.conn);
continue;
}
ConnectionState::Skip => continue,
ConnectionState::Live => conn = Some(c),
}
} else {
conn = Some(c);
}
break;
}
}
};
conn
};
// construct acquired. It's used to put Io type back to pool/ close the Io type.
// permit is carried with the whole lifecycle of Acquired.
let acquired = Acquired { key, inner, permit };
// match the connection and spawn new one if did not get anything.
match conn {
Some(conn) => {
Ok(ConnectionType::from_pool(conn.conn, conn.created, acquired))
}
None => {
let (io, proto) = connector.call(req).await?;
// TODO: remove when http3 is added in support.
assert!(proto != Protocol::Http3);
if proto == Protocol::Http1 {
Ok(ConnectionType::from_h1(io, Instant::now(), acquired))
} else {
let config = &acquired.inner.config;
let (sender, connection) = handshake(io, config).await?;
let inner = H2ConnectionInner::new(sender, connection);
Ok(ConnectionType::from_h2(inner, Instant::now(), acquired))
}
}
}
})
}
}
/// Type for check the connection and determine if it's usable.
struct ConnectionCheckFuture<'a, Io> {
io: &'a mut Io,
}
enum ConnectionState {
/// IO is pending and a new request would wake it.
Live,
/// IO unexpectedly has unread data and should be dropped.
Tainted,
/// IO should be skipped but not dropped.
Skip,
}
impl<Io> Future for ConnectionCheckFuture<'_, Io>
where
Io: AsyncRead + Unpin,
{
type Output = ConnectionState;
// this future is only used to get access to Context.
// It should never return Poll::Pending.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let mut buf = [0; 2];
let mut read_buf = ReadBuf::new(&mut buf);
let state = match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => {
ConnectionState::Tainted
}
Poll::Pending => ConnectionState::Live,
_ => ConnectionState::Skip,
};
Poll::Ready(state)
}
}
struct PooledConnection<Io> {
conn: ConnectionInnerType<Io>,
used: Instant,
created: Instant,
}
#[pin_project]
struct CloseConnection<Io> {
io: Io,
#[pin]
timeout: Sleep,
}
impl<Io> CloseConnection<Io>
where
Io: AsyncWrite + Unpin,
{
fn new(io: Io, timeout: Duration) -> Self {
CloseConnection {
io,
timeout: sleep(timeout),
}
}
}
impl<Io> Future for CloseConnection<Io>
where
Io: AsyncWrite + Unpin,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let this = self.project();
match this.timeout.poll(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Pin::new(this.io).poll_shutdown(cx).map(|_| ()),
}
}
}
pub struct Acquired<Io>
where
Io: AsyncWrite + Unpin + 'static,
{
/// authority key for identify connection.
key: Key,
/// handle to connection pool.
inner: ConnectionPoolInner<Io>,
/// permit for limit concurrent in-flight connection for a Client object.
permit: OwnedSemaphorePermit,
}
impl<Io: ConnectionIo> Acquired<Io> {
/// Close the IO.
pub(super) fn close(&self, conn: ConnectionInnerType<Io>) {
self.inner.close(conn);
}
/// Release IO back into pool.
pub(super) fn release(&self, conn: ConnectionInnerType<Io>, created: Instant) {
let Acquired { key, inner, .. } = self;
inner
.available
.borrow_mut()
.entry(key.clone())
.or_insert_with(VecDeque::new)
.push_back(PooledConnection {
conn,
created,
used: Instant::now(),
});
let _ = &self.permit;
}
}
#[cfg(test)]
mod test {
use std::{cell::Cell, io};
use http::Uri;
use super::*;
use crate::client::connection::ConnectionType;
/// A stream type that always returns pending on async read.
///
/// Mocks an idle TCP stream that is ready to be used for client connections.
struct TestStream(Rc<Cell<usize>>);
impl Drop for TestStream {
fn drop(&mut self) {
self.0.set(self.0.get() - 1);
}
}
impl AsyncRead for TestStream {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Poll::Pending
}
}
impl AsyncWrite for TestStream {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &[u8],
) -> Poll<io::Result<usize>> {
unimplemented!()
}
fn poll_flush(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<io::Result<()>> {
unimplemented!()
}
fn poll_shutdown(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[derive(Clone)]
struct TestPoolConnector {
generated: Rc<Cell<usize>>,
}
impl Service<Connect> for TestPoolConnector {
type Response = (TestStream, Protocol);
type Error = ConnectError;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::always_ready!();
fn call(&self, _: Connect) -> Self::Future {
self.generated.set(self.generated.get() + 1);
let generated = self.generated.clone();
Box::pin(async { Ok((TestStream(generated), Protocol::Http1)) })
}
}
fn release<T>(conn: ConnectionType<T>)
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
match conn {
ConnectionType::H1(mut conn) => conn.on_release(true),
ConnectionType::H2(mut conn) => conn.on_release(false),
}
}
#[actix_rt::test]
async fn test_pool_limit() {
let connector = TestPoolConnector {
generated: Rc::new(Cell::new(0)),
};
let config = ConnectorConfig {
limit: 1,
..Default::default()
};
let pool = super::ConnectionPool::new(connector, config);
let req = Connect {
uri: Uri::from_static("http://localhost"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
let waiting = Rc::new(Cell::new(true));
let waiting_clone = waiting.clone();
actix_rt::spawn(async move {
actix_rt::time::sleep(Duration::from_millis(100)).await;
waiting_clone.set(false);
drop(conn);
});
assert!(waiting.get());
let now = Instant::now();
let conn = pool.call(req).await.unwrap();
release(conn);
assert!(!waiting.get());
assert!(now.elapsed() >= Duration::from_millis(100));
}
#[actix_rt::test]
async fn test_pool_keep_alive() {
let generated = Rc::new(Cell::new(0));
let generated_clone = generated.clone();
let connector = TestPoolConnector { generated };
let config = ConnectorConfig {
conn_keep_alive: Duration::from_secs(1),
..Default::default()
};
let pool = super::ConnectionPool::new(connector, config);
let req = Connect {
uri: Uri::from_static("http://localhost"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
actix_rt::time::sleep(Duration::from_millis(1500)).await;
actix_rt::task::yield_now().await;
let conn = pool.call(req).await.unwrap();
// Note: spawned recycle connection is not ran yet.
// This is tokio current thread runtime specific behavior.
assert_eq!(2, generated_clone.get());
// yield task so the old connection is properly dropped.
actix_rt::task::yield_now().await;
assert_eq!(1, generated_clone.get());
release(conn);
}
#[actix_rt::test]
async fn test_pool_lifetime() {
let generated = Rc::new(Cell::new(0));
let generated_clone = generated.clone();
let connector = TestPoolConnector { generated };
let config = ConnectorConfig {
conn_lifetime: Duration::from_secs(1),
..Default::default()
};
let pool = super::ConnectionPool::new(connector, config);
let req = Connect {
uri: Uri::from_static("http://localhost"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
actix_rt::time::sleep(Duration::from_millis(1500)).await;
actix_rt::task::yield_now().await;
let conn = pool.call(req).await.unwrap();
// Note: spawned recycle connection is not ran yet.
// This is tokio current thread runtime specific behavior.
assert_eq!(2, generated_clone.get());
// yield task so the old connection is properly dropped.
actix_rt::task::yield_now().await;
assert_eq!(1, generated_clone.get());
release(conn);
}
#[actix_rt::test]
async fn test_pool_authority_key() {
let generated = Rc::new(Cell::new(0));
let generated_clone = generated.clone();
let connector = TestPoolConnector { generated };
let config = ConnectorConfig::default();
let pool = super::ConnectionPool::new(connector, config);
let req = Connect {
uri: Uri::from_static("https://crates.io"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
let conn = pool.call(req).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
let req = Connect {
uri: Uri::from_static("https://google.com"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(2, generated_clone.get());
release(conn);
let conn = pool.call(req).await.unwrap();
assert_eq!(2, generated_clone.get());
release(conn);
}
#[actix_rt::test]
async fn test_pool_drop() {
let generated = Rc::new(Cell::new(0));
let generated_clone = generated.clone();
let connector = TestPoolConnector { generated };
let config = ConnectorConfig::default();
let pool = Rc::new(super::ConnectionPool::new(connector, config));
let req = Connect {
uri: Uri::from_static("https://crates.io"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(1, generated_clone.get());
release(conn);
let req = Connect {
uri: Uri::from_static("https://google.com"),
addr: None,
};
let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(2, generated_clone.get());
release(conn);
let clone1 = pool.clone();
let clone2 = clone1.clone();
drop(clone2);
for _ in 0..2 {
actix_rt::task::yield_now().await;
}
assert_eq!(2, generated_clone.get());
drop(clone1);
for _ in 0..2 {
actix_rt::task::yield_now().await;
}
assert_eq!(2, generated_clone.get());
drop(pool);
for _ in 0..2 {
actix_rt::task::yield_now().await;
}
assert_eq!(0, generated_clone.get());
}
}

View File

@ -1,26 +1,29 @@
use std::cell::Cell;
use std::fmt::Write;
use std::rc::Rc;
use std::time::Duration;
use std::{fmt, net};
use std::{
cell::Cell,
fmt::{self, Write},
net,
rc::Rc,
time::{Duration, SystemTime},
};
use actix_rt::{
task::JoinHandle,
time::{interval, sleep_until, Instant, Sleep},
};
use bytes::BytesMut;
use time::OffsetDateTime;
/// "Sun, 06 Nov 1994 08:49:37 GMT".len()
const DATE_VALUE_LENGTH: usize = 29;
pub(crate) const DATE_VALUE_LENGTH: usize = 29;
#[derive(Debug, PartialEq, Clone, Copy)]
/// Server keep-alive setting
pub enum KeepAlive {
/// Keep alive in seconds
Timeout(usize),
/// Rely on OS to shutdown tcp connection
Os,
/// Disabled
Disabled,
}
@ -206,12 +209,7 @@ impl Date {
fn update(&mut self) {
self.pos = 0;
write!(
self,
"{}",
OffsetDateTime::now_utc().format("%a, %d %b %Y %H:%M:%S GMT")
)
.unwrap();
write!(self, "{}", httpdate::fmt_http_date(SystemTime::now())).unwrap();
}
}
@ -269,11 +267,11 @@ impl DateService {
}
// TODO: move to a util module for testing all spawn handle drop style tasks.
#[cfg(test)]
/// Test Module for checking the drop state of certain async tasks that are spawned
/// with `actix_rt::spawn`
///
/// The target task must explicitly generate `NotifyOnDrop` when spawn the task
#[cfg(test)]
mod notify_on_drop {
use std::cell::RefCell;
@ -283,9 +281,8 @@ mod notify_on_drop {
/// Check if the spawned task is dropped.
///
/// # Panic:
///
/// When there was no `NotifyOnDrop` instance on current thread
/// # Panics
/// Panics when there was no `NotifyOnDrop` instance on current thread.
pub(crate) fn is_dropped() -> bool {
NOTIFY_DROPPED.with(|bool| {
bool.borrow()

View File

@ -80,7 +80,7 @@ where
let encoding = headers
.get(&CONTENT_ENCODING)
.and_then(|val| val.to_str().ok())
.map(ContentEncoding::from)
.and_then(|x| x.parse().ok())
.unwrap_or(ContentEncoding::Identity);
Self::new(stream, encoding)

View File

@ -12,7 +12,7 @@ use actix_rt::task::{spawn_blocking, JoinHandle};
use bytes::Bytes;
use derive_more::Display;
use futures_core::ready;
use pin_project::pin_project;
use pin_project_lite::pin_project;
#[cfg(feature = "compress-brotli")]
use brotli2::write::BrotliEncoder;
@ -23,99 +23,103 @@ use flate2::write::{GzEncoder, ZlibEncoder};
#[cfg(feature = "compress-zstd")]
use zstd::stream::write::Encoder as ZstdEncoder;
use super::Writer;
use crate::{
body::{Body, BodySize, BoxAnyBody, MessageBody, ResponseBody},
body::{BodySize, MessageBody},
error::BlockingError,
http::{
header::{ContentEncoding, CONTENT_ENCODING},
HeaderValue, StatusCode,
},
Error, ResponseHead,
ResponseHead,
};
use super::Writer;
use crate::error::BlockingError;
const MAX_CHUNK_SIZE_ENCODE_IN_PLACE: usize = 1024;
#[pin_project]
pub struct Encoder<B> {
eof: bool,
#[pin]
body: EncoderBody<B>,
encoder: Option<ContentEncoder>,
fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>,
pin_project! {
pub struct Encoder<B> {
#[pin]
body: EncoderBody<B>,
encoder: Option<ContentEncoder>,
fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>,
eof: bool,
}
}
impl<B: MessageBody> Encoder<B> {
fn none() -> Self {
Encoder {
body: EncoderBody::None,
encoder: None,
fut: None,
eof: true,
}
}
pub fn response(
encoding: ContentEncoding,
head: &mut ResponseHead,
body: ResponseBody<B>,
) -> ResponseBody<Encoder<B>> {
body: B,
) -> Self {
let can_encode = !(head.headers().contains_key(&CONTENT_ENCODING)
|| head.status == StatusCode::SWITCHING_PROTOCOLS
|| head.status == StatusCode::NO_CONTENT
|| encoding == ContentEncoding::Identity
|| encoding == ContentEncoding::Auto);
let body = match body {
ResponseBody::Other(b) => match b {
Body::None => return ResponseBody::Other(Body::None),
Body::Empty => return ResponseBody::Other(Body::Empty),
Body::Bytes(buf) => {
if can_encode {
EncoderBody::Bytes(buf)
} else {
return ResponseBody::Other(Body::Bytes(buf));
}
}
Body::Message(stream) => EncoderBody::BoxedStream(stream),
},
ResponseBody::Body(stream) => EncoderBody::Stream(stream),
};
match body.size() {
// no need to compress an empty body
BodySize::None => return Self::none(),
// we cannot assume that Sized is not a stream
BodySize::Sized(_) | BodySize::Stream => {}
}
// TODO potentially some optimisation for single-chunk responses here by trying to read the
// payload eagerly, stopping after 2 polls if the first is a chunk and the second is None
if can_encode {
// Modify response body only if encoder is not None
// Modify response body only if encoder is set
if let Some(enc) = ContentEncoder::encoder(encoding) {
update_head(encoding, head);
head.no_chunking(false);
return ResponseBody::Body(Encoder {
body,
eof: false,
fut: None,
return Encoder {
body: EncoderBody::Stream { body },
encoder: Some(enc),
});
fut: None,
eof: false,
};
}
}
ResponseBody::Body(Encoder {
body,
eof: false,
fut: None,
Encoder {
body: EncoderBody::Stream { body },
encoder: None,
})
fut: None,
eof: false,
}
}
}
#[pin_project(project = EncoderBodyProj)]
enum EncoderBody<B> {
Bytes(Bytes),
Stream(#[pin] B),
BoxedStream(BoxAnyBody),
pin_project! {
#[project = EncoderBodyProj]
enum EncoderBody<B> {
None,
Stream { #[pin] body: B },
}
}
impl<B> MessageBody for EncoderBody<B>
where
B: MessageBody,
B::Error: Into<Error>,
{
type Error = EncoderError<B::Error>;
type Error = EncoderError;
fn size(&self) -> BodySize {
match self {
EncoderBody::Bytes(ref b) => b.size(),
EncoderBody::Stream(ref b) => b.size(),
EncoderBody::BoxedStream(ref b) => b.size(),
EncoderBody::None => BodySize::None,
EncoderBody::Stream { body } => body.size(),
}
}
@ -124,26 +128,11 @@ where
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Self::Error>>> {
match self.project() {
EncoderBodyProj::Bytes(b) => {
if b.is_empty() {
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(std::mem::take(b))))
}
}
// TODO: MSRV 1.51: poll_map_err
EncoderBodyProj::Stream(b) => match ready!(b.poll_next(cx)) {
Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Body(err)))),
Some(Ok(val)) => Poll::Ready(Some(Ok(val))),
None => Poll::Ready(None),
},
EncoderBodyProj::BoxedStream(ref mut b) => {
match ready!(b.as_pin_mut().poll_next(cx)) {
Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Boxed(err)))),
Some(Ok(val)) => Poll::Ready(Some(Ok(val))),
None => Poll::Ready(None),
}
}
EncoderBodyProj::None => Poll::Ready(None),
EncoderBodyProj::Stream { body } => body
.poll_next(cx)
.map_err(|err| EncoderError::Body(err.into())),
}
}
}
@ -151,9 +140,8 @@ where
impl<B> MessageBody for Encoder<B>
where
B: MessageBody,
B::Error: Into<Error>,
{
type Error = EncoderError<B::Error>;
type Error = EncoderError;
fn size(&self) -> BodySize {
if self.encoder.is_none() {
@ -216,6 +204,7 @@ where
None => {
if let Some(encoder) = this.encoder.take() {
let chunk = encoder.finish().map_err(EncoderError::Io)?;
if chunk.is_empty() {
return Poll::Ready(None);
} else {
@ -241,12 +230,15 @@ fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) {
enum ContentEncoder {
#[cfg(feature = "compress-gzip")]
Deflate(ZlibEncoder<Writer>),
#[cfg(feature = "compress-gzip")]
Gzip(GzEncoder<Writer>),
#[cfg(feature = "compress-brotli")]
Br(BrotliEncoder<Writer>),
// We need explicit 'static lifetime here because ZstdEncoder need lifetime
// argument, and we use `spawn_blocking` in `Encoder::poll_next` that require `FnOnce() -> R + Send + 'static`
// Wwe need explicit 'static lifetime here because ZstdEncoder needs a lifetime argument and we
// use `spawn_blocking` in `Encoder::poll_next` that requires `FnOnce() -> R + Send + 'static`.
#[cfg(feature = "compress-zstd")]
Zstd(ZstdEncoder<'static, Writer>),
}
@ -259,20 +251,24 @@ impl ContentEncoder {
Writer::new(),
flate2::Compression::fast(),
))),
#[cfg(feature = "compress-gzip")]
ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new(
Writer::new(),
flate2::Compression::fast(),
))),
#[cfg(feature = "compress-brotli")]
ContentEncoding::Br => {
Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3)))
}
#[cfg(feature = "compress-zstd")]
ContentEncoding::Zstd => {
let encoder = ZstdEncoder::new(Writer::new(), 3).ok()?;
Some(ContentEncoder::Zstd(encoder))
}
_ => None,
}
}
@ -282,10 +278,13 @@ impl ContentEncoder {
match *self {
#[cfg(feature = "compress-brotli")]
ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(),
#[cfg(feature = "compress-gzip")]
ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(),
#[cfg(feature = "compress-gzip")]
ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(),
#[cfg(feature = "compress-zstd")]
ContentEncoder::Zstd(ref mut encoder) => encoder.get_mut().take(),
}
@ -298,16 +297,19 @@ impl ContentEncoder {
Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err),
},
#[cfg(feature = "compress-gzip")]
ContentEncoder::Gzip(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err),
},
#[cfg(feature = "compress-gzip")]
ContentEncoder::Deflate(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()),
Err(err) => Err(err),
},
#[cfg(feature = "compress-zstd")]
ContentEncoder::Zstd(encoder) => match encoder.finish() {
Ok(writer) => Ok(writer.buf.freeze()),
@ -326,6 +328,7 @@ impl ContentEncoder {
Err(err)
}
},
#[cfg(feature = "compress-gzip")]
ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()),
@ -334,6 +337,7 @@ impl ContentEncoder {
Err(err)
}
},
#[cfg(feature = "compress-gzip")]
ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()),
@ -342,6 +346,7 @@ impl ContentEncoder {
Err(err)
}
},
#[cfg(feature = "compress-zstd")]
ContentEncoder::Zstd(ref mut encoder) => match encoder.write_all(data) {
Ok(_) => Ok(()),
@ -356,12 +361,9 @@ impl ContentEncoder {
#[derive(Debug, Display)]
#[non_exhaustive]
pub enum EncoderError<E> {
pub enum EncoderError {
#[display(fmt = "body")]
Body(E),
#[display(fmt = "boxed")]
Boxed(Box<dyn StdError>),
Body(Box<dyn StdError>),
#[display(fmt = "blocking")]
Blocking(BlockingError),
@ -370,19 +372,18 @@ pub enum EncoderError<E> {
Io(io::Error),
}
impl<E: StdError + 'static> StdError for EncoderError<E> {
impl StdError for EncoderError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
EncoderError::Body(err) => Some(err),
EncoderError::Boxed(err) => Some(&**err),
EncoderError::Body(err) => Some(&**err),
EncoderError::Blocking(err) => Some(err),
EncoderError::Io(err) => Some(err),
}
}
}
impl<E: StdError + 'static> From<EncoderError<E>> for crate::Error {
fn from(err: EncoderError<E>) -> Self {
impl From<EncoderError> for crate::Error {
fn from(err: EncoderError) -> Self {
crate::Error::new_encoder().with_cause(err)
}
}

View File

@ -10,6 +10,9 @@ mod encoder;
pub use self::decoder::Decoder;
pub use self::encoder::Encoder;
/// Special-purpose writer for streaming (de-)compression.
///
/// Pre-allocates 8KiB of capacity.
pub(self) struct Writer {
buf: BytesMut,
}

View File

@ -5,10 +5,7 @@ use std::{error::Error as StdError, fmt, io, str::Utf8Error, string::FromUtf8Err
use derive_more::{Display, Error, From};
use http::{uri::InvalidUri, StatusCode};
use crate::{
body::{AnyBody, Body},
ws, Response,
};
use crate::{body::BoxBody, ws, Response};
pub use http::Error as HttpError;
@ -29,6 +26,11 @@ impl Error {
}
}
pub(crate) fn with_cause(mut self, cause: impl Into<Box<dyn StdError>>) -> Self {
self.inner.cause = Some(cause.into());
self
}
pub(crate) fn new_http() -> Self {
Self::new(Kind::Http)
}
@ -49,14 +51,12 @@ impl Error {
Self::new(Kind::SendResponse)
}
// TODO: remove allow
#[allow(dead_code)]
#[allow(unused)] // reserved for future use (TODO: remove allow when being used)
pub(crate) fn new_io() -> Self {
Self::new(Kind::Io)
}
// used in encoder behind feature flag so ignore unused warning
#[allow(unused)]
#[allow(unused)] // used in encoder behind feature flag so ignore unused warning
pub(crate) fn new_encoder() -> Self {
Self::new(Kind::Encoder)
}
@ -64,26 +64,22 @@ impl Error {
pub(crate) fn new_ws() -> Self {
Self::new(Kind::Ws)
}
pub(crate) fn with_cause(mut self, cause: impl Into<Box<dyn StdError>>) -> Self {
self.inner.cause = Some(cause.into());
self
}
}
impl From<Error> for Response<AnyBody> {
impl From<Error> for Response<BoxBody> {
fn from(err: Error) -> Self {
// TODO: more appropriate error status codes, usage assessment needed
let status_code = match err.inner.kind {
Kind::Parse => StatusCode::BAD_REQUEST,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
Response::new(status_code).set_body(Body::from(err.to_string()))
Response::new(status_code).set_body(BoxBody::new(err.to_string()))
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum Kind {
pub(crate) enum Kind {
#[display(fmt = "error processing HTTP")]
Http,
@ -137,12 +133,6 @@ impl From<std::convert::Infallible> for Error {
}
}
impl From<ws::ProtocolError> for Error {
fn from(err: ws::ProtocolError) -> Self {
Self::new_ws().with_cause(err)
}
}
impl From<HttpError> for Error {
fn from(err: HttpError) -> Self {
Self::new_http().with_cause(err)
@ -155,6 +145,12 @@ impl From<ws::HandshakeError> for Error {
}
}
impl From<ws::ProtocolError> for Error {
fn from(err: ws::ProtocolError) -> Self {
Self::new_ws().with_cause(err)
}
}
/// A set of errors that can occur during parsing HTTP streams.
#[derive(Debug, Display, Error)]
#[non_exhaustive]
@ -196,7 +192,7 @@ pub enum ParseError {
#[display(fmt = "IO error: {}", _0)]
Io(io::Error),
/// Parsing a field as string failed
/// Parsing a field as string failed.
#[display(fmt = "UTF8 error: {}", _0)]
Utf8(Utf8Error),
}
@ -245,7 +241,7 @@ impl From<ParseError> for Error {
}
}
impl From<ParseError> for Response<AnyBody> {
impl From<ParseError> for Response<BoxBody> {
fn from(err: ParseError) -> Self {
Error::from(err).into()
}
@ -342,7 +338,7 @@ pub enum DispatchError {
/// Service error
// FIXME: display and error type
#[display(fmt = "Service Error")]
Service(#[error(not(source))] Response<AnyBody>),
Service(#[error(not(source))] Response<BoxBody>),
/// Body error
// FIXME: display and error type
@ -426,11 +422,11 @@ mod tests {
#[test]
fn test_into_response() {
let resp: Response<AnyBody> = ParseError::Incomplete.into();
let resp: Response<BoxBody> = ParseError::Incomplete.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into();
let resp: Response<AnyBody> = Error::new_http().with_cause(err).into();
let resp: Response<BoxBody> = Error::new_http().with_cause(err).into();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
@ -455,7 +451,7 @@ mod tests {
fn test_error_http_response() {
let orig = io::Error::new(io::ErrorKind::Other, "other");
let err = Error::new_io().with_cause(orig);
let resp: Response<AnyBody> = err.into();
let resp: Response<BoxBody> = err.into();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}

View File

@ -0,0 +1,432 @@
use std::{io, task::Poll};
use bytes::{Buf as _, Bytes, BytesMut};
macro_rules! byte (
($rdr:ident) => ({
if $rdr.len() > 0 {
let b = $rdr[0];
$rdr.advance(1);
b
} else {
return Poll::Pending
}
})
);
#[derive(Debug, PartialEq, Clone)]
pub(super) enum ChunkedState {
Size,
SizeLws,
Extension,
SizeLf,
Body,
BodyCr,
BodyLf,
EndCr,
EndLf,
End,
}
impl ChunkedState {
pub(super) fn step(
&self,
body: &mut BytesMut,
size: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
use self::ChunkedState::*;
match *self {
Size => ChunkedState::read_size(body, size),
SizeLws => ChunkedState::read_size_lws(body),
Extension => ChunkedState::read_extension(body),
SizeLf => ChunkedState::read_size_lf(body, *size),
Body => ChunkedState::read_body(body, size, buf),
BodyCr => ChunkedState::read_body_cr(body),
BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body),
End => Poll::Ready(Ok(ChunkedState::End)),
}
}
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16;
let rem = match byte!(rdr) {
b @ b'0'..=b'9' => b - b'0',
b @ b'a'..=b'f' => b + 10 - b'a',
b @ b'A'..=b'F' => b + 10 - b'A',
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size",
)));
}
};
match size.checked_mul(radix) {
Some(n) => {
*size = n as u64;
*size += rem as u64;
Poll::Ready(Ok(ChunkedState::Size))
}
None => {
log::debug!("chunk size would overflow u64");
Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size line: Size is too big",
)))
}
}
}
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space",
))),
}
}
fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
// strictly 0x20 (space) should be disallowed but we don't parse quoted strings here
0x00..=0x08 | 0x0a..=0x1f | 0x7f => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid character in chunk extension",
))),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
}
}
fn read_size_lf(
rdr: &mut BytesMut,
size: u64,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' if size > 0 => Poll::Ready(Ok(ChunkedState::Body)),
b'\n' if size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size LF",
))),
}
}
fn read_body(
rdr: &mut BytesMut,
rem: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
log::trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64;
if len == 0 {
Poll::Ready(Ok(ChunkedState::Body))
} else {
let slice;
if *rem > len {
slice = rdr.split().freeze();
*rem -= len;
} else {
slice = rdr.split_to(*rem as usize).freeze();
*rem = 0;
}
*buf = Some(slice);
if *rem > 0 {
Poll::Ready(Ok(ChunkedState::Body))
} else {
Poll::Ready(Ok(ChunkedState::BodyCr))
}
}
}
fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body CR",
))),
}
}
fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body LF",
))),
}
}
fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end CR",
))),
}
}
fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end LF",
))),
}
}
}
#[cfg(test)]
mod tests {
use actix_codec::Decoder as _;
use bytes::{Bytes, BytesMut};
use http::Method;
use crate::{
error::ParseError,
h1::decoder::{MessageDecoder, PayloadItem},
HttpMessage as _, Request,
};
macro_rules! parse_ready {
($e:expr) => {{
match MessageDecoder::<Request>::default().decode($e) {
Ok(Some((msg, _))) => msg,
Ok(_) => unreachable!("Eof during parsing http request"),
Err(err) => unreachable!("Error during parsing http request: {:?}", err),
}
}};
}
macro_rules! expect_parse_err {
($e:expr) => {{
match MessageDecoder::<Request>::default().decode($e) {
Err(err) => match err {
ParseError::Io(_) => unreachable!("Parse error expected"),
_ => {}
},
_ => unreachable!("Error expected"),
}
}};
}
#[test]
fn test_parse_chunked_payload_chunk_extension() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\
\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(msg.chunked().unwrap());
buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk();
assert_eq!(chunk, Bytes::from_static(b"data"));
let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk();
assert_eq!(chunk, Bytes::from_static(b"line"));
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert!(msg.eof());
}
#[test]
fn test_request_chunked() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let req = parse_ready!(&mut buf);
if let Ok(val) = req.chunked() {
assert!(val);
} else {
unreachable!("Error");
}
// intentional typo in "chunked"
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chnked\r\n\r\n",
);
expect_parse_err!(&mut buf);
}
#[test]
fn test_http_request_chunked_payload() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
assert_eq!(
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"data"
);
assert_eq!(
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"line"
);
assert!(pl.decode(&mut buf).unwrap().unwrap().eof());
}
#[test]
fn test_http_request_chunked_payload_and_next_message() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(
b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
POST /test2 HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n"
.iter(),
);
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"data");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"line");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert!(msg.eof());
let (req, _) = reader.decode(&mut buf).unwrap().unwrap();
assert!(req.chunked().unwrap());
assert_eq!(*req.method(), Method::POST);
assert!(req.chunked().unwrap());
}
#[test]
fn test_http_request_chunked_payload_chunks() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\n1111\r\n");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"1111");
buf.extend(b"4\r\ndata\r");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"data");
buf.extend(b"\n4");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\r");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\n");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"li");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"li");
//trailers
//buf.feed_data("test: test\r\n");
//not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.extend(b"ne\r\n0\r\n");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"ne");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\r\n");
assert!(pl.decode(&mut buf).unwrap().unwrap().eof());
}
#[test]
fn chunk_extension_quoted() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
Host: localhost:8080\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
2;hello=b;one=\"1 2 3\"\r\n\
xx",
);
let mut reader = MessageDecoder::<Request>::default();
let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let chunk = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"xx")));
}
#[test]
fn hrs_chunk_extension_invalid() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: localhost:8080\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
2;x\nx\r\n\
4c\r\n\
0\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let err = pl.decode(&mut buf).unwrap_err();
assert!(err
.to_string()
.contains("Invalid character in chunk extension"));
}
#[test]
fn hrs_chunk_size_overflow() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: example.com\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
f0000000000000003\r\n\
abc\r\n\
0\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let err = pl.decode(&mut buf).unwrap_err();
assert!(err
.to_string()
.contains("Invalid chunk size line: Size is too big"));
}
}

View File

@ -120,7 +120,7 @@ impl Decoder for ClientCodec {
debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set");
if let Some((req, payload)) = self.inner.decoder.decode(src)? {
if let Some(ctype) = req.ctype() {
if let Some(ctype) = req.conn_type() {
// do not use peer's keep-alive
self.inner.ctype = if ctype == ConnectionType::KeepAlive {
self.inner.ctype

View File

@ -29,7 +29,7 @@ pub struct Codec {
decoder: decoder::MessageDecoder<Request>,
payload: Option<PayloadDecoder>,
version: Version,
ctype: ConnectionType,
conn_type: ConnectionType,
// encoder part
flags: Flags,
@ -65,7 +65,7 @@ impl Codec {
decoder: decoder::MessageDecoder::default(),
payload: None,
version: Version::HTTP_11,
ctype: ConnectionType::Close,
conn_type: ConnectionType::Close,
encoder: encoder::MessageEncoder::default(),
}
}
@ -73,13 +73,13 @@ impl Codec {
/// Check if request is upgrade.
#[inline]
pub fn upgrade(&self) -> bool {
self.ctype == ConnectionType::Upgrade
self.conn_type == ConnectionType::Upgrade
}
/// Check if last response is keep-alive.
#[inline]
pub fn keepalive(&self) -> bool {
self.ctype == ConnectionType::KeepAlive
self.conn_type == ConnectionType::KeepAlive
}
/// Check if keep-alive enabled on server level.
@ -124,11 +124,11 @@ impl Decoder for Codec {
let head = req.head();
self.flags.set(Flags::HEAD, head.method == Method::HEAD);
self.version = head.version;
self.ctype = head.connection_type();
if self.ctype == ConnectionType::KeepAlive
self.conn_type = head.connection_type();
if self.conn_type == ConnectionType::KeepAlive
&& !self.flags.contains(Flags::KEEPALIVE_ENABLED)
{
self.ctype = ConnectionType::Close
self.conn_type = ConnectionType::Close
}
match payload {
PayloadType::None => self.payload = None,
@ -159,14 +159,14 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
res.head_mut().version = self.version;
// connection status
self.ctype = if let Some(ct) = res.head().ctype() {
self.conn_type = if let Some(ct) = res.head().conn_type() {
if ct == ConnectionType::KeepAlive {
self.ctype
self.conn_type
} else {
ct
}
} else {
self.ctype
self.conn_type
};
// encode message
@ -177,10 +177,9 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
self.flags.contains(Flags::STREAM),
self.version,
length,
self.ctype,
self.conn_type,
&self.config,
)?;
// self.headers_size = (dst.len() - len) as u32;
}
Message::Chunk(Some(bytes)) => {
self.encoder.encode_chunk(bytes.as_ref(), dst)?;
@ -189,6 +188,7 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
self.encoder.encode_eof(dst)?;
}
}
Ok(())
}
}

View File

@ -1,18 +1,18 @@
use std::convert::TryFrom;
use std::io;
use std::marker::PhantomData;
use std::task::Poll;
use std::{convert::TryFrom, io, marker::PhantomData, mem::MaybeUninit, task::Poll};
use actix_codec::Decoder;
use bytes::{Buf, Bytes, BytesMut};
use bytes::{Bytes, BytesMut};
use http::header::{HeaderName, HeaderValue};
use http::{header, Method, StatusCode, Uri, Version};
use log::{debug, error, trace};
use crate::error::ParseError;
use crate::header::HeaderMap;
use crate::message::{ConnectionType, ResponseHead};
use crate::request::Request;
use super::chunked::ChunkedState;
use crate::{
error::ParseError,
header::HeaderMap,
message::{ConnectionType, ResponseHead},
request::Request,
};
pub(crate) const MAX_BUFFER_SIZE: usize = 131_072;
const MAX_HEADERS: usize = 96;
@ -67,6 +67,7 @@ pub(crate) trait MessageType: Sized {
let mut has_upgrade_websocket = false;
let mut expect = false;
let mut chunked = false;
let mut seen_te = false;
let mut content_length = None;
{
@ -85,8 +86,17 @@ pub(crate) trait MessageType: Sized {
};
match name {
header::CONTENT_LENGTH => {
if let Ok(s) = value.to_str() {
header::CONTENT_LENGTH if content_length.is_some() => {
debug!("multiple Content-Length");
return Err(ParseError::Header);
}
header::CONTENT_LENGTH => match value.to_str() {
Ok(s) if s.trim().starts_with('+') => {
debug!("illegal Content-Length: {:?}", s);
return Err(ParseError::Header);
}
Ok(s) => {
if let Ok(len) = s.parse::<u64>() {
if len != 0 {
content_length = Some(len);
@ -95,15 +105,31 @@ pub(crate) trait MessageType: Sized {
debug!("illegal Content-Length: {:?}", s);
return Err(ParseError::Header);
}
} else {
}
Err(_) => {
debug!("illegal Content-Length: {:?}", value);
return Err(ParseError::Header);
}
}
},
// transfer-encoding
header::TRANSFER_ENCODING if seen_te => {
debug!("multiple Transfer-Encoding not allowed");
return Err(ParseError::Header);
}
header::TRANSFER_ENCODING => {
seen_te = true;
if let Ok(s) = value.to_str().map(str::trim) {
chunked = s.eq_ignore_ascii_case("chunked");
if s.eq_ignore_ascii_case("chunked") {
chunked = true;
} else if s.eq_ignore_ascii_case("identity") {
// allow silently since multiple TE headers are already checked
} else {
debug!("illegal Transfer-Encoding: {:?}", s);
return Err(ParseError::Header);
}
} else {
return Err(ParseError::Header);
}
@ -148,7 +174,7 @@ pub(crate) trait MessageType: Sized {
self.set_expect()
}
// https://tools.ietf.org/html/rfc7230#section-3.3.3
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
if chunked {
// Chunked encoding
Ok(PayloadLength::Payload(PayloadType::Payload(
@ -186,10 +212,17 @@ impl MessageType for Request {
let mut headers: [HeaderIndex; MAX_HEADERS] = EMPTY_HEADER_INDEX_ARRAY;
let (len, method, uri, ver, h_len) = {
let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = EMPTY_HEADER_ARRAY;
// SAFETY:
// Create an uninitialized array of `MaybeUninit`. The `assume_init` is
// safe because the type we are claiming to have initialized here is a
// bunch of `MaybeUninit`s, which do not require initialization.
let mut parsed = unsafe {
MaybeUninit::<[MaybeUninit<httparse::Header<'_>>; MAX_HEADERS]>::uninit()
.assume_init()
};
let mut req = httparse::Request::new(&mut parsed);
match req.parse(src)? {
let mut req = httparse::Request::new(&mut []);
match req.parse_with_uninit_headers(src, &mut parsed)? {
httparse::Status::Complete(len) => {
let method = Method::from_bytes(req.method.unwrap().as_bytes())
.map_err(|_| ParseError::Method)?;
@ -408,20 +441,6 @@ enum Kind {
Eof,
}
#[derive(Debug, PartialEq, Clone)]
enum ChunkedState {
Size,
SizeLws,
Extension,
SizeLf,
Body,
BodyCr,
BodyLf,
EndCr,
EndLf,
End,
}
impl Decoder for PayloadDecoder {
type Item = PayloadItem;
type Error = io::Error;
@ -451,19 +470,23 @@ impl Decoder for PayloadDecoder {
Kind::Chunked(ref mut state, ref mut size) => {
loop {
let mut buf = None;
// advances the chunked state
*state = match state.step(src, size, &mut buf) {
Poll::Pending => return Ok(None),
Poll::Ready(Ok(state)) => state,
Poll::Ready(Err(e)) => return Err(e),
};
if *state == ChunkedState::End {
trace!("End of chunked stream");
return Ok(Some(PayloadItem::Eof));
}
if let Some(buf) = buf {
return Ok(Some(PayloadItem::Chunk(buf)));
}
if src.is_empty() {
return Ok(None);
}
@ -480,201 +503,40 @@ impl Decoder for PayloadDecoder {
}
}
macro_rules! byte (
($rdr:ident) => ({
if $rdr.len() > 0 {
let b = $rdr[0];
$rdr.advance(1);
b
} else {
return Poll::Pending
}
})
);
impl ChunkedState {
fn step(
&self,
body: &mut BytesMut,
size: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
use self::ChunkedState::*;
match *self {
Size => ChunkedState::read_size(body, size),
SizeLws => ChunkedState::read_size_lws(body),
Extension => ChunkedState::read_extension(body),
SizeLf => ChunkedState::read_size_lf(body, size),
Body => ChunkedState::read_body(body, size, buf),
BodyCr => ChunkedState::read_body_cr(body),
BodyLf => ChunkedState::read_body_lf(body),
EndCr => ChunkedState::read_end_cr(body),
EndLf => ChunkedState::read_end_lf(body),
End => Poll::Ready(Ok(ChunkedState::End)),
}
}
fn read_size(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
let radix = 16;
match byte!(rdr) {
b @ b'0'..=b'9' => {
*size *= radix;
*size += u64::from(b - b'0');
}
b @ b'a'..=b'f' => {
*size *= radix;
*size += u64::from(b + 10 - b'a');
}
b @ b'A'..=b'F' => {
*size *= radix;
*size += u64::from(b + 10 - b'A');
}
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size",
)));
}
}
Poll::Ready(Ok(ChunkedState::Size))
}
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws");
match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space",
))),
}
}
fn read_extension(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
}
}
fn read_size_lf(
rdr: &mut BytesMut,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)),
b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size LF",
))),
}
}
fn read_body(
rdr: &mut BytesMut,
rem: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64;
if len == 0 {
Poll::Ready(Ok(ChunkedState::Body))
} else {
let slice;
if *rem > len {
slice = rdr.split().freeze();
*rem -= len;
} else {
slice = rdr.split_to(*rem as usize).freeze();
*rem = 0;
}
*buf = Some(slice);
if *rem > 0 {
Poll::Ready(Ok(ChunkedState::Body))
} else {
Poll::Ready(Ok(ChunkedState::BodyCr))
}
}
}
fn read_body_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body CR",
))),
}
}
fn read_body_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body LF",
))),
}
}
fn read_end_cr(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end CR",
))),
}
}
fn read_end_lf(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr) {
b'\n' => Poll::Ready(Ok(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end LF",
))),
}
}
}
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use http::{Method, Version};
use super::*;
use crate::error::ParseError;
use crate::http::header::{HeaderName, SET_COOKIE};
use crate::HttpMessage;
use crate::{
error::ParseError,
http::header::{HeaderName, SET_COOKIE},
HttpMessage as _,
};
impl PayloadType {
fn unwrap(self) -> PayloadDecoder {
pub(crate) fn unwrap(self) -> PayloadDecoder {
match self {
PayloadType::Payload(pl) => pl,
_ => panic!(),
}
}
fn is_unhandled(&self) -> bool {
pub(crate) fn is_unhandled(&self) -> bool {
matches!(self, PayloadType::Stream(_))
}
}
impl PayloadItem {
fn chunk(self) -> Bytes {
pub(crate) fn chunk(self) -> Bytes {
match self {
PayloadItem::Chunk(chunk) => chunk,
_ => panic!("error"),
}
}
fn eof(&self) -> bool {
pub(crate) fn eof(&self) -> bool {
matches!(*self, PayloadItem::Eof)
}
}
@ -967,34 +829,6 @@ mod tests {
assert!(req.upgrade());
}
#[test]
fn test_request_chunked() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let req = parse_ready!(&mut buf);
if let Ok(val) = req.chunked() {
assert!(val);
} else {
unreachable!("Error");
}
// intentional typo in "chunked"
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chnked\r\n\r\n",
);
let req = parse_ready!(&mut buf);
if let Ok(val) = req.chunked() {
assert!(!val);
} else {
unreachable!("Error");
}
}
#[test]
fn test_headers_content_length_err_1() {
let mut buf = BytesMut::from(
@ -1112,126 +946,6 @@ mod tests {
expect_parse_err!(&mut buf);
}
#[test]
fn test_http_request_chunked_payload() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
assert_eq!(
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"data"
);
assert_eq!(
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"line"
);
assert!(pl.decode(&mut buf).unwrap().unwrap().eof());
}
#[test]
fn test_http_request_chunked_payload_and_next_message() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(
b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
POST /test2 HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n"
.iter(),
);
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"data");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"line");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert!(msg.eof());
let (req, _) = reader.decode(&mut buf).unwrap().unwrap();
assert!(req.chunked().unwrap());
assert_eq!(*req.method(), Method::POST);
assert!(req.chunked().unwrap());
}
#[test]
fn test_http_request_chunked_payload_chunks() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\n1111\r\n");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"1111");
buf.extend(b"4\r\ndata\r");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"data");
buf.extend(b"\n4");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\r");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\n");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"li");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"li");
//trailers
//buf.feed_data("test: test\r\n");
//not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.extend(b"ne\r\n0\r\n");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"ne");
assert!(pl.decode(&mut buf).unwrap().is_none());
buf.extend(b"\r\n");
assert!(pl.decode(&mut buf).unwrap().unwrap().eof());
}
#[test]
fn test_parse_chunked_payload_chunk_extension() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\
\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(msg.chunked().unwrap());
buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk();
assert_eq!(chunk, Bytes::from_static(b"data"));
let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk();
assert_eq!(chunk, Bytes::from_static(b"line"));
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert!(msg.eof());
}
#[test]
fn test_response_http10_read_until_eof() {
let mut buf = BytesMut::from("HTTP/1.0 200 Ok\r\n\r\ntest data");
@ -1243,4 +957,84 @@ mod tests {
let chunk = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data")));
}
#[test]
fn hrs_multiple_content_length() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 4\r\n\
Content-Length: 2\r\n\
\r\n\
abcd",
);
expect_parse_err!(&mut buf);
}
#[test]
fn hrs_content_length_plus() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: +3\r\n\
\r\n\
000",
);
expect_parse_err!(&mut buf);
}
#[test]
fn hrs_unknown_transfer_encoding() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: example.com\r\n\
Transfer-Encoding: JUNK\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
5\r\n\
hello\r\n\
0",
);
expect_parse_err!(&mut buf);
}
#[test]
fn hrs_multiple_transfer_encoding() {
let mut buf = BytesMut::from(
"GET / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 51\r\n\
Transfer-Encoding: identity\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
0\r\n\
\r\n\
GET /forbidden HTTP/1.1\r\n\
Host: example.com\r\n\r\n",
);
expect_parse_err!(&mut buf);
}
#[test]
fn transfer_encoding_agrees() {
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 3\r\n\
Transfer-Encoding: identity\r\n\
\r\n\
0\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let chunk = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"0\r\n")));
}
}

View File

@ -1,6 +1,5 @@
use std::{
collections::VecDeque,
error::Error as StdError,
fmt,
future::Future,
io, mem, net,
@ -19,7 +18,7 @@ use log::{error, trace};
use pin_project::pin_project;
use crate::{
body::{AnyBody, BodySize, MessageBody},
body::{BodySize, BoxBody, MessageBody},
config::ServiceConfig,
error::{DispatchError, ParseError, PayloadError},
service::HttpFlow,
@ -51,13 +50,12 @@ bitflags! {
pub struct Dispatcher<T, S, B, X, U>
where
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -73,13 +71,12 @@ where
enum DispatcherState<T, S, B, X, U>
where
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -92,13 +89,12 @@ where
struct InnerDispatcher<T, S, B, X, U>
where
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -137,13 +133,12 @@ where
X: Service<Request, Response = Request>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
{
None,
ExpectCall(#[pin] X::Future),
ServiceCall(#[pin] S::Future),
SendPayload(#[pin] B),
SendErrorPayload(#[pin] AnyBody),
SendErrorPayload(#[pin] BoxBody),
}
impl<S, B, X> State<S, B, X>
@ -153,7 +148,6 @@ where
X: Service<Request, Response = Request>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
{
fn is_empty(&self) -> bool {
matches!(self, State::None)
@ -171,14 +165,13 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -232,14 +225,13 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -303,9 +295,9 @@ where
body: &impl MessageBody,
) -> Result<BodySize, DispatchError> {
let size = body.size();
let mut this = self.project();
let this = self.project();
this.codec
.encode(Message::Item((message, size)), &mut this.write_buf)
.encode(Message::Item((message, size)), this.write_buf)
.map_err(|err| {
if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::Incomplete(None));
@ -325,7 +317,7 @@ where
) -> Result<(), DispatchError> {
let size = self.as_mut().send_response_inner(message, &body)?;
let state = match size {
BodySize::None | BodySize::Empty => State::None,
BodySize::None | BodySize::Sized(0) => State::None,
_ => State::SendPayload(body),
};
self.project().state.set(state);
@ -335,11 +327,11 @@ where
fn send_error_response(
mut self: Pin<&mut Self>,
message: Response<()>,
body: AnyBody,
body: BoxBody,
) -> Result<(), DispatchError> {
let size = self.as_mut().send_response_inner(message, &body)?;
let state = match size {
BodySize::None | BodySize::Empty => State::None,
BodySize::None | BodySize::Sized(0) => State::None,
_ => State::SendErrorPayload(body),
};
self.project().state.set(state);
@ -380,7 +372,7 @@ where
// send_response would update InnerDispatcher state to SendPayload or
// None(If response body is empty).
// continue loop to poll it.
self.as_mut().send_error_response(res, AnyBody::Empty)?;
self.as_mut().send_error_response(res, BoxBody::new(()))?;
}
// return with upgrade request and poll it exclusively.
@ -400,7 +392,7 @@ where
// send service call error as response
Poll::Ready(Err(err)) => {
let res: Response<AnyBody> = err.into();
let res: Response<BoxBody> = err.into();
let (res, body) = res.replace_body(());
self.as_mut().send_error_response(res, body)?;
}
@ -425,13 +417,13 @@ where
Poll::Ready(Some(Ok(item))) => {
this.codec.encode(
Message::Chunk(Some(item)),
&mut this.write_buf,
this.write_buf,
)?;
}
Poll::Ready(None) => {
this.codec
.encode(Message::Chunk(None), &mut this.write_buf)?;
.encode(Message::Chunk(None), this.write_buf)?;
// payload stream finished.
// set state to None and handle next message
this.state.set(State::None);
@ -460,13 +452,13 @@ where
Poll::Ready(Some(Ok(item))) => {
this.codec.encode(
Message::Chunk(Some(item)),
&mut this.write_buf,
this.write_buf,
)?;
}
Poll::Ready(None) => {
this.codec
.encode(Message::Chunk(None), &mut this.write_buf)?;
.encode(Message::Chunk(None), this.write_buf)?;
// payload stream finished.
// set state to None and handle next message
this.state.set(State::None);
@ -497,7 +489,7 @@ where
// send expect error as response
Poll::Ready(Err(err)) => {
let res: Response<AnyBody> = err.into();
let res: Response<BoxBody> = err.into();
let (res, body) = res.replace_body(());
self.as_mut().send_error_response(res, body)?;
}
@ -546,7 +538,7 @@ where
// to notify the dispatcher a new state is set and the outer loop
// should be continue.
Poll::Ready(Err(err)) => {
let res: Response<AnyBody> = err.into();
let res: Response<BoxBody> = err.into();
let (res, body) = res.replace_body(());
return self.send_error_response(res, body);
}
@ -566,7 +558,7 @@ where
Poll::Pending => Ok(()),
// see the comment on ExpectCall state branch's Ready(Err(err)).
Poll::Ready(Err(err)) => {
let res: Response<AnyBody> = err.into();
let res: Response<BoxBody> = err.into();
let (res, body) = res.replace_body(());
self.send_error_response(res, body)
}
@ -592,7 +584,7 @@ where
let mut updated = false;
let mut this = self.as_mut().project();
loop {
match this.codec.decode(&mut this.read_buf) {
match this.codec.decode(this.read_buf) {
Ok(Some(msg)) => {
updated = true;
this.flags.insert(Flags::STARTED);
@ -772,7 +764,7 @@ where
trace!("Slow request timeout");
let _ = self.as_mut().send_error_response(
Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
AnyBody::Empty,
BoxBody::new(()),
);
this = self.project();
this.flags.insert(Flags::STARTED | Flags::SHUTDOWN);
@ -909,14 +901,13 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
@ -1060,24 +1051,26 @@ mod tests {
fn stabilize_date_header(payload: &mut [u8]) {
let mut from = 0;
while let Some(pos) = find_slice(&payload, b"date", from) {
while let Some(pos) = find_slice(payload, b"date", from) {
payload[(from + pos)..(from + pos + 35)]
.copy_from_slice(b"date: Thu, 01 Jan 1970 12:34:56 UTC");
from += 35;
}
}
fn ok_service() -> impl Service<Request, Response = Response<AnyBody>, Error = Error>
fn ok_service(
) -> impl Service<Request, Response = Response<impl MessageBody>, Error = Error>
{
fn_service(|_req: Request| ready(Ok::<_, Error>(Response::ok())))
}
fn echo_path_service(
) -> impl Service<Request, Response = Response<AnyBody>, Error = Error> {
) -> impl Service<Request, Response = Response<impl MessageBody>, Error = Error>
{
fn_service(|req: Request| {
let path = req.path().as_bytes();
ready(Ok::<_, Error>(
Response::ok().set_body(AnyBody::from_slice(path)),
Response::ok().set_body(Bytes::copy_from_slice(path)),
))
})
}

View File

@ -20,6 +20,7 @@ const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
pub(crate) struct MessageEncoder<T: MessageType> {
#[allow(dead_code)]
pub length: BodySize,
pub te: TransferEncoding,
_phantom: PhantomData<T>,
@ -55,7 +56,7 @@ pub(crate) trait MessageType: Sized {
dst: &mut BytesMut,
version: Version,
mut length: BodySize,
ctype: ConnectionType,
conn_type: ConnectionType,
config: &ServiceConfig,
) -> io::Result<()> {
let chunked = self.chunked();
@ -70,17 +71,28 @@ pub(crate) trait MessageType: Sized {
| StatusCode::PROCESSING
| StatusCode::NO_CONTENT => {
// skip content-length and transfer-encoding headers
// See https://tools.ietf.org/html/rfc7230#section-3.3.1
// and https://tools.ietf.org/html/rfc7230#section-3.3.2
// see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1
// and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
skip_len = true;
length = BodySize::None
}
StatusCode::NOT_MODIFIED => {
// 304 responses should never have a body but should retain a manually set
// content-length header
// see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
skip_len = false;
length = BodySize::None;
}
_ => {}
}
}
match length {
BodySize::Stream => {
if chunked {
skip_len = true;
if camel_case {
dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n")
} else {
@ -91,19 +103,16 @@ pub(crate) trait MessageType: Sized {
dst.put_slice(b"\r\n");
}
}
BodySize::Empty => {
if camel_case {
dst.put_slice(b"\r\nContent-Length: 0\r\n");
} else {
dst.put_slice(b"\r\ncontent-length: 0\r\n");
}
BodySize::Sized(0) if camel_case => {
dst.put_slice(b"\r\nContent-Length: 0\r\n")
}
BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"),
BodySize::Sized(len) => helpers::write_content_length(len, dst),
BodySize::None => dst.put_slice(b"\r\n"),
}
// Connection
match ctype {
match conn_type {
ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
if camel_case {
@ -174,7 +183,7 @@ pub(crate) trait MessageType: Sized {
unsafe {
if camel_case {
// use Camel-Case headers
write_camel_case(k, from_raw_parts_mut(buf, k_len));
write_camel_case(k, buf, k_len);
} else {
write_data(k, buf, k_len);
}
@ -328,13 +337,13 @@ impl<T: MessageType> MessageEncoder<T> {
stream: bool,
version: Version,
length: BodySize,
ctype: ConnectionType,
conn_type: ConnectionType,
config: &ServiceConfig,
) -> io::Result<()> {
// transfer encoding
if !head {
self.te = match length {
BodySize::Empty => TransferEncoding::empty(),
BodySize::Sized(0) => TransferEncoding::empty(),
BodySize::Sized(len) => TransferEncoding::length(len),
BodySize::Stream => {
if message.chunked() && !stream {
@ -350,7 +359,7 @@ impl<T: MessageType> MessageEncoder<T> {
}
message.encode_status(dst)?;
message.encode_headers(dst, version, length, ctype, config)
message.encode_headers(dst, version, length, conn_type, config)
}
}
@ -364,10 +373,12 @@ pub(crate) struct TransferEncoding {
enum TransferEncodingKind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked(bool),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Length(u64),
/// An Encoder for when Content-Length is not known.
///
/// Application decides when to stop writing.
@ -472,15 +483,22 @@ impl TransferEncoding {
}
/// # Safety
/// Callers must ensure that the given length matches given value length.
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes.
unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) {
debug_assert_eq!(value.len(), len);
copy_nonoverlapping(value.as_ptr(), buf, len);
}
fn write_camel_case(value: &[u8], buffer: &mut [u8]) {
/// # Safety
/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is
/// valid for writes of at least `len` bytes.
unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) {
// first copy entire (potentially wrong) slice to output
buffer[..value.len()].copy_from_slice(value);
write_data(value, buf, len);
// SAFETY: We just initialized the buffer with `value`
let buffer = from_raw_parts_mut(buf, len);
let mut iter = value.iter();
@ -544,7 +562,7 @@ mod tests {
let _ = head.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Empty,
BodySize::Sized(0),
ConnectionType::Close,
&ServiceConfig::default(),
);
@ -615,7 +633,7 @@ mod tests {
let _ = head.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Empty,
BodySize::Sized(0),
ConnectionType::Close,
&ServiceConfig::default(),
);

View File

@ -1,6 +1,8 @@
//! HTTP/1 protocol implementation.
use bytes::{Bytes, BytesMut};
mod chunked;
mod client;
mod codec;
mod decoder;

View File

@ -1,5 +1,4 @@
use std::{
error::Error as StdError,
fmt,
marker::PhantomData,
net,
@ -16,7 +15,7 @@ use actix_utils::future::ready;
use futures_core::future::LocalBoxFuture;
use crate::{
body::{AnyBody, MessageBody},
body::{BoxBody, MessageBody},
config::ServiceConfig,
error::DispatchError,
service::HttpServiceHandler,
@ -38,7 +37,7 @@ pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler> {
impl<T, S, B> H1Service<T, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
@ -63,21 +62,20 @@ impl<S, B, X, U> H1Service<TcpStream, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Framed<TcpStream, Codec>), Config = (), Response = ()>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create simple tcp stream service
@ -102,9 +100,11 @@ where
mod openssl {
use super::*;
use actix_service::ServiceFactoryExt;
use actix_tls::accept::{
openssl::{Acceptor, SslAcceptor, SslError, TlsStream},
openssl::{
reexports::{Error as SslError, SslAcceptor},
Acceptor, TlsStream,
},
TlsError,
};
@ -112,16 +112,15 @@ mod openssl {
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<
@ -130,10 +129,10 @@ mod openssl {
Response = (),
>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create openssl based service
/// Create OpenSSL based service.
pub fn openssl(
self,
acceptor: SslAcceptor,
@ -145,11 +144,13 @@ mod openssl {
InitError = (),
> {
Acceptor::new(acceptor)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.and_then(|io: TlsStream<TcpStream>| {
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.map(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok();
ready(Ok((io, peer_addr)))
(io, peer_addr)
})
.and_then(self.map_err(TlsError::Service))
}
@ -158,30 +159,30 @@ mod openssl {
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use std::io;
use actix_service::ServiceFactoryExt;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
rustls::{Acceptor, ServerConfig, TlsStream},
rustls::{reexports::ServerConfig, Acceptor, TlsStream},
TlsError,
};
use super::*;
impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<
@ -190,10 +191,10 @@ mod rustls {
Response = (),
>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create rustls based service
/// Create Rustls based service.
pub fn rustls(
self,
config: ServerConfig,
@ -205,11 +206,13 @@ mod rustls {
InitError = (),
> {
Acceptor::new(config)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.and_then(|io: TlsStream<TcpStream>| {
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.map(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().0.peer_addr().ok();
ready(Ok((io, peer_addr)))
(io, peer_addr)
})
.and_then(self.map_err(TlsError::Service))
}
@ -219,7 +222,7 @@ mod rustls {
impl<T, S, B, X, U> H1Service<T, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
B: MessageBody,
@ -227,7 +230,7 @@ where
pub fn expect<X1>(self, expect: X1) -> H1Service<T, S, B, X1, U>
where
X1: ServiceFactory<Request, Response = Request>,
X1::Error: Into<Response<AnyBody>>,
X1::Error: Into<Response<BoxBody>>,
X1::InitError: fmt::Debug,
{
H1Service {
@ -270,21 +273,20 @@ where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Framed<T, Codec>), Config = (), Response = ()>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
type Response = ();
@ -340,17 +342,16 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
{
type Response = ();
type Error = DispatchError;

View File

@ -1,22 +1,30 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use pin_project_lite::pin_project;
use crate::body::{BodySize, MessageBody};
use crate::error::Error;
use crate::h1::{Codec, Message};
use crate::response::Response;
use crate::{
body::{BodySize, MessageBody},
error::Error,
h1::{Codec, Message},
response::Response,
};
/// Send HTTP/1 response
#[pin_project::pin_project]
pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>,
#[pin]
body: Option<B>,
#[pin]
framed: Option<Framed<T, Codec>>,
pin_project! {
/// Send HTTP/1 response
pub struct SendResponse<T, B> {
res: Option<Message<(Response<()>, BodySize)>>,
#[pin]
body: Option<B>,
#[pin]
framed: Option<Framed<T, Codec>>,
}
}
impl<T, B> SendResponse<T, B>
@ -63,7 +71,6 @@ where
.is_write_buf_full()
{
let next =
// TODO: MSRV 1.51: poll_map_err
match this.body.as_mut().as_pin_mut().unwrap().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(item)),
Poll::Ready(Some(Err(err))) => {

View File

@ -10,17 +10,21 @@ use std::{
};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::time::{sleep, Sleep};
use actix_service::Service;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use futures_core::ready;
use h2::server::{Connection, SendResponse};
use h2::{
server::{Connection, SendResponse},
Ping, PingPong,
};
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use log::{error, trace};
use pin_project_lite::pin_project;
use crate::{
body::{AnyBody, BodySize, MessageBody},
body::{BodySize, BoxBody, MessageBody},
config::ServiceConfig,
service::HttpFlow,
OnConnectData, Payload, Request, Response, ResponseHead,
@ -36,40 +40,63 @@ pin_project! {
on_connect_data: OnConnectData,
config: ServiceConfig,
peer_addr: Option<net::SocketAddr>,
_phantom: PhantomData<B>,
ping_pong: Option<H2PingPong>,
_phantom: PhantomData<B>
}
}
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U> {
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
flow: Rc<HttpFlow<S, X, U>>,
connection: Connection<T, Bytes>,
mut conn: Connection<T, Bytes>,
on_connect_data: OnConnectData,
config: ServiceConfig,
peer_addr: Option<net::SocketAddr>,
timer: Option<Pin<Box<Sleep>>>,
) -> Self {
let ping_pong = config.keep_alive().map(|dur| H2PingPong {
timer: timer
.map(|mut timer| {
// reset timer if it's received from new function.
timer.as_mut().reset(config.now() + dur);
timer
})
.unwrap_or_else(|| Box::pin(sleep(dur))),
on_flight: false,
ping_pong: conn.ping_pong().unwrap(),
});
Self {
flow,
config,
peer_addr,
connection,
connection: conn,
on_connect_data,
ping_pong,
_phantom: PhantomData,
}
}
}
struct H2PingPong {
timer: Pin<Box<Sleep>>,
on_flight: bool,
ping_pong: PingPong,
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
S::Future: 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
{
type Output = Result<(), crate::error::DispatchError>;
@ -77,54 +104,92 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
while let Some((req, tx)) =
ready!(Pin::new(&mut this.connection).poll_accept(cx)?)
{
let (parts, body) = req.into_parts();
let pl = crate::h2::Payload::new(body);
let pl = Payload::<crate::payload::PayloadStream>::H2(pl);
let mut req = Request::with_payload(pl);
loop {
match Pin::new(&mut this.connection).poll_accept(cx)? {
Poll::Ready(Some((req, tx))) => {
let (parts, body) = req.into_parts();
let pl = crate::h2::Payload::new(body);
let pl = Payload::<crate::payload::PayloadStream>::H2(pl);
let mut req = Request::with_payload(pl);
let head = req.head_mut();
head.uri = parts.uri;
head.method = parts.method;
head.version = parts.version;
head.headers = parts.headers.into();
head.peer_addr = this.peer_addr;
let head = req.head_mut();
head.uri = parts.uri;
head.method = parts.method;
head.version = parts.version;
head.headers = parts.headers.into();
head.peer_addr = this.peer_addr;
// merge on_connect_ext data into request extensions
this.on_connect_data.merge_into(&mut req);
// merge on_connect_ext data into request extensions
this.on_connect_data.merge_into(&mut req);
let fut = this.flow.service.call(req);
let config = this.config.clone();
let fut = this.flow.service.call(req);
let config = this.config.clone();
// multiplex request handling with spawn task
actix_rt::spawn(async move {
// resolve service call and send response.
let res = match fut.await {
Ok(res) => handle_response(res.into(), tx, config).await,
Err(err) => {
let res: Response<AnyBody> = err.into();
handle_response(res, tx, config).await
}
};
// multiplex request handling with spawn task
actix_rt::spawn(async move {
// resolve service call and send response.
let res = match fut.await {
Ok(res) => handle_response(res.into(), tx, config).await,
Err(err) => {
let res: Response<BoxBody> = err.into();
handle_response(res, tx, config).await
}
};
// log error.
if let Err(err) = res {
match err {
DispatchError::SendResponse(err) => {
trace!("Error sending HTTP/2 response: {:?}", err)
// log error.
if let Err(err) = res {
match err {
DispatchError::SendResponse(err) => {
trace!("Error sending HTTP/2 response: {:?}", err)
}
DispatchError::SendData(err) => warn!("{:?}", err),
DispatchError::ResponseBody(err) => {
error!("Response payload stream error: {:?}", err)
}
}
}
DispatchError::SendData(err) => warn!("{:?}", err),
DispatchError::ResponseBody(err) => {
error!("Response payload stream error: {:?}", err)
}
}
});
}
});
}
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Pending => match this.ping_pong.as_mut() {
Some(ping_pong) => loop {
if ping_pong.on_flight {
// When have on flight ping pong. poll pong and and keep alive timer.
// on success pong received update keep alive timer to determine the next timing of
// ping pong.
match ping_pong.ping_pong.poll_pong(cx)? {
Poll::Ready(_) => {
ping_pong.on_flight = false;
Poll::Ready(Ok(()))
let dead_line =
this.config.keep_alive_expire().unwrap();
ping_pong.timer.as_mut().reset(dead_line);
}
Poll::Pending => {
return ping_pong
.timer
.as_mut()
.poll(cx)
.map(|_| Ok(()))
}
}
} else {
// When there is no on flight ping pong. keep alive timer is used to wait for next
// timing of ping pong. Therefore at this point it serves as an interval instead.
ready!(ping_pong.timer.as_mut().poll(cx));
ping_pong.ping_pong.send_ping(Ping::opaque())?;
let dead_line = this.config.keep_alive_expire().unwrap();
ping_pong.timer.as_mut().reset(dead_line);
ping_pong.on_flight = true;
}
},
None => return Poll::Pending,
},
}
}
}
}
@ -141,7 +206,6 @@ async fn handle_response<B>(
) -> Result<(), DispatchError>
where
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
{
let (res, body) = res.replace_body(());
@ -226,9 +290,11 @@ fn prepare_response(
let _ = match size {
BodySize::None | BodySize::Stream => None,
BodySize::Empty => res
BodySize::Sized(0) => res
.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
BodySize::Sized(len) => {
let mut buf = itoa::Buffer::new();
@ -243,7 +309,7 @@ fn prepare_response(
for (key, value) in head.headers.iter() {
match *key {
// TODO: consider skipping other headers according to:
// https://tools.ietf.org/html/rfc7540#section-8.1.2.2
// https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2
// omit HTTP/1.x only headers
CONNECTION | TRANSFER_ENCODING => continue,
CONTENT_LENGTH if skip_len => continue,

View File

@ -1,20 +1,30 @@
//! HTTP/2 protocol.
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::time::Sleep;
use bytes::Bytes;
use futures_core::{ready, Stream};
use h2::RecvStream;
use h2::{
server::{handshake, Connection, Handshake},
RecvStream,
};
mod dispatcher;
mod service;
pub use self::dispatcher::Dispatcher;
pub use self::service::H2Service;
use crate::error::PayloadError;
use crate::{
config::ServiceConfig,
error::{DispatchError, PayloadError},
};
/// HTTP/2 peer stream.
pub struct Payload {
@ -50,3 +60,44 @@ impl Stream for Payload {
}
}
}
pub(crate) fn handshake_with_timeout<T>(
io: T,
config: &ServiceConfig,
) -> HandshakeWithTimeout<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
HandshakeWithTimeout {
handshake: handshake(io),
timer: config.client_timer().map(Box::pin),
}
}
pub(crate) struct HandshakeWithTimeout<T: AsyncRead + AsyncWrite + Unpin> {
handshake: Handshake<T>,
timer: Option<Pin<Box<Sleep>>>,
}
impl<T> Future for HandshakeWithTimeout<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(Connection<T, Bytes>, Option<Pin<Box<Sleep>>>), DispatchError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match Pin::new(&mut this.handshake).poll(cx)? {
// return the timer on success handshake. It can be re-used for h2 ping-pong.
Poll::Ready(conn) => Poll::Ready(Ok((conn, this.timer.take()))),
Poll::Pending => match this.timer.as_mut() {
Some(timer) => {
ready!(timer.as_mut().poll(cx));
Poll::Ready(Err(DispatchError::SlowRequestTimeout))
}
None => Poll::Pending,
},
}
}
}

View File

@ -1,5 +1,4 @@
use std::{
error::Error as StdError,
future::Future,
marker::PhantomData,
net,
@ -15,20 +14,18 @@ use actix_service::{
ServiceFactoryExt as _,
};
use actix_utils::future::ready;
use bytes::Bytes;
use futures_core::{future::LocalBoxFuture, ready};
use h2::server::{handshake as h2_handshake, Handshake as H2Handshake};
use log::error;
use crate::{
body::{AnyBody, MessageBody},
body::{BoxBody, MessageBody},
config::ServiceConfig,
error::DispatchError,
service::HttpFlow,
ConnectCallback, OnConnectData, Request, Response,
};
use super::dispatcher::Dispatcher;
use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
/// `ServiceFactory` implementation for HTTP/2 transport
pub struct H2Service<T, S, B> {
@ -41,12 +38,11 @@ pub struct H2Service<T, S, B> {
impl<T, S, B> H2Service<T, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
/// Create new `H2Service` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
@ -72,12 +68,11 @@ impl<S, B> H2Service<TcpStream, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
/// Create plain TCP based service
pub fn tcp(
@ -101,9 +96,14 @@ where
#[cfg(feature = "openssl")]
mod openssl {
use actix_service::{fn_factory, fn_service, ServiceFactoryExt};
use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream};
use actix_tls::accept::TlsError;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
openssl::{
reexports::{Error as SslError, SslAcceptor},
Acceptor, TlsStream,
},
TlsError,
};
use super::*;
@ -111,14 +111,13 @@ mod openssl {
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
/// Create OpenSSL based service
/// Create OpenSSL based service.
pub fn openssl(
self,
acceptor: SslAcceptor,
@ -130,16 +129,14 @@ mod openssl {
InitError = S::InitError,
> {
Acceptor::new(acceptor)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.and_then(fn_factory(|| {
ready(Ok::<_, S::InitError>(fn_service(
|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok();
ready(Ok((io, peer_addr)))
},
)))
}))
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.map(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().peer_addr().ok();
(io, peer_addr)
})
.and_then(self.map_err(TlsError::Service))
}
}
@ -147,24 +144,27 @@ mod openssl {
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use actix_service::ServiceFactoryExt;
use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream};
use actix_tls::accept::TlsError;
use std::io;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
rustls::{reexports::ServerConfig, Acceptor, TlsStream},
TlsError,
};
use super::*;
impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
/// Create Rustls based service
/// Create Rustls based service.
pub fn rustls(
self,
mut config: ServerConfig,
@ -177,19 +177,17 @@ mod rustls {
> {
let mut protos = vec![b"h2".to_vec()];
protos.extend_from_slice(&config.alpn_protocols);
config.set_protocols(&protos);
config.alpn_protocols = protos;
Acceptor::new(config)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.and_then(fn_factory(|| {
ready(Ok::<_, S::InitError>(fn_service(
|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().0.peer_addr().ok();
ready(Ok((io, peer_addr)))
},
)))
}))
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.map(|io: TlsStream<TcpStream>| {
let peer_addr = io.get_ref().0.peer_addr().ok();
(io, peer_addr)
})
.and_then(self.map_err(TlsError::Service))
}
}
@ -201,12 +199,11 @@ where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
type Response = ();
type Error = DispatchError;
@ -241,7 +238,7 @@ where
impl<T, S, B> H2ServiceHandler<T, S, B>
where
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
@ -264,11 +261,10 @@ impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
type Response = ();
type Error = DispatchError;
@ -292,7 +288,7 @@ where
Some(self.cfg.clone()),
addr,
on_connect_data,
h2_handshake(io),
handshake_with_timeout(io, &self.cfg),
),
}
}
@ -309,7 +305,7 @@ where
Option<ServiceConfig>,
Option<net::SocketAddr>,
OnConnectData,
H2Handshake<T, Bytes>,
HandshakeWithTimeout<T>,
),
}
@ -317,7 +313,7 @@ pub struct H2ServiceHandlerResponse<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
@ -329,11 +325,10 @@ impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
{
type Output = Result<(), DispatchError>;
@ -347,7 +342,7 @@ where
ref mut on_connect_data,
ref mut handshake,
) => match ready!(Pin::new(handshake).poll(cx)) {
Ok(conn) => {
Ok((conn, timer)) => {
let on_connect_data = std::mem::take(on_connect_data);
self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(),
@ -355,12 +350,13 @@ where
on_connect_data,
config.take().unwrap(),
*peer_addr,
timer,
));
self.poll(cx)
}
Err(err) => {
trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into()))
Poll::Ready(Err(err))
}
},
}

View File

@ -1,11 +1,12 @@
//! Helper trait for types that can be effectively borrowed as a [HeaderValue].
//!
//! [HeaderValue]: crate::http::HeaderValue
//! Sealed [`AsHeaderName`] trait and implementations.
use std::{borrow::Cow, str::FromStr};
use std::{borrow::Cow, str::FromStr as _};
use http::header::{HeaderName, InvalidHeaderName};
/// Sealed trait implemented for types that can be effectively borrowed as a [`HeaderValue`].
///
/// [`HeaderValue`]: crate::http::HeaderValue
pub trait AsHeaderName: Sealed {}
pub struct Seal;

View File

@ -1,4 +1,6 @@
use std::convert::TryFrom;
//! [`IntoHeaderPair`] trait and implementations.
use std::convert::TryFrom as _;
use http::{
header::{HeaderName, InvalidHeaderName, InvalidHeaderValue},
@ -7,7 +9,10 @@ use http::{
use super::{Header, IntoHeaderValue};
/// Transforms structures into header K/V pairs for inserting into `HeaderMap`s.
/// An interface for types that can be converted into a [`HeaderName`]/[`HeaderValue`] pair for
/// insertion into a [`HeaderMap`].
///
/// [`HeaderMap`]: crate::http::HeaderMap
pub trait IntoHeaderPair: Sized {
type Error: Into<HttpError>;

View File

@ -1,10 +1,12 @@
use std::convert::TryFrom;
//! [`IntoHeaderValue`] trait and implementations.
use std::convert::TryFrom as _;
use bytes::Bytes;
use http::{header::InvalidHeaderValue, Error as HttpError, HeaderValue};
use mime::Mime;
/// A trait for any object that can be Converted to a `HeaderValue`
/// An interface for types that can be converted into a [`HeaderValue`].
pub trait IntoHeaderValue: Sized {
/// The type returned in the event of a conversion error.
type Error: Into<HttpError>;

View File

@ -1,6 +1,6 @@
//! A multi-value [`HeaderMap`] and its iterators.
use std::{borrow::Cow, collections::hash_map, ops};
use std::{borrow::Cow, collections::hash_map, iter, ops};
use ahash::AHashMap;
use http::header::{HeaderName, HeaderValue};
@ -288,7 +288,7 @@ impl HeaderMap {
/// Returns an iterator over all values associated with a header name.
///
/// The returned iterator does not incur any allocations and will yield no items if there are no
/// values associated with the key. Iteration order is **not** guaranteed to be the same as
/// values associated with the key. Iteration order is guaranteed to be the same as
/// insertion order.
///
/// # Examples
@ -355,6 +355,19 @@ impl HeaderMap {
///
/// assert_eq!(map.len(), 1);
/// ```
///
/// A convenience method is provided on the returned iterator to check if the insertion replaced
/// any values.
/// ```
/// # use actix_http::http::{header, HeaderMap, HeaderValue};
/// let mut map = HeaderMap::new();
///
/// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/plain"));
/// assert!(removed.is_empty());
///
/// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/html"));
/// assert!(!removed.is_empty());
/// ```
pub fn insert(&mut self, key: HeaderName, val: HeaderValue) -> Removed {
let value = self.inner.insert(key, Value::one(val));
Removed::new(value)
@ -393,6 +406,9 @@ impl HeaderMap {
/// Removes all headers for a particular header name from the map.
///
/// Providing an invalid header names (as a string argument) will have no effect and return
/// without error.
///
/// # Examples
/// ```
/// # use actix_http::http::{header, HeaderMap, HeaderValue};
@ -409,6 +425,21 @@ impl HeaderMap {
/// assert!(removed.next().is_none());
///
/// assert!(map.is_empty());
/// ```
///
/// A convenience method is provided on the returned iterator to check if the `remove` call
/// actually removed any values.
/// ```
/// # use actix_http::http::{header, HeaderMap, HeaderValue};
/// let mut map = HeaderMap::new();
///
/// let removed = map.remove("accept");
/// assert!(removed.is_empty());
///
/// map.insert(header::ACCEPT, HeaderValue::from_static("text/html"));
/// let removed = map.remove("accept");
/// assert!(!removed.is_empty());
/// ```
pub fn remove(&mut self, key: impl AsHeaderName) -> Removed {
let value = match key.try_as_name(super::as_name::Seal) {
Ok(Cow::Borrowed(name)) => self.inner.remove(name),
@ -550,7 +581,8 @@ impl HeaderMap {
}
}
/// Note that this implementation will clone a [HeaderName] for each value.
/// Note that this implementation will clone a [HeaderName] for each value. Consider using
/// [`drain`](Self::drain) to control header name cloning.
impl IntoIterator for HeaderMap {
type Item = (HeaderName, HeaderValue);
type IntoIter = IntoIter;
@ -571,7 +603,7 @@ impl<'a> IntoIterator for &'a HeaderMap {
}
}
/// Iterator for all values with the same header name.
/// Iterator over borrowed values with the same associated name.
///
/// See [`HeaderMap::get_all`].
#[derive(Debug)]
@ -613,18 +645,36 @@ impl<'a> Iterator for GetAll<'a> {
}
}
/// Iterator for owned [`HeaderValue`]s with the same associated [`HeaderName`] returned from methods
/// on [`HeaderMap`] that remove or replace items.
impl ExactSizeIterator for GetAll<'_> {}
impl iter::FusedIterator for GetAll<'_> {}
/// Iterator over removed, owned values with the same associated name.
///
/// Returned from methods that remove or replace items. See [`HeaderMap::insert`]
/// and [`HeaderMap::remove`].
#[derive(Debug)]
pub struct Removed {
inner: Option<smallvec::IntoIter<[HeaderValue; 4]>>,
}
impl<'a> Removed {
impl Removed {
fn new(value: Option<Value>) -> Self {
let inner = value.map(|value| value.inner.into_iter());
Self { inner }
}
/// Returns true if iterator contains no elements, without consuming it.
///
/// If called immediately after [`HeaderMap::insert`] or [`HeaderMap::remove`], it will indicate
/// wether any items were actually replaced or removed, respectively.
pub fn is_empty(&self) -> bool {
match self.inner {
// size hint lower bound of smallvec is the correct length
Some(ref iter) => iter.size_hint().0 == 0,
None => true,
}
}
}
impl Iterator for Removed {
@ -644,7 +694,11 @@ impl Iterator for Removed {
}
}
/// Iterator over all [`HeaderName`]s in the map.
impl ExactSizeIterator for Removed {}
impl iter::FusedIterator for Removed {}
/// Iterator over all names in the map.
#[derive(Debug)]
pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>);
@ -662,6 +716,11 @@ impl<'a> Iterator for Keys<'a> {
}
}
impl ExactSizeIterator for Keys<'_> {}
impl iter::FusedIterator for Keys<'_> {}
/// Iterator over borrowed name-value pairs.
#[derive(Debug)]
pub struct Iter<'a> {
inner: hash_map::Iter<'a, HeaderName, Value>,
@ -684,7 +743,7 @@ impl<'a> Iterator for Iter<'a> {
fn next(&mut self) -> Option<Self::Item> {
// handle in-progress multi value lists first
if let Some((ref name, ref mut vals)) = self.multi_inner {
if let Some((name, ref mut vals)) = self.multi_inner {
match vals.get(self.multi_idx) {
Some(val) => {
self.multi_idx += 1;
@ -713,6 +772,10 @@ impl<'a> Iterator for Iter<'a> {
}
}
impl ExactSizeIterator for Iter<'_> {}
impl iter::FusedIterator for Iter<'_> {}
/// Iterator over drained name-value pairs.
///
/// Iterator items are `(Option<HeaderName>, HeaderValue)` to avoid cloning.
@ -764,6 +827,10 @@ impl<'a> Iterator for Drain<'a> {
}
}
impl ExactSizeIterator for Drain<'_> {}
impl iter::FusedIterator for Drain<'_> {}
/// Iterator over owned name-value pairs.
///
/// Implementation necessarily clones header names for each value.
@ -814,12 +881,27 @@ impl Iterator for IntoIter {
}
}
impl ExactSizeIterator for IntoIter {}
impl iter::FusedIterator for IntoIter {}
#[cfg(test)]
mod tests {
use std::iter::FusedIterator;
use http::header;
use static_assertions::assert_impl_all;
use super::*;
assert_impl_all!(HeaderMap: IntoIterator);
assert_impl_all!(Keys<'_>: Iterator, ExactSizeIterator, FusedIterator);
assert_impl_all!(GetAll<'_>: Iterator, ExactSizeIterator, FusedIterator);
assert_impl_all!(Removed: Iterator, ExactSizeIterator, FusedIterator);
assert_impl_all!(Iter<'_>: Iterator, ExactSizeIterator, FusedIterator);
assert_impl_all!(IntoIter: Iterator, ExactSizeIterator, FusedIterator);
assert_impl_all!(Drain<'_>: Iterator, ExactSizeIterator, FusedIterator);
#[test]
fn create() {
let map = HeaderMap::new();
@ -945,6 +1027,56 @@ mod tests {
assert_eq!(vals.next(), removed.next().as_ref());
}
#[test]
fn get_all_iteration_order_matches_insertion_order() {
let mut map = HeaderMap::new();
let mut vals = map.get_all(header::COOKIE);
assert!(vals.next().is_none());
map.append(header::COOKIE, HeaderValue::from_static("1"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"1");
assert!(vals.next().is_none());
map.append(header::COOKIE, HeaderValue::from_static("2"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"1");
assert_eq!(vals.next().unwrap().as_bytes(), b"2");
assert!(vals.next().is_none());
map.append(header::COOKIE, HeaderValue::from_static("3"));
map.append(header::COOKIE, HeaderValue::from_static("4"));
map.append(header::COOKIE, HeaderValue::from_static("5"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"1");
assert_eq!(vals.next().unwrap().as_bytes(), b"2");
assert_eq!(vals.next().unwrap().as_bytes(), b"3");
assert_eq!(vals.next().unwrap().as_bytes(), b"4");
assert_eq!(vals.next().unwrap().as_bytes(), b"5");
assert!(vals.next().is_none());
let _ = map.insert(header::COOKIE, HeaderValue::from_static("6"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"6");
assert!(vals.next().is_none());
let _ = map.insert(header::COOKIE, HeaderValue::from_static("7"));
let _ = map.insert(header::COOKIE, HeaderValue::from_static("8"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"8");
assert!(vals.next().is_none());
map.append(header::COOKIE, HeaderValue::from_static("9"));
let mut vals = map.get_all(header::COOKIE);
assert_eq!(vals.next().unwrap().as_bytes(), b"8");
assert_eq!(vals.next().unwrap().as_bytes(), b"9");
assert!(vals.next().is_none());
// check for fused-ness
assert!(vals.next().is_none());
}
fn owned_pair<'a>(
(name, val): (&'a HeaderName, &'a HeaderValue),
) -> (HeaderName, HeaderValue) {

View File

@ -29,35 +29,34 @@ pub use http::header::{
X_FRAME_OPTIONS, X_XSS_PROTECTION,
};
use crate::error::ParseError;
use crate::HttpMessage;
use crate::{error::ParseError, HttpMessage};
mod as_name;
mod into_pair;
mod into_value;
mod utils;
pub(crate) mod map;
pub mod map;
mod shared;
#[doc(hidden)]
pub use self::shared::*;
mod utils;
pub use self::as_name::AsHeaderName;
pub use self::into_pair::IntoHeaderPair;
pub use self::into_value::IntoHeaderValue;
#[doc(hidden)]
pub use self::map::GetAll;
pub use self::map::HeaderMap;
pub use self::utils::*;
pub use self::shared::{
parse_extended_value, q, Charset, ContentEncoding, ExtendedValue, HttpDate,
LanguageTag, Quality, QualityItem,
};
pub use self::utils::{
fmt_comma_delimited, from_comma_delimited, from_one_raw_str, http_percent_encode,
};
/// A trait for any object that already represents a valid header field and value.
/// An interface for types that already represent a valid header.
pub trait Header: IntoHeaderValue {
/// Returns the name of the header field
fn name() -> HeaderName;
/// Parse a header
fn parse<T: HttpMessage>(msg: &T) -> Result<Self, ParseError>;
fn parse<M: HttpMessage>(msg: &M) -> Result<Self, ParseError>;
}
/// Convert `http::HeaderMap` to our `HeaderMap`.
@ -68,7 +67,7 @@ impl From<http::HeaderMap> for HeaderMap {
}
/// This encode set is used for HTTP header values and is defined at
/// https://tools.ietf.org/html/rfc5987#section-3.2.
/// <https://datatracker.ietf.org/doc/html/rfc5987#section-3.2>.
pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS
.add(b' ')
.add(b'"')

View File

@ -1,14 +1,13 @@
use std::fmt::{self, Display};
use std::str::FromStr;
use std::{fmt, str};
use self::Charset::*;
/// A Mime charset.
/// A MIME character set.
///
/// The string representation is normalized to upper case.
///
/// See <http://www.iana.org/assignments/character-sets/character-sets.xhtml>.
#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum Charset {
/// US ASCII
@ -88,20 +87,20 @@ impl Charset {
Iso_8859_8_E => "ISO-8859-8-E",
Iso_8859_8_I => "ISO-8859-8-I",
Gb2312 => "GB2312",
Big5 => "big5",
Big5 => "Big5",
Koi8_R => "KOI8-R",
Ext(ref s) => s,
}
}
}
impl Display for Charset {
impl fmt::Display for Charset {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.label())
}
}
impl FromStr for Charset {
impl str::FromStr for Charset {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Charset, crate::Error> {
@ -128,7 +127,7 @@ impl FromStr for Charset {
"ISO-8859-8-E" => Iso_8859_8_E,
"ISO-8859-8-I" => Iso_8859_8_I,
"GB2312" => Gb2312,
"big5" => Big5,
"BIG5" => Big5,
"KOI8-R" => Koi8_R,
s => Ext(s.to_owned()),
})

View File

@ -1,5 +1,6 @@
use std::{convert::Infallible, str::FromStr};
use std::{convert::TryFrom, str::FromStr};
use derive_more::{Display, Error};
use http::header::InvalidHeaderValue;
use crate::{
@ -8,8 +9,19 @@ use crate::{
HttpMessage,
};
/// Error returned when a content encoding is unknown.
#[derive(Debug, Display, Error)]
#[display(fmt = "unsupported content encoding")]
pub struct ContentEncodingParseError;
/// Represents a supported content encoding.
#[derive(Copy, Clone, PartialEq, Debug)]
///
/// Includes a commonly-used subset of media types appropriate for use as HTTP content encodings.
/// See [IANA HTTP Content Coding Registry].
///
/// [IANA HTTP Content Coding Registry]: https://www.iana.org/assignments/http-parameters/http-parameters.xhtml
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum ContentEncoding {
/// Automatically select encoding based on encoding negotiation.
Auto,
@ -23,7 +35,7 @@ pub enum ContentEncoding {
/// Gzip algorithm.
Gzip,
// Zstd algorithm.
/// Zstd algorithm.
Zstd,
/// Indicates the identity function (i.e. no compression, nor modification).
@ -37,7 +49,7 @@ impl ContentEncoding {
matches!(self, ContentEncoding::Identity | ContentEncoding::Auto)
}
/// Convert content encoding to string
/// Convert content encoding to string.
#[inline]
pub fn as_str(self) -> &'static str {
match self {
@ -48,18 +60,6 @@ impl ContentEncoding {
ContentEncoding::Identity | ContentEncoding::Auto => "identity",
}
}
/// Default Q-factor (quality) value.
#[inline]
pub fn quality(self) -> f64 {
match self {
ContentEncoding::Br => 1.1,
ContentEncoding::Gzip => 1.0,
ContentEncoding::Deflate => 0.9,
ContentEncoding::Identity | ContentEncoding::Auto => 0.1,
ContentEncoding::Zstd => 0.0,
}
}
}
impl Default for ContentEncoding {
@ -69,31 +69,33 @@ impl Default for ContentEncoding {
}
impl FromStr for ContentEncoding {
type Err = Infallible;
type Err = ContentEncodingParseError;
fn from_str(val: &str) -> Result<Self, Self::Err> {
Ok(Self::from(val))
}
}
impl From<&str> for ContentEncoding {
fn from(val: &str) -> ContentEncoding {
let val = val.trim();
if val.eq_ignore_ascii_case("br") {
ContentEncoding::Br
Ok(ContentEncoding::Br)
} else if val.eq_ignore_ascii_case("gzip") {
ContentEncoding::Gzip
Ok(ContentEncoding::Gzip)
} else if val.eq_ignore_ascii_case("deflate") {
ContentEncoding::Deflate
Ok(ContentEncoding::Deflate)
} else if val.eq_ignore_ascii_case("zstd") {
ContentEncoding::Zstd
Ok(ContentEncoding::Zstd)
} else {
ContentEncoding::default()
Err(ContentEncodingParseError)
}
}
}
impl TryFrom<&str> for ContentEncoding {
type Error = ContentEncodingParseError;
fn try_from(val: &str) -> Result<Self, Self::Error> {
val.parse()
}
}
impl IntoHeaderValue for ContentEncoding {
type Error = InvalidHeaderValue;

View File

@ -1,17 +1,17 @@
//! Originally taken from `hyper::header::parsing`.
use std::{fmt, str::FromStr};
use language_tags::LanguageTag;
use crate::header::{Charset, HTTP_VALUE};
// From hyper v0.11.27 src/header/parsing.rs
/// The value part of an extended parameter consisting of three parts:
/// - The REQUIRED character set name (`charset`).
/// - The OPTIONAL language information (`language_tag`).
/// - A character sequence representing the actual value (`value`), separated by single quotes.
///
/// It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2).
/// It is defined in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2).
#[derive(Clone, Debug, PartialEq)]
pub struct ExtendedValue {
/// The character set that is used to encode the `value` to a string.
@ -24,17 +24,17 @@ pub struct ExtendedValue {
pub value: Vec<u8>,
}
/// Parses extended header parameter values (`ext-value`), as defined in
/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2).
/// Parses extended header parameter values (`ext-value`), as defined
/// in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2).
///
/// Extended values are denoted by parameter names that end with `*`.
///
/// ## ABNF
///
/// ```text
/// ```plain
/// ext-value = charset "'" [ language ] "'" value-chars
/// ; like RFC 2231's <extended-initial-value>
/// ; (see [RFC2231], Section 7)
/// ; (see [RFC 2231 §7])
///
/// charset = "UTF-8" / "ISO-8859-1" / mime-charset
///
@ -43,22 +43,26 @@ pub struct ExtendedValue {
/// / "!" / "#" / "$" / "%" / "&"
/// / "+" / "-" / "^" / "_" / "`"
/// / "{" / "}" / "~"
/// ; as <mime-charset> in Section 2.3 of [RFC2978]
/// ; as <mime-charset> in [RFC 2978 §2.3]
/// ; except that the single quote is not included
/// ; SHOULD be registered in the IANA charset registry
///
/// language = <Language-Tag, defined in [RFC5646], Section 2.1>
/// language = <Language-Tag, defined in [RFC 5646 §2.1]>
///
/// value-chars = *( pct-encoded / attr-char )
///
/// pct-encoded = "%" HEXDIG HEXDIG
/// ; see [RFC3986], Section 2.1
/// ; see [RFC 3986 §2.1]
///
/// attr-char = ALPHA / DIGIT
/// / "!" / "#" / "$" / "&" / "+" / "-" / "."
/// / "^" / "_" / "`" / "|" / "~"
/// ; token except ( "*" / "'" / "%" )
/// ```
///
/// [RFC 2231 §7]: https://datatracker.ietf.org/doc/html/rfc2231#section-7
/// [RFC 2978 §2.3]: https://datatracker.ietf.org/doc/html/rfc2978#section-2.3
/// [RFC 3986 §2.1]: https://datatracker.ietf.org/doc/html/rfc5646#section-2.1
pub fn parse_extended_value(
val: &str,
) -> Result<ExtendedValue, crate::error::ParseError> {

View File

@ -0,0 +1,82 @@
use std::{fmt, io::Write, str::FromStr, time::SystemTime};
use bytes::BytesMut;
use http::header::{HeaderValue, InvalidHeaderValue};
use crate::{
config::DATE_VALUE_LENGTH, error::ParseError, header::IntoHeaderValue,
helpers::MutWriter,
};
/// A timestamp with HTTP-style formatting and parsing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HttpDate(SystemTime);
impl FromStr for HttpDate {
type Err = ParseError;
fn from_str(s: &str) -> Result<HttpDate, ParseError> {
match httpdate::parse_http_date(s) {
Ok(sys_time) => Ok(HttpDate(sys_time)),
Err(_) => Err(ParseError::Header),
}
}
}
impl fmt::Display for HttpDate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let date_str = httpdate::fmt_http_date(self.0);
f.write_str(&date_str)
}
}
impl IntoHeaderValue for HttpDate {
type Error = InvalidHeaderValue;
fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
let mut buf = BytesMut::with_capacity(DATE_VALUE_LENGTH);
let mut wrt = MutWriter(&mut buf);
// unwrap: date output is known to be well formed and of known length
write!(wrt, "{}", httpdate::fmt_http_date(self.0)).unwrap();
HeaderValue::from_maybe_shared(buf.split().freeze())
}
}
impl From<SystemTime> for HttpDate {
fn from(sys_time: SystemTime) -> HttpDate {
HttpDate(sys_time)
}
}
impl From<HttpDate> for SystemTime {
fn from(HttpDate(sys_time): HttpDate) -> SystemTime {
sys_time
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
#[test]
fn date_header() {
macro_rules! assert_parsed_date {
($case:expr, $exp:expr) => {
assert_eq!($case.parse::<HttpDate>().unwrap(), $exp);
};
}
// 784198117 = SystemTime::from(datetime!(1994-11-07 08:48:37).assume_utc()).duration_since(SystemTime::UNIX_EPOCH));
let nov_07 = HttpDate(SystemTime::UNIX_EPOCH + Duration::from_secs(784198117));
assert_parsed_date!("Mon, 07 Nov 1994 08:48:37 GMT", nov_07);
assert_parsed_date!("Monday, 07-Nov-94 08:48:37 GMT", nov_07);
assert_parsed_date!("Mon Nov 7 08:48:37 1994", nov_07);
assert!("this-is-no-date".parse::<HttpDate>().is_err());
}
}

View File

@ -1,97 +0,0 @@
use std::{
fmt,
io::Write,
str::FromStr,
time::{SystemTime, UNIX_EPOCH},
};
use bytes::buf::BufMut;
use bytes::BytesMut;
use http::header::{HeaderValue, InvalidHeaderValue};
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use crate::error::ParseError;
use crate::header::IntoHeaderValue;
use crate::time_parser;
/// A timestamp with HTTP formatting and parsing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HttpDate(OffsetDateTime);
impl FromStr for HttpDate {
type Err = ParseError;
fn from_str(s: &str) -> Result<HttpDate, ParseError> {
match time_parser::parse_http_date(s) {
Some(t) => Ok(HttpDate(t.assume_utc())),
None => Err(ParseError::Header),
}
}
}
impl fmt::Display for HttpDate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0.format("%a, %d %b %Y %H:%M:%S GMT"), f)
}
}
impl From<SystemTime> for HttpDate {
fn from(sys: SystemTime) -> HttpDate {
HttpDate(PrimitiveDateTime::from(sys).assume_utc())
}
}
impl IntoHeaderValue for HttpDate {
type Error = InvalidHeaderValue;
fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer();
write!(
wrt,
"{}",
self.0
.to_offset(UtcOffset::UTC)
.format("%a, %d %b %Y %H:%M:%S GMT")
)
.unwrap();
HeaderValue::from_maybe_shared(wrt.get_mut().split().freeze())
}
}
impl From<HttpDate> for SystemTime {
fn from(date: HttpDate) -> SystemTime {
let dt = date.0;
let epoch = OffsetDateTime::unix_epoch();
UNIX_EPOCH + (dt - epoch)
}
}
#[cfg(test)]
mod tests {
use super::HttpDate;
use time::{date, time, PrimitiveDateTime};
#[test]
fn test_date() {
let nov_07 = HttpDate(
PrimitiveDateTime::new(date!(1994 - 11 - 07), time!(8:48:37)).assume_utc(),
);
assert_eq!(
"Sun, 07 Nov 1994 08:48:37 GMT".parse::<HttpDate>().unwrap(),
nov_07
);
assert_eq!(
"Sunday, 07-Nov-94 08:48:37 GMT"
.parse::<HttpDate>()
.unwrap(),
nov_07
);
assert_eq!(
"Sun Nov 7 08:48:37 1994".parse::<HttpDate>().unwrap(),
nov_07
);
assert!("this-is-no-date".parse::<HttpDate>().is_err());
}
}

View File

@ -3,12 +3,14 @@
mod charset;
mod content_encoding;
mod extended;
mod httpdate;
mod http_date;
mod quality;
mod quality_item;
pub use self::charset::Charset;
pub use self::content_encoding::ContentEncoding;
pub use self::extended::{parse_extended_value, ExtendedValue};
pub use self::httpdate::HttpDate;
pub use self::quality_item::{q, qitem, Quality, QualityItem};
pub use self::http_date::HttpDate;
pub use self::quality::{q, Quality};
pub use self::quality_item::QualityItem;
pub use language_tags::LanguageTag;

View File

@ -0,0 +1,208 @@
use std::{
convert::{TryFrom, TryInto},
fmt,
};
use derive_more::{Display, Error};
const MAX_QUALITY_INT: u16 = 1000;
const MAX_QUALITY_FLOAT: f32 = 1.0;
/// Represents a quality used in q-factor values.
///
/// The default value is equivalent to `q=1.0` (the [max](Self::MAX) value).
///
/// # Implementation notes
/// The quality value is defined as a number between 0.0 and 1.0 with three decimal places.
/// This means there are 1001 possible values. Since floating point numbers are not exact and the
/// smallest floating point data type (`f32`) consumes four bytes, we use an `u16` value to store
/// the quality internally.
///
/// [RFC 7231 §5.3.1] gives more information on quality values in HTTP header fields.
///
/// # Examples
/// ```
/// use actix_http::header::{Quality, q};
/// assert_eq!(q(1.0), Quality::MAX);
///
/// assert_eq!(q(0.42).to_string(), "0.42");
/// assert_eq!(q(1.0).to_string(), "1");
/// assert_eq!(Quality::MIN.to_string(), "0");
/// ```
///
/// [RFC 7231 §5.3.1]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Quality(pub(super) u16);
impl Quality {
/// The maximum quality value, equivalent to `q=1.0`.
pub const MAX: Quality = Quality(MAX_QUALITY_INT);
/// The minimum quality value, equivalent to `q=0.0`.
pub const MIN: Quality = Quality(0);
/// Converts a float in the range 0.01.0 to a `Quality`.
///
/// Intentionally private. External uses should rely on the `TryFrom` impl.
///
/// # Panics
/// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0.
fn from_f32(value: f32) -> Self {
// Check that `value` is within range should be done before calling this method.
// Just in case, this debug_assert should catch if we were forgetful.
debug_assert!(
(0.0f32..=1.0f32).contains(&value),
"q value must be between 0.0 and 1.0"
);
Quality((value * MAX_QUALITY_INT as f32) as u16)
}
}
/// The default value is [`Quality::MAX`].
impl Default for Quality {
fn default() -> Quality {
Quality::MAX
}
}
impl fmt::Display for Quality {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
0 => f.write_str("0"),
MAX_QUALITY_INT => f.write_str("1"),
// some number in the range 1999
x => {
f.write_str("0.")?;
// This implementation avoids string allocation for removing trailing zeroes.
// In benchmarks it is twice as fast as approach using something like
// `format!("{}").trim_end_matches('0')` for non-fast-path quality values.
if x < 10 {
// x in is range 19
f.write_str("00")?;
// 0 is already handled so it's not possible to have a trailing 0 in this range
// we can just write the integer
itoa::fmt(f, x)
} else if x < 100 {
// x in is range 1099
f.write_str("0")?;
if x % 10 == 0 {
// trailing 0, divide by 10 and write
itoa::fmt(f, x / 10)
} else {
itoa::fmt(f, x)
}
} else {
// x is in range 100999
if x % 100 == 0 {
// two trailing 0s, divide by 100 and write
itoa::fmt(f, x / 100)
} else if x % 10 == 0 {
// one trailing 0, divide by 10 and write
itoa::fmt(f, x / 10)
} else {
itoa::fmt(f, x)
}
}
}
}
}
}
#[derive(Debug, Clone, Display, Error)]
#[display(fmt = "quality out of bounds")]
#[non_exhaustive]
pub struct QualityOutOfBounds;
impl TryFrom<f32> for Quality {
type Error = QualityOutOfBounds;
#[inline]
fn try_from(value: f32) -> Result<Self, Self::Error> {
if (0.0..=MAX_QUALITY_FLOAT).contains(&value) {
Ok(Quality::from_f32(value))
} else {
Err(QualityOutOfBounds)
}
}
}
/// Convenience function to create a [`Quality`] from an `f32` (0.01.0).
///
/// Not recommended for use with user input. Rely on the `TryFrom` impls where possible.
///
/// # Panics
/// Panics if value is out of range.
///
/// # Examples
/// ```
/// # use actix_http::header::{q, Quality};
/// let q1 = q(1.0);
/// assert_eq!(q1, Quality::MAX);
///
/// let q2 = q(0.0);
/// assert_eq!(q2, Quality::MIN);
///
/// let q3 = q(0.42);
/// ```
///
/// An out-of-range `f32` quality will panic.
/// ```should_panic
/// # use actix_http::header::q;
/// let _q2 = q(1.42);
/// ```
#[inline]
pub fn q<T>(quality: T) -> Quality
where
T: TryInto<Quality>,
T::Error: fmt::Debug,
{
quality.try_into().expect("quality value was out of bounds")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn q_helper() {
assert_eq!(q(0.5), Quality(500));
}
#[test]
fn display_output() {
assert_eq!(q(0.0).to_string(), "0");
assert_eq!(q(1.0).to_string(), "1");
assert_eq!(q(0.001).to_string(), "0.001");
assert_eq!(q(0.5).to_string(), "0.5");
assert_eq!(q(0.22).to_string(), "0.22");
assert_eq!(q(0.123).to_string(), "0.123");
assert_eq!(q(0.999).to_string(), "0.999");
for x in 0..=1000 {
// if trailing zeroes are handled correctly, we would not expect the serialized length
// to ever exceed "0." + 3 decimal places = 5 in length
assert!(q(x as f32 / 1000.0).to_string().len() <= 5);
}
}
#[test]
#[should_panic]
fn negative_quality() {
q(-1.0);
}
#[test]
#[should_panic]
fn quality_out_of_bounds() {
q(2.0);
}
}

View File

@ -1,101 +1,65 @@
use std::{
cmp,
convert::{TryFrom, TryInto},
fmt, str,
};
use std::{cmp, convert::TryFrom as _, fmt, str};
use derive_more::{Display, Error};
use crate::error::ParseError;
const MAX_QUALITY: u16 = 1000;
const MAX_FLOAT_QUALITY: f32 = 1.0;
use super::Quality;
/// Represents a quality used in quality values.
/// Represents an item with a quality value as defined
/// in [RFC 7231 §5.3.1](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1).
///
/// Can be created with the [`q`] function.
/// # Parsing and Formatting
/// This wrapper be used to parse header value items that have a q-factor annotation as well as
/// serialize items with a their q-factor.
///
/// # Implementation notes
/// # Ordering
/// Since this context of use for this type is header value items, ordering is defined for
/// `QualityItem`s but _only_ considers the item's quality. Order of appearance should be used as
/// the secondary sorting parameter; i.e., a stable sort over the quality values will produce a
/// correctly sorted sequence.
///
/// The quality value is defined as a number between 0 and 1 with three decimal
/// places. This means there are 1001 possible values. Since floating point
/// numbers are not exact and the smallest floating point data type (`f32`)
/// consumes four bytes, hyper uses an `u16` value to store the
/// quality internally. For performance reasons you may set quality directly to
/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality
/// `q=0.532`.
/// # Examples
/// ```
/// # use actix_http::header::{QualityItem, q};
/// let q_item: QualityItem<String> = "hello;q=0.3".parse().unwrap();
/// assert_eq!(&q_item.item, "hello");
/// assert_eq!(q_item.quality, q(0.3));
///
/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1)
/// gives more information on quality values in HTTP header fields.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Quality(u16);
impl Quality {
/// # Panics
/// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0.
fn from_f32(value: f32) -> Self {
// Check that `value` is within range should be done before calling this method.
// Just in case, this debug_assert should catch if we were forgetful.
debug_assert!(
(0.0f32..=1.0f32).contains(&value),
"q value must be between 0.0 and 1.0"
);
Quality((value * MAX_QUALITY as f32) as u16)
}
}
impl Default for Quality {
fn default() -> Quality {
Quality(MAX_QUALITY)
}
}
#[derive(Debug, Clone, Display, Error)]
pub struct QualityOutOfBounds;
impl TryFrom<u16> for Quality {
type Error = QualityOutOfBounds;
fn try_from(value: u16) -> Result<Self, Self::Error> {
if (0..=MAX_QUALITY).contains(&value) {
Ok(Quality(value))
} else {
Err(QualityOutOfBounds)
}
}
}
impl TryFrom<f32> for Quality {
type Error = QualityOutOfBounds;
fn try_from(value: f32) -> Result<Self, Self::Error> {
if (0.0..=MAX_FLOAT_QUALITY).contains(&value) {
Ok(Quality::from_f32(value))
} else {
Err(QualityOutOfBounds)
}
}
}
/// Represents an item with a quality value as defined in
/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-5.3.1).
#[derive(Clone, PartialEq, Debug)]
/// // note that format is normalized compared to parsed item
/// assert_eq!(q_item.to_string(), "hello; q=0.3");
///
/// // item with q=0.3 is greater than item with q=0.1
/// let q_item_fallback: QualityItem<String> = "abc;q=0.1".parse().unwrap();
/// assert!(q_item > q_item_fallback);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QualityItem<T> {
/// The actual contents of the field.
/// The wrapped contents of the field.
pub item: T,
/// The quality (client or server preference) for the value.
pub quality: Quality,
}
impl<T> QualityItem<T> {
/// Creates a new `QualityItem` from an item and a quality.
/// The item can be of any type.
/// The quality should be a value in the range [0, 1].
pub fn new(item: T, quality: Quality) -> QualityItem<T> {
/// Constructs a new `QualityItem` from an item and a quality value.
///
/// The item can be of any type. The quality should be a value in the range [0, 1].
pub fn new(item: T, quality: Quality) -> Self {
QualityItem { item, quality }
}
/// Constructs a new `QualityItem` from an item, using the maximum q-value.
pub fn max(item: T) -> Self {
Self::new(item, Quality::MAX)
}
/// Constructs a new `QualityItem` from an item, using the minimum q-value.
pub fn min(item: T) -> Self {
Self::new(item, Quality::MIN)
}
}
impl<T: PartialEq> cmp::PartialOrd for QualityItem<T> {
impl<T: PartialEq> PartialOrd for QualityItem<T> {
fn partial_cmp(&self, other: &QualityItem<T>) -> Option<cmp::Ordering> {
self.quality.partial_cmp(&other.quality)
}
@ -105,97 +69,77 @@ impl<T: fmt::Display> fmt::Display for QualityItem<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.item, f)?;
match self.quality.0 {
MAX_QUALITY => Ok(()),
0 => f.write_str("; q=0"),
x => write!(f, "; q=0.{}", format!("{:03}", x).trim_end_matches('0')),
match self.quality {
// q-factor value is implied for max value
Quality::MAX => Ok(()),
Quality::MIN => f.write_str("; q=0"),
q => write!(f, "; q={}", q),
}
}
}
impl<T: str::FromStr> str::FromStr for QualityItem<T> {
type Err = crate::error::ParseError;
type Err = ParseError;
fn from_str(qitem_str: &str) -> Result<QualityItem<T>, crate::error::ParseError> {
if !qitem_str.is_ascii() {
return Err(crate::error::ParseError::Header);
fn from_str(q_item_str: &str) -> Result<Self, Self::Err> {
if !q_item_str.is_ascii() {
return Err(ParseError::Header);
}
// Set defaults used if parsing fails.
let mut raw_item = qitem_str;
let mut quality = 1f32;
// set defaults used if quality-item parsing fails, i.e., item has no q attribute
let mut raw_item = q_item_str;
let mut quality = Quality::MAX;
let parts: Vec<_> = qitem_str.rsplitn(2, ';').map(str::trim).collect();
let parts = q_item_str
.rsplit_once(';')
.map(|(item, q_attr)| (item.trim(), q_attr.trim()));
if parts.len() == 2 {
if let Some((val, q_attr)) = parts {
// example for item with q-factor:
//
// gzip; q=0.65
// ^^^^^^ parts[0]
// ^^ start
// ^^^^ q_val
// ^^^^ parts[1]
// gzip;q=0.65
// ^^^^ val
// ^^^^^^ q_attr
// ^^ q
// ^^^^ q_val
if parts[0].len() < 2 {
if q_attr.len() < 2 {
// Can't possibly be an attribute since an attribute needs at least a name followed
// by an equals sign. And bare identifiers are forbidden.
return Err(crate::error::ParseError::Header);
return Err(ParseError::Header);
}
let start = &parts[0][0..2];
let q = &q_attr[0..2];
if start == "q=" || start == "Q=" {
let q_val = &parts[0][2..];
if q == "q=" || q == "Q=" {
let q_val = &q_attr[2..];
if q_val.len() > 5 {
// longer than 5 indicates an over-precise q-factor
return Err(crate::error::ParseError::Header);
return Err(ParseError::Header);
}
let q_value = q_val
.parse::<f32>()
.map_err(|_| crate::error::ParseError::Header)?;
let q_value = q_val.parse::<f32>().map_err(|_| ParseError::Header)?;
let q_value =
Quality::try_from(q_value).map_err(|_| ParseError::Header)?;
if (0f32..=1f32).contains(&q_value) {
quality = q_value;
raw_item = parts[1];
} else {
return Err(crate::error::ParseError::Header);
}
quality = q_value;
raw_item = val;
}
}
let item = raw_item
.parse::<T>()
.map_err(|_| crate::error::ParseError::Header)?;
let item = raw_item.parse::<T>().map_err(|_| ParseError::Header)?;
// we already checked above that the quality is within range
Ok(QualityItem::new(item, Quality::from_f32(quality)))
Ok(QualityItem::new(item, quality))
}
}
/// Convenience function to wrap a value in a `QualityItem`
/// Sets `q` to the default 1.0
pub fn qitem<T>(item: T) -> QualityItem<T> {
QualityItem::new(item, Quality::default())
}
/// Convenience function to create a `Quality` from a float or integer.
///
/// Implemented for `u16` and `f32`. Panics if value is out of range.
pub fn q<T>(val: T) -> Quality
where
T: TryInto<Quality>,
T::Error: fmt::Debug,
{
// TODO: on next breaking change, handle unwrap differently
val.try_into().unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
// copy of encoding from actix-web headers
#[allow(clippy::enum_variant_names)] // allow Encoding prefix on EncodingExt
#[derive(Clone, PartialEq, Debug)]
pub enum Encoding {
Chunked,
@ -244,7 +188,7 @@ mod tests {
#[test]
fn test_quality_item_fmt_q_1() {
use Encoding::*;
let x = qitem(Chunked);
let x = QualityItem::max(Chunked);
assert_eq!(format!("{}", x), "chunked");
}
#[test]
@ -343,25 +287,8 @@ mod tests {
fn test_quality_item_ordering() {
let x: QualityItem<Encoding> = "gzip; q=0.5".parse().ok().unwrap();
let y: QualityItem<Encoding> = "gzip; q=0.273".parse().ok().unwrap();
let comparision_result: bool = x.gt(&y);
assert!(comparision_result)
}
#[test]
fn test_quality() {
assert_eq!(q(0.5), Quality(500));
}
#[test]
#[should_panic]
fn test_quality_invalid() {
q(-1.0);
}
#[test]
#[should_panic]
fn test_quality_invalid2() {
q(2.0);
let comparison_result: bool = x.gt(&y);
assert!(comparison_result)
}
#[test]

View File

@ -1,3 +1,5 @@
//! Header parsing utilities.
use std::{fmt, str::FromStr};
use super::HeaderValue;
@ -10,9 +12,12 @@ where
I: Iterator<Item = &'a HeaderValue> + 'a,
T: FromStr,
{
let mut result = Vec::new();
let size_guess = all.size_hint().1.unwrap_or(2);
let mut result = Vec::with_capacity(size_guess);
for h in all {
let s = h.to_str().map_err(|_| ParseError::Header)?;
result.extend(
s.split(',')
.filter_map(|x| match x.trim() {
@ -22,6 +27,7 @@ where
.filter_map(|x| x.trim().parse().ok()),
)
}
Ok(result)
}
@ -30,10 +36,12 @@ where
pub fn from_one_raw_str<T: FromStr>(val: Option<&HeaderValue>) -> Result<T, ParseError> {
if let Some(line) = val {
let line = line.to_str().map_err(|_| ParseError::Header)?;
if !line.is_empty() {
return T::from_str(line).or(Err(ParseError::Header));
}
}
Err(ParseError::Header)
}
@ -44,19 +52,53 @@ where
T: fmt::Display,
{
let mut iter = parts.iter();
if let Some(part) = iter.next() {
fmt::Display::fmt(part, f)?;
}
for part in iter {
f.write_str(", ")?;
fmt::Display::fmt(part, f)?;
}
Ok(())
}
/// Percent encode a sequence of bytes with a character set defined in
/// <https://tools.ietf.org/html/rfc5987#section-3.2>
/// Percent encode a sequence of bytes with a character set defined in [RFC 5987 §3.2].
///
/// [RFC 5987 §3.2]: https://datatracker.ietf.org/doc/html/rfc5987#section-3.2
#[inline]
pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE);
fmt::Display::fmt(&encoded, f)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn comma_delimited_parsing() {
let headers = vec![];
let res: Vec<usize> = from_comma_delimited(headers.iter()).unwrap();
assert_eq!(res, vec![0; 0]);
let headers = vec![
HeaderValue::from_static("1, 2"),
HeaderValue::from_static("3,4"),
];
let res: Vec<usize> = from_comma_delimited(headers.iter()).unwrap();
assert_eq!(res, vec![1, 2, 3, 4]);
let headers = vec![
HeaderValue::from_static(""),
HeaderValue::from_static(","),
HeaderValue::from_static(" "),
HeaderValue::from_static("1 ,"),
HeaderValue::from_static(""),
];
let res: Vec<usize> = from_comma_delimited(headers.iter()).unwrap();
assert_eq!(res, vec![1]);
}
}

View File

@ -14,7 +14,7 @@
//! [rustls]: https://crates.io/crates/rustls
//! [trust-dns]: https://crates.io/crates/trust-dns
#![deny(rust_2018_idioms, nonstandard_style)]
#![deny(rust_2018_idioms, nonstandard_style, clippy::uninit_assumed_init)]
#![allow(
clippy::type_complexity,
clippy::too_many_arguments,
@ -29,7 +29,6 @@ extern crate log;
pub mod body;
mod builder;
pub mod client;
mod config;
#[cfg(feature = "__compress")]
@ -44,7 +43,6 @@ mod request;
mod response;
mod response_builder;
mod service;
mod time_parser;
pub mod error;
pub mod h1;
@ -104,14 +102,9 @@ type ConnectCallback<IO> = dyn Fn(&IO, &mut CloneableExtensions);
///
/// # Implementation Details
/// Uses Option to reduce necessary allocations when merging with request extensions.
#[derive(Default)]
pub(crate) struct OnConnectData(Option<CloneableExtensions>);
impl Default for OnConnectData {
fn default() -> Self {
Self(None)
}
}
impl OnConnectData {
/// Construct by calling the on-connect callback with the underlying transport I/O.
pub(crate) fn from_io<T>(

View File

@ -46,8 +46,8 @@ pub trait Head: Default + 'static {
#[derive(Debug)]
pub struct RequestHead {
pub uri: Uri,
pub method: Method,
pub uri: Uri,
pub version: Version,
pub headers: HeaderMap,
pub extensions: RefCell<Extensions>,
@ -58,13 +58,13 @@ pub struct RequestHead {
impl Default for RequestHead {
fn default() -> RequestHead {
RequestHead {
uri: Uri::default(),
method: Method::default(),
uri: Uri::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
flags: Flags::empty(),
peer_addr: None,
extensions: RefCell::new(Extensions::new()),
peer_addr: None,
flags: Flags::empty(),
}
}
}
@ -192,6 +192,7 @@ impl RequestHead {
}
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum RequestHeadType {
Owned(RequestHead),
Rc(Rc<RequestHead>, Option<HeaderMap>),
@ -209,7 +210,7 @@ impl RequestHeadType {
impl AsRef<RequestHead> for RequestHeadType {
fn as_ref(&self) -> &RequestHead {
match self {
RequestHeadType::Owned(head) => &head,
RequestHeadType::Owned(head) => head,
RequestHeadType::Rc(head, _) => head.as_ref(),
}
}
@ -317,7 +318,7 @@ impl ResponseHead {
}
#[inline]
pub(crate) fn ctype(&self) -> Option<ConnectionType> {
pub(crate) fn conn_type(&self) -> Option<ConnectionType> {
if self.flags.contains(Flags::CLOSE) {
Some(ConnectionType::Close)
} else if self.flags.contains(Flags::KEEP_ALIVE) {
@ -363,7 +364,7 @@ impl<T: Head> std::ops::Deref for Message<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.head.as_ref()
self.head.as_ref()
}
}

View File

@ -6,14 +6,15 @@ use std::{
};
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use crate::{
body::{AnyBody, MessageBody},
error::Error,
body::{BoxBody, MessageBody},
extensions::Extensions,
header::{self, IntoHeaderValue},
http::{HeaderMap, StatusCode},
message::{BoxedResponseHead, ResponseHead},
ResponseBuilder,
Error, ResponseBuilder,
};
/// An HTTP response.
@ -22,13 +23,13 @@ pub struct Response<B> {
pub(crate) body: B,
}
impl Response<AnyBody> {
impl Response<BoxBody> {
/// Constructs a new response with default body.
#[inline]
pub fn new(status: StatusCode) -> Self {
Response {
head: BoxedResponseHead::new(status),
body: AnyBody::Empty,
body: BoxBody::new(()),
}
}
@ -189,6 +190,14 @@ impl<B> Response<B> {
}
}
#[inline]
pub fn map_into_boxed_body(self) -> Response<BoxBody>
where
B: MessageBody + 'static,
{
self.map_body(|_, body| BoxBody::new(body))
}
/// Returns body, consuming this response.
pub fn into_body(self) -> B {
self.body
@ -223,81 +232,99 @@ impl<B: Default> Default for Response<B> {
}
}
impl<I: Into<Response<AnyBody>>, E: Into<Error>> From<Result<I, E>>
for Response<AnyBody>
impl<I: Into<Response<BoxBody>>, E: Into<Error>> From<Result<I, E>>
for Response<BoxBody>
{
fn from(res: Result<I, E>) -> Self {
match res {
Ok(val) => val.into(),
Err(err) => err.into().into(),
Err(err) => Response::from(err.into()),
}
}
}
impl From<ResponseBuilder> for Response<AnyBody> {
impl From<ResponseBuilder> for Response<BoxBody> {
fn from(mut builder: ResponseBuilder) -> Self {
builder.finish()
builder.finish().map_into_boxed_body()
}
}
impl From<std::convert::Infallible> for Response<AnyBody> {
impl From<std::convert::Infallible> for Response<BoxBody> {
fn from(val: std::convert::Infallible) -> Self {
match val {}
}
}
impl From<&'static str> for Response<AnyBody> {
impl From<&'static str> for Response<&'static str> {
fn from(val: &'static str) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::TEXT_PLAIN_UTF_8)
.body(val)
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl From<&'static [u8]> for Response<AnyBody> {
impl From<&'static [u8]> for Response<&'static [u8]> {
fn from(val: &'static [u8]) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::APPLICATION_OCTET_STREAM)
.body(val)
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl From<String> for Response<AnyBody> {
impl From<String> for Response<String> {
fn from(val: String) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::TEXT_PLAIN_UTF_8)
.body(val)
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl<'a> From<&'a String> for Response<AnyBody> {
fn from(val: &'a String) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::TEXT_PLAIN_UTF_8)
.body(val)
impl From<&String> for Response<String> {
fn from(val: &String) -> Self {
let mut res = Response::with_body(StatusCode::OK, val.clone());
let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl From<Bytes> for Response<AnyBody> {
impl From<Bytes> for Response<Bytes> {
fn from(val: Bytes) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::APPLICATION_OCTET_STREAM)
.body(val)
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl From<BytesMut> for Response<AnyBody> {
impl From<BytesMut> for Response<BytesMut> {
fn from(val: BytesMut) -> Self {
Response::build(StatusCode::OK)
.content_type(mime::APPLICATION_OCTET_STREAM)
.body(val)
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
impl From<ByteString> for Response<ByteString> {
fn from(val: ByteString) -> Self {
let mut res = Response::with_body(StatusCode::OK, val);
let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap();
res.headers_mut().insert(header::CONTENT_TYPE, mime);
res
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE};
use crate::{
body::to_bytes,
http::header::{HeaderValue, CONTENT_TYPE, COOKIE},
};
#[test]
fn test_debug() {
@ -309,73 +336,73 @@ mod tests {
assert!(dbg.contains("Response"));
}
#[test]
fn test_into_response() {
let resp: Response<AnyBody> = "test".into();
assert_eq!(resp.status(), StatusCode::OK);
#[actix_rt::test]
async fn test_into_response() {
let res = Response::from("test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let resp: Response<AnyBody> = b"test".as_ref().into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from(b"test".as_ref());
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let resp: Response<AnyBody> = "test".to_owned().into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from("test".to_owned());
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let resp: Response<AnyBody> = (&"test".to_owned()).into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from("test".to_owned());
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("text/plain; charset=utf-8")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let b = Bytes::from_static(b"test");
let resp: Response<AnyBody> = b.into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from(b);
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let b = Bytes::from_static(b"test");
let resp: Response<AnyBody> = b.into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from(b);
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
let b = BytesMut::from("test");
let resp: Response<AnyBody> = b.into();
assert_eq!(resp.status(), StatusCode::OK);
let res = Response::from(b);
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
res.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/octet-stream")
);
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().get_ref(), b"test");
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]);
}
}

View File

@ -2,19 +2,11 @@
use std::{
cell::{Ref, RefMut},
error::Error as StdError,
fmt,
future::Future,
pin::Pin,
str,
task::{Context, Poll},
fmt, str,
};
use bytes::Bytes;
use futures_core::Stream;
use crate::{
body::{AnyBody, BodyStream},
body::{EitherBody, MessageBody},
error::{Error, HttpError},
header::{self, IntoHeaderPair, IntoHeaderValue},
message::{BoxedResponseHead, ConnectionType, ResponseHead},
@ -235,10 +227,14 @@ impl ResponseBuilder {
/// Generate response with a wrapped body.
///
/// This `ResponseBuilder` will be left in a useless state.
#[inline]
pub fn body<B: Into<AnyBody>>(&mut self, body: B) -> Response<AnyBody> {
self.message_body(body.into())
.unwrap_or_else(Response::from)
pub fn body<B>(&mut self, body: B) -> Response<EitherBody<B>>
where
B: MessageBody + 'static,
{
match self.message_body(body) {
Ok(res) => res.map_body(|_, body| EitherBody::left(body)),
Err(err) => Response::from(err).map_body(|_, body| EitherBody::right(body)),
}
}
/// Generate response with a body.
@ -253,24 +249,12 @@ impl ResponseBuilder {
Ok(Response { head, body })
}
/// Generate response with a streaming body.
///
/// This `ResponseBuilder` will be left in a useless state.
#[inline]
pub fn streaming<S, E>(&mut self, stream: S) -> Response<AnyBody>
where
S: Stream<Item = Result<Bytes, E>> + 'static,
E: Into<Box<dyn StdError>> + 'static,
{
self.body(AnyBody::from_message(BodyStream::new(stream)))
}
/// Generate response with an empty body.
///
/// This `ResponseBuilder` will be left in a useless state.
#[inline]
pub fn finish(&mut self) -> Response<AnyBody> {
self.body(AnyBody::Empty)
pub fn finish(&mut self) -> Response<EitherBody<()>> {
self.body(())
}
/// Create an owned `ResponseBuilder`, leaving the original in a useless state.
@ -327,14 +311,6 @@ impl<'a> From<&'a ResponseHead> for ResponseBuilder {
}
}
impl Future for ResponseBuilder {
type Output = Result<Response<AnyBody>, Error>;
fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(Ok(self.finish()))
}
}
impl fmt::Debug for ResponseBuilder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let head = self.head.as_ref().unwrap();
@ -356,8 +332,9 @@ impl fmt::Debug for ResponseBuilder {
#[cfg(test)]
mod tests {
use bytes::Bytes;
use super::*;
use crate::body::Body;
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
#[test]
@ -383,20 +360,28 @@ mod tests {
#[test]
fn test_force_close() {
let resp = Response::build(StatusCode::OK).force_close().finish();
assert!(!resp.keep_alive())
assert!(!resp.keep_alive());
}
#[test]
fn test_content_type() {
let resp = Response::build(StatusCode::OK)
.content_type("text/plain")
.body(Body::Empty);
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain")
.body(Bytes::new());
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain");
let resp = Response::build(StatusCode::OK)
.content_type(mime::APPLICATION_JAVASCRIPT_UTF_8)
.body(Bytes::new());
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
"application/javascript; charset=utf-8"
);
}
#[test]
fn test_into_builder() {
let mut resp: Response<Body> = "test".into();
let mut resp: Response<_> = "test".into();
assert_eq!(resp.status(), StatusCode::OK);
resp.headers_mut().insert(

View File

@ -1,5 +1,4 @@
use std::{
error::Error as StdError,
fmt,
future::Future,
marker::PhantomData,
@ -9,18 +8,16 @@ use std::{
task::{Context, Poll},
};
use ::h2::server::{handshake as h2_handshake, Handshake as H2Handshake};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_rt::net::TcpStream;
use actix_service::{
fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
};
use bytes::Bytes;
use futures_core::{future::LocalBoxFuture, ready};
use pin_project::pin_project;
use pin_project_lite::pin_project;
use crate::{
body::{AnyBody, MessageBody},
body::{BoxBody, MessageBody},
builder::HttpServiceBuilder,
config::{KeepAlive, ServiceConfig},
error::DispatchError,
@ -40,7 +37,7 @@ pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler> {
impl<T, S, B> HttpService<T, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
@ -55,12 +52,11 @@ where
impl<T, S, B> HttpService<T, S, B>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
{
/// Create new `HttpService` instance.
pub fn new<F: IntoServiceFactory<S, Request>>(service: F) -> Self {
@ -95,7 +91,7 @@ where
impl<T, S, B, X, U> HttpService<T, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
@ -109,7 +105,7 @@ where
pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U>
where
X1: ServiceFactory<Request, Config = (), Response = Request>,
X1::Error: Into<Response<AnyBody>>,
X1::Error: Into<Response<BoxBody>>,
X1::InitError: fmt::Debug,
{
HttpService {
@ -153,17 +149,16 @@ impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<
@ -172,7 +167,7 @@ where
Response = (),
>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create simple tcp stream service
@ -195,9 +190,14 @@ where
#[cfg(feature = "openssl")]
mod openssl {
use actix_service::ServiceFactoryExt;
use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, TlsStream};
use actix_tls::accept::TlsError;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
openssl::{
reexports::{Error as SslError, SslAcceptor},
Acceptor, TlsStream,
},
TlsError,
};
use super::*;
@ -205,17 +205,16 @@ mod openssl {
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<
@ -224,10 +223,10 @@ mod openssl {
Response = (),
>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create openssl based service
/// Create OpenSSL based service.
pub fn openssl(
self,
acceptor: SslAcceptor,
@ -239,9 +238,11 @@ mod openssl {
InitError = (),
> {
Acceptor::new(acceptor)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.and_then(|io: TlsStream<TcpStream>| async {
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.map(|io: TlsStream<TcpStream>| {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
@ -251,8 +252,9 @@ mod openssl {
} else {
Protocol::Http1
};
let peer_addr = io.get_ref().peer_addr().ok();
Ok((io, proto, peer_addr))
(io, proto, peer_addr)
})
.and_then(self.map_err(TlsError::Service))
}
@ -263,27 +265,28 @@ mod openssl {
mod rustls {
use std::io;
use actix_tls::accept::rustls::{Acceptor, ServerConfig, Session, TlsStream};
use actix_tls::accept::TlsError;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
rustls::{reexports::ServerConfig, Acceptor, TlsStream},
TlsError,
};
use super::*;
use actix_service::ServiceFactoryExt;
impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<
@ -292,10 +295,10 @@ mod rustls {
Response = (),
>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
/// Create rustls based service
/// Create Rustls based service.
pub fn rustls(
self,
mut config: ServerConfig,
@ -308,14 +311,15 @@ mod rustls {
> {
let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
protos.extend_from_slice(&config.alpn_protocols);
config.set_protocols(&protos);
config.alpn_protocols = protos;
Acceptor::new(config)
.map_err(TlsError::Tls)
.map_init_err(|_| panic!())
.map_init_err(|_| {
unreachable!("TLS acceptor service factory does not error on init")
})
.map_err(TlsError::into_service_error)
.and_then(|io: TlsStream<TcpStream>| async {
let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol()
{
let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
} else {
@ -339,22 +343,21 @@ where
S: ServiceFactory<Request, Config = ()>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service<Request>>::Future: 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: ServiceFactory<Request, Config = (), Response = Request>,
X::Future: 'static,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Framed<T, h1::Codec>), Config = (), Response = ()>,
U::Future: 'static,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
U::InitError: fmt::Debug,
{
type Response = ();
@ -417,11 +420,11 @@ where
impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U>
where
S: Service<Request>,
S::Error: Into<Response<AnyBody>>,
S::Error: Into<Response<BoxBody>>,
X: Service<Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, h1::Codec>)>,
U::Error: Into<Response<AnyBody>>,
U::Error: Into<Response<BoxBody>>,
{
pub(super) fn new(
cfg: ServiceConfig,
@ -441,7 +444,7 @@ where
pub(super) fn _poll_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), Response<AnyBody>>> {
) -> Poll<Result<(), Response<BoxBody>>> {
ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?;
ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?;
@ -477,18 +480,17 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display + Into<Response<AnyBody>>,
U::Error: fmt::Display + Into<Response<BoxBody>>,
{
type Response = ();
type Error = DispatchError;
@ -510,23 +512,27 @@ where
match proto {
Protocol::Http2 => HttpServiceHandlerResponse {
state: State::H2Handshake(Some((
h2_handshake(io),
self.cfg.clone(),
self.flow.clone(),
on_connect_data,
peer_addr,
))),
state: State::H2Handshake {
handshake: Some((
h2::handshake_with_timeout(io, &self.cfg),
self.cfg.clone(),
self.flow.clone(),
on_connect_data,
peer_addr,
)),
},
},
Protocol::Http1 => HttpServiceHandlerResponse {
state: State::H1(h1::Dispatcher::new(
io,
self.cfg.clone(),
self.flow.clone(),
on_connect_data,
peer_addr,
)),
state: State::H1 {
dispatcher: h1::Dispatcher::new(
io,
self.cfg.clone(),
self.flow.clone(),
on_connect_data,
peer_addr,
),
},
},
proto => unimplemented!("Unsupported HTTP version: {:?}.", proto),
@ -534,58 +540,65 @@ where
}
}
#[pin_project(project = StateProj)]
enum State<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
pin_project! {
#[project = StateProj]
enum State<T, S, B, X, U>
where
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
S: Service<Request>,
S::Future: 'static,
S::Error: Into<Response<AnyBody>>,
S: Service<Request>,
S::Future: 'static,
S::Error: Into<Response<BoxBody>>,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
B: MessageBody,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
H1(#[pin] h1::Dispatcher<T, S, B, X, U>),
H2(#[pin] h2::Dispatcher<T, S, B, X, U>),
H2Handshake(
Option<(
H2Handshake<T, Bytes>,
ServiceConfig,
Rc<HttpFlow<S, X, U>>,
OnConnectData,
Option<net::SocketAddr>,
)>,
),
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
H1 { #[pin] dispatcher: h1::Dispatcher<T, S, B, X, U> },
H2 { #[pin] dispatcher: h2::Dispatcher<T, S, B, X, U> },
H2Handshake {
handshake: Option<(
h2::HandshakeWithTimeout<T>,
ServiceConfig,
Rc<HttpFlow<S, X, U>>,
OnConnectData,
Option<net::SocketAddr>,
)>,
},
}
}
#[pin_project]
pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
pin_project! {
pub struct HttpServiceHandlerResponse<T, S, B, X, U>
where
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
S: Service<Request>,
S::Error: Into<Response<BoxBody>>,
S::Error: 'static,
S::Future: 'static,
S::Response: Into<Response<B>>,
S::Response: 'static,
B: MessageBody,
B::Error: Into<Box<dyn StdError>>,
B: MessageBody,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[pin]
state: State<T, S, B, X, U>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
#[pin]
state: State<T, S, B, X, U>,
}
}
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
@ -593,15 +606,14 @@ where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Response<AnyBody>> + 'static,
S::Error: Into<Response<BoxBody>> + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
B::Error: Into<Box<dyn StdError>>,
X: Service<Request, Response = Request>,
X::Error: Into<Response<AnyBody>>,
X::Error: Into<Response<BoxBody>>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
@ -610,27 +622,29 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().state.project() {
StateProj::H1(disp) => disp.poll(cx),
StateProj::H2(disp) => disp.poll(cx),
StateProj::H2Handshake(data) => {
StateProj::H1 { dispatcher } => dispatcher.poll(cx),
StateProj::H2 { dispatcher } => dispatcher.poll(cx),
StateProj::H2Handshake { handshake: data } => {
match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) {
Ok(conn) => {
let (_, cfg, srv, on_connect_data, peer_addr) =
Ok((conn, timer)) => {
let (_, config, flow, on_connect_data, peer_addr) =
data.take().unwrap();
self.as_mut().project().state.set(State::H2(
h2::Dispatcher::new(
srv,
self.as_mut().project().state.set(State::H2 {
dispatcher: h2::Dispatcher::new(
flow,
conn,
on_connect_data,
cfg,
config,
peer_addr,
timer,
),
));
});
self.poll(cx)
}
Err(err) => {
trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into()))
Poll::Ready(Err(err))
}
}
}

View File

@ -1,72 +0,0 @@
use time::{Date, OffsetDateTime, PrimitiveDateTime};
/// Attempt to parse a `time` string as one of either RFC 1123, RFC 850, or asctime.
pub(crate) fn parse_http_date(time: &str) -> Option<PrimitiveDateTime> {
try_parse_rfc_1123(time)
.or_else(|| try_parse_rfc_850(time))
.or_else(|| try_parse_asctime(time))
}
/// Attempt to parse a `time` string as a RFC 1123 formatted date time string.
///
/// Eg: `Fri, 12 Feb 2021 00:14:29 GMT`
fn try_parse_rfc_1123(time: &str) -> Option<PrimitiveDateTime> {
time::parse(time, "%a, %d %b %Y %H:%M:%S").ok()
}
/// Attempt to parse a `time` string as a RFC 850 formatted date time string.
///
/// Eg: `Wednesday, 11-Jan-21 13:37:41 UTC`
fn try_parse_rfc_850(time: &str) -> Option<PrimitiveDateTime> {
let dt = PrimitiveDateTime::parse(time, "%A, %d-%b-%y %H:%M:%S").ok()?;
// If the `time` string contains a two-digit year, then as per RFC 2616 § 19.3,
// we consider the year as part of this century if it's within the next 50 years,
// otherwise we consider as part of the previous century.
let now = OffsetDateTime::now_utc();
let century_start_year = (now.year() / 100) * 100;
let mut expanded_year = century_start_year + dt.year();
if expanded_year > now.year() + 50 {
expanded_year -= 100;
}
let date = Date::try_from_ymd(expanded_year, dt.month(), dt.day()).ok()?;
Some(PrimitiveDateTime::new(date, dt.time()))
}
/// Attempt to parse a `time` string using ANSI C's `asctime` format.
///
/// Eg: `Wed Feb 13 15:46:11 2013`
fn try_parse_asctime(time: &str) -> Option<PrimitiveDateTime> {
time::parse(time, "%a %b %_d %H:%M:%S %Y").ok()
}
#[cfg(test)]
mod tests {
use time::{date, time};
use super::*;
#[test]
fn test_rfc_850_year_shift() {
let date = try_parse_rfc_850("Friday, 19-Nov-82 16:14:55 EST").unwrap();
assert_eq!(date, date!(1982 - 11 - 19).with_time(time!(16:14:55)));
let date = try_parse_rfc_850("Wednesday, 11-Jan-62 13:37:41 EST").unwrap();
assert_eq!(date, date!(2062 - 01 - 11).with_time(time!(13:37:41)));
let date = try_parse_rfc_850("Wednesday, 11-Jan-21 13:37:41 EST").unwrap();
assert_eq!(date, date!(2021 - 01 - 11).with_time(time!(13:37:41)));
let date = try_parse_rfc_850("Wednesday, 11-Jan-23 13:37:41 EST").unwrap();
assert_eq!(date, date!(2023 - 01 - 11).with_time(time!(13:37:41)));
let date = try_parse_rfc_850("Wednesday, 11-Jan-99 13:37:41 EST").unwrap();
assert_eq!(date, date!(1999 - 01 - 11).with_time(time!(13:37:41)));
let date = try_parse_rfc_850("Wednesday, 11-Jan-00 13:37:41 EST").unwrap();
assert_eq!(date, date!(2000 - 01 - 11).with_time(time!(13:37:41)));
}
}

View File

@ -63,8 +63,8 @@ pub enum Item {
Last(Bytes),
}
#[derive(Debug, Copy, Clone)]
/// WebSocket protocol codec.
#[derive(Debug, Clone)]
pub struct Codec {
flags: Flags,
max_size: usize,
@ -89,7 +89,8 @@ impl Codec {
/// Set max frame size.
///
/// By default max size is set to 64kB.
/// By default max size is set to 64KiB.
#[must_use = "This returns the a new Codec, without modifying the original."]
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
@ -98,12 +99,19 @@ impl Codec {
/// Set decoder to client mode.
///
/// By default decoder works in server mode.
#[must_use = "This returns the a new Codec, without modifying the original."]
pub fn client_mode(mut self) -> Self {
self.flags.remove(Flags::SERVER);
self
}
}
impl Default for Codec {
fn default() -> Self {
Self::new()
}
}
impl Encoder<Message> for Codec {
type Error = ProtocolError;

View File

@ -4,17 +4,21 @@ use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{IntoService, Service};
use pin_project_lite::pin_project;
use super::{Codec, Frame, Message};
#[pin_project::pin_project]
pub struct Dispatcher<S, T>
where
S: Service<Frame, Response = Message> + 'static,
T: AsyncRead + AsyncWrite,
{
#[pin]
inner: inner::Dispatcher<S, T, Codec, Message>,
pin_project! {
pub struct Dispatcher<S, T>
where
S: Service<Frame, Response = Message>,
S: 'static,
T: AsyncRead,
T: AsyncWrite,
{
#[pin]
inner: inner::Dispatcher<S, T, Codec, Message>,
}
}
impl<S, T> Dispatcher<S, T>
@ -72,7 +76,7 @@ mod inner {
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
use crate::{body::AnyBody, Response};
use crate::{body::BoxBody, Response};
/// Framed transport errors
pub enum DispatcherError<E, U, I>
@ -136,7 +140,7 @@ mod inner {
}
}
impl<E, U, I> From<DispatcherError<E, U, I>> for Response<AnyBody>
impl<E, U, I> From<DispatcherError<E, U, I>> for Response<BoxBody>
where
E: fmt::Debug + fmt::Display,
U: Encoder<I> + Decoder,
@ -144,7 +148,7 @@ mod inner {
<U as Decoder>::Error: fmt::Debug,
{
fn from(err: DispatcherError<E, U, I>) -> Self {
Response::internal_server_error().set_body(AnyBody::from(err.to_string()))
Response::internal_server_error().set_body(BoxBody::new(err.to_string()))
}
}

View File

@ -25,8 +25,8 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
//
// un aligned prefix and suffix would be mask/unmask per byte.
// proper aligned middle slice goes into fast path and operates on 4-byte blocks.
let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::<u32>() };
apply_mask_fallback(&mut prefix, mask);
let (prefix, words, suffix) = unsafe { buf.align_to_mut::<u32>() };
apply_mask_fallback(prefix, mask);
let head = prefix.len() & 3;
let mask_u32 = if head > 0 {
if cfg!(target_endian = "big") {
@ -40,7 +40,7 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
for word in words.iter_mut() {
*word ^= mask_u32;
}
apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes());
apply_mask_fallback(suffix, mask_u32.to_ne_bytes());
}
#[cfg(test)]

View File

@ -8,9 +8,9 @@ use std::io;
use derive_more::{Display, Error, From};
use http::{header, Method, StatusCode};
use crate::body::BoxBody;
use crate::{
body::AnyBody, header::HeaderValue, message::RequestHead, response::Response,
ResponseBuilder,
header::HeaderValue, message::RequestHead, response::Response, ResponseBuilder,
};
mod codec;
@ -69,7 +69,7 @@ pub enum ProtocolError {
}
/// WebSocket handshake errors
#[derive(Debug, PartialEq, Display, Error)]
#[derive(Debug, Clone, Copy, PartialEq, Display, Error)]
pub enum HandshakeError {
/// Only get method is allowed.
#[display(fmt = "Method not allowed.")]
@ -96,8 +96,8 @@ pub enum HandshakeError {
BadWebsocketKey,
}
impl From<&HandshakeError> for Response<AnyBody> {
fn from(err: &HandshakeError) -> Self {
impl From<HandshakeError> for Response<BoxBody> {
fn from(err: HandshakeError) -> Self {
match err {
HandshakeError::GetMethodRequired => {
let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED);
@ -139,9 +139,9 @@ impl From<&HandshakeError> for Response<AnyBody> {
}
}
impl From<HandshakeError> for Response<AnyBody> {
fn from(err: HandshakeError) -> Self {
(&err).into()
impl From<&HandshakeError> for Response<BoxBody> {
fn from(err: &HandshakeError) -> Self {
(*err).into()
}
}
@ -210,7 +210,6 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
Response::build(StatusCode::SWITCHING_PROTOCOLS)
.upgrade("websocket")
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.insert_header((
header::SEC_WEBSOCKET_ACCEPT,
// key is known to be header value safe ascii
@ -221,9 +220,10 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
#[cfg(test)]
mod tests {
use crate::{header, Method};
use super::*;
use crate::{body::AnyBody, test::TestRequest};
use http::{header, Method};
use crate::test::TestRequest;
#[test]
fn test_handshake() {
@ -337,17 +337,17 @@ mod tests {
#[test]
fn test_ws_error_http_response() {
let resp: Response<AnyBody> = HandshakeError::GetMethodRequired.into();
let resp: Response<BoxBody> = HandshakeError::GetMethodRequired.into();
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let resp: Response<AnyBody> = HandshakeError::NoWebsocketUpgrade.into();
let resp: Response<BoxBody> = HandshakeError::NoWebsocketUpgrade.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: Response<AnyBody> = HandshakeError::NoConnectionUpgrade.into();
let resp: Response<BoxBody> = HandshakeError::NoConnectionUpgrade.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: Response<AnyBody> = HandshakeError::NoVersionHeader.into();
let resp: Response<BoxBody> = HandshakeError::NoVersionHeader.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: Response<AnyBody> = HandshakeError::UnsupportedVersion.into();
let resp: Response<BoxBody> = HandshakeError::UnsupportedVersion.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let resp: Response<AnyBody> = HandshakeError::BadWebsocketKey.into();
let resp: Response<BoxBody> = HandshakeError::BadWebsocketKey.into();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
}

View File

@ -3,7 +3,9 @@ use std::{
fmt,
};
/// Operation codes as part of RFC6455.
/// Operation codes defined in [RFC 6455 §11.8].
///
/// [RFC 6455]: https://datatracker.ietf.org/doc/html/rfc6455#section-11.8
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum OpCode {
/// Indicates a continuation frame of a fragmented message.
@ -105,7 +107,7 @@ pub enum CloseCode {
Abnormal,
/// Indicates that an endpoint is terminating the connection because it has received data within
/// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
/// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC 3629\]
/// data within a text message).
Invalid,
@ -220,7 +222,8 @@ impl<T: Into<String>> From<(CloseCode, T)> for CloseReason {
}
}
/// The WebSocket GUID as stated in the spec. See https://tools.ietf.org/html/rfc6455#section-1.3.
/// The WebSocket GUID as stated in the spec.
/// See <https://datatracker.ietf.org/doc/html/rfc6455#section-1.3>.
static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Hashes the `Sec-WebSocket-Key` header according to the WebSocket spec.

View File

@ -1,7 +1,7 @@
use std::convert::Infallible;
use actix_http::{
body::AnyBody, http, http::StatusCode, HttpMessage, HttpService, Request, Response,
body::BoxBody, http, http::StatusCode, HttpMessage, HttpService, Request, Response,
};
use actix_http_test::test_server;
use actix_service::ServiceFactoryExt;
@ -99,7 +99,7 @@ async fn test_with_query_parameter() {
#[display(fmt = "expect failed")]
struct ExpectFailed;
impl From<ExpectFailed> for Response<AnyBody> {
impl From<ExpectFailed> for Response<BoxBody> {
fn from(_: ExpectFailed) -> Self {
Response::new(StatusCode::EXPECTATION_FAILED)
}

View File

@ -0,0 +1,153 @@
use std::io;
use actix_http::{error::Error, HttpService, Response};
use actix_server::Server;
use tokio::io::AsyncWriteExt;
#[actix_rt::test]
async fn h2_ping_pong() -> io::Result<()> {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
let lst = std::net::TcpListener::bind("127.0.0.1:0")?;
let addr = lst.local_addr().unwrap();
let join = std::thread::spawn(move || {
actix_rt::System::new().block_on(async move {
let srv = Server::build()
.disable_signals()
.workers(1)
.listen("h2_ping_pong", lst, || {
HttpService::build()
.keep_alive(3)
.h2(|_| async { Ok::<_, Error>(Response::ok()) })
.tcp()
})?
.run();
tx.send(srv.handle()).unwrap();
srv.await
})
});
let handle = rx.recv().unwrap();
let (sync_tx, rx) = std::sync::mpsc::sync_channel(1);
// use a separate thread for h2 client so it can be blocked.
std::thread::spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async move {
let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut tx, conn) = h2::client::handshake(stream).await.unwrap();
tokio::spawn(async move { conn.await.unwrap() });
let (res, _) = tx.send_request(::http::Request::new(()), true).unwrap();
let res = res.await.unwrap();
assert_eq!(res.status().as_u16(), 200);
sync_tx.send(()).unwrap();
// intentionally block the client thread so it can not answer ping pong.
std::thread::sleep(std::time::Duration::from_secs(1000));
})
});
rx.recv().unwrap();
let now = std::time::Instant::now();
// stop server gracefully. this step would take up to 30 seconds.
handle.stop(true).await;
// join server thread. only when connection are all gone this step would finish.
join.join().unwrap()?;
// check the time used for join server thread so it's known that the server shutdown
// is from keep alive and not server graceful shutdown timeout.
assert!(now.elapsed() < std::time::Duration::from_secs(30));
Ok(())
}
#[actix_rt::test]
async fn h2_handshake_timeout() -> io::Result<()> {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
let lst = std::net::TcpListener::bind("127.0.0.1:0")?;
let addr = lst.local_addr().unwrap();
let join = std::thread::spawn(move || {
actix_rt::System::new().block_on(async move {
let srv = Server::build()
.disable_signals()
.workers(1)
.listen("h2_ping_pong", lst, || {
HttpService::build()
.keep_alive(30)
// set first request timeout to 5 seconds.
// this is the timeout used for http2 handshake.
.client_timeout(5000)
.h2(|_| async { Ok::<_, Error>(Response::ok()) })
.tcp()
})?
.run();
tx.send(srv.handle()).unwrap();
srv.await
})
});
let handle = rx.recv().unwrap();
let (sync_tx, rx) = std::sync::mpsc::sync_channel(1);
// use a separate thread for tcp client so it can be blocked.
std::thread::spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async move {
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
// do not send the last new line intentionally.
// This should hang the server handshake
let malicious_buf = b"PRI * HTTP/2.0\r\n\r\nSM\r\n";
stream.write_all(malicious_buf).await.unwrap();
stream.flush().await.unwrap();
sync_tx.send(()).unwrap();
// intentionally block the client thread so it sit idle and not do handshake.
std::thread::sleep(std::time::Duration::from_secs(1000));
drop(stream)
})
});
rx.recv().unwrap();
let now = std::time::Instant::now();
// stop server gracefully. this step would take up to 30 seconds.
handle.stop(true).await;
// join server thread. only when connection are all gone this step would finish.
join.join().unwrap()?;
// check the time used for join server thread so it's known that the server shutdown
// is from handshake timeout and not server graceful shutdown timeout.
assert!(now.elapsed() < std::time::Duration::from_secs(30));
Ok(())
}

View File

@ -5,10 +5,10 @@ extern crate tls_openssl as openssl;
use std::{convert::Infallible, io};
use actix_http::{
body::{AnyBody, Body, SizedStream},
body::{BodyStream, BoxBody, SizedStream},
error::PayloadError,
http::{
header::{self, HeaderName, HeaderValue},
header::{self, HeaderValue},
Method, StatusCode, Version,
},
Error, HttpMessage, HttpService, Request, Response,
@ -143,38 +143,25 @@ async fn test_h2_content_length() {
})
.await;
let header = HeaderName::from_static("content-length");
let value = HeaderValue::from_static("0");
static VALUE: HeaderValue = HeaderValue::from_static("0");
{
for &i in &[0] {
let req = srv
.request(Method::HEAD, srv.surl(&format!("/{}", i)))
.send();
let _response = req.await.expect_err("should timeout on recv 1xx frame");
// assert_eq!(response.headers().get(&header), None);
let req = srv.request(Method::HEAD, srv.surl("/0")).send();
req.await.expect_err("should timeout on recv 1xx frame");
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let _response = req.await.expect_err("should timeout on recv 1xx frame");
// assert_eq!(response.headers().get(&header), None);
}
let req = srv.request(Method::GET, srv.surl("/0")).send();
req.await.expect_err("should timeout on recv 1xx frame");
for &i in &[1] {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);
}
let req = srv.request(Method::GET, srv.surl("/1")).send();
let response = req.await.unwrap();
assert!(response.headers().get("content-length").is_none());
for &i in &[2, 3] {
let req = srv
.request(Method::GET, srv.surl(&format!("/{}", i)))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), Some(&value));
assert_eq!(response.headers().get("content-length"), Some(&VALUE));
}
}
}
@ -361,7 +348,7 @@ async fn test_h2_body_chunked_explicit() {
ok::<_, Infallible>(
Response::build(StatusCode::OK)
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.streaming(body),
.body(BodyStream::new(body)),
)
})
.openssl(tls_config())
@ -412,9 +399,11 @@ async fn test_h2_response_http_error_handling() {
#[display(fmt = "error")]
struct BadRequest;
impl From<BadRequest> for Response<AnyBody> {
impl From<BadRequest> for Response<BoxBody> {
fn from(err: BadRequest) -> Self {
Response::build(StatusCode::BAD_REQUEST).body(err.to_string())
Response::build(StatusCode::BAD_REQUEST)
.body(err.to_string())
.map_into_boxed_body()
}
}
@ -422,7 +411,7 @@ impl From<BadRequest> for Response<AnyBody> {
async fn test_h2_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| err::<Response<Body>, _>(BadRequest))
.h2(|_| err::<Response<BoxBody>, _>(BadRequest))
.openssl(tls_config())
.map_err(|_| ())
})

View File

@ -3,14 +3,14 @@
extern crate tls_rustls as rustls;
use std::{
convert::Infallible,
convert::{Infallible, TryFrom},
io::{self, BufReader, Write},
net::{SocketAddr, TcpStream as StdTcpStream},
sync::Arc,
};
use actix_http::{
body::{AnyBody, Body, SizedStream},
body::{BodyStream, BoxBody, SizedStream},
error::PayloadError,
http::{
header::{self, HeaderName, HeaderValue},
@ -20,16 +20,14 @@ use actix_http::{
};
use actix_http_test::test_server;
use actix_service::{fn_factory_with_config, fn_service};
use actix_tls::connect::rustls::webpki_roots_cert_store;
use actix_utils::future::{err, ok};
use bytes::{Bytes, BytesMut};
use derive_more::{Display, Error};
use futures_core::Stream;
use futures_util::stream::{once, StreamExt as _};
use rustls::{
internal::pemfile::{certs, pkcs8_private_keys},
NoClientAuth, ServerConfig as RustlsServerConfig, Session,
};
use webpki::DNSNameRef;
use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig, ServerName};
use rustls_pemfile::{certs, pkcs8_private_keys};
async fn load_body<S>(mut stream: S) -> Result<BytesMut, PayloadError>
where
@ -47,13 +45,24 @@ fn tls_config() -> RustlsServerConfig {
let cert_file = cert.serialize_pem().unwrap();
let key_file = cert.serialize_private_key_pem();
let mut config = RustlsServerConfig::new(NoClientAuth::new());
let cert_file = &mut BufReader::new(cert_file.as_bytes());
let key_file = &mut BufReader::new(key_file.as_bytes());
let cert_chain = certs(cert_file).unwrap();
let cert_chain = certs(cert_file)
.unwrap()
.into_iter()
.map(Certificate)
.collect();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();
let mut config = RustlsServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert_chain, PrivateKey(keys.remove(0)))
.unwrap();
config.alpn_protocols.push(HTTP1_1_ALPN_PROTOCOL.to_vec());
config.alpn_protocols.push(H2_ALPN_PROTOCOL.to_vec());
config
}
@ -62,19 +71,28 @@ pub fn get_negotiated_alpn_protocol(
addr: SocketAddr,
client_alpn_protocol: &[u8],
) -> Option<Vec<u8>> {
let mut config = rustls::ClientConfig::new();
let mut config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(webpki_roots_cert_store())
.with_no_client_auth();
config.alpn_protocols.push(client_alpn_protocol.to_vec());
let mut sess = rustls::ClientSession::new(
&Arc::new(config),
DNSNameRef::try_from_ascii_str("localhost").unwrap(),
);
let mut sess = rustls::ClientConnection::new(
Arc::new(config),
ServerName::try_from("localhost").unwrap(),
)
.unwrap();
let mut sock = StdTcpStream::connect(addr).unwrap();
let mut stream = rustls::Stream::new(&mut sess, &mut sock);
// The handshake will fails because the client will not be able to verify the server
// certificate, but it doesn't matter here as we are just interested in the negotiated ALPN
// protocol
let _ = stream.flush();
sess.get_alpn_protocol().map(|proto| proto.to_vec())
sess.alpn_protocol().map(|proto| proto.to_vec())
}
#[actix_rt::test]
@ -398,7 +416,7 @@ async fn test_h2_body_chunked_explicit() {
ok::<_, Infallible>(
Response::build(StatusCode::OK)
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.streaming(body),
.body(BodyStream::new(body)),
)
})
.rustls(tls_config())
@ -449,9 +467,9 @@ async fn test_h2_response_http_error_handling() {
#[display(fmt = "error")]
struct BadRequest;
impl From<BadRequest> for Response<AnyBody> {
impl From<BadRequest> for Response<BoxBody> {
fn from(_: BadRequest) -> Self {
Response::bad_request().set_body(AnyBody::from("error"))
Response::bad_request().set_body(BoxBody::new("error"))
}
}
@ -459,7 +477,7 @@ impl From<BadRequest> for Response<AnyBody> {
async fn test_h2_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| err::<Response<Body>, _>(BadRequest))
.h2(|_| err::<Response<BoxBody>, _>(BadRequest))
.rustls(tls_config())
})
.await;
@ -476,7 +494,7 @@ async fn test_h2_service_error() {
async fn test_h1_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h1(|_| err::<Response<Body>, _>(BadRequest))
.h1(|_| err::<Response<BoxBody>, _>(BadRequest))
.rustls(tls_config())
})
.await;

View File

@ -6,7 +6,7 @@ use std::{
};
use actix_http::{
body::{AnyBody, Body, SizedStream},
body::{self, BodyStream, BoxBody, SizedStream},
header, http, Error, HttpMessage, HttpService, KeepAlive, Request, Response,
StatusCode,
};
@ -24,7 +24,7 @@ use regex::Regex;
#[actix_rt::test]
async fn test_h1() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
@ -39,11 +39,13 @@ async fn test_h1() {
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
srv.stop().await;
}
#[actix_rt::test]
async fn test_h1_2() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
@ -59,13 +61,15 @@ async fn test_h1_2() {
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
srv.stop().await;
}
#[derive(Debug, Display, Error)]
#[display(fmt = "expect failed")]
struct ExpectFailed;
impl From<ExpectFailed> for Response<AnyBody> {
impl From<ExpectFailed> for Response<BoxBody> {
fn from(_: ExpectFailed) -> Self {
Response::new(StatusCode::EXPECTATION_FAILED)
}
@ -73,7 +77,7 @@ impl From<ExpectFailed> for Response<AnyBody> {
#[actix_rt::test]
async fn test_expect_continue() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.expect(fn_service(|req: Request| {
if req.head().uri.query() == Some("yes=") {
@ -98,11 +102,13 @@ async fn test_expect_continue() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
srv.stop().await;
}
#[actix_rt::test]
async fn test_expect_continue_h1() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.expect(fn_service(|req: Request| {
sleep(Duration::from_millis(20)).then(move |_| {
@ -129,6 +135,8 @@ async fn test_expect_continue_h1() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
srv.stop().await;
}
#[actix_rt::test]
@ -136,7 +144,7 @@ async fn test_chunked_payload() {
let chunk_sizes = vec![32768, 32, 32768];
let total_size: usize = chunk_sizes.iter().sum();
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(fn_service(|mut request: Request| {
request
@ -183,15 +191,18 @@ async fn test_chunked_payload() {
Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(),
None => panic!("Failed to find size in HTTP Response: {}", data),
};
size
};
assert_eq!(returned_size, total_size);
srv.stop().await;
}
#[actix_rt::test]
async fn test_slow_request() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.client_timeout(100)
.finish(|_| ok::<_, Infallible>(Response::ok()))
@ -204,11 +215,13 @@ async fn test_slow_request() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 408 Request Timeout"));
srv.stop().await;
}
#[actix_rt::test]
async fn test_http1_malformed_request() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok()))
.tcp()
@ -220,11 +233,13 @@ async fn test_http1_malformed_request() {
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 400 Bad Request"));
srv.stop().await;
}
#[actix_rt::test]
async fn test_http1_keepalive() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok()))
.tcp()
@ -241,11 +256,13 @@ async fn test_http1_keepalive() {
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
srv.stop().await;
}
#[actix_rt::test]
async fn test_http1_keepalive_timeout() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.keep_alive(1)
.h1(|_| ok::<_, Infallible>(Response::ok()))
@ -263,11 +280,13 @@ async fn test_http1_keepalive_timeout() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
srv.stop().await;
}
#[actix_rt::test]
async fn test_http1_keepalive_close() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok()))
.tcp()
@ -284,11 +303,13 @@ async fn test_http1_keepalive_close() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
srv.stop().await;
}
#[actix_rt::test]
async fn test_http10_keepalive_default_close() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok()))
.tcp()
@ -304,11 +325,13 @@ async fn test_http10_keepalive_default_close() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
srv.stop().await;
}
#[actix_rt::test]
async fn test_http10_keepalive() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok()))
.tcp()
@ -331,11 +354,13 @@ async fn test_http10_keepalive() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
srv.stop().await;
}
#[actix_rt::test]
async fn test_http1_keepalive_disabled() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.h1(|_| ok::<_, Infallible>(Response::ok()))
@ -352,6 +377,8 @@ async fn test_http1_keepalive_disabled() {
let mut data = vec![0; 1024];
let res = stream.read(&mut data).unwrap();
assert_eq!(res, 0);
srv.stop().await;
}
#[actix_rt::test]
@ -361,7 +388,7 @@ async fn test_content_length() {
StatusCode,
};
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
@ -399,6 +426,8 @@ async fn test_content_length() {
assert_eq!(response.headers().get(&header), Some(&value));
}
}
srv.stop().await;
}
#[actix_rt::test]
@ -438,6 +467,8 @@ async fn test_h1_headers() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from(data2));
srv.stop().await;
}
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
@ -477,6 +508,8 @@ async fn test_h1_body() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
srv.stop().await;
}
#[actix_rt::test]
@ -502,6 +535,8 @@ async fn test_h1_head_empty() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
srv.stop().await;
}
#[actix_rt::test]
@ -527,11 +562,13 @@ async fn test_h1_head_binary() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert!(bytes.is_empty());
srv.stop().await;
}
#[actix_rt::test]
async fn test_h1_head_binary2() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR)))
.tcp()
@ -548,6 +585,8 @@ async fn test_h1_head_binary2() {
.unwrap();
assert_eq!(format!("{}", STR.len()), len.to_str().unwrap());
}
srv.stop().await;
}
#[actix_rt::test]
@ -570,6 +609,8 @@ async fn test_h1_body_length() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
srv.stop().await;
}
#[actix_rt::test]
@ -581,7 +622,7 @@ async fn test_h1_body_chunked_explicit() {
ok::<_, Infallible>(
Response::build(StatusCode::OK)
.insert_header((header::TRANSFER_ENCODING, "chunked"))
.streaming(body),
.body(BodyStream::new(body)),
)
})
.tcp()
@ -605,6 +646,8 @@ async fn test_h1_body_chunked_explicit() {
// decode
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
srv.stop().await;
}
#[actix_rt::test]
@ -613,7 +656,9 @@ async fn test_h1_body_chunked_implicit() {
HttpService::build()
.h1(|_| {
let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref())));
ok::<_, Infallible>(Response::build(StatusCode::OK).streaming(body))
ok::<_, Infallible>(
Response::build(StatusCode::OK).body(BodyStream::new(body)),
)
})
.tcp()
})
@ -634,6 +679,8 @@ async fn test_h1_body_chunked_implicit() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
srv.stop().await;
}
#[actix_rt::test]
@ -661,15 +708,17 @@ async fn test_h1_response_http_error_handling() {
bytes,
Bytes::from_static(b"error processing HTTP: failed to parse header value")
);
srv.stop().await;
}
#[derive(Debug, Display, Error)]
#[display(fmt = "error")]
struct BadRequest;
impl From<BadRequest> for Response<AnyBody> {
impl From<BadRequest> for Response<BoxBody> {
fn from(_: BadRequest) -> Self {
Response::bad_request().set_body(AnyBody::from("error"))
Response::bad_request().set_body(BoxBody::new("error"))
}
}
@ -677,7 +726,7 @@ impl From<BadRequest> for Response<AnyBody> {
async fn test_h1_service_error() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| err::<Response<Body>, _>(BadRequest))
.h1(|_| err::<Response<()>, _>(BadRequest))
.tcp()
})
.await;
@ -688,11 +737,13 @@ async fn test_h1_service_error() {
// read response
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
srv.stop().await;
}
#[actix_rt::test]
async fn test_h1_on_connect() {
let srv = test_server(|| {
let mut srv = test_server(|| {
HttpService::build()
.on_connect_ext(|_, data| {
data.insert(20isize);
@ -707,4 +758,92 @@ async fn test_h1_on_connect() {
let response = srv.get("/").send().await.unwrap();
assert!(response.status().is_success());
srv.stop().await;
}
/// Tests compliance with 304 Not Modified spec in RFC 7232 §4.1.
/// https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
#[actix_rt::test]
async fn test_not_modified_spec_h1() {
// TODO: this test needing a few seconds to complete reveals some weirdness with either the
// dispatcher or the client, though similar hangs occur on other tests in this file, only
// succeeding, it seems, because of the keepalive timer
static CL: header::HeaderName = header::CONTENT_LENGTH;
let mut srv = test_server(|| {
HttpService::build()
.h1(|req: Request| {
let res: Response<BoxBody> = match req.path() {
// with no content-length
"/none" => {
Response::with_body(StatusCode::NOT_MODIFIED, body::None::new())
.map_into_boxed_body()
}
// with no content-length
"/body" => Response::with_body(StatusCode::NOT_MODIFIED, "1234")
.map_into_boxed_body(),
// with manual content-length header and specific None body
"/cl-none" => {
let mut res = Response::with_body(
StatusCode::NOT_MODIFIED,
body::None::new(),
);
res.headers_mut()
.insert(CL.clone(), header::HeaderValue::from_static("24"));
res.map_into_boxed_body()
}
// with manual content-length header and ignore-able body
"/cl-body" => {
let mut res =
Response::with_body(StatusCode::NOT_MODIFIED, "1234");
res.headers_mut()
.insert(CL.clone(), header::HeaderValue::from_static("4"));
res.map_into_boxed_body()
}
_ => panic!("unknown route"),
};
ok::<_, Infallible>(res)
})
.tcp()
})
.await;
let res = srv.get("/none").send().await.unwrap();
assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED);
assert_eq!(res.headers().get(&CL), None);
assert!(srv.load_body(res).await.unwrap().is_empty());
let res = srv.get("/body").send().await.unwrap();
assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED);
assert_eq!(res.headers().get(&CL), None);
assert!(srv.load_body(res).await.unwrap().is_empty());
let res = srv.get("/cl-none").send().await.unwrap();
assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED);
assert_eq!(
res.headers().get(&CL),
Some(&header::HeaderValue::from_static("24")),
);
assert!(srv.load_body(res).await.unwrap().is_empty());
let res = srv.get("/cl-body").send().await.unwrap();
assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED);
assert_eq!(
res.headers().get(&CL),
Some(&header::HeaderValue::from_static("4")),
);
// server does not prevent payload from being sent but clients may choose not to read it
// TODO: this is probably a bug, especially since CL header can differ in length from the body
assert!(!srv.load_body(res).await.unwrap().is_empty());
// TODO: add stream response tests
srv.stop().await;
}

View File

@ -6,7 +6,7 @@ use std::{
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{
body::{AnyBody, BodySize},
body::{BodySize, BoxBody},
h1,
ws::{self, CloseCode, Frame, Item, Message},
Error, HttpService, Request, Response,
@ -50,14 +50,14 @@ enum WsServiceError {
Dispatcher,
}
impl From<WsServiceError> for Response<AnyBody> {
impl From<WsServiceError> for Response<BoxBody> {
fn from(err: WsServiceError) -> Self {
match err {
WsServiceError::Http(err) => err.into(),
WsServiceError::Ws(err) => err.into(),
WsServiceError::Io(_err) => unreachable!(),
WsServiceError::Dispatcher => Response::internal_server_error()
.set_body(AnyBody::from(format!("{}", err))),
.set_body(BoxBody::new(format!("{}", err))),
}
}
}