1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-27 17:52:56 +01:00

feat: support add connection header thought header apis

This commit is contained in:
Jonatan Lemes 2024-05-19 18:13:18 -03:00
parent 48aaf41638
commit a17945fb6e

View File

@ -6,6 +6,7 @@ use std::{
slice::from_raw_parts_mut, slice::from_raw_parts_mut,
}; };
use ahash::AHashMap;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use crate::{ use crate::{
@ -109,28 +110,21 @@ pub(crate) trait MessageType: Sized {
BodySize::None => dst.put_slice(b"\r\n"), BodySize::None => dst.put_slice(b"\r\n"),
} }
// Connection let headers = match self.extra_headers() {
match conn_type { Some(extra_headers) => self
ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"), .headers()
ConnectionType::KeepAlive if version < Version::HTTP_11 => { .inner
if camel_case { .iter()
dst.put_slice(b"Connection: keep-alive\r\n") .filter(|(name, _)| !extra_headers.contains_key(*name))
} else { .chain(extra_headers.inner.iter())
dst.put_slice(b"connection: keep-alive\r\n") .collect::<AHashMap<_, _>>(),
} None => self.headers().inner.iter().collect::<AHashMap<_, _>>(),
} };
ConnectionType::Close if version >= Version::HTTP_11 => {
if camel_case { // write connection header
dst.put_slice(b"Connection: close\r\n") self.write_connection_header(&headers, conn_type, version, dst);
} else {
dst.put_slice(b"connection: close\r\n")
}
}
_ => {}
}
// write headers // write headers
let mut has_date = false; let mut has_date = false;
let mut buf = dst.chunk_mut().as_mut_ptr(); let mut buf = dst.chunk_mut().as_mut_ptr();
@ -141,7 +135,7 @@ pub(crate) trait MessageType: Sized {
// container's knowledge, this is used to sync the containers cursor after data is written // container's knowledge, this is used to sync the containers cursor after data is written
let mut pos = 0; let mut pos = 0;
self.write_headers(|key, value| { self.write_headers(&headers, |key, value| {
match *key { match *key {
CONNECTION => return, CONNECTION => return,
TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return, TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
@ -221,22 +215,54 @@ pub(crate) trait MessageType: Sized {
Ok(()) Ok(())
} }
fn write_headers<F>(&mut self, mut f: F) fn write_connection_header<B: BufMut>(
&self,
headers: &AHashMap<&HeaderName, &Value>,
conn_type: ConnectionType,
version: Version,
buf: &mut B,
) {
let camel_case = self.camel_case();
if let Some(header_value) = headers.get(&CONNECTION) {
if camel_case {
buf.put_slice(b"Connection: ");
} else {
buf.put_slice(b"connection: ");
}
for val in header_value.iter() {
buf.put_slice(val.as_ref());
}
buf.put_slice(b"\r\n");
return;
}
// Connection
match conn_type {
ConnectionType::Upgrade => buf.put_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
if camel_case {
buf.put_slice(b"Connection: keep-alive\r\n")
} else {
buf.put_slice(b"connection: keep-alive\r\n")
}
}
ConnectionType::Close if version >= Version::HTTP_11 => {
if camel_case {
buf.put_slice(b"Connection: close\r\n")
} else {
buf.put_slice(b"connection: close\r\n")
}
}
_ => {}
}
}
fn write_headers<F>(&self, headers: &AHashMap<&HeaderName, &Value>, mut f: F)
where where
F: FnMut(&HeaderName, &Value), F: FnMut(&HeaderName, &Value),
{ {
match self.extra_headers() { headers.iter().for_each(|(key, value)| f(key, value));
Some(headers) => {
// merging headers from head and extra headers.
self.headers()
.inner
.iter()
.filter(|(name, _)| !headers.contains_key(*name))
.chain(headers.inner.iter())
.for_each(|(k, v)| f(k, v))
}
None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)),
}
} }
} }
@ -668,4 +694,61 @@ mod tests {
assert!(!data.contains("content-length: 0\r\n")); assert!(!data.contains("content-length: 0\r\n"));
assert!(!data.contains("transfer-encoding: chunked\r\n")); assert!(!data.contains("transfer-encoding: chunked\r\n"));
} }
#[actix_rt::test]
async fn test_close_connection_header_even_keep_alive_was_provided() {
let mut bytes = BytesMut::with_capacity(2048);
let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("close"));
let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: close\r\n"));
}
#[actix_rt::test]
async fn test_keep_alive_connection_header_when_provided() {
let mut bytes = BytesMut::with_capacity(2048);
let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::KeepAlive,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: keep-alive\r\n"));
}
#[actix_rt::test]
async fn test_keep_alive_connection_header_even_close_was_provided() {
let mut bytes = BytesMut::with_capacity(2048);
let mut res = Response::with_body(StatusCode::OK, ());
res.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
let _ = res.encode_headers(
&mut bytes,
Version::HTTP_11,
BodySize::Stream,
ConnectionType::Close,
&ServiceConfig::default(),
);
let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap();
assert!(data.contains("connection: keep-alive\r\n"));
}
} }