mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-23 23:51:06 +01:00
refactor task impl, extract stream writer to separate struct
This commit is contained in:
parent
f010672885
commit
4add742aba
@ -28,8 +28,6 @@ default = []
|
||||
# tls
|
||||
tls = ["native-tls", "tokio-tls"]
|
||||
|
||||
# http2 = ["h2"]
|
||||
|
||||
[dependencies]
|
||||
log = "0.3"
|
||||
time = "0.1"
|
||||
|
4
cov.sh
4
cov.sh
@ -1,4 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
for file in target/debug/actix_web-*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done &&
|
||||
for file in target/debug/test_*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done
|
@ -47,6 +47,7 @@ impl<A> ActorContext for HttpContext<A> where A: Actor<Context=Self> + Route
|
||||
{
|
||||
/// Stop actor execution
|
||||
fn stop(&mut self) {
|
||||
self.stream.push_back(Frame::Payload(None));
|
||||
self.items.stop();
|
||||
self.address.close();
|
||||
if self.state == ActorState::Running {
|
||||
@ -141,7 +142,6 @@ impl<A> HttpContext<A> where A: Actor<Context=Self> + Route {
|
||||
|
||||
/// Indicate end of streamimng payload. Also this method calls `Self::close`.
|
||||
pub fn write_eof(&mut self) {
|
||||
self.stream.push_back(Frame::Payload(None));
|
||||
self.stop();
|
||||
}
|
||||
|
||||
|
465
src/h1.rs
465
src/h1.rs
@ -1,22 +1,282 @@
|
||||
use std::{self, io, ptr};
|
||||
use std::rc::Rc;
|
||||
use std::cell::UnsafeCell;
|
||||
use std::time::Duration;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use actix::Arbiter;
|
||||
use httparse;
|
||||
use http::{Method, Version, HttpTryFrom, HeaderMap};
|
||||
use http::header::{self, HeaderName, HeaderValue};
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use futures::{Async, Poll};
|
||||
use tokio_io::AsyncRead;
|
||||
use futures::{Future, Poll, Async};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_core::reactor::Timeout;
|
||||
use percent_encoding;
|
||||
|
||||
use task::Task;
|
||||
use server::HttpHandler;
|
||||
use error::ParseError;
|
||||
use httpcodes::HTTPNotFound;
|
||||
use httprequest::HttpRequest;
|
||||
use payload::{Payload, PayloadError, PayloadSender};
|
||||
use h1writer::H1Writer;
|
||||
|
||||
const MAX_HEADERS: usize = 100;
|
||||
const KEEPALIVE_PERIOD: u64 = 15; // seconds
|
||||
const INIT_BUFFER_SIZE: usize = 8192;
|
||||
const MAX_BUFFER_SIZE: usize = 131_072;
|
||||
const MAX_HEADERS: usize = 100;
|
||||
const MAX_PIPELINED_MESSAGES: usize = 16;
|
||||
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
|
||||
|
||||
pub(crate) enum Http1Result {
|
||||
Done,
|
||||
Upgrade,
|
||||
}
|
||||
|
||||
pub(crate) struct Http1<T: AsyncWrite + 'static, A: 'static, H: 'static> {
|
||||
router: Rc<Vec<H>>,
|
||||
#[allow(dead_code)]
|
||||
addr: A,
|
||||
stream: H1Writer<T>,
|
||||
reader: Reader,
|
||||
read_buf: BytesMut,
|
||||
error: bool,
|
||||
tasks: VecDeque<Entry>,
|
||||
keepalive: bool,
|
||||
keepalive_timer: Option<Timeout>,
|
||||
h2: bool,
|
||||
}
|
||||
|
||||
struct Entry {
|
||||
task: Task,
|
||||
req: UnsafeCell<HttpRequest>,
|
||||
eof: bool,
|
||||
error: bool,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl<T, A, H> Http1<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static,
|
||||
A: 'static,
|
||||
H: HttpHandler + 'static
|
||||
{
|
||||
pub fn new(stream: T, addr: A, router: Rc<Vec<H>>) -> Self {
|
||||
Http1{ router: router,
|
||||
addr: addr,
|
||||
stream: H1Writer::new(stream),
|
||||
reader: Reader::new(),
|
||||
read_buf: BytesMut::new(),
|
||||
error: false,
|
||||
tasks: VecDeque::new(),
|
||||
keepalive: true,
|
||||
keepalive_timer: None,
|
||||
h2: false }
|
||||
}
|
||||
|
||||
pub fn into_inner(mut self) -> (T, A, Rc<Vec<H>>, Bytes) {
|
||||
(self.stream.into_inner(), self.addr, self.router, self.read_buf.freeze())
|
||||
}
|
||||
|
||||
pub fn poll(&mut self) -> Poll<Http1Result, ()> {
|
||||
// keep-alive timer
|
||||
if let Some(ref mut timeout) = self.keepalive_timer {
|
||||
match timeout.poll() {
|
||||
Ok(Async::Ready(_)) =>
|
||||
return Ok(Async::Ready(Http1Result::Done)),
|
||||
Ok(Async::NotReady) => (),
|
||||
Err(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let mut not_ready = true;
|
||||
|
||||
// check in-flight messages
|
||||
let mut io = false;
|
||||
let mut idx = 0;
|
||||
while idx < self.tasks.len() {
|
||||
let item = &mut self.tasks[idx];
|
||||
|
||||
if !io && !item.eof {
|
||||
if item.error {
|
||||
return Err(())
|
||||
}
|
||||
|
||||
// this is anoying
|
||||
let req = unsafe {item.req.get().as_mut().unwrap()};
|
||||
match item.task.poll_io(&mut self.stream, req)
|
||||
{
|
||||
Ok(Async::Ready(ready)) => {
|
||||
not_ready = false;
|
||||
|
||||
// overide keep-alive state
|
||||
if self.keepalive {
|
||||
self.keepalive = self.stream.keepalive();
|
||||
}
|
||||
self.stream = H1Writer::new(self.stream.into_inner());
|
||||
|
||||
item.eof = true;
|
||||
if ready {
|
||||
item.finished = true;
|
||||
}
|
||||
},
|
||||
Ok(Async::NotReady) => {
|
||||
// no more IO for this iteration
|
||||
io = true;
|
||||
},
|
||||
Err(_) => {
|
||||
// it is not possible to recover from error
|
||||
// during task handling, so just drop connection
|
||||
return Err(())
|
||||
}
|
||||
}
|
||||
} else if !item.finished {
|
||||
match item.task.poll() {
|
||||
Ok(Async::NotReady) => (),
|
||||
Ok(Async::Ready(_)) => {
|
||||
not_ready = false;
|
||||
item.finished = true;
|
||||
},
|
||||
Err(_) =>
|
||||
item.error = true,
|
||||
}
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// cleanup finished tasks
|
||||
while !self.tasks.is_empty() {
|
||||
if self.tasks[0].eof && self.tasks[0].finished {
|
||||
self.tasks.pop_front();
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// no keep-alive
|
||||
if !self.keepalive && self.tasks.is_empty() {
|
||||
if self.h2 {
|
||||
return Ok(Async::Ready(Http1Result::Upgrade))
|
||||
} else {
|
||||
return Ok(Async::Ready(Http1Result::Done))
|
||||
}
|
||||
}
|
||||
|
||||
// read incoming data
|
||||
if !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES {
|
||||
match self.reader.parse(self.stream.get_mut(), &mut self.read_buf) {
|
||||
Ok(Async::Ready(Item::Http1(mut req, payload))) => {
|
||||
not_ready = false;
|
||||
|
||||
// stop keepalive timer
|
||||
self.keepalive_timer.take();
|
||||
|
||||
// start request processing
|
||||
let mut task = None;
|
||||
for h in self.router.iter() {
|
||||
if req.path().starts_with(h.prefix()) {
|
||||
task = Some(h.handle(&mut req, payload));
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
self.tasks.push_back(
|
||||
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
|
||||
req: UnsafeCell::new(req),
|
||||
eof: false,
|
||||
error: false,
|
||||
finished: false});
|
||||
}
|
||||
Ok(Async::Ready(Item::Http2)) => {
|
||||
self.h2 = true;
|
||||
}
|
||||
Err(ReaderError::Disconnect) => {
|
||||
not_ready = false;
|
||||
self.error = true;
|
||||
self.stream.disconnected();
|
||||
for entry in &mut self.tasks {
|
||||
entry.task.disconnected()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
// notify all tasks
|
||||
not_ready = false;
|
||||
self.stream.disconnected();
|
||||
for entry in &mut self.tasks {
|
||||
entry.task.disconnected()
|
||||
}
|
||||
|
||||
// kill keepalive
|
||||
self.keepalive = false;
|
||||
self.keepalive_timer.take();
|
||||
|
||||
// on parse error, stop reading stream but
|
||||
// tasks need to be completed
|
||||
self.error = true;
|
||||
|
||||
if self.tasks.is_empty() {
|
||||
if let ReaderError::Error(err) = err {
|
||||
self.tasks.push_back(
|
||||
Entry {task: Task::reply(err),
|
||||
req: UnsafeCell::new(HttpRequest::for_error()),
|
||||
eof: false,
|
||||
error: false,
|
||||
finished: false});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Async::NotReady) => {
|
||||
// start keep-alive timer, this is also slow request timeout
|
||||
if self.tasks.is_empty() {
|
||||
if self.keepalive {
|
||||
if self.keepalive_timer.is_none() {
|
||||
trace!("Start keep-alive timer");
|
||||
let mut timeout = Timeout::new(
|
||||
Duration::new(KEEPALIVE_PERIOD, 0),
|
||||
Arbiter::handle()).unwrap();
|
||||
// register timeout
|
||||
let _ = timeout.poll();
|
||||
self.keepalive_timer = Some(timeout);
|
||||
}
|
||||
} else {
|
||||
// keep-alive disable, drop connection
|
||||
return Ok(Async::Ready(Http1Result::Done))
|
||||
}
|
||||
}
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check for parse error
|
||||
if self.tasks.is_empty() {
|
||||
if self.error || self.keepalive_timer.is_none() {
|
||||
return Ok(Async::Ready(Http1Result::Done))
|
||||
}
|
||||
else if self.h2 {
|
||||
return Ok(Async::Ready(Http1Result::Upgrade))
|
||||
}
|
||||
}
|
||||
|
||||
if not_ready {
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Item {
|
||||
Http1(HttpRequest, Payload),
|
||||
Http2,
|
||||
}
|
||||
|
||||
struct Reader {
|
||||
h1: bool,
|
||||
payload: Option<PayloadInfo>,
|
||||
}
|
||||
|
||||
enum Decoding {
|
||||
Paused,
|
||||
Ready,
|
||||
@ -28,19 +288,9 @@ struct PayloadInfo {
|
||||
decoder: Decoder,
|
||||
}
|
||||
|
||||
pub(crate) struct Reader {
|
||||
read_buf: BytesMut,
|
||||
payload: Option<PayloadInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ReaderItem {
|
||||
Http1(HttpRequest, Payload),
|
||||
Http2,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ReaderError {
|
||||
enum ReaderError {
|
||||
Disconnect,
|
||||
Payload,
|
||||
Error(ParseError),
|
||||
}
|
||||
@ -55,19 +305,19 @@ enum Message {
|
||||
impl Reader {
|
||||
pub fn new() -> Reader {
|
||||
Reader {
|
||||
read_buf: BytesMut::new(),
|
||||
h1: false,
|
||||
payload: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn decode(&mut self) -> std::result::Result<Decoding, ReaderError>
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result<Decoding, ReaderError>
|
||||
{
|
||||
if let Some(ref mut payload) = self.payload {
|
||||
if payload.tx.maybe_paused() {
|
||||
return Ok(Decoding::Paused)
|
||||
}
|
||||
loop {
|
||||
match payload.decoder.decode(&mut self.read_buf) {
|
||||
match payload.decoder.decode(buf) {
|
||||
Ok(Async::Ready(Some(bytes))) => {
|
||||
payload.tx.feed_data(bytes)
|
||||
},
|
||||
@ -87,18 +337,18 @@ impl Reader {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse<T>(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ReaderError>
|
||||
pub fn parse<T>(&mut self, io: &mut T, buf: &mut BytesMut) -> Poll<Item, ReaderError>
|
||||
where T: AsyncRead
|
||||
{
|
||||
loop {
|
||||
match self.decode()? {
|
||||
match self.decode(buf)? {
|
||||
Decoding::Paused => return Ok(Async::NotReady),
|
||||
Decoding::Ready => {
|
||||
self.payload = None;
|
||||
break
|
||||
},
|
||||
Decoding::NotReady => {
|
||||
match self.read_from_io(io) {
|
||||
match self.read_from_io(io, buf) {
|
||||
Ok(Async::Ready(0)) => {
|
||||
if let Some(ref mut payload) = self.payload {
|
||||
payload.tx.set_error(PayloadError::Incomplete);
|
||||
@ -123,7 +373,7 @@ impl Reader {
|
||||
}
|
||||
|
||||
loop {
|
||||
match Reader::parse_message(&mut self.read_buf).map_err(ReaderError::Error)? {
|
||||
match Reader::parse_message(buf).map_err(ReaderError::Error)? {
|
||||
Message::Http1(msg, decoder) => {
|
||||
let payload = if let Some(decoder) = decoder {
|
||||
let (tx, rx) = Payload::new(false);
|
||||
@ -134,7 +384,7 @@ impl Reader {
|
||||
self.payload = Some(payload);
|
||||
|
||||
loop {
|
||||
match self.decode()? {
|
||||
match self.decode(buf)? {
|
||||
Decoding::Paused =>
|
||||
break,
|
||||
Decoding::Ready => {
|
||||
@ -142,7 +392,7 @@ impl Reader {
|
||||
break
|
||||
},
|
||||
Decoding::NotReady => {
|
||||
match self.read_from_io(io) {
|
||||
match self.read_from_io(io, buf) {
|
||||
Ok(Async::Ready(0)) => {
|
||||
trace!("parse eof");
|
||||
if let Some(ref mut payload) = self.payload {
|
||||
@ -171,21 +421,26 @@ impl Reader {
|
||||
let (_, rx) = Payload::new(true);
|
||||
rx
|
||||
};
|
||||
return Ok(Async::Ready((msg, payload)));
|
||||
self.h1 = true;
|
||||
return Ok(Async::Ready(Item::Http1(msg, payload)));
|
||||
},
|
||||
Message::Http2 => {
|
||||
if self.h1 {
|
||||
return Err(ReaderError::Error(ParseError::Version))
|
||||
}
|
||||
return Ok(Async::Ready(Item::Http2));
|
||||
},
|
||||
Message::NotReady => {
|
||||
if self.read_buf.capacity() >= MAX_BUFFER_SIZE {
|
||||
if buf.capacity() >= MAX_BUFFER_SIZE {
|
||||
debug!("MAX_BUFFER_SIZE reached, closing");
|
||||
return Err(ReaderError::Error(ParseError::TooLarge));
|
||||
}
|
||||
},
|
||||
}
|
||||
match self.read_from_io(io) {
|
||||
match self.read_from_io(io, buf) {
|
||||
Ok(Async::Ready(0)) => {
|
||||
trace!("Eof during parse");
|
||||
return Err(ReaderError::Error(ParseError::Incomplete));
|
||||
debug!("Ignored premature client disconnection");
|
||||
return Err(ReaderError::Disconnect);
|
||||
},
|
||||
Ok(Async::Ready(_)) => (),
|
||||
Ok(Async::NotReady) =>
|
||||
@ -196,17 +451,19 @@ impl Reader {
|
||||
}
|
||||
}
|
||||
|
||||
fn read_from_io<T: AsyncRead>(&mut self, io: &mut T) -> Poll<usize, io::Error> {
|
||||
if self.read_buf.remaining_mut() < INIT_BUFFER_SIZE {
|
||||
self.read_buf.reserve(INIT_BUFFER_SIZE);
|
||||
fn read_from_io<T: AsyncRead>(&mut self, io: &mut T, buf: &mut BytesMut)
|
||||
-> Poll<usize, io::Error>
|
||||
{
|
||||
if buf.remaining_mut() < INIT_BUFFER_SIZE {
|
||||
buf.reserve(INIT_BUFFER_SIZE);
|
||||
unsafe { // Zero out unused memory
|
||||
let buf = self.read_buf.bytes_mut();
|
||||
let len = buf.len();
|
||||
ptr::write_bytes(buf.as_mut_ptr(), 0, len);
|
||||
let b = buf.bytes_mut();
|
||||
let len = b.len();
|
||||
ptr::write_bytes(b.as_mut_ptr(), 0, len);
|
||||
}
|
||||
}
|
||||
unsafe {
|
||||
let n = match io.read(self.read_buf.bytes_mut()) {
|
||||
let n = match io.read(buf.bytes_mut()) {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
@ -215,18 +472,17 @@ impl Reader {
|
||||
return Err(e)
|
||||
}
|
||||
};
|
||||
self.read_buf.advance_mut(n);
|
||||
buf.advance_mut(n);
|
||||
Ok(Async::Ready(n))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_message(buf: &mut BytesMut) -> Result<Message, ParseError>
|
||||
{
|
||||
println!("BUF: {:?}", buf);
|
||||
if buf.is_empty() || buf.len() < 14 {
|
||||
if buf.is_empty() {
|
||||
return Ok(Message::NotReady);
|
||||
}
|
||||
if &buf[..14] == &HTTP2_PREFACE[..] {
|
||||
if buf.len() >= 14 && &buf[..14] == &HTTP2_PREFACE[..] {
|
||||
return Ok(Message::Http2)
|
||||
}
|
||||
|
||||
@ -368,7 +624,7 @@ fn record_header_indices(bytes: &[u8],
|
||||
/// If a message body does not include a Transfer-Encoding, it *should*
|
||||
/// include a Content-Length header.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Decoder {
|
||||
struct Decoder {
|
||||
kind: Kind,
|
||||
}
|
||||
|
||||
@ -424,7 +680,7 @@ enum ChunkedState {
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
pub fn is_eof(&self) -> bool {
|
||||
/*pub fn is_eof(&self) -> bool {
|
||||
trace!("is_eof? {:?}", self);
|
||||
match self.kind {
|
||||
Kind::Length(0) |
|
||||
@ -432,7 +688,7 @@ impl Decoder {
|
||||
Kind::Eof(true) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
@ -633,7 +889,7 @@ mod tests {
|
||||
use futures::{Async};
|
||||
use tokio_io::AsyncRead;
|
||||
use http::{Version, Method};
|
||||
use super::{Reader, ReaderError};
|
||||
use super::*;
|
||||
|
||||
struct Buffer {
|
||||
buf: Bytes,
|
||||
@ -682,8 +938,8 @@ mod tests {
|
||||
|
||||
macro_rules! parse_ready {
|
||||
($e:expr) => (
|
||||
match Reader::new().parse($e) {
|
||||
Ok(Async::Ready((req, payload))) => (req, payload),
|
||||
match Reader::new().parse($e, &mut BytesMut::new()) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
|
||||
Ok(_) => panic!("Eof during parsing http request"),
|
||||
Err(err) => panic!("Error during parsing http request: {:?}", err),
|
||||
}
|
||||
@ -693,7 +949,7 @@ mod tests {
|
||||
macro_rules! reader_parse_ready {
|
||||
($e:expr) => (
|
||||
match $e {
|
||||
Ok(Async::Ready((req, payload))) => (req, payload),
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
|
||||
Ok(_) => panic!("Eof during parsing http request"),
|
||||
Err(err) => panic!("Error during parsing http request: {:?}", err),
|
||||
}
|
||||
@ -701,22 +957,28 @@ mod tests {
|
||||
}
|
||||
|
||||
macro_rules! expect_parse_err {
|
||||
($e:expr) => (match Reader::new().parse($e) {
|
||||
Err(err) => match err {
|
||||
ReaderError::Error(_) => (),
|
||||
_ => panic!("Parse error expected"),
|
||||
},
|
||||
_ => panic!("Error expected"),
|
||||
})
|
||||
($e:expr) => ({
|
||||
let mut buf = BytesMut::new();
|
||||
match Reader::new().parse($e, &mut buf) {
|
||||
Err(err) => match err {
|
||||
ReaderError::Error(_) => (),
|
||||
_ => panic!("Parse error expected"),
|
||||
},
|
||||
val => {
|
||||
panic!("Error expected")
|
||||
}
|
||||
}}
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::GET);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -729,16 +991,17 @@ mod tests {
|
||||
#[test]
|
||||
fn test_parse_partial() {
|
||||
let mut buf = Buffer::new("PUT /test HTTP/1");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::NotReady) => (),
|
||||
_ => panic!("Error"),
|
||||
}
|
||||
|
||||
buf.feed_data(".1\r\n\r\n");
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::PUT);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -751,10 +1014,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_parse_post() {
|
||||
let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_10);
|
||||
assert_eq!(*req.method(), Method::POST);
|
||||
assert_eq!(req.path(), "/test2");
|
||||
@ -767,10 +1031,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_parse_body() {
|
||||
let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, mut payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, mut payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::GET);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -784,10 +1049,11 @@ mod tests {
|
||||
fn test_parse_body_crlf() {
|
||||
let mut buf = Buffer::new(
|
||||
"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, mut payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, mut payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::GET);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -800,13 +1066,14 @@ mod tests {
|
||||
#[test]
|
||||
fn test_parse_partial_eof() {
|
||||
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
not_ready!{ reader.parse(&mut buf) }
|
||||
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
|
||||
|
||||
buf.feed_data("\r\n");
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::GET);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -819,19 +1086,20 @@ mod tests {
|
||||
#[test]
|
||||
fn test_headers_split_field() {
|
||||
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
not_ready!{ reader.parse(&mut buf) }
|
||||
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
|
||||
|
||||
buf.feed_data("t");
|
||||
not_ready!{ reader.parse(&mut buf) }
|
||||
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
|
||||
|
||||
buf.feed_data("es");
|
||||
not_ready!{ reader.parse(&mut buf) }
|
||||
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
|
||||
|
||||
buf.feed_data("t: value\r\n\r\n");
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, payload))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, payload))) => {
|
||||
assert_eq!(req.version(), Version::HTTP_11);
|
||||
assert_eq!(*req.method(), Method::GET);
|
||||
assert_eq!(req.path(), "/test");
|
||||
@ -848,10 +1116,11 @@ mod tests {
|
||||
"GET /test HTTP/1.1\r\n\
|
||||
Set-Cookie: c1=cookie1\r\n\
|
||||
Set-Cookie: c2=cookie2\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf) {
|
||||
Ok(Async::Ready((req, _))) => {
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http1(req, _))) => {
|
||||
let val: Vec<_> = req.headers().get_all("Set-Cookie")
|
||||
.iter().map(|v| v.to_str().unwrap().to_owned()).collect();
|
||||
assert_eq!(val[0], "c1=cookie1");
|
||||
@ -1081,14 +1350,15 @@ mod tests {
|
||||
let mut buf = Buffer::new(
|
||||
"GET /test HTTP/1.1\r\n\
|
||||
transfer-encoding: chunked\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(req.chunked().unwrap());
|
||||
assert!(!payload.eof());
|
||||
|
||||
buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(!payload.eof());
|
||||
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
|
||||
assert!(payload.eof());
|
||||
@ -1099,10 +1369,11 @@ mod tests {
|
||||
let mut buf = Buffer::new(
|
||||
"GET /test HTTP/1.1\r\n\
|
||||
transfer-encoding: chunked\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(req.chunked().unwrap());
|
||||
assert!(!payload.eof());
|
||||
|
||||
@ -1111,7 +1382,7 @@ mod tests {
|
||||
POST /test2 HTTP/1.1\r\n\
|
||||
transfer-encoding: chunked\r\n\r\n");
|
||||
|
||||
let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf));
|
||||
let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert_eq!(*req2.method(), Method::POST);
|
||||
assert!(req2.chunked().unwrap());
|
||||
assert!(!payload2.eof());
|
||||
@ -1125,37 +1396,38 @@ mod tests {
|
||||
let mut buf = Buffer::new(
|
||||
"GET /test HTTP/1.1\r\n\
|
||||
transfer-encoding: chunked\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(req.chunked().unwrap());
|
||||
assert!(!payload.eof());
|
||||
|
||||
buf.feed_data("4\r\ndata\r");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
buf.feed_data("\n4");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
buf.feed_data("\r");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
buf.feed_data("\n");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
buf.feed_data("li");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
buf.feed_data("ne\r\n0\r\n");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
//buf.feed_data("test: test\r\n");
|
||||
//not_ready!(reader.parse(&mut buf));
|
||||
//not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
|
||||
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
|
||||
assert!(!payload.eof());
|
||||
|
||||
buf.feed_data("\r\n");
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(payload.eof());
|
||||
}
|
||||
|
||||
@ -1164,14 +1436,15 @@ mod tests {
|
||||
let mut buf = Buffer::new(
|
||||
"GET /test HTTP/1.1\r\n\
|
||||
transfer-encoding: chunked\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
|
||||
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(req.chunked().unwrap());
|
||||
assert!(!payload.eof());
|
||||
|
||||
buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
|
||||
not_ready!(reader.parse(&mut buf));
|
||||
not_ready!(reader.parse(&mut buf, &mut readbuf));
|
||||
assert!(!payload.eof());
|
||||
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
|
||||
assert!(payload.eof());
|
||||
@ -1193,4 +1466,16 @@ mod tests {
|
||||
Err(err) => panic!("{:?}", err),
|
||||
}
|
||||
}*/
|
||||
|
||||
#[test]
|
||||
fn test_http2_prefix() {
|
||||
let mut buf = Buffer::new("PRI * HTTP/2.0\r\n\r\n");
|
||||
let mut readbuf = BytesMut::new();
|
||||
|
||||
let mut reader = Reader::new();
|
||||
match reader.parse(&mut buf, &mut readbuf) {
|
||||
Ok(Async::Ready(Item::Http2)) => (),
|
||||
Ok(_) | Err(_) => panic!("Error during parsing http request"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
351
src/h1writer.rs
Normal file
351
src/h1writer.rs
Normal file
@ -0,0 +1,351 @@
|
||||
use std::{cmp, io};
|
||||
use std::fmt::Write;
|
||||
use bytes::BytesMut;
|
||||
use futures::{Async, Poll};
|
||||
use tokio_io::AsyncWrite;
|
||||
use http::{Version, StatusCode};
|
||||
use http::header::{HeaderValue,
|
||||
CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
|
||||
|
||||
use date;
|
||||
use body::Body;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
|
||||
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
|
||||
const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
|
||||
|
||||
|
||||
pub(crate) enum WriterState {
|
||||
Done,
|
||||
Pause,
|
||||
}
|
||||
|
||||
/// Send stream
|
||||
pub(crate) trait Writer {
|
||||
fn start(&mut self, req: &mut HttpRequest, resp: &mut HttpResponse)
|
||||
-> Result<WriterState, io::Error>;
|
||||
|
||||
fn write(&mut self, payload: &[u8]) -> Result<WriterState, io::Error>;
|
||||
|
||||
fn write_eof(&mut self) -> Result<WriterState, io::Error>;
|
||||
|
||||
fn poll_complete(&mut self) -> Poll<(), io::Error>;
|
||||
}
|
||||
|
||||
|
||||
pub(crate) struct H1Writer<T: AsyncWrite> {
|
||||
stream: Option<T>,
|
||||
buffer: BytesMut,
|
||||
started: bool,
|
||||
encoder: Encoder,
|
||||
upgrade: bool,
|
||||
keepalive: bool,
|
||||
disconnected: bool,
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> H1Writer<T> {
|
||||
|
||||
pub fn new(stream: T) -> H1Writer<T> {
|
||||
H1Writer {
|
||||
stream: Some(stream),
|
||||
buffer: BytesMut::new(),
|
||||
started: false,
|
||||
encoder: Encoder::length(0),
|
||||
upgrade: false,
|
||||
keepalive: false,
|
||||
disconnected: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut T {
|
||||
self.stream.as_mut().unwrap()
|
||||
}
|
||||
|
||||
pub fn into_inner(&mut self) -> T {
|
||||
self.stream.take().unwrap()
|
||||
}
|
||||
|
||||
pub fn disconnected(&mut self) {
|
||||
let len = self.buffer.len();
|
||||
self.buffer.split_to(len);
|
||||
}
|
||||
|
||||
pub fn keepalive(&self) -> bool {
|
||||
self.keepalive && !self.upgrade
|
||||
}
|
||||
|
||||
fn write_to_stream(&mut self) -> Result<WriterState, io::Error> {
|
||||
if let Some(ref mut stream) = self.stream {
|
||||
while !self.buffer.is_empty() {
|
||||
match stream.write(self.buffer.as_ref()) {
|
||||
Ok(n) => {
|
||||
self.buffer.split_to(n);
|
||||
},
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
|
||||
return Ok(WriterState::Pause)
|
||||
} else {
|
||||
return Ok(WriterState::Done)
|
||||
}
|
||||
}
|
||||
Err(err) =>
|
||||
return Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(WriterState::Done)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> Writer for H1Writer<T> {
|
||||
|
||||
fn start(&mut self, req: &mut HttpRequest, msg: &mut HttpResponse)
|
||||
-> Result<WriterState, io::Error>
|
||||
{
|
||||
trace!("Prepare message status={:?}", msg.status);
|
||||
|
||||
// prepare task
|
||||
let mut extra = 0;
|
||||
let body = msg.replace_body(Body::Empty);
|
||||
let version = msg.version().unwrap_or_else(|| req.version());
|
||||
self.started = true;
|
||||
self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
|
||||
|
||||
match body {
|
||||
Body::Empty => {
|
||||
if msg.chunked() {
|
||||
error!("Chunked transfer is enabled but body is set to Empty");
|
||||
}
|
||||
msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(0);
|
||||
},
|
||||
Body::Length(n) => {
|
||||
if msg.chunked() {
|
||||
error!("Chunked transfer is enabled but body with specific length is specified");
|
||||
}
|
||||
msg.headers.insert(
|
||||
CONTENT_LENGTH,
|
||||
HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(n);
|
||||
},
|
||||
Body::Binary(ref bytes) => {
|
||||
extra = bytes.len();
|
||||
msg.headers.insert(
|
||||
CONTENT_LENGTH,
|
||||
HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(0);
|
||||
}
|
||||
Body::Streaming => {
|
||||
if msg.chunked() {
|
||||
if version < Version::HTTP_11 {
|
||||
error!("Chunked transfer encoding is forbidden for {:?}", version);
|
||||
}
|
||||
msg.headers.remove(CONTENT_LENGTH);
|
||||
msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
|
||||
self.encoder = Encoder::chunked();
|
||||
} else {
|
||||
self.encoder = Encoder::eof();
|
||||
}
|
||||
}
|
||||
Body::Upgrade => {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
|
||||
self.encoder = Encoder::eof();
|
||||
}
|
||||
}
|
||||
|
||||
// Connection upgrade
|
||||
if msg.upgrade() {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
|
||||
}
|
||||
// keep-alive
|
||||
else if self.keepalive {
|
||||
if version < Version::HTTP_11 {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
|
||||
}
|
||||
} else if version >= Version::HTTP_11 {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
|
||||
}
|
||||
|
||||
// render message
|
||||
let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
|
||||
self.buffer.reserve(init_cap);
|
||||
|
||||
if version == Version::HTTP_11 && msg.status == StatusCode::OK {
|
||||
self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
|
||||
} else {
|
||||
let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
|
||||
}
|
||||
for (key, value) in &msg.headers {
|
||||
let t: &[u8] = key.as_ref();
|
||||
self.buffer.extend(t);
|
||||
self.buffer.extend(b": ");
|
||||
self.buffer.extend(value.as_ref());
|
||||
self.buffer.extend(b"\r\n");
|
||||
}
|
||||
|
||||
// using http::h1::date is quite a lot faster than generating
|
||||
// a unique Date header each time like req/s goes up about 10%
|
||||
if !msg.headers.contains_key(DATE) {
|
||||
self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
|
||||
self.buffer.extend(b"Date: ");
|
||||
date::extend(&mut self.buffer);
|
||||
self.buffer.extend(b"\r\n");
|
||||
}
|
||||
|
||||
// default content-type
|
||||
if !msg.headers.contains_key(CONTENT_TYPE) {
|
||||
self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
|
||||
}
|
||||
|
||||
self.buffer.extend(b"\r\n");
|
||||
|
||||
if let Body::Binary(ref bytes) = body {
|
||||
self.buffer.extend_from_slice(bytes.as_ref());
|
||||
return Ok(WriterState::Done)
|
||||
}
|
||||
msg.replace_body(body);
|
||||
|
||||
Ok(WriterState::Done)
|
||||
}
|
||||
|
||||
fn write(&mut self, payload: &[u8]) -> Result<WriterState, io::Error> {
|
||||
if !self.disconnected {
|
||||
if self.started {
|
||||
// TODO: add warning, write after EOF
|
||||
self.encoder.encode(&mut self.buffer, payload);
|
||||
} else {
|
||||
// might be response for EXCEPT
|
||||
self.buffer.extend_from_slice(payload)
|
||||
}
|
||||
}
|
||||
|
||||
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
|
||||
return Ok(WriterState::Pause)
|
||||
} else {
|
||||
return Ok(WriterState::Done)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_eof(&mut self) -> Result<WriterState, io::Error> {
|
||||
if !self.encoder.encode_eof(&mut self.buffer) {
|
||||
//debug!("last payload item, but it is not EOF ");
|
||||
Err(io::Error::new(io::ErrorKind::Other,
|
||||
"Last payload item, but eof is not reached"))
|
||||
} else {
|
||||
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
|
||||
return Ok(WriterState::Pause)
|
||||
} else {
|
||||
return Ok(WriterState::Done)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_complete(&mut self) -> Poll<(), io::Error> {
|
||||
match self.write_to_stream() {
|
||||
Ok(WriterState::Done) => Ok(Async::Ready(())),
|
||||
Ok(WriterState::Pause) => Ok(Async::NotReady),
|
||||
Err(err) => Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoders to handle different Transfer-Encodings.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Encoder {
|
||||
kind: Kind,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
enum Kind {
|
||||
/// An Encoder for when Transfer-Encoding includes `chunked`.
|
||||
Chunked(bool),
|
||||
/// An Encoder for when Content-Length is set.
|
||||
///
|
||||
/// Enforces that the body is not longer than the Content-Length header.
|
||||
Length(u64),
|
||||
/// An Encoder for when Content-Length is not known.
|
||||
///
|
||||
/// Appliction decides when to stop writing.
|
||||
Eof,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
|
||||
pub fn eof() -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Eof,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunked() -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Chunked(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn length(len: u64) -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Length(len),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode message. Return `EOF` state of encoder
|
||||
pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
|
||||
match self.kind {
|
||||
Kind::Eof => {
|
||||
dst.extend(msg);
|
||||
msg.is_empty()
|
||||
},
|
||||
Kind::Chunked(ref mut eof) => {
|
||||
if *eof {
|
||||
return true;
|
||||
}
|
||||
|
||||
if msg.is_empty() {
|
||||
*eof = true;
|
||||
dst.extend(b"0\r\n\r\n");
|
||||
} else {
|
||||
write!(dst, "{:X}\r\n", msg.len()).unwrap();
|
||||
dst.extend(msg);
|
||||
dst.extend(b"\r\n");
|
||||
}
|
||||
*eof
|
||||
},
|
||||
Kind::Length(ref mut remaining) => {
|
||||
if msg.is_empty() {
|
||||
return *remaining == 0
|
||||
}
|
||||
let max = cmp::min(*remaining, msg.len() as u64);
|
||||
trace!("sized write = {}", max);
|
||||
dst.extend(msg[..max as usize].as_ref());
|
||||
|
||||
*remaining -= max as u64;
|
||||
trace!("encoded {} bytes, remaining = {}", max, remaining);
|
||||
*remaining == 0
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode eof. Return `EOF` state of encoder
|
||||
pub fn encode_eof(&mut self, dst: &mut BytesMut) -> bool {
|
||||
match self.kind {
|
||||
Kind::Eof => true,
|
||||
Kind::Chunked(ref mut eof) => {
|
||||
if *eof {
|
||||
return true;
|
||||
}
|
||||
|
||||
*eof = true;
|
||||
dst.extend(b"0\r\n\r\n");
|
||||
true
|
||||
},
|
||||
Kind::Length(ref mut remaining) => {
|
||||
return *remaining == 0
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
148
src/h2.rs
148
src/h2.rs
@ -1,9 +1,147 @@
|
||||
use std::{io, cmp};
|
||||
use std::{io, cmp, mem};
|
||||
use std::rc::Rc;
|
||||
use std::io::{Read, Write};
|
||||
use std::cell::UnsafeCell;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use http::request::Parts;
|
||||
use http2::{RecvStream};
|
||||
use http2::server::{Server, Handshake, Respond};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::Poll;
|
||||
use futures::{Async, Poll, Future, Stream};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use task::Task;
|
||||
use server::HttpHandler;
|
||||
use httpcodes::HTTPNotFound;
|
||||
use httprequest::HttpRequest;
|
||||
use payload::{Payload, PayloadError, PayloadSender};
|
||||
|
||||
|
||||
pub(crate) struct Http2<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
|
||||
{
|
||||
router: Rc<Vec<H>>,
|
||||
#[allow(dead_code)]
|
||||
addr: A,
|
||||
state: State<IoWrapper<T>>,
|
||||
error: bool,
|
||||
tasks: VecDeque<Entry>,
|
||||
}
|
||||
|
||||
enum State<T: AsyncRead + AsyncWrite> {
|
||||
Handshake(Handshake<T, Bytes>),
|
||||
Server(Server<T, Bytes>),
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl<T, A, H> Http2<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static,
|
||||
A: 'static,
|
||||
H: HttpHandler + 'static
|
||||
{
|
||||
pub fn new(stream: T, addr: A, router: Rc<Vec<H>>, buf: Bytes) -> Self {
|
||||
Http2{ router: router,
|
||||
addr: addr,
|
||||
error: false,
|
||||
tasks: VecDeque::new(),
|
||||
state: State::Handshake(
|
||||
Server::handshake(IoWrapper{unread: Some(buf), inner: stream})) }
|
||||
}
|
||||
|
||||
pub fn poll(&mut self) -> Poll<(), ()> {
|
||||
// handshake
|
||||
self.state = if let State::Handshake(ref mut handshake) = self.state {
|
||||
match handshake.poll() {
|
||||
Ok(Async::Ready(srv)) => {
|
||||
State::Server(srv)
|
||||
},
|
||||
Ok(Async::NotReady) =>
|
||||
return Ok(Async::NotReady),
|
||||
Err(err) => {
|
||||
trace!("Error handling connection: {}", err);
|
||||
return Err(())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mem::replace(&mut self.state, State::Empty)
|
||||
};
|
||||
|
||||
// get request
|
||||
let poll = if let State::Server(ref mut server) = self.state {
|
||||
server.poll()
|
||||
} else {
|
||||
unreachable!("Http2::poll() state was not advanced completely!")
|
||||
};
|
||||
|
||||
match poll {
|
||||
Ok(Async::NotReady) => {
|
||||
// Ok(Async::NotReady);
|
||||
()
|
||||
}
|
||||
Err(err) => {
|
||||
trace!("Connection error: {}", err);
|
||||
self.error = true;
|
||||
},
|
||||
Ok(Async::Ready(None)) => {
|
||||
|
||||
},
|
||||
Ok(Async::Ready(Some((req, resp)))) => {
|
||||
let (parts, body) = req.into_parts();
|
||||
let entry = Entry::new(parts, body, resp, &self.router);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry {
|
||||
task: Task,
|
||||
req: UnsafeCell<HttpRequest>,
|
||||
payload: PayloadSender,
|
||||
recv: RecvStream,
|
||||
respond: Respond<Bytes>,
|
||||
eof: bool,
|
||||
error: bool,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
fn new<H>(parts: Parts,
|
||||
recv: RecvStream,
|
||||
resp: Respond<Bytes>,
|
||||
router: &Rc<Vec<H>>) -> Entry
|
||||
where H: HttpHandler + 'static
|
||||
{
|
||||
let path = parts.uri.path().to_owned();
|
||||
let query = parts.uri.query().unwrap_or("").to_owned();
|
||||
|
||||
println!("PARTS: {:?}", parts);
|
||||
let mut req = HttpRequest::new(
|
||||
parts.method, path, parts.version, parts.headers, query);
|
||||
let (psender, payload) = Payload::new(false);
|
||||
|
||||
// start request processing
|
||||
let mut task = None;
|
||||
for h in router.iter() {
|
||||
if req.path().starts_with(h.prefix()) {
|
||||
task = Some(h.handle(&mut req, payload));
|
||||
break
|
||||
}
|
||||
}
|
||||
println!("REQ: {:?}", req);
|
||||
|
||||
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
|
||||
req: UnsafeCell::new(req),
|
||||
payload: psender,
|
||||
recv: recv,
|
||||
respond: resp,
|
||||
eof: false,
|
||||
error: false,
|
||||
finished: false}
|
||||
}
|
||||
}
|
||||
|
||||
struct IoWrapper<T> {
|
||||
unread: Option<Bytes>,
|
||||
@ -14,9 +152,9 @@ impl<T: Read> Read for IoWrapper<T> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
if let Some(mut bytes) = self.unread.take() {
|
||||
let size = cmp::min(buf.len(), bytes.len());
|
||||
buf.copy_from_slice(&bytes[..size]);
|
||||
bytes.split_to(size);
|
||||
if !bytes.is_empty() {
|
||||
buf[..size].copy_from_slice(&bytes[..size]);
|
||||
if bytes.len() > size {
|
||||
bytes.split_to(size);
|
||||
self.unread = Some(bytes);
|
||||
}
|
||||
Ok(size)
|
||||
|
@ -20,6 +20,7 @@ extern crate mime_guess;
|
||||
extern crate url;
|
||||
extern crate percent_encoding;
|
||||
extern crate actix;
|
||||
extern crate h2 as http2;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
extern crate native_tls;
|
||||
@ -45,6 +46,7 @@ mod wsframe;
|
||||
mod wsproto;
|
||||
mod h1;
|
||||
mod h2;
|
||||
mod h1writer;
|
||||
|
||||
pub mod ws;
|
||||
pub mod dev;
|
||||
|
250
src/server.rs
250
src/server.rs
@ -1,13 +1,9 @@
|
||||
use std::{io, net};
|
||||
use std::{io, net, mem};
|
||||
use std::rc::Rc;
|
||||
use std::cell::UnsafeCell;
|
||||
use std::time::Duration;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use actix::dev::*;
|
||||
use futures::{Future, Poll, Async, Stream};
|
||||
use tokio_core::reactor::Timeout;
|
||||
use tokio_core::net::{TcpListener, TcpStream};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
@ -17,9 +13,9 @@ use native_tls::TlsAcceptor;
|
||||
use tokio_tls::{TlsStream, TlsAcceptorExt};
|
||||
|
||||
use h1;
|
||||
use h2;
|
||||
use task::Task;
|
||||
use payload::Payload;
|
||||
use httpcodes::HTTPNotFound;
|
||||
use httprequest::HttpRequest;
|
||||
|
||||
/// Low level http request handler
|
||||
@ -153,11 +149,10 @@ impl<H: HttpHandler> HttpServer<TlsStream<TcpStream>, net::SocketAddr, H> {
|
||||
println!("SSL");
|
||||
TlsAcceptorExt::accept_async(acc.as_ref(), stream)
|
||||
.map(move |t| {
|
||||
println!("connected {:?} {:?}", t, addr);
|
||||
IoStream(t, addr)
|
||||
})
|
||||
.map_err(|err| {
|
||||
println!("ERR: {:?}", err);
|
||||
trace!("Error during handling tls connection: {}", err);
|
||||
io::Error::new(io::ErrorKind::Other, err)
|
||||
})
|
||||
}));
|
||||
@ -195,42 +190,25 @@ impl<T, A, H> Handler<IoStream<T, A>, io::Error> for HttpServer<T, A, H>
|
||||
-> Response<Self, IoStream<T, A>>
|
||||
{
|
||||
Arbiter::handle().spawn(
|
||||
HttpChannel{router: Rc::clone(&self.h),
|
||||
addr: msg.1,
|
||||
stream: msg.0,
|
||||
reader: h1::Reader::new(),
|
||||
error: false,
|
||||
items: VecDeque::new(),
|
||||
inactive: VecDeque::new(),
|
||||
keepalive: true,
|
||||
keepalive_timer: None,
|
||||
HttpChannel{
|
||||
proto: Protocol::H1(h1::Http1::new(msg.0, msg.1, Rc::clone(&self.h)))
|
||||
});
|
||||
Self::empty()
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry {
|
||||
task: Task,
|
||||
req: UnsafeCell<HttpRequest>,
|
||||
eof: bool,
|
||||
error: bool,
|
||||
finished: bool,
|
||||
enum Protocol<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
|
||||
{
|
||||
H1(h1::Http1<T, A, H>),
|
||||
H2(h2::Http2<T, A, H>),
|
||||
None,
|
||||
}
|
||||
|
||||
const KEEPALIVE_PERIOD: u64 = 15; // seconds
|
||||
const MAX_PIPELINED_MESSAGES: usize = 16;
|
||||
|
||||
pub struct HttpChannel<T: 'static, A: 'static, H: 'static> {
|
||||
router: Rc<Vec<H>>,
|
||||
#[allow(dead_code)]
|
||||
addr: A,
|
||||
stream: T,
|
||||
reader: h1::Reader,
|
||||
error: bool,
|
||||
items: VecDeque<Entry>,
|
||||
inactive: VecDeque<Entry>,
|
||||
keepalive: bool,
|
||||
keepalive_timer: Option<Timeout>,
|
||||
pub struct HttpChannel<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
|
||||
{
|
||||
proto: Protocol<T, A, H>,
|
||||
}
|
||||
|
||||
/*impl<T: 'static, A: 'static, H: 'static> Drop for HttpChannel<T, A, H> {
|
||||
@ -240,193 +218,45 @@ pub struct HttpChannel<T: 'static, A: 'static, H: 'static> {
|
||||
}*/
|
||||
|
||||
impl<T, A, H> Actor for HttpChannel<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static,
|
||||
A: 'static,
|
||||
H: HttpHandler + 'static
|
||||
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
|
||||
{
|
||||
type Context = Context<Self>;
|
||||
}
|
||||
|
||||
impl<T, A, H> Future for HttpChannel<T, A, H>
|
||||
where T: AsyncRead + AsyncWrite + 'static,
|
||||
A: 'static,
|
||||
H: HttpHandler + 'static
|
||||
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
|
||||
{
|
||||
type Item = ();
|
||||
type Error = ();
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
// keep-alive timer
|
||||
if let Some(ref mut timeout) = self.keepalive_timer {
|
||||
match timeout.poll() {
|
||||
Ok(Async::Ready(_)) =>
|
||||
return Ok(Async::Ready(())),
|
||||
Ok(Async::NotReady) => (),
|
||||
Err(_) => unreachable!(),
|
||||
match self.proto {
|
||||
Protocol::H1(ref mut h1) => {
|
||||
match h1.poll() {
|
||||
Ok(Async::Ready(h1::Http1Result::Done)) =>
|
||||
return Ok(Async::Ready(())),
|
||||
Ok(Async::Ready(h1::Http1Result::Upgrade)) => (),
|
||||
Ok(Async::NotReady) =>
|
||||
return Ok(Async::NotReady),
|
||||
Err(_) =>
|
||||
return Err(()),
|
||||
}
|
||||
}
|
||||
Protocol::H2(ref mut h2) =>
|
||||
return h2.poll(),
|
||||
Protocol::None =>
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
loop {
|
||||
let mut not_ready = true;
|
||||
|
||||
// check in-flight messages
|
||||
let mut idx = 0;
|
||||
while idx < self.items.len() {
|
||||
if idx == 0 {
|
||||
if self.items[idx].error {
|
||||
return Err(())
|
||||
}
|
||||
|
||||
// this is anoying
|
||||
let req = unsafe {self.items[idx].req.get().as_mut().unwrap()};
|
||||
match self.items[idx].task.poll_io(&mut self.stream, req)
|
||||
{
|
||||
Ok(Async::Ready(ready)) => {
|
||||
not_ready = false;
|
||||
let mut item = self.items.pop_front().unwrap();
|
||||
|
||||
// overide keep-alive state
|
||||
if self.keepalive {
|
||||
self.keepalive = item.task.keepalive();
|
||||
}
|
||||
if !ready {
|
||||
item.eof = true;
|
||||
self.inactive.push_back(item);
|
||||
}
|
||||
|
||||
// no keep-alive
|
||||
if ready && !self.keepalive &&
|
||||
self.items.is_empty() && self.inactive.is_empty()
|
||||
{
|
||||
return Ok(Async::Ready(()))
|
||||
}
|
||||
continue
|
||||
},
|
||||
Ok(Async::NotReady) => (),
|
||||
Err(_) => {
|
||||
// it is not possible to recover from error
|
||||
// during task handling, so just drop connection
|
||||
return Err(())
|
||||
}
|
||||
}
|
||||
} else if !self.items[idx].finished && !self.items[idx].error {
|
||||
match self.items[idx].task.poll() {
|
||||
Ok(Async::NotReady) => (),
|
||||
Ok(Async::Ready(_)) => {
|
||||
not_ready = false;
|
||||
self.items[idx].finished = true;
|
||||
},
|
||||
Err(_) =>
|
||||
self.items[idx].error = true,
|
||||
}
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// check inactive tasks
|
||||
let mut idx = 0;
|
||||
while idx < self.inactive.len() {
|
||||
if idx == 0 && self.inactive[idx].error && self.inactive[idx].finished {
|
||||
let _ = self.inactive.pop_front();
|
||||
continue
|
||||
}
|
||||
|
||||
if !self.inactive[idx].finished && !self.inactive[idx].error {
|
||||
match self.inactive[idx].task.poll() {
|
||||
Ok(Async::NotReady) => (),
|
||||
Ok(Async::Ready(_)) => {
|
||||
not_ready = false;
|
||||
self.inactive[idx].finished = true
|
||||
}
|
||||
Err(_) =>
|
||||
self.inactive[idx].error = true,
|
||||
}
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// read incoming data
|
||||
if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES {
|
||||
match self.reader.parse(&mut self.stream) {
|
||||
Ok(Async::Ready((mut req, payload))) => {
|
||||
not_ready = false;
|
||||
|
||||
// stop keepalive timer
|
||||
self.keepalive_timer.take();
|
||||
|
||||
// start request processing
|
||||
let mut task = None;
|
||||
for h in self.router.iter() {
|
||||
if req.path().starts_with(h.prefix()) {
|
||||
task = Some(h.handle(&mut req, payload));
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
self.items.push_back(
|
||||
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
|
||||
req: UnsafeCell::new(req),
|
||||
eof: false,
|
||||
error: false,
|
||||
finished: false});
|
||||
}
|
||||
Err(err) => {
|
||||
// notify all tasks
|
||||
not_ready = false;
|
||||
for entry in &mut self.items {
|
||||
entry.task.disconnected()
|
||||
}
|
||||
|
||||
// kill keepalive
|
||||
self.keepalive = false;
|
||||
self.keepalive_timer.take();
|
||||
|
||||
// on parse error, stop reading stream but
|
||||
// tasks need to be completed
|
||||
self.error = true;
|
||||
|
||||
if self.items.is_empty() {
|
||||
if let h1::ReaderError::Error(err) = err {
|
||||
self.items.push_back(
|
||||
Entry {task: Task::reply(err),
|
||||
req: UnsafeCell::new(HttpRequest::for_error()),
|
||||
eof: false,
|
||||
error: false,
|
||||
finished: false});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Async::NotReady) => {
|
||||
// start keep-alive timer, this is also slow request timeout
|
||||
if self.items.is_empty() && self.inactive.is_empty() {
|
||||
if self.keepalive {
|
||||
if self.keepalive_timer.is_none() {
|
||||
trace!("Start keep-alive timer");
|
||||
let mut timeout = Timeout::new(
|
||||
Duration::new(KEEPALIVE_PERIOD, 0),
|
||||
Arbiter::handle()).unwrap();
|
||||
// register timeout
|
||||
let _ = timeout.poll();
|
||||
self.keepalive_timer = Some(timeout);
|
||||
}
|
||||
} else {
|
||||
// keep-alive disable, drop connection
|
||||
return Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check for parse error
|
||||
if self.items.is_empty() && self.inactive.is_empty() && self.error {
|
||||
return Ok(Async::Ready(()))
|
||||
}
|
||||
|
||||
if not_ready {
|
||||
return Ok(Async::NotReady)
|
||||
// upgrade to h2
|
||||
let proto = mem::replace(&mut self.proto, Protocol::None);
|
||||
match proto {
|
||||
Protocol::H1(h1) => {
|
||||
let (stream, addr, router, buf) = h1.into_inner();
|
||||
self.proto = Protocol::H2(h2::Http2::new(stream, addr, router, buf));
|
||||
return self.poll()
|
||||
}
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
375
src/task.rs
375
src/task.rs
@ -1,27 +1,18 @@
|
||||
use std::{mem, cmp, io};
|
||||
use std::{mem, io};
|
||||
use std::rc::Rc;
|
||||
use std::fmt::Write;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use http::{StatusCode, Version};
|
||||
use http::header::{HeaderValue,
|
||||
CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
|
||||
use bytes::BytesMut;
|
||||
use futures::{Async, Future, Poll, Stream};
|
||||
use futures::task::{Task as FutureTask, current as current_task};
|
||||
use tokio_io::AsyncWrite;
|
||||
|
||||
use date;
|
||||
use body::Body;
|
||||
use h1writer::{Writer, WriterState};
|
||||
use route::Frame;
|
||||
use application::Middleware;
|
||||
use httprequest::HttpRequest;
|
||||
use httpresponse::HttpResponse;
|
||||
|
||||
type FrameStream = Stream<Item=Frame, Error=io::Error>;
|
||||
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
|
||||
const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
enum TaskRunningState {
|
||||
@ -34,6 +25,16 @@ impl TaskRunningState {
|
||||
fn is_done(&self) -> bool {
|
||||
*self == TaskRunningState::Done
|
||||
}
|
||||
fn pause(&mut self) {
|
||||
if *self != TaskRunningState::Done {
|
||||
*self = TaskRunningState::Paused
|
||||
}
|
||||
}
|
||||
fn resume(&mut self) {
|
||||
if *self != TaskRunningState::Done {
|
||||
*self = TaskRunningState::Running
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
@ -100,17 +101,12 @@ impl Future for DrainFut {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct Task {
|
||||
state: TaskRunningState,
|
||||
iostate: TaskIOState,
|
||||
frames: VecDeque<Frame>,
|
||||
stream: TaskStream,
|
||||
encoder: Encoder,
|
||||
buffer: BytesMut,
|
||||
drain: Vec<Rc<RefCell<DrainFut>>>,
|
||||
upgrade: bool,
|
||||
keepalive: bool,
|
||||
prepared: Option<HttpResponse>,
|
||||
disconnected: bool,
|
||||
middlewares: Option<Rc<Vec<Box<Middleware>>>>,
|
||||
@ -129,10 +125,6 @@ impl Task {
|
||||
frames: frames,
|
||||
drain: Vec::new(),
|
||||
stream: TaskStream::None,
|
||||
encoder: Encoder::length(0),
|
||||
buffer: BytesMut::new(),
|
||||
upgrade: false,
|
||||
keepalive: false,
|
||||
prepared: None,
|
||||
disconnected: false,
|
||||
middlewares: None,
|
||||
@ -147,11 +139,7 @@ impl Task {
|
||||
iostate: TaskIOState::ReadingMessage,
|
||||
frames: VecDeque::new(),
|
||||
stream: TaskStream::Stream(Box::new(stream)),
|
||||
encoder: Encoder::length(0),
|
||||
buffer: BytesMut::new(),
|
||||
drain: Vec::new(),
|
||||
upgrade: false,
|
||||
keepalive: false,
|
||||
prepared: None,
|
||||
disconnected: false,
|
||||
middlewares: None,
|
||||
@ -165,158 +153,26 @@ impl Task {
|
||||
iostate: TaskIOState::ReadingMessage,
|
||||
frames: VecDeque::new(),
|
||||
stream: TaskStream::Context(Box::new(ctx)),
|
||||
encoder: Encoder::length(0),
|
||||
buffer: BytesMut::new(),
|
||||
drain: Vec::new(),
|
||||
upgrade: false,
|
||||
keepalive: false,
|
||||
prepared: None,
|
||||
disconnected: false,
|
||||
middlewares: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn keepalive(&self) -> bool {
|
||||
self.keepalive && !self.upgrade
|
||||
}
|
||||
|
||||
pub(crate) fn set_middlewares(&mut self, middlewares: Rc<Vec<Box<Middleware>>>) {
|
||||
self.middlewares = Some(middlewares);
|
||||
}
|
||||
|
||||
pub(crate) fn disconnected(&mut self) {
|
||||
let len = self.buffer.len();
|
||||
self.buffer.split_to(len);
|
||||
self.disconnected = true;
|
||||
if let TaskStream::Context(ref mut ctx) = self.stream {
|
||||
ctx.disconnected();
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse)
|
||||
{
|
||||
trace!("Prepare message status={:?}", msg.status);
|
||||
|
||||
// run middlewares
|
||||
let mut msg = if let Some(middlewares) = self.middlewares.take() {
|
||||
let mut msg = msg;
|
||||
for middleware in middlewares.iter() {
|
||||
msg = middleware.response(req, msg);
|
||||
}
|
||||
self.middlewares = Some(middlewares);
|
||||
msg
|
||||
} else {
|
||||
msg
|
||||
};
|
||||
|
||||
// prepare task
|
||||
let mut extra = 0;
|
||||
let body = msg.replace_body(Body::Empty);
|
||||
let version = msg.version().unwrap_or_else(|| req.version());
|
||||
self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
|
||||
|
||||
match body {
|
||||
Body::Empty => {
|
||||
if msg.chunked() {
|
||||
error!("Chunked transfer is enabled but body is set to Empty");
|
||||
}
|
||||
msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(0);
|
||||
},
|
||||
Body::Length(n) => {
|
||||
if msg.chunked() {
|
||||
error!("Chunked transfer is enabled but body with specific length is specified");
|
||||
}
|
||||
msg.headers.insert(
|
||||
CONTENT_LENGTH,
|
||||
HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(n);
|
||||
},
|
||||
Body::Binary(ref bytes) => {
|
||||
extra = bytes.len();
|
||||
msg.headers.insert(
|
||||
CONTENT_LENGTH,
|
||||
HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
|
||||
msg.headers.remove(TRANSFER_ENCODING);
|
||||
self.encoder = Encoder::length(0);
|
||||
}
|
||||
Body::Streaming => {
|
||||
if msg.chunked() {
|
||||
if version < Version::HTTP_11 {
|
||||
error!("Chunked transfer encoding is forbidden for {:?}", version);
|
||||
}
|
||||
msg.headers.remove(CONTENT_LENGTH);
|
||||
msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
|
||||
self.encoder = Encoder::chunked();
|
||||
} else {
|
||||
self.encoder = Encoder::eof();
|
||||
}
|
||||
}
|
||||
Body::Upgrade => {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
|
||||
self.encoder = Encoder::eof();
|
||||
}
|
||||
}
|
||||
|
||||
// Connection upgrade
|
||||
if msg.upgrade() {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
|
||||
}
|
||||
// keep-alive
|
||||
else if self.keepalive {
|
||||
if version < Version::HTTP_11 {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
|
||||
}
|
||||
} else if version >= Version::HTTP_11 {
|
||||
msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
|
||||
}
|
||||
|
||||
// render message
|
||||
let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
|
||||
self.buffer.reserve(init_cap);
|
||||
|
||||
if version == Version::HTTP_11 && msg.status == StatusCode::OK {
|
||||
self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
|
||||
} else {
|
||||
let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
|
||||
}
|
||||
for (key, value) in &msg.headers {
|
||||
let t: &[u8] = key.as_ref();
|
||||
self.buffer.extend(t);
|
||||
self.buffer.extend(b": ");
|
||||
self.buffer.extend(value.as_ref());
|
||||
self.buffer.extend(b"\r\n");
|
||||
}
|
||||
|
||||
// using http::h1::date is quite a lot faster than generating
|
||||
// a unique Date header each time like req/s goes up about 10%
|
||||
if !msg.headers.contains_key(DATE) {
|
||||
self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
|
||||
self.buffer.extend(b"Date: ");
|
||||
date::extend(&mut self.buffer);
|
||||
self.buffer.extend(b"\r\n");
|
||||
}
|
||||
|
||||
// default content-type
|
||||
if !msg.headers.contains_key(CONTENT_TYPE) {
|
||||
self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
|
||||
}
|
||||
|
||||
self.buffer.extend(b"\r\n");
|
||||
|
||||
if let Body::Binary(ref bytes) = body {
|
||||
self.buffer.extend_from_slice(bytes.as_ref());
|
||||
self.prepared = Some(msg);
|
||||
return
|
||||
}
|
||||
msg.replace_body(body);
|
||||
self.prepared = Some(msg);
|
||||
}
|
||||
|
||||
pub(crate) fn poll_io<T>(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll<bool, ()>
|
||||
where T: AsyncWrite
|
||||
where T: Writer
|
||||
{
|
||||
trace!("POLL-IO frames:{:?}", self.frames.len());
|
||||
// response is completed
|
||||
@ -328,87 +184,76 @@ impl Task {
|
||||
match self.poll() {
|
||||
Ok(Async::Ready(_)) => {
|
||||
self.state = TaskRunningState::Done;
|
||||
}
|
||||
},
|
||||
Ok(Async::NotReady) => (),
|
||||
Err(_) => return Err(())
|
||||
}
|
||||
}
|
||||
|
||||
// use exiting frames
|
||||
while let Some(frame) = self.frames.pop_front() {
|
||||
trace!("IO Frame: {:?}", frame);
|
||||
match frame {
|
||||
Frame::Message(response) => {
|
||||
if !self.disconnected {
|
||||
self.prepare(req, response);
|
||||
if self.state != TaskRunningState::Paused {
|
||||
while let Some(frame) = self.frames.pop_front() {
|
||||
trace!("IO Frame: {:?}", frame);
|
||||
let res = match frame {
|
||||
Frame::Message(mut response) => {
|
||||
trace!("Prepare message status={:?}", response.status);
|
||||
|
||||
// run middlewares
|
||||
let mut response =
|
||||
if let Some(middlewares) = self.middlewares.take() {
|
||||
let mut response = response;
|
||||
for middleware in middlewares.iter() {
|
||||
response = middleware.response(req, response);
|
||||
}
|
||||
self.middlewares = Some(middlewares);
|
||||
response
|
||||
} else {
|
||||
response
|
||||
};
|
||||
|
||||
let result = io.start(req, &mut response);
|
||||
self.prepared = Some(response);
|
||||
result
|
||||
}
|
||||
}
|
||||
Frame::Payload(Some(chunk)) => {
|
||||
if !self.disconnected {
|
||||
if self.prepared.is_some() {
|
||||
// TODO: add warning, write after EOF
|
||||
self.encoder.encode(&mut self.buffer, chunk.as_ref());
|
||||
} else {
|
||||
// might be response for EXCEPT
|
||||
self.buffer.extend_from_slice(chunk.as_ref())
|
||||
}
|
||||
Frame::Payload(Some(chunk)) => {
|
||||
io.write(chunk.as_ref())
|
||||
},
|
||||
Frame::Payload(None) => {
|
||||
self.iostate = TaskIOState::Done;
|
||||
io.write_eof()
|
||||
},
|
||||
Frame::Drain(fut) => {
|
||||
self.drain.push(fut);
|
||||
break
|
||||
}
|
||||
},
|
||||
Frame::Payload(None) => {
|
||||
if !self.disconnected &&
|
||||
!self.encoder.encode(&mut self.buffer, [].as_ref())
|
||||
{
|
||||
// TODO: add error "not eof""
|
||||
debug!("last payload item, but it is not EOF ");
|
||||
return Err(())
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(WriterState::Pause) => {
|
||||
self.state.pause();
|
||||
break
|
||||
}
|
||||
break
|
||||
},
|
||||
Frame::Drain(fut) => {
|
||||
self.drain.push(fut);
|
||||
break
|
||||
Ok(WriterState::Done) => self.state.resume(),
|
||||
Err(_) => return Err(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write bytes to TcpStream
|
||||
if !self.disconnected {
|
||||
while !self.buffer.is_empty() {
|
||||
match io.write(self.buffer.as_ref()) {
|
||||
Ok(n) => {
|
||||
self.buffer.split_to(n);
|
||||
},
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
break
|
||||
}
|
||||
Err(_) => return Err(()),
|
||||
}
|
||||
// flush io
|
||||
match io.poll_complete() {
|
||||
Ok(Async::Ready(())) => self.state.resume(),
|
||||
Ok(Async::NotReady) => {
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
|
||||
// should pause task
|
||||
if self.state != TaskRunningState::Done {
|
||||
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
|
||||
self.state = TaskRunningState::Paused;
|
||||
} else if self.state == TaskRunningState::Paused {
|
||||
self.state = TaskRunningState::Running;
|
||||
Err(err) => {
|
||||
trace!("Error sending data: {}", err);
|
||||
return Err(())
|
||||
}
|
||||
} else {
|
||||
// at this point we wont get any more Frames
|
||||
self.iostate = TaskIOState::Done;
|
||||
}
|
||||
|
||||
// drain
|
||||
if self.buffer.is_empty() && !self.drain.is_empty() {
|
||||
match io.flush() {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
Err(_) => return Err(()),
|
||||
}
|
||||
|
||||
if !self.drain.is_empty() {
|
||||
for fut in &mut self.drain {
|
||||
fut.borrow_mut().set()
|
||||
}
|
||||
@ -416,7 +261,7 @@ impl Task {
|
||||
}
|
||||
|
||||
// response is completed
|
||||
if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() {
|
||||
if self.iostate.is_done() {
|
||||
// run middlewares
|
||||
if let Some(ref mut resp) = self.prepared {
|
||||
if let Some(middlewares) = self.middlewares.take() {
|
||||
@ -443,8 +288,8 @@ impl Task {
|
||||
error!("Non expected frame {:?}", frame);
|
||||
return Err(())
|
||||
}
|
||||
self.upgrade = msg.upgrade();
|
||||
if self.upgrade || msg.body().has_body() {
|
||||
let upgrade = msg.upgrade();
|
||||
if upgrade || msg.body().has_body() {
|
||||
self.iostate = TaskIOState::ReadingPayload;
|
||||
} else {
|
||||
self.iostate = TaskIOState::Done;
|
||||
@ -489,89 +334,3 @@ impl Future for Task {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoders to handle different Transfer-Encodings.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Encoder {
|
||||
kind: Kind,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
enum Kind {
|
||||
/// An Encoder for when Transfer-Encoding includes `chunked`.
|
||||
Chunked(bool),
|
||||
/// An Encoder for when Content-Length is set.
|
||||
///
|
||||
/// Enforces that the body is not longer than the Content-Length header.
|
||||
Length(u64),
|
||||
/// An Encoder for when Content-Length is not known.
|
||||
///
|
||||
/// Appliction decides when to stop writing.
|
||||
Eof,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
|
||||
pub fn eof() -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Eof,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunked() -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Chunked(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn length(len: u64) -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Length(len),
|
||||
}
|
||||
}
|
||||
|
||||
/*pub fn is_eof(&self) -> bool {
|
||||
match self.kind {
|
||||
Kind::Eof | Kind::Length(0) => true,
|
||||
Kind::Chunked(eof) => eof,
|
||||
_ => false,
|
||||
}
|
||||
}*/
|
||||
|
||||
/// Encode message. Return `EOF` state of encoder
|
||||
pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
|
||||
match self.kind {
|
||||
Kind::Eof => {
|
||||
dst.extend(msg);
|
||||
msg.is_empty()
|
||||
},
|
||||
Kind::Chunked(ref mut eof) => {
|
||||
if *eof {
|
||||
return true;
|
||||
}
|
||||
|
||||
if msg.is_empty() {
|
||||
*eof = true;
|
||||
dst.extend(b"0\r\n\r\n");
|
||||
} else {
|
||||
write!(dst, "{:X}\r\n", msg.len()).unwrap();
|
||||
dst.extend(msg);
|
||||
dst.extend(b"\r\n");
|
||||
}
|
||||
*eof
|
||||
},
|
||||
Kind::Length(ref mut remaining) => {
|
||||
if msg.is_empty() {
|
||||
return *remaining == 0
|
||||
}
|
||||
let max = cmp::min(*remaining, msg.len() as u64);
|
||||
trace!("sized write = {}", max);
|
||||
dst.extend(msg[..max as usize].as_ref());
|
||||
|
||||
*remaining -= max as u64;
|
||||
trace!("encoded {} bytes, remaining = {}", max, remaining);
|
||||
*remaining == 0
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user