mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-24 16:02:59 +01:00
d9988f3ab6
fix missing content length when no compression is used
411 lines
13 KiB
Rust
411 lines
13 KiB
Rust
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
|
|
|
|
use std::cell::RefCell;
|
|
use std::fmt::Write as FmtWrite;
|
|
use std::io::{self, Write};
|
|
|
|
#[cfg(feature = "brotli")]
|
|
use brotli2::write::BrotliEncoder;
|
|
use bytes::{BufMut, BytesMut};
|
|
#[cfg(feature = "flate2")]
|
|
use flate2::write::{DeflateEncoder, GzEncoder};
|
|
#[cfg(feature = "flate2")]
|
|
use flate2::Compression;
|
|
use futures::{Async, Poll};
|
|
use http::header::{
|
|
HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
|
|
};
|
|
use http::{HttpTryFrom, Version};
|
|
use time::{self, Duration};
|
|
use tokio_io::AsyncWrite;
|
|
|
|
use body::{Binary, Body};
|
|
use header::ContentEncoding;
|
|
use server::output::{ContentEncoder, Output, TransferEncoding};
|
|
use server::WriterState;
|
|
|
|
use client::ClientRequest;
|
|
|
|
const AVERAGE_HEADER_SIZE: usize = 30;
|
|
|
|
bitflags! {
|
|
struct Flags: u8 {
|
|
const STARTED = 0b0000_0001;
|
|
const UPGRADE = 0b0000_0010;
|
|
const KEEPALIVE = 0b0000_0100;
|
|
const DISCONNECTED = 0b0000_1000;
|
|
}
|
|
}
|
|
|
|
pub(crate) struct HttpClientWriter {
|
|
flags: Flags,
|
|
written: u64,
|
|
headers_size: u32,
|
|
buffer: Output,
|
|
buffer_capacity: usize,
|
|
}
|
|
|
|
impl HttpClientWriter {
|
|
pub fn new() -> HttpClientWriter {
|
|
HttpClientWriter {
|
|
flags: Flags::empty(),
|
|
written: 0,
|
|
headers_size: 0,
|
|
buffer_capacity: 0,
|
|
buffer: Output::Buffer(BytesMut::new()),
|
|
}
|
|
}
|
|
|
|
pub fn disconnected(&mut self) {
|
|
self.buffer.take();
|
|
}
|
|
|
|
pub fn is_completed(&self) -> bool {
|
|
self.buffer.is_empty()
|
|
}
|
|
|
|
// pub fn keepalive(&self) -> bool {
|
|
// self.flags.contains(Flags::KEEPALIVE) &&
|
|
// !self.flags.contains(Flags::UPGRADE) }
|
|
|
|
fn write_to_stream<T: AsyncWrite>(
|
|
&mut self, stream: &mut T,
|
|
) -> io::Result<WriterState> {
|
|
while !self.buffer.is_empty() {
|
|
match stream.write(self.buffer.as_ref().as_ref()) {
|
|
Ok(0) => {
|
|
self.disconnected();
|
|
return Ok(WriterState::Done);
|
|
}
|
|
Ok(n) => {
|
|
let _ = self.buffer.split_to(n);
|
|
}
|
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
|
if self.buffer.len() > self.buffer_capacity {
|
|
return Ok(WriterState::Pause);
|
|
} else {
|
|
return Ok(WriterState::Done);
|
|
}
|
|
}
|
|
Err(err) => return Err(err),
|
|
}
|
|
}
|
|
Ok(WriterState::Done)
|
|
}
|
|
}
|
|
|
|
pub struct Writer<'a>(pub &'a mut BytesMut);
|
|
|
|
impl<'a> io::Write for Writer<'a> {
|
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
self.0.extend_from_slice(buf);
|
|
Ok(buf.len())
|
|
}
|
|
fn flush(&mut self) -> io::Result<()> {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl HttpClientWriter {
|
|
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
|
|
// prepare task
|
|
self.buffer = content_encoder(self.buffer.take(), msg);
|
|
self.flags.insert(Flags::STARTED);
|
|
if msg.upgrade() {
|
|
self.flags.insert(Flags::UPGRADE);
|
|
}
|
|
|
|
// render message
|
|
{
|
|
// output buffer
|
|
let buffer = self.buffer.as_mut();
|
|
|
|
// status line
|
|
writeln!(
|
|
Writer(buffer),
|
|
"{} {} {:?}\r",
|
|
msg.method(),
|
|
msg.uri()
|
|
.path_and_query()
|
|
.map(|u| u.as_str())
|
|
.unwrap_or("/"),
|
|
msg.version()
|
|
).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
|
|
|
// write headers
|
|
if let Body::Binary(ref bytes) = *msg.body() {
|
|
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
|
|
} else {
|
|
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
|
|
}
|
|
|
|
for (key, value) in msg.headers() {
|
|
let v = value.as_ref();
|
|
let k = key.as_str().as_bytes();
|
|
buffer.reserve(k.len() + v.len() + 4);
|
|
buffer.put_slice(k);
|
|
buffer.put_slice(b": ");
|
|
buffer.put_slice(v);
|
|
buffer.put_slice(b"\r\n");
|
|
}
|
|
|
|
// set date header
|
|
if !msg.headers().contains_key(DATE) {
|
|
buffer.extend_from_slice(b"date: ");
|
|
set_date(buffer);
|
|
buffer.extend_from_slice(b"\r\n\r\n");
|
|
} else {
|
|
buffer.extend_from_slice(b"\r\n");
|
|
}
|
|
}
|
|
self.headers_size = self.buffer.len() as u32;
|
|
|
|
if msg.body().is_binary() {
|
|
if let Body::Binary(bytes) = msg.replace_body(Body::Empty) {
|
|
self.written += bytes.len() as u64;
|
|
self.buffer.write(bytes.as_ref())?;
|
|
}
|
|
} else {
|
|
self.buffer_capacity = msg.write_buffer_capacity();
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn write(&mut self, payload: &[u8]) -> io::Result<WriterState> {
|
|
self.written += payload.len() as u64;
|
|
if !self.flags.contains(Flags::DISCONNECTED) {
|
|
self.buffer.write(payload)?;
|
|
}
|
|
|
|
if self.buffer.len() > self.buffer_capacity {
|
|
Ok(WriterState::Pause)
|
|
} else {
|
|
Ok(WriterState::Done)
|
|
}
|
|
}
|
|
|
|
pub fn write_eof(&mut self) -> io::Result<()> {
|
|
if self.buffer.write_eof()? {
|
|
Ok(())
|
|
} else {
|
|
Err(io::Error::new(
|
|
io::ErrorKind::Other,
|
|
"Last payload item, but eof is not reached",
|
|
))
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn poll_completed<T: AsyncWrite>(
|
|
&mut self, stream: &mut T, shutdown: bool,
|
|
) -> Poll<(), io::Error> {
|
|
match self.write_to_stream(stream) {
|
|
Ok(WriterState::Done) => {
|
|
if shutdown {
|
|
stream.shutdown()
|
|
} else {
|
|
Ok(Async::Ready(()))
|
|
}
|
|
}
|
|
Ok(WriterState::Pause) => Ok(Async::NotReady),
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn content_encoder(buf: BytesMut, req: &mut ClientRequest) -> Output {
|
|
let version = req.version();
|
|
let mut body = req.replace_body(Body::Empty);
|
|
let mut encoding = req.content_encoding();
|
|
|
|
let transfer = match body {
|
|
Body::Empty => {
|
|
req.headers_mut().remove(CONTENT_LENGTH);
|
|
return Output::Empty(buf);
|
|
}
|
|
Body::Binary(ref mut bytes) => {
|
|
#[cfg(any(feature = "flate2", feature = "brotli"))]
|
|
{
|
|
if encoding.is_compression() {
|
|
let mut tmp = BytesMut::new();
|
|
let mut transfer = TransferEncoding::eof(tmp);
|
|
let mut enc = match encoding {
|
|
#[cfg(feature = "flate2")]
|
|
ContentEncoding::Deflate => ContentEncoder::Deflate(
|
|
DeflateEncoder::new(transfer, Compression::default()),
|
|
),
|
|
#[cfg(feature = "flate2")]
|
|
ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
|
|
transfer,
|
|
Compression::default(),
|
|
)),
|
|
#[cfg(feature = "brotli")]
|
|
ContentEncoding::Br => {
|
|
ContentEncoder::Br(BrotliEncoder::new(transfer, 5))
|
|
}
|
|
ContentEncoding::Auto | ContentEncoding::Identity => {
|
|
unreachable!()
|
|
}
|
|
};
|
|
// TODO return error!
|
|
let _ = enc.write(bytes.as_ref());
|
|
let _ = enc.write_eof();
|
|
*bytes = Binary::from(enc.buf_mut().take());
|
|
|
|
req.headers_mut().insert(
|
|
CONTENT_ENCODING,
|
|
HeaderValue::from_static(encoding.as_str()),
|
|
);
|
|
encoding = ContentEncoding::Identity;
|
|
}
|
|
let mut b = BytesMut::new();
|
|
let _ = write!(b, "{}", bytes.len());
|
|
req.headers_mut()
|
|
.insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
|
|
TransferEncoding::eof(buf)
|
|
}
|
|
#[cfg(not(any(feature = "flate2", feature = "brotli")))]
|
|
{
|
|
let mut b = BytesMut::new();
|
|
let _ = write!(b, "{}", bytes.len());
|
|
req.headers_mut()
|
|
.insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
|
|
TransferEncoding::eof(buf)
|
|
}
|
|
}
|
|
Body::Streaming(_) | Body::Actor(_) => {
|
|
if req.upgrade() {
|
|
if version == Version::HTTP_2 {
|
|
error!("Connection upgrade is forbidden for HTTP/2");
|
|
} else {
|
|
req.headers_mut()
|
|
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
|
|
}
|
|
if encoding != ContentEncoding::Identity {
|
|
encoding = ContentEncoding::Identity;
|
|
req.headers_mut().remove(CONTENT_ENCODING);
|
|
}
|
|
TransferEncoding::eof(buf)
|
|
} else {
|
|
streaming_encoding(buf, version, req)
|
|
}
|
|
}
|
|
};
|
|
|
|
if encoding.is_compression() {
|
|
req.headers_mut().insert(
|
|
CONTENT_ENCODING,
|
|
HeaderValue::from_static(encoding.as_str()),
|
|
);
|
|
}
|
|
|
|
req.replace_body(body);
|
|
let enc = match encoding {
|
|
#[cfg(feature = "flate2")]
|
|
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
|
|
transfer,
|
|
Compression::default(),
|
|
)),
|
|
#[cfg(feature = "flate2")]
|
|
ContentEncoding::Gzip => {
|
|
ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::default()))
|
|
}
|
|
#[cfg(feature = "brotli")]
|
|
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
|
|
ContentEncoding::Identity | ContentEncoding::Auto => return Output::TE(transfer),
|
|
};
|
|
Output::Encoder(enc)
|
|
}
|
|
|
|
fn streaming_encoding(
|
|
buf: BytesMut, version: Version, req: &mut ClientRequest,
|
|
) -> TransferEncoding {
|
|
if req.chunked() {
|
|
// Enable transfer encoding
|
|
req.headers_mut().remove(CONTENT_LENGTH);
|
|
if version == Version::HTTP_2 {
|
|
req.headers_mut().remove(TRANSFER_ENCODING);
|
|
TransferEncoding::eof(buf)
|
|
} else {
|
|
req.headers_mut()
|
|
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
|
|
TransferEncoding::chunked(buf)
|
|
}
|
|
} else {
|
|
// if Content-Length is specified, then use it as length hint
|
|
let (len, chunked) = if let Some(len) = req.headers().get(CONTENT_LENGTH) {
|
|
// Content-Length
|
|
if let Ok(s) = len.to_str() {
|
|
if let Ok(len) = s.parse::<u64>() {
|
|
(Some(len), false)
|
|
} else {
|
|
error!("illegal Content-Length: {:?}", len);
|
|
(None, false)
|
|
}
|
|
} else {
|
|
error!("illegal Content-Length: {:?}", len);
|
|
(None, false)
|
|
}
|
|
} else {
|
|
(None, true)
|
|
};
|
|
|
|
if !chunked {
|
|
if let Some(len) = len {
|
|
TransferEncoding::length(len, buf)
|
|
} else {
|
|
TransferEncoding::eof(buf)
|
|
}
|
|
} else {
|
|
// Enable transfer encoding
|
|
match version {
|
|
Version::HTTP_11 => {
|
|
req.headers_mut()
|
|
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
|
|
TransferEncoding::chunked(buf)
|
|
}
|
|
_ => {
|
|
req.headers_mut().remove(TRANSFER_ENCODING);
|
|
TransferEncoding::eof(buf)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// "Sun, 06 Nov 1994 08:49:37 GMT".len()
|
|
pub const DATE_VALUE_LENGTH: usize = 29;
|
|
|
|
fn set_date(dst: &mut BytesMut) {
|
|
CACHED.with(|cache| {
|
|
let mut cache = cache.borrow_mut();
|
|
let now = time::get_time();
|
|
if now > cache.next_update {
|
|
cache.update(now);
|
|
}
|
|
dst.extend_from_slice(cache.buffer());
|
|
})
|
|
}
|
|
|
|
struct CachedDate {
|
|
bytes: [u8; DATE_VALUE_LENGTH],
|
|
next_update: time::Timespec,
|
|
}
|
|
|
|
thread_local!(static CACHED: RefCell<CachedDate> = RefCell::new(CachedDate {
|
|
bytes: [0; DATE_VALUE_LENGTH],
|
|
next_update: time::Timespec::new(0, 0),
|
|
}));
|
|
|
|
impl CachedDate {
|
|
fn buffer(&self) -> &[u8] {
|
|
&self.bytes[..]
|
|
}
|
|
|
|
fn update(&mut self, now: time::Timespec) {
|
|
write!(&mut self.bytes[..], "{}", time::at_utc(now).rfc822()).unwrap();
|
|
self.next_update = now + Duration::seconds(1);
|
|
self.next_update.nsec = 0;
|
|
}
|
|
}
|