From 32cefb84551067a8c00ebd5cab6607d8b49c0206 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 4 Nov 2017 09:07:44 -0700 Subject: [PATCH] implement h2 writer --- Cargo.toml | 5 +- README.md | 6 +- examples/tls/src/main.rs | 4 +- src/h1writer.rs | 3 +- src/h2.rs | 160 ++++++++++++++++----- src/h2writer.rs | 296 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/payload.rs | 59 ++++++-- src/server.rs | 3 +- src/task.rs | 6 +- 10 files changed, 486 insertions(+), 57 deletions(-) create mode 100644 src/h2writer.rs diff --git a/Cargo.toml b/Cargo.toml index 37586ca2c..58dd2a3fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "0.2.1" +version = "0.3.0" authors = ["Nikolay Kim "] description = "Actix web framework" readme = "README.md" @@ -48,8 +48,7 @@ futures = "0.1" tokio-io = "0.1" tokio-core = "0.1" -h2 = { path = '../h2' } -# h2 = { git = 'https://github.com/carllerche/h2', optional = true } +h2 = { git = 'https://github.com/carllerche/h2' } # tls native-tls = { version="0.1", optional = true } diff --git a/README.md b/README.md index 7424cfc69..543c30f9f 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Actix web is licensed under the [Apache-2.0 license](http://opensource.org/licen ## Features - * HTTP 1.1 and 1.0 support + * HTTP/1 and HTTP/2 support * Streaming and pipelining support * Keep-alive and slow requests support * [WebSockets support](https://actix.github.io/actix-web/actix_web/ws/index.html) @@ -27,7 +27,7 @@ To use `actix-web`, add this to your `Cargo.toml`: ```toml [dependencies] -actix-web = "0.2" +actix-web = { git = "https://github.com/actix/actix-web" } ``` ## Example @@ -37,7 +37,7 @@ actix-web = "0.2" * [Mulitpart streams](https://github.com/actix/actix-web/tree/master/examples/multipart) * [Simple websocket session](https://github.com/actix/actix-web/tree/master/examples/websocket.rs) * [Tcp/Websocket chat](https://github.com/actix/actix-web/tree/master/examples/websocket-chat) -* [SockJS Server](https://github.com/fafhrd91/actix-sockjs) +* [SockJS Server](https://github.com/actix/actix-sockjs) ```rust diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 9a2c12258..5e2d37544 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -17,7 +17,9 @@ fn index(req: &mut HttpRequest, _payload: Payload, state: &()) -> HttpResponse { } fn main() { - ::std::env::set_var("RUST_LOG", "actix_web=info"); + if ::std::env::var("RUST_LOG").is_err() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + } let _ = env_logger::init(); let sys = actix::System::new("ws-example"); diff --git a/src/h1writer.rs b/src/h1writer.rs index 98f2aa4fa..0aa7b94d3 100644 --- a/src/h1writer.rs +++ b/src/h1writer.rs @@ -16,6 +16,7 @@ const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k +#[derive(Debug)] pub(crate) enum WriterState { Done, Pause, @@ -255,7 +256,7 @@ impl Writer for H1Writer { /// Encoders to handle different Transfer-Encodings. #[derive(Debug, Clone)] -struct Encoder { +pub(crate) struct Encoder { kind: Kind, } diff --git a/src/h2.rs b/src/h2.rs index ecf429c0a..9c89c18e5 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -5,7 +5,7 @@ use std::cell::UnsafeCell; use std::collections::VecDeque; use http::request::Parts; -use http2::{RecvStream}; +use http2::{Reason, RecvStream}; use http2::server::{Server, Handshake, Respond}; use bytes::{Buf, Bytes}; use futures::{Async, Poll, Future, Stream}; @@ -16,6 +16,7 @@ use server::HttpHandler; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; use payload::{Payload, PayloadError, PayloadSender}; +use h2writer::H2Writer; pub(crate) struct Http2 @@ -25,7 +26,7 @@ pub(crate) struct Http2 #[allow(dead_code)] addr: A, state: State>, - error: bool, + disconnected: bool, tasks: VecDeque, } @@ -43,13 +44,101 @@ impl Http2 pub fn new(stream: T, addr: A, router: Rc>, buf: Bytes) -> Self { Http2{ router: router, addr: addr, - error: false, + disconnected: false, tasks: VecDeque::new(), state: State::Handshake( Server::handshake(IoWrapper{unread: Some(buf), inner: stream})) } } pub fn poll(&mut self) -> Poll<(), ()> { + // server + if let State::Server(ref mut server) = self.state { + loop { + let mut not_ready = true; + + // check in-flight connections + for item in &mut self.tasks { + // read payload + item.poll_payload(); + + if !item.eof { + let req = unsafe {item.req.get().as_mut().unwrap()}; + match item.task.poll_io(&mut item.stream, req) { + Ok(Async::Ready(ready)) => { + item.eof = true; + if ready { + item.finished = true; + } + not_ready = false; + }, + Ok(Async::NotReady) => (), + Err(_) => { + item.eof = true; + item.error = true; + item.stream.reset(Reason::INTERNAL_ERROR); + } + } + } else if !item.finished { + match item.task.poll() { + Ok(Async::NotReady) => (), + Ok(Async::Ready(_)) => { + not_ready = false; + item.finished = true; + }, + Err(_) => { + item.error = true; + item.finished = true; + } + } + } + } + + // cleanup finished tasks + while !self.tasks.is_empty() { + if self.tasks[0].eof && self.tasks[0].finished || self.tasks[0].error { + self.tasks.pop_front(); + } else { + break + } + } + + // get request + if !self.disconnected { + match server.poll() { + Ok(Async::NotReady) => { + // Ok(Async::NotReady); + () + } + Err(err) => { + trace!("Connection error: {}", err); + self.disconnected = true; + }, + Ok(Async::Ready(None)) => { + not_ready = false; + self.disconnected = true; + for entry in &mut self.tasks { + entry.task.disconnected() + } + }, + Ok(Async::Ready(Some((req, resp)))) => { + not_ready = false; + let (parts, body) = req.into_parts(); + self.tasks.push_back( + Entry::new(parts, body, resp, &self.router)); + } + } + } + + if not_ready { + if self.tasks.is_empty() && self.disconnected { + return Ok(Async::Ready(())) + } else { + return Ok(Async::NotReady) + } + } + } + } + // handshake self.state = if let State::Handshake(ref mut handshake) = self.state { match handshake.poll() { @@ -67,32 +156,7 @@ impl Http2 mem::replace(&mut self.state, State::Empty) }; - // get request - let poll = if let State::Server(ref mut server) = self.state { - server.poll() - } else { - unreachable!("Http2::poll() state was not advanced completely!") - }; - - match poll { - Ok(Async::NotReady) => { - // Ok(Async::NotReady); - () - } - Err(err) => { - trace!("Connection error: {}", err); - self.error = true; - }, - Ok(Async::Ready(None)) => { - - }, - Ok(Async::Ready(Some((req, resp)))) => { - let (parts, body) = req.into_parts(); - let entry = Entry::new(parts, body, resp, &self.router); - } - } - - Ok(Async::Ready(())) + self.poll() } } @@ -101,10 +165,12 @@ struct Entry { req: UnsafeCell, payload: PayloadSender, recv: RecvStream, - respond: Respond, + stream: H2Writer, eof: bool, error: bool, finished: bool, + reof: bool, + capacity: usize, } impl Entry { @@ -117,7 +183,6 @@ impl Entry { let path = parts.uri.path().to_owned(); let query = parts.uri.query().unwrap_or("").to_owned(); - println!("PARTS: {:?}", parts); let mut req = HttpRequest::new( parts.method, path, parts.version, parts.headers, query); let (psender, payload) = Payload::new(false); @@ -130,16 +195,43 @@ impl Entry { break } } - println!("REQ: {:?}", req); Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), req: UnsafeCell::new(req), payload: psender, recv: recv, - respond: resp, + stream: H2Writer::new(resp), eof: false, error: false, - finished: false} + finished: false, + reof: false, + capacity: 0, + } + } + + fn poll_payload(&mut self) { + if !self.reof { + match self.recv.poll() { + Ok(Async::Ready(Some(chunk))) => { + self.payload.feed_data(chunk); + }, + Ok(Async::Ready(None)) => { + self.reof = true; + }, + Ok(Async::NotReady) => (), + Err(err) => { + self.payload.set_error(PayloadError::Http2(err)) + } + } + + let capacity = self.payload.capacity(); + if self.capacity != capacity { + self.capacity = capacity; + if let Err(err) = self.recv.release_capacity().release_capacity(capacity) { + self.payload.set_error(PayloadError::Http2(err)) + } + } + } } } diff --git a/src/h2writer.rs b/src/h2writer.rs new file mode 100644 index 000000000..ae3c9bc82 --- /dev/null +++ b/src/h2writer.rs @@ -0,0 +1,296 @@ +use std::{io, cmp}; +use bytes::{Bytes, BytesMut}; +use futures::{Async, Poll}; +use http2::{Reason, SendStream}; +use http2::server::Respond; +use http::{Version, HttpTryFrom, Response}; +use http::header::{HeaderValue, CONNECTION, CONTENT_TYPE, + CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; + +use date; +use body::Body; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; +use h1writer::{Writer, WriterState}; + +const CHUNK_SIZE: usize = 16_384; +const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k + + +pub(crate) struct H2Writer { + respond: Respond, + stream: Option>, + buffer: BytesMut, + started: bool, + encoder: Encoder, + disconnected: bool, + eof: bool, +} + +impl H2Writer { + + pub fn new(respond: Respond) -> H2Writer { + H2Writer { + respond: respond, + stream: None, + buffer: BytesMut::new(), + started: false, + encoder: Encoder::length(0), + disconnected: false, + eof: true, + } + } + + pub fn reset(&mut self, reason: Reason) { + if let Some(mut stream) = self.stream.take() { + stream.send_reset(reason) + } + } + + fn write_to_stream(&mut self) -> Result { + if !self.started { + return Ok(WriterState::Done) + } + + if let Some(ref mut stream) = self.stream { + if self.buffer.is_empty() { + if self.eof { + let _ = stream.send_data(Bytes::new(), true); + } + return Ok(WriterState::Done) + } + + loop { + match stream.poll_capacity() { + Ok(Async::NotReady) => { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + Ok(Async::Ready(None)) => { + return Ok(WriterState::Done) + } + Ok(Async::Ready(Some(cap))) => { + let len = self.buffer.len(); + let bytes = self.buffer.split_to(cmp::min(cap, len)); + let eof = self.buffer.is_empty() && self.eof; + + if let Err(_) = stream.send_data(bytes.freeze(), eof) { + return Err(io::Error::new(io::ErrorKind::Other, "")) + } else { + if !self.buffer.is_empty() { + let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); + stream.reserve_capacity(cap); + } else { + return Ok(WriterState::Done) + } + } + } + Err(_) => { + return Err(io::Error::new(io::ErrorKind::Other, "")) + } + } + } + } + return Ok(WriterState::Done) + } +} + +impl Writer for H2Writer { + + fn start(&mut self, _: &mut HttpRequest, msg: &mut HttpResponse) + -> Result + { + trace!("Prepare message status={:?}", msg); + + // prepare response + self.started = true; + let body = msg.replace_body(Body::Empty); + + // http2 specific + msg.headers.remove(CONNECTION); + msg.headers.remove(TRANSFER_ENCODING); + + match body { + Body::Empty => { + if msg.chunked() { + error!("Chunked transfer is enabled but body is set to Empty"); + } + msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); + self.encoder = Encoder::length(0); + }, + Body::Length(n) => { + if msg.chunked() { + error!("Chunked transfer is enabled but body with specific length is specified"); + } + self.eof = false; + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); + self.encoder = Encoder::length(n); + }, + Body::Binary(ref bytes) => { + self.eof = false; + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); + self.encoder = Encoder::length(0); + } + _ => { + msg.headers.remove(CONTENT_LENGTH); + self.eof = false; + self.encoder = Encoder::eof(); + } + } + + // using http::h1::date is quite a lot faster than generating + // a unique Date header each time like req/s goes up about 10% + if !msg.headers.contains_key(DATE) { + let mut bytes = BytesMut::with_capacity(29); + date::extend(&mut bytes); + msg.headers.insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + } + + // default content-type + if !msg.headers.contains_key(CONTENT_TYPE) { + msg.headers.insert( + CONTENT_TYPE, HeaderValue::from_static("application/octet-stream")); + } + + let mut resp = Response::new(()); + *resp.status_mut() = msg.status; + *resp.version_mut() = Version::HTTP_2; + for (key, value) in msg.headers().iter() { + resp.headers_mut().insert(key, value.clone()); + } + + match self.respond.send_response(resp, self.eof) { + Ok(stream) => { + self.stream = Some(stream); + } + Err(_) => { + return Err(io::Error::new(io::ErrorKind::Other, "err")) + } + } + + if let Body::Binary(ref bytes) = body { + self.eof = true; + self.buffer.extend_from_slice(bytes.as_ref()); + if let Some(ref mut stream) = self.stream { + stream.reserve_capacity(cmp::min(self.buffer.len(), CHUNK_SIZE)); + } + return Ok(WriterState::Done) + } + msg.replace_body(body); + + Ok(WriterState::Done) + } + + fn write(&mut self, payload: &[u8]) -> Result { + if !self.disconnected { + if self.started { + // TODO: add warning, write after EOF + self.encoder.encode(&mut self.buffer, payload); + } else { + // might be response for EXCEPT + self.buffer.extend_from_slice(payload) + } + } + + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + + fn write_eof(&mut self) -> Result { + self.eof = true; + if !self.encoder.encode_eof(&mut self.buffer) { + Err(io::Error::new(io::ErrorKind::Other, + "Last payload item, but eof is not reached")) + } else { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + } + + fn poll_complete(&mut self) -> Poll<(), io::Error> { + match self.write_to_stream() { + Ok(WriterState::Done) => Ok(Async::Ready(())), + Ok(WriterState::Pause) => Ok(Async::NotReady), + Err(err) => Err(err) + } + } +} + + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone)] +pub(crate) struct Encoder { + kind: Kind, +} + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when Content-Length is not known. + /// + /// Appliction decides when to stop writing. + Eof, +} + +impl Encoder { + + pub fn eof() -> Encoder { + Encoder { + kind: Kind::Eof, + } + } + + pub fn length(len: u64) -> Encoder { + Encoder { + kind: Kind::Length(len), + } + } + + /// Encode message. Return `EOF` state of encoder + pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool { + match self.kind { + Kind::Eof => { + dst.extend(msg); + msg.is_empty() + }, + Kind::Length(ref mut remaining) => { + if msg.is_empty() { + return *remaining == 0 + } + let max = cmp::min(*remaining, msg.len() as u64); + trace!("sized write = {}", max); + dst.extend(msg[..max as usize].as_ref()); + + *remaining -= max as u64; + trace!("encoded {} bytes, remaining = {}", max, remaining); + *remaining == 0 + }, + } + } + + /// Encode eof. Return `EOF` state of encoder + pub fn encode_eof(&mut self, _dst: &mut BytesMut) -> bool { + match self.kind { + Kind::Eof => true, + Kind::Length(ref mut remaining) => { + return *remaining == 0 + }, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 9dc538124..d42004a4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ mod wsproto; mod h1; mod h2; mod h1writer; +mod h2writer; pub mod ws; pub mod dev; diff --git a/src/payload.rs b/src/payload.rs index ae4c881a6..863984940 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -5,12 +5,13 @@ use std::collections::VecDeque; use std::error::Error; use std::io::Error as IoError; use bytes::{Bytes, BytesMut}; +use http2::Error as Http2Error; use futures::{Async, Poll, Stream}; use futures::task::{Task, current as current_task}; use actix::ResponseType; -const MAX_PAYLOAD_SIZE: usize = 65_536; // max buffer size 64k +const DEFAULT_BUFFER_SIZE: usize = 65_536; // max buffer size 64k /// Just Bytes object pub struct PayloadItem(pub Bytes); @@ -27,6 +28,8 @@ pub enum PayloadError { Incomplete, /// Parse error ParseError(IoError), + /// Http2 error + Http2(Http2Error), } impl fmt::Display for PayloadError { @@ -43,6 +46,7 @@ impl Error for PayloadError { match *self { PayloadError::Incomplete => "A payload reached EOF, but is not complete.", PayloadError::ParseError(ref e) => e.description(), + PayloadError::Http2(ref e) => e.description(), } } @@ -130,6 +134,16 @@ impl Payload { pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } + + /// Get size of payload buffer + pub fn buffer_size(&self) -> usize { + self.inner.borrow().buffer_size() + } + + /// Set size of payload buffer + pub fn set_buffer_size(&self, size: usize) { + self.inner.borrow_mut().set_buffer_size(size) + } } @@ -147,33 +161,33 @@ pub(crate) struct PayloadSender { } impl PayloadSender { - pub(crate) fn set_error(&mut self, err: PayloadError) { + pub fn set_error(&mut self, err: PayloadError) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().set_error(err) } } - pub(crate) fn feed_eof(&mut self) { + pub fn feed_eof(&mut self) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_eof() } } - pub(crate) fn feed_data(&mut self, data: Bytes) { + pub fn feed_data(&mut self, data: Bytes) { if let Some(shared) = self.inner.upgrade() { shared.borrow_mut().feed_data(data) } } - pub(crate) fn maybe_paused(&self) -> bool { + pub fn maybe_paused(&self) -> bool { match self.inner.upgrade() { Some(shared) => { let inner = shared.borrow(); - if inner.paused() && inner.len() < MAX_PAYLOAD_SIZE { + if inner.paused() && inner.len() < inner.buffer_size() { drop(inner); shared.borrow_mut().resume(); false - } else if !inner.paused() && inner.len() > MAX_PAYLOAD_SIZE { + } else if !inner.paused() && inner.len() > inner.buffer_size() { drop(inner); shared.borrow_mut().pause(); true @@ -184,6 +198,14 @@ impl PayloadSender { None => false, } } + + pub fn capacity(&self) -> usize { + if let Some(shared) = self.inner.upgrade() { + shared.borrow().capacity() + } else { + 0 + } + } } #[derive(Debug)] @@ -194,6 +216,7 @@ struct Inner { err: Option, task: Option, items: VecDeque, + buf_size: usize, } impl Inner { @@ -206,6 +229,7 @@ impl Inner { err: None, task: None, items: VecDeque::new(), + buf_size: DEFAULT_BUFFER_SIZE, } } @@ -347,7 +371,6 @@ impl Inner { self.readuntil(b"\n") } - #[doc(hidden)] pub fn readall(&mut self) -> Option { let len = self.items.iter().fold(0, |cur, item| cur + item.len()); if len > 0 { @@ -363,10 +386,26 @@ impl Inner { } } - pub fn unread_data(&mut self, data: Bytes) { + fn unread_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_front(data) } + + fn capacity(&self) -> usize { + if self.len > self.buf_size { + 0 + } else { + self.buf_size - self.len + } + } + + fn buffer_size(&self) -> usize { + self.buf_size + } + + fn set_buffer_size(&mut self, size: usize) { + self.buf_size = size + } } #[cfg(test)] @@ -569,7 +608,7 @@ mod tests { assert!(!payload.paused()); assert!(!sender.maybe_paused()); - for _ in 0..MAX_PAYLOAD_SIZE+1 { + for _ in 0..DEFAULT_BUFFER_SIZE+1 { sender.feed_data(Bytes::from("1")); } assert!(sender.maybe_paused()); diff --git a/src/server.rs b/src/server.rs index 55f85b6d3..c7cd6613f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -146,7 +146,6 @@ impl HttpServer, net::SocketAddr, H> { let acc = acceptor.clone(); ctx.add_stream(tcp.incoming().and_then(move |(stream, addr)| { - println!("SSL"); TlsAcceptorExt::accept_async(acc.as_ref(), stream) .map(move |t| { IoStream(t, addr) @@ -183,7 +182,7 @@ impl Handler, io::Error> for HttpServer H: HttpHandler + 'static, { fn error(&mut self, err: io::Error, _: &mut Context) { - println!("Error handling request: {}", err) + debug!("Error handling request: {}", err) } fn handle(&mut self, msg: IoStream, _: &mut Context) diff --git a/src/task.rs b/src/task.rs index 073cde62b..0a8a46065 100644 --- a/src/task.rs +++ b/src/task.rs @@ -247,7 +247,7 @@ impl Task { return Ok(Async::NotReady) } Err(err) => { - trace!("Error sending data: {}", err); + debug!("Error sending data: {}", err); return Err(()) } } @@ -285,7 +285,7 @@ impl Task { match frame { Frame::Message(ref msg) => { if self.iostate != TaskIOState::ReadingMessage { - error!("Non expected frame {:?}", frame); + error!("Unexpected frame {:?}", frame); return Err(()) } let upgrade = msg.upgrade(); @@ -299,7 +299,7 @@ impl Task { if chunk.is_none() { self.iostate = TaskIOState::Done; } else if self.iostate != TaskIOState::ReadingPayload { - error!("Non expected frame {:?}", self.iostate); + error!("Unexpected frame {:?}", self.iostate); return Err(()) } },