1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-25 00:12:59 +01:00
actix-extras/src/client/writer.rs

391 lines
12 KiB
Rust
Raw Normal View History

2018-02-26 23:33:56 +01:00
#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))]
2018-02-20 08:18:18 +01:00
use std::cell::RefCell;
use std::fmt::Write as FmtWrite;
2018-04-14 01:02:01 +02:00
use std::io::{self, Write};
2018-02-20 08:18:18 +01:00
2018-04-14 01:02:01 +02:00
#[cfg(feature = "brotli")]
use brotli2::write::BrotliEncoder;
use bytes::{BufMut, BytesMut};
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
2018-04-14 01:02:01 +02:00
use flate2::write::{DeflateEncoder, GzEncoder};
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
use flate2::Compression;
2018-01-28 07:03:03 +01:00
use futures::{Async, Poll};
2018-04-14 01:02:01 +02:00
use http::header::{HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, DATE,
TRANSFER_ENCODING};
use http::{HttpTryFrom, Version};
use time::{self, Duration};
2018-01-28 07:03:03 +01:00
use tokio_io::AsyncWrite;
2018-04-14 01:02:01 +02:00
use body::{Binary, Body};
2018-03-14 01:21:22 +01:00
use header::ContentEncoding;
2018-02-19 12:11:11 +01:00
use server::encoding::{ContentEncoder, TransferEncoding};
2018-04-14 01:02:01 +02:00
use server::shared::SharedBytes;
2018-04-24 21:24:04 +02:00
use server::WriterState;
2018-01-28 07:03:03 +01:00
use client::ClientRequest;
2018-01-28 07:03:03 +01:00
const AVERAGE_HEADER_SIZE: usize = 30;
2018-01-28 07:03:03 +01:00
bitflags! {
struct Flags: u8 {
const STARTED = 0b0000_0001;
const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const DISCONNECTED = 0b0000_1000;
}
}
pub(crate) struct HttpClientWriter {
2018-01-28 07:03:03 +01:00
flags: Flags,
written: u64,
headers_size: u32,
buffer: SharedBytes,
2018-03-09 19:09:13 +01:00
buffer_capacity: usize,
2018-02-19 12:11:11 +01:00
encoder: ContentEncoder,
2018-01-28 07:03:03 +01:00
}
impl HttpClientWriter {
2018-02-26 23:33:56 +01:00
pub fn new(buffer: SharedBytes) -> HttpClientWriter {
let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone()));
HttpClientWriter {
2018-01-28 07:03:03 +01:00
flags: Flags::empty(),
written: 0,
headers_size: 0,
2018-03-09 19:09:13 +01:00
buffer_capacity: 0,
2018-02-26 23:33:56 +01:00
buffer,
encoder,
2018-01-28 07:03:03 +01:00
}
}
pub fn disconnected(&mut self) {
self.buffer.take();
}
2018-04-16 18:30:59 +02:00
pub fn is_completed(&self) -> bool {
self.buffer.is_empty()
}
// pub fn keepalive(&self) -> bool {
2018-04-14 01:02:01 +02:00
// self.flags.contains(Flags::KEEPALIVE) &&
// !self.flags.contains(Flags::UPGRADE) }
2018-01-28 07:03:03 +01:00
2018-04-14 01:02:01 +02:00
fn write_to_stream<T: AsyncWrite>(
2018-04-24 21:24:04 +02:00
&mut self, stream: &mut T,
2018-04-14 01:02:01 +02:00
) -> io::Result<WriterState> {
2018-01-28 07:03:03 +01:00
while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) {
Ok(0) => {
self.disconnected();
return Ok(WriterState::Done);
2018-04-14 01:02:01 +02:00
}
2018-01-28 07:03:03 +01:00
Ok(n) => {
let _ = self.buffer.split_to(n);
2018-04-14 01:02:01 +02:00
}
2018-01-28 07:03:03 +01:00
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
2018-03-09 19:09:13 +01:00
if self.buffer.len() > self.buffer_capacity {
2018-04-14 01:02:01 +02:00
return Ok(WriterState::Pause);
2018-01-28 07:03:03 +01:00
} else {
2018-04-14 01:02:01 +02:00
return Ok(WriterState::Done);
2018-01-28 07:03:03 +01:00
}
}
Err(err) => return Err(err),
}
}
Ok(WriterState::Done)
}
}
impl HttpClientWriter {
2018-02-19 12:11:11 +01:00
pub fn start(&mut self, msg: &mut ClientRequest) -> io::Result<()> {
2018-01-28 07:03:03 +01:00
// prepare task
self.flags.insert(Flags::STARTED);
2018-02-19 12:11:11 +01:00
self.encoder = content_encoder(self.buffer.clone(), msg);
2018-01-28 07:03:03 +01:00
2018-03-22 15:44:16 +01:00
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
}
2018-01-28 07:03:03 +01:00
// render message
{
2018-03-22 15:44:16 +01:00
// status line
2018-04-14 01:02:01 +02:00
writeln!(
self.buffer,
"{} {} {:?}\r",
msg.method(),
2018-04-29 07:55:47 +02:00
msg.uri().path_and_query().map(|u| u.as_str()).unwrap_or("/"),
2018-04-14 01:02:01 +02:00
msg.version()
)?;
2018-03-22 15:44:16 +01:00
// write headers
2018-02-20 08:18:18 +01:00
let mut buffer = self.buffer.get_mut();
2018-02-19 12:11:11 +01:00
if let Body::Binary(ref bytes) = *msg.body() {
2018-03-22 15:44:16 +01:00
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len());
2018-02-19 12:11:11 +01:00
} else {
2018-03-22 15:44:16 +01:00
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
2018-02-19 12:11:11 +01:00
}
2018-01-28 07:03:03 +01:00
for (key, value) in msg.headers() {
2018-01-28 07:03:03 +01:00
let v = value.as_ref();
let k = key.as_str().as_bytes();
buffer.reserve(k.len() + v.len() + 4);
buffer.put_slice(k);
buffer.put_slice(b": ");
buffer.put_slice(v);
buffer.put_slice(b"\r\n");
}
2018-02-20 08:18:18 +01:00
// set date header
if !msg.headers().contains_key(DATE) {
buffer.extend_from_slice(b"date: ");
set_date(&mut buffer);
buffer.extend_from_slice(b"\r\n\r\n");
} else {
2018-01-28 07:03:03 +01:00
buffer.extend_from_slice(b"\r\n");
2018-02-20 08:18:18 +01:00
}
2018-01-28 07:03:03 +01:00
self.headers_size = buffer.len() as u32;
2018-02-19 12:11:11 +01:00
if msg.body().is_binary() {
if let Body::Binary(bytes) = msg.replace_body(Body::Empty) {
self.written += bytes.len() as u64;
self.encoder.write(bytes)?;
}
2018-03-09 19:09:13 +01:00
} else {
self.buffer_capacity = msg.write_buffer_capacity();
2018-02-19 12:11:11 +01:00
}
2018-01-28 07:03:03 +01:00
}
2018-02-19 12:11:11 +01:00
Ok(())
2018-01-28 07:03:03 +01:00
}
2018-02-20 07:48:27 +01:00
pub fn write(&mut self, payload: Binary) -> io::Result<WriterState> {
2018-01-28 07:03:03 +01:00
self.written += payload.len() as u64;
if !self.flags.contains(Flags::DISCONNECTED) {
2018-02-20 07:48:27 +01:00
if self.flags.contains(Flags::UPGRADE) {
self.buffer.extend(payload);
} else {
self.encoder.write(payload)?;
}
2018-01-28 07:03:03 +01:00
}
2018-03-09 19:09:13 +01:00
if self.buffer.len() > self.buffer_capacity {
2018-01-28 07:03:03 +01:00
Ok(WriterState::Pause)
} else {
Ok(WriterState::Done)
}
}
2018-02-20 07:48:27 +01:00
pub fn write_eof(&mut self) -> io::Result<()> {
self.encoder.write_eof()?;
2018-02-20 08:18:18 +01:00
if self.encoder.is_eof() {
Ok(())
} else {
2018-04-14 01:02:01 +02:00
Err(io::Error::new(
io::ErrorKind::Other,
"Last payload item, but eof is not reached",
))
2018-01-28 07:03:03 +01:00
}
}
#[inline]
2018-04-14 01:02:01 +02:00
pub fn poll_completed<T: AsyncWrite>(
2018-04-24 21:24:04 +02:00
&mut self, stream: &mut T, shutdown: bool,
2018-04-14 01:02:01 +02:00
) -> Poll<(), io::Error> {
2018-01-28 07:03:03 +01:00
match self.write_to_stream(stream) {
Ok(WriterState::Done) => {
if shutdown {
stream.shutdown()
} else {
Ok(Async::Ready(()))
}
2018-04-14 01:02:01 +02:00
}
2018-01-28 07:03:03 +01:00
Ok(WriterState::Pause) => Ok(Async::NotReady),
2018-04-14 01:02:01 +02:00
Err(err) => Err(err),
2018-01-28 07:03:03 +01:00
}
}
}
2018-02-19 12:11:11 +01:00
fn content_encoder(buf: SharedBytes, req: &mut ClientRequest) -> ContentEncoder {
let version = req.version();
let mut body = req.replace_body(Body::Empty);
let mut encoding = req.content_encoding();
let transfer = match body {
Body::Empty => {
req.headers_mut().remove(CONTENT_LENGTH);
TransferEncoding::length(0, buf)
2018-04-14 01:02:01 +02:00
}
2018-02-19 12:11:11 +01:00
Body::Binary(ref mut bytes) => {
if encoding.is_compression() {
let tmp = SharedBytes::default();
let transfer = TransferEncoding::eof(tmp.clone());
let mut enc = match encoding {
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
2018-02-19 12:11:11 +01:00
ContentEncoding::Deflate => ContentEncoder::Deflate(
2018-04-14 01:02:01 +02:00
DeflateEncoder::new(transfer, Compression::default()),
),
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
2018-04-14 01:02:01 +02:00
ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new(
transfer,
Compression::default(),
)),
#[cfg(feature = "brotli")]
ContentEncoding::Br => {
ContentEncoder::Br(BrotliEncoder::new(transfer, 5))
}
2018-02-19 12:11:11 +01:00
ContentEncoding::Identity => ContentEncoder::Identity(transfer),
2018-04-14 01:02:01 +02:00
ContentEncoding::Auto => unreachable!(),
2018-02-19 12:11:11 +01:00
};
// TODO return error!
let _ = enc.write(bytes.clone());
let _ = enc.write_eof();
*bytes = Binary::from(tmp.take());
2018-02-19 22:18:18 +01:00
req.headers_mut().insert(
2018-04-14 01:02:01 +02:00
CONTENT_ENCODING,
HeaderValue::from_static(encoding.as_str()),
);
2018-02-19 12:11:11 +01:00
encoding = ContentEncoding::Identity;
}
let mut b = BytesMut::new();
let _ = write!(b, "{}", bytes.len());
2018-04-29 07:55:47 +02:00
req.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::try_from(b.freeze()).unwrap());
2018-02-19 12:11:11 +01:00
TransferEncoding::eof(buf)
2018-04-14 01:02:01 +02:00
}
2018-02-19 12:11:11 +01:00
Body::Streaming(_) | Body::Actor(_) => {
if req.upgrade() {
if version == Version::HTTP_2 {
error!("Connection upgrade is forbidden for HTTP/2");
} else {
2018-04-14 01:02:01 +02:00
req.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
2018-02-19 12:11:11 +01:00
}
if encoding != ContentEncoding::Identity {
encoding = ContentEncoding::Identity;
req.headers_mut().remove(CONTENT_ENCODING);
}
TransferEncoding::eof(buf)
} else {
streaming_encoding(buf, version, req)
}
}
};
2018-02-19 22:18:18 +01:00
if encoding.is_compression() {
2018-04-29 07:55:47 +02:00
req.headers_mut()
.insert(CONTENT_ENCODING, HeaderValue::from_static(encoding.as_str()));
2018-02-19 22:18:18 +01:00
}
2018-02-19 12:11:11 +01:00
req.replace_body(body);
match encoding {
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
2018-04-14 01:02:01 +02:00
ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new(
transfer,
Compression::default(),
)),
2018-04-24 21:24:04 +02:00
#[cfg(feature = "flate2")]
2018-04-14 01:02:01 +02:00
ContentEncoding::Gzip => {
ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::default()))
}
#[cfg(feature = "brotli")]
ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 5)),
ContentEncoding::Identity | ContentEncoding::Auto => {
ContentEncoder::Identity(transfer)
}
2018-02-19 12:11:11 +01:00
}
}
2018-04-14 01:02:01 +02:00
fn streaming_encoding(
2018-04-24 21:24:04 +02:00
buf: SharedBytes, version: Version, req: &mut ClientRequest,
2018-04-14 01:02:01 +02:00
) -> TransferEncoding {
2018-02-19 12:11:11 +01:00
if req.chunked() {
// Enable transfer encoding
req.headers_mut().remove(CONTENT_LENGTH);
if version == Version::HTTP_2 {
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
} else {
2018-04-14 01:02:01 +02:00
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
2018-02-19 12:11:11 +01:00
TransferEncoding::chunked(buf)
}
} else {
// if Content-Length is specified, then use it as length hint
2018-04-14 01:02:01 +02:00
let (len, chunked) = if let Some(len) = req.headers().get(CONTENT_LENGTH) {
// Content-Length
if let Ok(s) = len.to_str() {
if let Ok(len) = s.parse::<u64>() {
(Some(len), false)
2018-02-19 12:11:11 +01:00
} else {
error!("illegal Content-Length: {:?}", len);
(None, false)
}
} else {
2018-04-14 01:02:01 +02:00
error!("illegal Content-Length: {:?}", len);
(None, false)
}
} else {
(None, true)
};
2018-02-19 12:11:11 +01:00
if !chunked {
if let Some(len) = len {
TransferEncoding::length(len, buf)
} else {
TransferEncoding::eof(buf)
}
} else {
// Enable transfer encoding
match version {
Version::HTTP_11 => {
2018-04-14 01:02:01 +02:00
req.headers_mut()
.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
2018-02-19 12:11:11 +01:00
TransferEncoding::chunked(buf)
2018-04-14 01:02:01 +02:00
}
2018-02-19 12:11:11 +01:00
_ => {
req.headers_mut().remove(TRANSFER_ENCODING);
TransferEncoding::eof(buf)
}
}
}
}
}
2018-02-20 08:18:18 +01:00
/// Length of a formatted date value, e.g. `"Sun, 06 Nov 1994 08:49:37 GMT"`.
pub const DATE_VALUE_LENGTH: usize = 29;

