1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-27 17:22:57 +01:00

refactor task impl, extract stream writer to separate struct

This commit is contained in:
Nikolay Kim 2017-11-03 11:29:44 -07:00
parent f010672885
commit 4add742aba
9 changed files with 979 additions and 620 deletions

View File

@ -28,8 +28,6 @@ default = []
# tls
tls = ["native-tls", "tokio-tls"]
# http2 = ["h2"]
[dependencies]
log = "0.3"
time = "0.1"

4
cov.sh
View File

@ -1,4 +0,0 @@
#!/bin/bash
for file in target/debug/actix_web-*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done &&
for file in target/debug/test_*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done

View File

@ -47,6 +47,7 @@ impl<A> ActorContext for HttpContext<A> where A: Actor<Context=Self> + Route
{
/// Stop actor execution
fn stop(&mut self) {
self.stream.push_back(Frame::Payload(None));
self.items.stop();
self.address.close();
if self.state == ActorState::Running {
@ -141,7 +142,6 @@ impl<A> HttpContext<A> where A: Actor<Context=Self> + Route {
/// Indicate end of streamimng payload. Also this method calls `Self::close`.
pub fn write_eof(&mut self) {
self.stream.push_back(Frame::Payload(None));
self.stop();
}

465
src/h1.rs
View File

@ -1,22 +1,282 @@
use std::{self, io, ptr};
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::time::Duration;
use std::collections::VecDeque;
use actix::Arbiter;
use httparse;
use http::{Method, Version, HttpTryFrom, HeaderMap};
use http::header::{self, HeaderName, HeaderValue};
use bytes::{Bytes, BytesMut, BufMut};
use futures::{Async, Poll};
use tokio_io::AsyncRead;
use futures::{Future, Poll, Async};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_core::reactor::Timeout;
use percent_encoding;
use task::Task;
use server::HttpHandler;
use error::ParseError;
use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
use payload::{Payload, PayloadError, PayloadSender};
use h1writer::H1Writer;
const MAX_HEADERS: usize = 100;
const KEEPALIVE_PERIOD: u64 = 15; // seconds
const INIT_BUFFER_SIZE: usize = 8192;
const MAX_BUFFER_SIZE: usize = 131_072;
const MAX_HEADERS: usize = 100;
const MAX_PIPELINED_MESSAGES: usize = 16;
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
pub(crate) enum Http1Result {
Done,
Upgrade,
}
pub(crate) struct Http1<T: AsyncWrite + 'static, A: 'static, H: 'static> {
router: Rc<Vec<H>>,
#[allow(dead_code)]
addr: A,
stream: H1Writer<T>,
reader: Reader,
read_buf: BytesMut,
error: bool,
tasks: VecDeque<Entry>,
keepalive: bool,
keepalive_timer: Option<Timeout>,
h2: bool,
}
struct Entry {
task: Task,
req: UnsafeCell<HttpRequest>,
eof: bool,
error: bool,
finished: bool,
}
impl<T, A, H> Http1<T, A, H>
where T: AsyncRead + AsyncWrite + 'static,
A: 'static,
H: HttpHandler + 'static
{
pub fn new(stream: T, addr: A, router: Rc<Vec<H>>) -> Self {
Http1{ router: router,
addr: addr,
stream: H1Writer::new(stream),
reader: Reader::new(),
read_buf: BytesMut::new(),
error: false,
tasks: VecDeque::new(),
keepalive: true,
keepalive_timer: None,
h2: false }
}
pub fn into_inner(mut self) -> (T, A, Rc<Vec<H>>, Bytes) {
(self.stream.into_inner(), self.addr, self.router, self.read_buf.freeze())
}
pub fn poll(&mut self) -> Poll<Http1Result, ()> {
// keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() {
Ok(Async::Ready(_)) =>
return Ok(Async::Ready(Http1Result::Done)),
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
}
}
loop {
let mut not_ready = true;
// check in-flight messages
let mut io = false;
let mut idx = 0;
while idx < self.tasks.len() {
let item = &mut self.tasks[idx];
if !io && !item.eof {
if item.error {
return Err(())
}
// this is anoying
let req = unsafe {item.req.get().as_mut().unwrap()};
match item.task.poll_io(&mut self.stream, req)
{
Ok(Async::Ready(ready)) => {
not_ready = false;
// overide keep-alive state
if self.keepalive {
self.keepalive = self.stream.keepalive();
}
self.stream = H1Writer::new(self.stream.into_inner());
item.eof = true;
if ready {
item.finished = true;
}
},
Ok(Async::NotReady) => {
// no more IO for this iteration
io = true;
},
Err(_) => {
// it is not possible to recover from error
// during task handling, so just drop connection
return Err(())
}
}
} else if !item.finished {
match item.task.poll() {
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
item.finished = true;
},
Err(_) =>
item.error = true,
}
}
idx += 1;
}
// cleanup finished tasks
while !self.tasks.is_empty() {
if self.tasks[0].eof && self.tasks[0].finished {
self.tasks.pop_front();
} else {
break
}
}
// no keep-alive
if !self.keepalive && self.tasks.is_empty() {
if self.h2 {
return Ok(Async::Ready(Http1Result::Upgrade))
} else {
return Ok(Async::Ready(Http1Result::Done))
}
}
// read incoming data
if !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES {
match self.reader.parse(self.stream.get_mut(), &mut self.read_buf) {
Ok(Async::Ready(Item::Http1(mut req, payload))) => {
not_ready = false;
// stop keepalive timer
self.keepalive_timer.take();
// start request processing
let mut task = None;
for h in self.router.iter() {
if req.path().starts_with(h.prefix()) {
task = Some(h.handle(&mut req, payload));
break
}
}
self.tasks.push_back(
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
req: UnsafeCell::new(req),
eof: false,
error: false,
finished: false});
}
Ok(Async::Ready(Item::Http2)) => {
self.h2 = true;
}
Err(ReaderError::Disconnect) => {
not_ready = false;
self.error = true;
self.stream.disconnected();
for entry in &mut self.tasks {
entry.task.disconnected()
}
},
Err(err) => {
// notify all tasks
not_ready = false;
self.stream.disconnected();
for entry in &mut self.tasks {
entry.task.disconnected()
}
// kill keepalive
self.keepalive = false;
self.keepalive_timer.take();
// on parse error, stop reading stream but
// tasks need to be completed
self.error = true;
if self.tasks.is_empty() {
if let ReaderError::Error(err) = err {
self.tasks.push_back(
Entry {task: Task::reply(err),
req: UnsafeCell::new(HttpRequest::for_error()),
eof: false,
error: false,
finished: false});
}
}
}
Ok(Async::NotReady) => {
// start keep-alive timer, this is also slow request timeout
if self.tasks.is_empty() {
if self.keepalive {
if self.keepalive_timer.is_none() {
trace!("Start keep-alive timer");
let mut timeout = Timeout::new(
Duration::new(KEEPALIVE_PERIOD, 0),
Arbiter::handle()).unwrap();
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
}
} else {
// keep-alive disable, drop connection
return Ok(Async::Ready(Http1Result::Done))
}
}
return Ok(Async::NotReady)
}
}
}
// check for parse error
if self.tasks.is_empty() {
if self.error || self.keepalive_timer.is_none() {
return Ok(Async::Ready(Http1Result::Done))
}
else if self.h2 {
return Ok(Async::Ready(Http1Result::Upgrade))
}
}
if not_ready {
return Ok(Async::NotReady)
}
}
}
}
#[derive(Debug)]
enum Item {
Http1(HttpRequest, Payload),
Http2,
}
struct Reader {
h1: bool,
payload: Option<PayloadInfo>,
}
enum Decoding {
Paused,
Ready,
@ -28,19 +288,9 @@ struct PayloadInfo {
decoder: Decoder,
}
pub(crate) struct Reader {
read_buf: BytesMut,
payload: Option<PayloadInfo>,
}
#[derive(Debug)]
pub(crate) enum ReaderItem {
Http1(HttpRequest, Payload),
Http2,
}
#[derive(Debug)]
pub(crate) enum ReaderError {
enum ReaderError {
Disconnect,
Payload,
Error(ParseError),
}
@ -55,19 +305,19 @@ enum Message {
impl Reader {
pub fn new() -> Reader {
Reader {
read_buf: BytesMut::new(),
h1: false,
payload: None,
}
}
fn decode(&mut self) -> std::result::Result<Decoding, ReaderError>
fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result<Decoding, ReaderError>
{
if let Some(ref mut payload) = self.payload {
if payload.tx.maybe_paused() {
return Ok(Decoding::Paused)
}
loop {
match payload.decoder.decode(&mut self.read_buf) {
match payload.decoder.decode(buf) {
Ok(Async::Ready(Some(bytes))) => {
payload.tx.feed_data(bytes)
},
@ -87,18 +337,18 @@ impl Reader {
}
}
pub fn parse<T>(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ReaderError>
pub fn parse<T>(&mut self, io: &mut T, buf: &mut BytesMut) -> Poll<Item, ReaderError>
where T: AsyncRead
{
loop {
match self.decode()? {
match self.decode(buf)? {
Decoding::Paused => return Ok(Async::NotReady),
Decoding::Ready => {
self.payload = None;
break
},
Decoding::NotReady => {
match self.read_from_io(io) {
match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
if let Some(ref mut payload) = self.payload {
payload.tx.set_error(PayloadError::Incomplete);
@ -123,7 +373,7 @@ impl Reader {
}
loop {
match Reader::parse_message(&mut self.read_buf).map_err(ReaderError::Error)? {
match Reader::parse_message(buf).map_err(ReaderError::Error)? {
Message::Http1(msg, decoder) => {
let payload = if let Some(decoder) = decoder {
let (tx, rx) = Payload::new(false);
@ -134,7 +384,7 @@ impl Reader {
self.payload = Some(payload);
loop {
match self.decode()? {
match self.decode(buf)? {
Decoding::Paused =>
break,
Decoding::Ready => {
@ -142,7 +392,7 @@ impl Reader {
break
},
Decoding::NotReady => {
match self.read_from_io(io) {
match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
trace!("parse eof");
if let Some(ref mut payload) = self.payload {
@ -171,21 +421,26 @@ impl Reader {
let (_, rx) = Payload::new(true);
rx
};
return Ok(Async::Ready((msg, payload)));
self.h1 = true;
return Ok(Async::Ready(Item::Http1(msg, payload)));
},
Message::Http2 => {
if self.h1 {
return Err(ReaderError::Error(ParseError::Version))
}
return Ok(Async::Ready(Item::Http2));
},
Message::NotReady => {
if self.read_buf.capacity() >= MAX_BUFFER_SIZE {
if buf.capacity() >= MAX_BUFFER_SIZE {
debug!("MAX_BUFFER_SIZE reached, closing");
return Err(ReaderError::Error(ParseError::TooLarge));
}
},
}
match self.read_from_io(io) {
match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
trace!("Eof during parse");
return Err(ReaderError::Error(ParseError::Incomplete));
debug!("Ignored premature client disconnection");
return Err(ReaderError::Disconnect);
},
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) =>
@ -196,17 +451,19 @@ impl Reader {
}
}
fn read_from_io<T: AsyncRead>(&mut self, io: &mut T) -> Poll<usize, io::Error> {
if self.read_buf.remaining_mut() < INIT_BUFFER_SIZE {
self.read_buf.reserve(INIT_BUFFER_SIZE);
fn read_from_io<T: AsyncRead>(&mut self, io: &mut T, buf: &mut BytesMut)
-> Poll<usize, io::Error>
{
if buf.remaining_mut() < INIT_BUFFER_SIZE {
buf.reserve(INIT_BUFFER_SIZE);
unsafe { // Zero out unused memory
let buf = self.read_buf.bytes_mut();
let len = buf.len();
ptr::write_bytes(buf.as_mut_ptr(), 0, len);
let b = buf.bytes_mut();
let len = b.len();
ptr::write_bytes(b.as_mut_ptr(), 0, len);
}
}
unsafe {
let n = match io.read(self.read_buf.bytes_mut()) {
let n = match io.read(buf.bytes_mut()) {
Ok(n) => n,
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
@ -215,18 +472,17 @@ impl Reader {
return Err(e)
}
};
self.read_buf.advance_mut(n);
buf.advance_mut(n);
Ok(Async::Ready(n))
}
}
fn parse_message(buf: &mut BytesMut) -> Result<Message, ParseError>
{
println!("BUF: {:?}", buf);
if buf.is_empty() || buf.len() < 14 {
if buf.is_empty() {
return Ok(Message::NotReady);
}
if &buf[..14] == &HTTP2_PREFACE[..] {
if buf.len() >= 14 && &buf[..14] == &HTTP2_PREFACE[..] {
return Ok(Message::Http2)
}
@ -368,7 +624,7 @@ fn record_header_indices(bytes: &[u8],
/// If a message body does not include a Transfer-Encoding, it *should*
/// include a Content-Length header.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
struct Decoder {
kind: Kind,
}
@ -424,7 +680,7 @@ enum ChunkedState {
}
impl Decoder {
pub fn is_eof(&self) -> bool {
/*pub fn is_eof(&self) -> bool {
trace!("is_eof? {:?}", self);
match self.kind {
Kind::Length(0) |
@ -432,7 +688,7 @@ impl Decoder {
Kind::Eof(true) => true,
_ => false,
}
}
}*/
}
impl Decoder {
@ -633,7 +889,7 @@ mod tests {
use futures::{Async};
use tokio_io::AsyncRead;
use http::{Version, Method};
use super::{Reader, ReaderError};
use super::*;
struct Buffer {
buf: Bytes,
@ -682,8 +938,8 @@ mod tests {
macro_rules! parse_ready {
($e:expr) => (
match Reader::new().parse($e) {
Ok(Async::Ready((req, payload))) => (req, payload),
match Reader::new().parse($e, &mut BytesMut::new()) {
Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
Ok(_) => panic!("Eof during parsing http request"),
Err(err) => panic!("Error during parsing http request: {:?}", err),
}
@ -693,7 +949,7 @@ mod tests {
macro_rules! reader_parse_ready {
($e:expr) => (
match $e {
Ok(Async::Ready((req, payload))) => (req, payload),
Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
Ok(_) => panic!("Eof during parsing http request"),
Err(err) => panic!("Error during parsing http request: {:?}", err),
}
@ -701,22 +957,28 @@ mod tests {
}
macro_rules! expect_parse_err {
($e:expr) => (match Reader::new().parse($e) {
Err(err) => match err {
ReaderError::Error(_) => (),
_ => panic!("Parse error expected"),
},
_ => panic!("Error expected"),
})
($e:expr) => ({
let mut buf = BytesMut::new();
match Reader::new().parse($e, &mut buf) {
Err(err) => match err {
ReaderError::Error(_) => (),
_ => panic!("Parse error expected"),
},
val => {
panic!("Error expected")
}
}}
)
}
#[test]
fn test_parse() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
Ok(Async::Ready((req, payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -729,16 +991,17 @@ mod tests {
#[test]
fn test_parse_partial() {
let mut buf = Buffer::new("PUT /test HTTP/1");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::NotReady) => (),
_ => panic!("Error"),
}
buf.feed_data(".1\r\n\r\n");
match reader.parse(&mut buf) {
Ok(Async::Ready((req, payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::PUT);
assert_eq!(req.path(), "/test");
@ -751,10 +1014,11 @@ mod tests {
#[test]
fn test_parse_post() {
let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
Ok(Async::Ready((req, payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_10);
assert_eq!(*req.method(), Method::POST);
assert_eq!(req.path(), "/test2");
@ -767,10 +1031,11 @@ mod tests {
#[test]
fn test_parse_body() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
Ok(Async::Ready((req, mut payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, mut payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -784,10 +1049,11 @@ mod tests {
fn test_parse_body_crlf() {
let mut buf = Buffer::new(
"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
Ok(Async::Ready((req, mut payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, mut payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -800,13 +1066,14 @@ mod tests {
#[test]
fn test_parse_partial_eof() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
not_ready!{ reader.parse(&mut buf) }
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("\r\n");
match reader.parse(&mut buf) {
Ok(Async::Ready((req, payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -819,19 +1086,20 @@ mod tests {
#[test]
fn test_headers_split_field() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
not_ready!{ reader.parse(&mut buf) }
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("t");
not_ready!{ reader.parse(&mut buf) }
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("es");
not_ready!{ reader.parse(&mut buf) }
not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("t: value\r\n\r\n");
match reader.parse(&mut buf) {
Ok(Async::Ready((req, payload))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -848,10 +1116,11 @@ mod tests {
"GET /test HTTP/1.1\r\n\
Set-Cookie: c1=cookie1\r\n\
Set-Cookie: c2=cookie2\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf) {
Ok(Async::Ready((req, _))) => {
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http1(req, _))) => {
let val: Vec<_> = req.headers().get_all("Set-Cookie")
.iter().map(|v| v.to_str().unwrap().to_owned()).collect();
assert_eq!(val[0], "c1=cookie1");
@ -1081,14 +1350,15 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(!payload.eof());
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(payload.eof());
@ -1099,10 +1369,11 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
@ -1111,7 +1382,7 @@ mod tests {
POST /test2 HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf));
let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert_eq!(*req2.method(), Method::POST);
assert!(req2.chunked().unwrap());
assert!(!payload2.eof());
@ -1125,37 +1396,38 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4\r\ndata\r");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\n4");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\r");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\n");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("li");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("ne\r\n0\r\n");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
//buf.feed_data("test: test\r\n");
//not_ready!(reader.parse(&mut buf));
//not_ready!(reader.parse(&mut buf, &mut readbuf));
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(!payload.eof());
buf.feed_data("\r\n");
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(payload.eof());
}
@ -1164,14 +1436,15 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
not_ready!(reader.parse(&mut buf));
not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(!payload.eof());
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(payload.eof());
@ -1193,4 +1466,16 @@ mod tests {
Err(err) => panic!("{:?}", err),
}
}*/
#[test]
fn test_http2_prefix() {
let mut buf = Buffer::new("PRI * HTTP/2.0\r\n\r\n");
let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::Ready(Item::Http2)) => (),
Ok(_) | Err(_) => panic!("Error during parsing http request"),
}
}
}

351
src/h1writer.rs Normal file
View File

@ -0,0 +1,351 @@
use std::{cmp, io};
use std::fmt::Write;
use bytes::BytesMut;
use futures::{Async, Poll};
use tokio_io::AsyncWrite;
use http::{Version, StatusCode};
use http::header::{HeaderValue,
CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
use date;
use body::Body;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
pub(crate) enum WriterState {
Done,
Pause,
}
/// Send stream
pub(crate) trait Writer {
fn start(&mut self, req: &mut HttpRequest, resp: &mut HttpResponse)
-> Result<WriterState, io::Error>;
fn write(&mut self, payload: &[u8]) -> Result<WriterState, io::Error>;
fn write_eof(&mut self) -> Result<WriterState, io::Error>;
fn poll_complete(&mut self) -> Poll<(), io::Error>;
}
pub(crate) struct H1Writer<T: AsyncWrite> {
stream: Option<T>,
buffer: BytesMut,
started: bool,
encoder: Encoder,
upgrade: bool,
keepalive: bool,
disconnected: bool,
}
impl<T: AsyncWrite> H1Writer<T> {
pub fn new(stream: T) -> H1Writer<T> {
H1Writer {
stream: Some(stream),
buffer: BytesMut::new(),
started: false,
encoder: Encoder::length(0),
upgrade: false,
keepalive: false,
disconnected: false,
}
}
pub fn get_mut(&mut self) -> &mut T {
self.stream.as_mut().unwrap()
}
pub fn into_inner(&mut self) -> T {
self.stream.take().unwrap()
}
pub fn disconnected(&mut self) {
let len = self.buffer.len();
self.buffer.split_to(len);
}
pub fn keepalive(&self) -> bool {
self.keepalive && !self.upgrade
}
fn write_to_stream(&mut self) -> Result<WriterState, io::Error> {
if let Some(ref mut stream) = self.stream {
while !self.buffer.is_empty() {
match stream.write(self.buffer.as_ref()) {
Ok(n) => {
self.buffer.split_to(n);
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
return Ok(WriterState::Pause)
} else {
return Ok(WriterState::Done)
}
}
Err(err) =>
return Err(err),
}
}
}
return Ok(WriterState::Done)
}
}
impl<T: AsyncWrite> Writer for H1Writer<T> {
fn start(&mut self, req: &mut HttpRequest, msg: &mut HttpResponse)
-> Result<WriterState, io::Error>
{
trace!("Prepare message status={:?}", msg.status);
// prepare task
let mut extra = 0;
let body = msg.replace_body(Body::Empty);
let version = msg.version().unwrap_or_else(|| req.version());
self.started = true;
self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
match body {
Body::Empty => {
if msg.chunked() {
error!("Chunked transfer is enabled but body is set to Empty");
}
msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(0);
},
Body::Length(n) => {
if msg.chunked() {
error!("Chunked transfer is enabled but body with specific length is specified");
}
msg.headers.insert(
CONTENT_LENGTH,
HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(n);
},
Body::Binary(ref bytes) => {
extra = bytes.len();
msg.headers.insert(
CONTENT_LENGTH,
HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(0);
}
Body::Streaming => {
if msg.chunked() {
if version < Version::HTTP_11 {
error!("Chunked transfer encoding is forbidden for {:?}", version);
}
msg.headers.remove(CONTENT_LENGTH);
msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
self.encoder = Encoder::chunked();
} else {
self.encoder = Encoder::eof();
}
}
Body::Upgrade => {
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
self.encoder = Encoder::eof();
}
}
// Connection upgrade
if msg.upgrade() {
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
else if self.keepalive {
if version < Version::HTTP_11 {
msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
} else if version >= Version::HTTP_11 {
msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
}
// render message
let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
self.buffer.reserve(init_cap);
if version == Version::HTTP_11 && msg.status == StatusCode::OK {
self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
} else {
let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
}
for (key, value) in &msg.headers {
let t: &[u8] = key.as_ref();
self.buffer.extend(t);
self.buffer.extend(b": ");
self.buffer.extend(value.as_ref());
self.buffer.extend(b"\r\n");
}
// using http::h1::date is quite a lot faster than generating
// a unique Date header each time like req/s goes up about 10%
if !msg.headers.contains_key(DATE) {
self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
self.buffer.extend(b"Date: ");
date::extend(&mut self.buffer);
self.buffer.extend(b"\r\n");
}
// default content-type
if !msg.headers.contains_key(CONTENT_TYPE) {
self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
}
self.buffer.extend(b"\r\n");
if let Body::Binary(ref bytes) = body {
self.buffer.extend_from_slice(bytes.as_ref());
return Ok(WriterState::Done)
}
msg.replace_body(body);
Ok(WriterState::Done)
}
fn write(&mut self, payload: &[u8]) -> Result<WriterState, io::Error> {
if !self.disconnected {
if self.started {
// TODO: add warning, write after EOF
self.encoder.encode(&mut self.buffer, payload);
} else {
// might be response for EXCEPT
self.buffer.extend_from_slice(payload)
}
}
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
return Ok(WriterState::Pause)
} else {
return Ok(WriterState::Done)
}
}
fn write_eof(&mut self) -> Result<WriterState, io::Error> {
if !self.encoder.encode_eof(&mut self.buffer) {
//debug!("last payload item, but it is not EOF ");
Err(io::Error::new(io::ErrorKind::Other,
"Last payload item, but eof is not reached"))
} else {
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
return Ok(WriterState::Pause)
} else {
return Ok(WriterState::Done)
}
}
}
fn poll_complete(&mut self) -> Poll<(), io::Error> {
match self.write_to_stream() {
Ok(WriterState::Done) => Ok(Async::Ready(())),
Ok(WriterState::Pause) => Ok(Async::NotReady),
Err(err) => Err(err)
}
}
}
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone)]
struct Encoder {
kind: Kind,
}
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked(bool),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Length(u64),
/// An Encoder for when Content-Length is not known.
///
/// Appliction decides when to stop writing.
Eof,
}
impl Encoder {
pub fn eof() -> Encoder {
Encoder {
kind: Kind::Eof,
}
}
pub fn chunked() -> Encoder {
Encoder {
kind: Kind::Chunked(false),
}
}
pub fn length(len: u64) -> Encoder {
Encoder {
kind: Kind::Length(len),
}
}
/// Encode message. Return `EOF` state of encoder
pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
match self.kind {
Kind::Eof => {
dst.extend(msg);
msg.is_empty()
},
Kind::Chunked(ref mut eof) => {
if *eof {
return true;
}
if msg.is_empty() {
*eof = true;
dst.extend(b"0\r\n\r\n");
} else {
write!(dst, "{:X}\r\n", msg.len()).unwrap();
dst.extend(msg);
dst.extend(b"\r\n");
}
*eof
},
Kind::Length(ref mut remaining) => {
if msg.is_empty() {
return *remaining == 0
}
let max = cmp::min(*remaining, msg.len() as u64);
trace!("sized write = {}", max);
dst.extend(msg[..max as usize].as_ref());
*remaining -= max as u64;
trace!("encoded {} bytes, remaining = {}", max, remaining);
*remaining == 0
},
}
}
/// Encode eof. Return `EOF` state of encoder
pub fn encode_eof(&mut self, dst: &mut BytesMut) -> bool {
match self.kind {
Kind::Eof => true,
Kind::Chunked(ref mut eof) => {
if *eof {
return true;
}
*eof = true;
dst.extend(b"0\r\n\r\n");
true
},
Kind::Length(ref mut remaining) => {
return *remaining == 0
},
}
}
}

148
src/h2.rs
View File

@ -1,9 +1,147 @@
use std::{io, cmp};
use std::{io, cmp, mem};
use std::rc::Rc;
use std::io::{Read, Write};
use std::cell::UnsafeCell;
use std::collections::VecDeque;
use http::request::Parts;
use http2::{RecvStream};
use http2::server::{Server, Handshake, Respond};
use bytes::{Buf, Bytes};
use futures::Poll;
use futures::{Async, Poll, Future, Stream};
use tokio_io::{AsyncRead, AsyncWrite};
use task::Task;
use server::HttpHandler;
use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
use payload::{Payload, PayloadError, PayloadSender};
pub(crate) struct Http2<T, A, H>
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
{
router: Rc<Vec<H>>,
#[allow(dead_code)]
addr: A,
state: State<IoWrapper<T>>,
error: bool,
tasks: VecDeque<Entry>,
}
enum State<T: AsyncRead + AsyncWrite> {
Handshake(Handshake<T, Bytes>),
Server(Server<T, Bytes>),
Empty,
}
impl<T, A, H> Http2<T, A, H>
where T: AsyncRead + AsyncWrite + 'static,
A: 'static,
H: HttpHandler + 'static
{
pub fn new(stream: T, addr: A, router: Rc<Vec<H>>, buf: Bytes) -> Self {
Http2{ router: router,
addr: addr,
error: false,
tasks: VecDeque::new(),
state: State::Handshake(
Server::handshake(IoWrapper{unread: Some(buf), inner: stream})) }
}
pub fn poll(&mut self) -> Poll<(), ()> {
// handshake
self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() {
Ok(Async::Ready(srv)) => {
State::Server(srv)
},
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Err(err) => {
trace!("Error handling connection: {}", err);
return Err(())
}
}
} else {
mem::replace(&mut self.state, State::Empty)
};
// get request
let poll = if let State::Server(ref mut server) = self.state {
server.poll()
} else {
unreachable!("Http2::poll() state was not advanced completely!")
};
match poll {
Ok(Async::NotReady) => {
// Ok(Async::NotReady);
()
}
Err(err) => {
trace!("Connection error: {}", err);
self.error = true;
},
Ok(Async::Ready(None)) => {
},
Ok(Async::Ready(Some((req, resp)))) => {
let (parts, body) = req.into_parts();
let entry = Entry::new(parts, body, resp, &self.router);
}
}
Ok(Async::Ready(()))
}
}
struct Entry {
task: Task,
req: UnsafeCell<HttpRequest>,
payload: PayloadSender,
recv: RecvStream,
respond: Respond<Bytes>,
eof: bool,
error: bool,
finished: bool,
}
impl Entry {
fn new<H>(parts: Parts,
recv: RecvStream,
resp: Respond<Bytes>,
router: &Rc<Vec<H>>) -> Entry
where H: HttpHandler + 'static
{
let path = parts.uri.path().to_owned();
let query = parts.uri.query().unwrap_or("").to_owned();
println!("PARTS: {:?}", parts);
let mut req = HttpRequest::new(
parts.method, path, parts.version, parts.headers, query);
let (psender, payload) = Payload::new(false);
// start request processing
let mut task = None;
for h in router.iter() {
if req.path().starts_with(h.prefix()) {
task = Some(h.handle(&mut req, payload));
break
}
}
println!("REQ: {:?}", req);
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
req: UnsafeCell::new(req),
payload: psender,
recv: recv,
respond: resp,
eof: false,
error: false,
finished: false}
}
}
struct IoWrapper<T> {
unread: Option<Bytes>,
@ -14,9 +152,9 @@ impl<T: Read> Read for IoWrapper<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Some(mut bytes) = self.unread.take() {
let size = cmp::min(buf.len(), bytes.len());
buf.copy_from_slice(&bytes[..size]);
bytes.split_to(size);
if !bytes.is_empty() {
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
self.unread = Some(bytes);
}
Ok(size)

View File

@ -20,6 +20,7 @@ extern crate mime_guess;
extern crate url;
extern crate percent_encoding;
extern crate actix;
extern crate h2 as http2;
#[cfg(feature="tls")]
extern crate native_tls;
@ -45,6 +46,7 @@ mod wsframe;
mod wsproto;
mod h1;
mod h2;
mod h1writer;
pub mod ws;
pub mod dev;

View File

@ -1,13 +1,9 @@
use std::{io, net};
use std::{io, net, mem};
use std::rc::Rc;
use std::cell::UnsafeCell;
use std::time::Duration;
use std::marker::PhantomData;
use std::collections::VecDeque;
use actix::dev::*;
use futures::{Future, Poll, Async, Stream};
use tokio_core::reactor::Timeout;
use tokio_core::net::{TcpListener, TcpStream};
use tokio_io::{AsyncRead, AsyncWrite};
@ -17,9 +13,9 @@ use native_tls::TlsAcceptor;
use tokio_tls::{TlsStream, TlsAcceptorExt};
use h1;
use h2;
use task::Task;
use payload::Payload;
use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
/// Low level http request handler
@ -153,11 +149,10 @@ impl<H: HttpHandler> HttpServer<TlsStream<TcpStream>, net::SocketAddr, H> {
println!("SSL");
TlsAcceptorExt::accept_async(acc.as_ref(), stream)
.map(move |t| {
println!("connected {:?} {:?}", t, addr);
IoStream(t, addr)
})
.map_err(|err| {
println!("ERR: {:?}", err);
trace!("Error during handling tls connection: {}", err);
io::Error::new(io::ErrorKind::Other, err)
})
}));
@ -195,42 +190,25 @@ impl<T, A, H> Handler<IoStream<T, A>, io::Error> for HttpServer<T, A, H>
-> Response<Self, IoStream<T, A>>
{
Arbiter::handle().spawn(
HttpChannel{router: Rc::clone(&self.h),
addr: msg.1,
stream: msg.0,
reader: h1::Reader::new(),
error: false,
items: VecDeque::new(),
inactive: VecDeque::new(),
keepalive: true,
keepalive_timer: None,
HttpChannel{
proto: Protocol::H1(h1::Http1::new(msg.0, msg.1, Rc::clone(&self.h)))
});
Self::empty()
}
}
struct Entry {
task: Task,
req: UnsafeCell<HttpRequest>,
eof: bool,
error: bool,
finished: bool,
enum Protocol<T, A, H>
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
{
H1(h1::Http1<T, A, H>),
H2(h2::Http2<T, A, H>),
None,
}
const KEEPALIVE_PERIOD: u64 = 15; // seconds
const MAX_PIPELINED_MESSAGES: usize = 16;
pub struct HttpChannel<T: 'static, A: 'static, H: 'static> {
router: Rc<Vec<H>>,
#[allow(dead_code)]
addr: A,
stream: T,
reader: h1::Reader,
error: bool,
items: VecDeque<Entry>,
inactive: VecDeque<Entry>,
keepalive: bool,
keepalive_timer: Option<Timeout>,
pub struct HttpChannel<T, A, H>
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
{
proto: Protocol<T, A, H>,
}
/*impl<T: 'static, A: 'static, H: 'static> Drop for HttpChannel<T, A, H> {
@ -240,193 +218,45 @@ pub struct HttpChannel<T: 'static, A: 'static, H: 'static> {
}*/
impl<T, A, H> Actor for HttpChannel<T, A, H>
where T: AsyncRead + AsyncWrite + 'static,
A: 'static,
H: HttpHandler + 'static
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
{
type Context = Context<Self>;
}
impl<T, A, H> Future for HttpChannel<T, A, H>
where T: AsyncRead + AsyncWrite + 'static,
A: 'static,
H: HttpHandler + 'static
where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
{
type Item = ();
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
// keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() {
Ok(Async::Ready(_)) =>
return Ok(Async::Ready(())),
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
match self.proto {
Protocol::H1(ref mut h1) => {
match h1.poll() {
Ok(Async::Ready(h1::Http1Result::Done)) =>
return Ok(Async::Ready(())),
Ok(Async::Ready(h1::Http1Result::Upgrade)) => (),
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Err(_) =>
return Err(()),
}
}
Protocol::H2(ref mut h2) =>
return h2.poll(),
Protocol::None =>
unreachable!()
}
loop {
let mut not_ready = true;
// check in-flight messages
let mut idx = 0;
while idx < self.items.len() {
if idx == 0 {
if self.items[idx].error {
return Err(())
}
// this is anoying
let req = unsafe {self.items[idx].req.get().as_mut().unwrap()};
match self.items[idx].task.poll_io(&mut self.stream, req)
{
Ok(Async::Ready(ready)) => {
not_ready = false;
let mut item = self.items.pop_front().unwrap();
// overide keep-alive state
if self.keepalive {
self.keepalive = item.task.keepalive();
}
if !ready {
item.eof = true;
self.inactive.push_back(item);
}
// no keep-alive
if ready && !self.keepalive &&
self.items.is_empty() && self.inactive.is_empty()
{
return Ok(Async::Ready(()))
}
continue
},
Ok(Async::NotReady) => (),
Err(_) => {
// it is not possible to recover from error
// during task handling, so just drop connection
return Err(())
}
}
} else if !self.items[idx].finished && !self.items[idx].error {
match self.items[idx].task.poll() {
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
self.items[idx].finished = true;
},
Err(_) =>
self.items[idx].error = true,
}
}
idx += 1;
}
// check inactive tasks
let mut idx = 0;
while idx < self.inactive.len() {
if idx == 0 && self.inactive[idx].error && self.inactive[idx].finished {
let _ = self.inactive.pop_front();
continue
}
if !self.inactive[idx].finished && !self.inactive[idx].error {
match self.inactive[idx].task.poll() {
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
self.inactive[idx].finished = true
}
Err(_) =>
self.inactive[idx].error = true,
}
}
idx += 1;
}
// read incoming data
if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES {
match self.reader.parse(&mut self.stream) {
Ok(Async::Ready((mut req, payload))) => {
not_ready = false;
// stop keepalive timer
self.keepalive_timer.take();
// start request processing
let mut task = None;
for h in self.router.iter() {
if req.path().starts_with(h.prefix()) {
task = Some(h.handle(&mut req, payload));
break
}
}
self.items.push_back(
Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
req: UnsafeCell::new(req),
eof: false,
error: false,
finished: false});
}
Err(err) => {
// notify all tasks
not_ready = false;
for entry in &mut self.items {
entry.task.disconnected()
}
// kill keepalive
self.keepalive = false;
self.keepalive_timer.take();
// on parse error, stop reading stream but
// tasks need to be completed
self.error = true;
if self.items.is_empty() {
if let h1::ReaderError::Error(err) = err {
self.items.push_back(
Entry {task: Task::reply(err),
req: UnsafeCell::new(HttpRequest::for_error()),
eof: false,
error: false,
finished: false});
}
}
}
Ok(Async::NotReady) => {
// start keep-alive timer, this is also slow request timeout
if self.items.is_empty() && self.inactive.is_empty() {
if self.keepalive {
if self.keepalive_timer.is_none() {
trace!("Start keep-alive timer");
let mut timeout = Timeout::new(
Duration::new(KEEPALIVE_PERIOD, 0),
Arbiter::handle()).unwrap();
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
}
} else {
// keep-alive disable, drop connection
return Ok(Async::Ready(()))
}
}
return Ok(Async::NotReady)
}
}
}
// check for parse error
if self.items.is_empty() && self.inactive.is_empty() && self.error {
return Ok(Async::Ready(()))
}
if not_ready {
return Ok(Async::NotReady)
// upgrade to h2
let proto = mem::replace(&mut self.proto, Protocol::None);
match proto {
Protocol::H1(h1) => {
let (stream, addr, router, buf) = h1.into_inner();
self.proto = Protocol::H2(h2::Http2::new(stream, addr, router, buf));
return self.poll()
}
_ => unreachable!()
}
}
}

View File

@ -1,27 +1,18 @@
use std::{mem, cmp, io};
use std::{mem, io};
use std::rc::Rc;
use std::fmt::Write;
use std::cell::RefCell;
use std::collections::VecDeque;
use http::{StatusCode, Version};
use http::header::{HeaderValue,
CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
use bytes::BytesMut;
use futures::{Async, Future, Poll, Stream};
use futures::task::{Task as FutureTask, current as current_task};
use tokio_io::AsyncWrite;
use date;
use body::Body;
use h1writer::{Writer, WriterState};
use route::Frame;
use application::Middleware;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
type FrameStream = Stream<Item=Frame, Error=io::Error>;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
#[derive(PartialEq, Debug)]
enum TaskRunningState {
@ -34,6 +25,16 @@ impl TaskRunningState {
fn is_done(&self) -> bool {
*self == TaskRunningState::Done
}
fn pause(&mut self) {
if *self != TaskRunningState::Done {
*self = TaskRunningState::Paused
}
}
fn resume(&mut self) {
if *self != TaskRunningState::Done {
*self = TaskRunningState::Running
}
}
}
#[derive(PartialEq, Debug)]
@ -100,17 +101,12 @@ impl Future for DrainFut {
}
}
pub struct Task {
state: TaskRunningState,
iostate: TaskIOState,
frames: VecDeque<Frame>,
stream: TaskStream,
encoder: Encoder,
buffer: BytesMut,
drain: Vec<Rc<RefCell<DrainFut>>>,
upgrade: bool,
keepalive: bool,
prepared: Option<HttpResponse>,
disconnected: bool,
middlewares: Option<Rc<Vec<Box<Middleware>>>>,
@ -129,10 +125,6 @@ impl Task {
frames: frames,
drain: Vec::new(),
stream: TaskStream::None,
encoder: Encoder::length(0),
buffer: BytesMut::new(),
upgrade: false,
keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
@ -147,11 +139,7 @@ impl Task {
iostate: TaskIOState::ReadingMessage,
frames: VecDeque::new(),
stream: TaskStream::Stream(Box::new(stream)),
encoder: Encoder::length(0),
buffer: BytesMut::new(),
drain: Vec::new(),
upgrade: false,
keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
@ -165,158 +153,26 @@ impl Task {
iostate: TaskIOState::ReadingMessage,
frames: VecDeque::new(),
stream: TaskStream::Context(Box::new(ctx)),
encoder: Encoder::length(0),
buffer: BytesMut::new(),
drain: Vec::new(),
upgrade: false,
keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
}
}
pub(crate) fn keepalive(&self) -> bool {
self.keepalive && !self.upgrade
}
pub(crate) fn set_middlewares(&mut self, middlewares: Rc<Vec<Box<Middleware>>>) {
self.middlewares = Some(middlewares);
}
pub(crate) fn disconnected(&mut self) {
let len = self.buffer.len();
self.buffer.split_to(len);
self.disconnected = true;
if let TaskStream::Context(ref mut ctx) = self.stream {
ctx.disconnected();
}
}
fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse)
{
trace!("Prepare message status={:?}", msg.status);
// run middlewares
let mut msg = if let Some(middlewares) = self.middlewares.take() {
let mut msg = msg;
for middleware in middlewares.iter() {
msg = middleware.response(req, msg);
}
self.middlewares = Some(middlewares);
msg
} else {
msg
};
// prepare task
let mut extra = 0;
let body = msg.replace_body(Body::Empty);
let version = msg.version().unwrap_or_else(|| req.version());
self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
match body {
Body::Empty => {
if msg.chunked() {
error!("Chunked transfer is enabled but body is set to Empty");
}
msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(0);
},
Body::Length(n) => {
if msg.chunked() {
error!("Chunked transfer is enabled but body with specific length is specified");
}
msg.headers.insert(
CONTENT_LENGTH,
HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(n);
},
Body::Binary(ref bytes) => {
extra = bytes.len();
msg.headers.insert(
CONTENT_LENGTH,
HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
msg.headers.remove(TRANSFER_ENCODING);
self.encoder = Encoder::length(0);
}
Body::Streaming => {
if msg.chunked() {
if version < Version::HTTP_11 {
error!("Chunked transfer encoding is forbidden for {:?}", version);
}
msg.headers.remove(CONTENT_LENGTH);
msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
self.encoder = Encoder::chunked();
} else {
self.encoder = Encoder::eof();
}
}
Body::Upgrade => {
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
self.encoder = Encoder::eof();
}
}
// Connection upgrade
if msg.upgrade() {
msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
else if self.keepalive {
if version < Version::HTTP_11 {
msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
} else if version >= Version::HTTP_11 {
msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
}
// render message
let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
self.buffer.reserve(init_cap);
if version == Version::HTTP_11 && msg.status == StatusCode::OK {
self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
} else {
let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
}
for (key, value) in &msg.headers {
let t: &[u8] = key.as_ref();
self.buffer.extend(t);
self.buffer.extend(b": ");
self.buffer.extend(value.as_ref());
self.buffer.extend(b"\r\n");
}
// using http::h1::date is quite a lot faster than generating
// a unique Date header each time like req/s goes up about 10%
if !msg.headers.contains_key(DATE) {
self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
self.buffer.extend(b"Date: ");
date::extend(&mut self.buffer);
self.buffer.extend(b"\r\n");
}
// default content-type
if !msg.headers.contains_key(CONTENT_TYPE) {
self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
}
self.buffer.extend(b"\r\n");
if let Body::Binary(ref bytes) = body {
self.buffer.extend_from_slice(bytes.as_ref());
self.prepared = Some(msg);
return
}
msg.replace_body(body);
self.prepared = Some(msg);
}
pub(crate) fn poll_io<T>(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll<bool, ()>
where T: AsyncWrite
where T: Writer
{
trace!("POLL-IO frames:{:?}", self.frames.len());
// response is completed
@ -328,87 +184,76 @@ impl Task {
match self.poll() {
Ok(Async::Ready(_)) => {
self.state = TaskRunningState::Done;
}
},
Ok(Async::NotReady) => (),
Err(_) => return Err(())
}
}
// use exiting frames
while let Some(frame) = self.frames.pop_front() {
trace!("IO Frame: {:?}", frame);
match frame {
Frame::Message(response) => {
if !self.disconnected {
self.prepare(req, response);
if self.state != TaskRunningState::Paused {
while let Some(frame) = self.frames.pop_front() {
trace!("IO Frame: {:?}", frame);
let res = match frame {
Frame::Message(mut response) => {
trace!("Prepare message status={:?}", response.status);
// run middlewares
let mut response =
if let Some(middlewares) = self.middlewares.take() {
let mut response = response;
for middleware in middlewares.iter() {
response = middleware.response(req, response);
}
self.middlewares = Some(middlewares);
response
} else {
response
};
let result = io.start(req, &mut response);
self.prepared = Some(response);
result
}
}
Frame::Payload(Some(chunk)) => {
if !self.disconnected {
if self.prepared.is_some() {
// TODO: add warning, write after EOF
self.encoder.encode(&mut self.buffer, chunk.as_ref());
} else {
// might be response for EXCEPT
self.buffer.extend_from_slice(chunk.as_ref())
}
Frame::Payload(Some(chunk)) => {
io.write(chunk.as_ref())
},
Frame::Payload(None) => {
self.iostate = TaskIOState::Done;
io.write_eof()
},
Frame::Drain(fut) => {
self.drain.push(fut);
break
}
},
Frame::Payload(None) => {
if !self.disconnected &&
!self.encoder.encode(&mut self.buffer, [].as_ref())
{
// TODO: add error "not eof""
debug!("last payload item, but it is not EOF ");
return Err(())
};
match res {
Ok(WriterState::Pause) => {
self.state.pause();
break
}
break
},
Frame::Drain(fut) => {
self.drain.push(fut);
break
Ok(WriterState::Done) => self.state.resume(),
Err(_) => return Err(())
}
}
}
}
// write bytes to TcpStream
if !self.disconnected {
while !self.buffer.is_empty() {
match io.write(self.buffer.as_ref()) {
Ok(n) => {
self.buffer.split_to(n);
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
break
}
Err(_) => return Err(()),
}
// flush io
match io.poll_complete() {
Ok(Async::Ready(())) => self.state.resume(),
Ok(Async::NotReady) => {
return Ok(Async::NotReady)
}
}
// should pause task
if self.state != TaskRunningState::Done {
if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
self.state = TaskRunningState::Paused;
} else if self.state == TaskRunningState::Paused {
self.state = TaskRunningState::Running;
Err(err) => {
trace!("Error sending data: {}", err);
return Err(())
}
} else {
// at this point we wont get any more Frames
self.iostate = TaskIOState::Done;
}
// drain
if self.buffer.is_empty() && !self.drain.is_empty() {
match io.flush() {
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady)
}
Err(_) => return Err(()),
}
if !self.drain.is_empty() {
for fut in &mut self.drain {
fut.borrow_mut().set()
}
@ -416,7 +261,7 @@ impl Task {
}
// response is completed
if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() {
if self.iostate.is_done() {
// run middlewares
if let Some(ref mut resp) = self.prepared {
if let Some(middlewares) = self.middlewares.take() {
@ -443,8 +288,8 @@ impl Task {
error!("Non expected frame {:?}", frame);
return Err(())
}
self.upgrade = msg.upgrade();
if self.upgrade || msg.body().has_body() {
let upgrade = msg.upgrade();
if upgrade || msg.body().has_body() {
self.iostate = TaskIOState::ReadingPayload;
} else {
self.iostate = TaskIOState::Done;
@ -489,89 +334,3 @@ impl Future for Task {
result
}
}
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone)]
struct Encoder {
kind: Kind,
}
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked(bool),
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Length(u64),
/// An Encoder for when Content-Length is not known.
///
/// Appliction decides when to stop writing.
Eof,
}
impl Encoder {
pub fn eof() -> Encoder {
Encoder {
kind: Kind::Eof,
}
}
pub fn chunked() -> Encoder {
Encoder {
kind: Kind::Chunked(false),
}
}
pub fn length(len: u64) -> Encoder {
Encoder {
kind: Kind::Length(len),
}
}
/*pub fn is_eof(&self) -> bool {
match self.kind {
Kind::Eof | Kind::Length(0) => true,
Kind::Chunked(eof) => eof,
_ => false,
}
}*/
/// Encode message. Return `EOF` state of encoder
pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
match self.kind {
Kind::Eof => {
dst.extend(msg);
msg.is_empty()
},
Kind::Chunked(ref mut eof) => {
if *eof {
return true;
}
if msg.is_empty() {
*eof = true;
dst.extend(b"0\r\n\r\n");
} else {
write!(dst, "{:X}\r\n", msg.len()).unwrap();
dst.extend(msg);
dst.extend(b"\r\n");
}
*eof
},
Kind::Length(ref mut remaining) => {
if msg.is_empty() {
return *remaining == 0
}
let max = cmp::min(*remaining, msg.len() as u64);
trace!("sized write = {}", max);
dst.extend(msg[..max as usize].as_ref());
*remaining -= max as u64;
trace!("encoded {} bytes, remaining = {}", max, remaining);
*remaining == 0
},
}
}
}