/// Appends the current UTC date, formatted for the `date` header, to `dst`.
///
/// The formatted value is cached per thread and re-rendered at most once per
/// second.
fn set_date(dst: &mut BytesMut) {
    CACHED.with(|cell| {
        let mut cached = cell.borrow_mut();
        let now = time::get_time();
        if cached.next_update < now {
            cached.update(now);
        }
        dst.extend_from_slice(cached.buffer());
    })
}

/// Per-thread cache of the rendered date header value.
struct CachedDate {
    /// Rendered date bytes.
    bytes: [u8; DATE_VALUE_LENGTH],
    /// Time after which `bytes` must be re-rendered.
    next_update: time::Timespec,
}

thread_local!(static CACHED: RefCell<CachedDate> = RefCell::new(CachedDate {
    bytes: [0; DATE_VALUE_LENGTH],
    next_update: time::Timespec::new(0, 0),
}));

impl CachedDate {
    /// The cached, rendered date.
    fn buffer(&self) -> &[u8] {
        &self.bytes
    }

    /// Re-renders the date for `now`.
    ///
    /// The next refresh is scheduled for the start of the following second.
    fn update(&mut self, now: time::Timespec) {
        let rendered = time::at_utc(now).rfc822();
        write!(&mut self.bytes[..], "{}", rendered).unwrap();
        let mut next = now + Duration::seconds(1);
        next.nsec = 0;
        self.next_update = next;
    }
}