From aa20e2670d44b893ae47ad01c61fc1fd3b65dcce Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 16 Nov 2018 21:09:33 -0800 Subject: [PATCH] refactor h1 dispatcher --- src/h1/client.rs | 69 +++++++-- src/h1/decoder.rs | 90 +++++------- src/h1/dispatcher.rs | 338 ++++++++++++++++++++++--------------------- 3 files changed, 267 insertions(+), 230 deletions(-) diff --git a/src/h1/client.rs b/src/h1/client.rs index 8d4051d3..f871cb33 100644 --- a/src/h1/client.rs +++ b/src/h1/client.rs @@ -125,6 +125,15 @@ impl ClientPayloadCodec { } } +fn prn_version(ver: Version) -> &'static str { + match ver { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + } +} + impl ClientCodecInner { fn encode_response( &mut self, @@ -135,33 +144,63 @@ impl ClientCodecInner { // render message { // status line - writeln!( + write!( Writer(buffer), - "{} {} {:?}\r", + "{} {} {}", msg.method, msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), - msg.version + prn_version(msg.version) ).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; // write headers buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE); - for (key, value) in &msg.headers { - let v = value.as_ref(); - let k = key.as_str().as_bytes(); - buffer.reserve(k.len() + v.len() + 4); - buffer.put_slice(k); - buffer.put_slice(b": "); - buffer.put_slice(v); - buffer.put_slice(b"\r\n"); - // Connection upgrade - if key == UPGRADE { - self.flags.insert(Flags::UPGRADE); + // content length + let mut len_is_set = true; + match btype { + BodyType::Sized(len) => { + buffer.extend_from_slice(b"\r\ncontent-length: "); + write!(buffer.writer(), "{}", len)?; + buffer.extend_from_slice(b"\r\n"); } + BodyType::Unsized => { + buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") + } + BodyType::Zero => { + len_is_set = false; + buffer.extend_from_slice(b"\r\n") + } + BodyType::None => buffer.extend_from_slice(b"\r\n"), + } + + let mut has_date = false; + + for (key, value) in &msg.headers { + match *key { + TRANSFER_ENCODING => continue, + CONTENT_LENGTH => match btype { + BodyType::None => (), + BodyType::Zero => len_is_set = true, + _ => continue, + }, + DATE => has_date = true, + UPGRADE => self.flags.insert(Flags::UPGRADE), + _ => (), + } + + buffer.put_slice(key.as_ref()); + buffer.put_slice(b": "); + buffer.put_slice(value.as_ref()); + buffer.put_slice(b"\r\n"); + } + + // set content length + if !len_is_set { + buffer.extend_from_slice(b"content-length: 0\r\n") } // set date header - if !msg.headers.contains_key(DATE) { + if !has_date { self.config.set_date(buffer); } else { buffer.extend_from_slice(b"\r\n"); diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index fe2aa707..61cab7ad 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -42,10 +42,9 @@ impl Decoder for MessageDecoder { } pub(crate) enum PayloadLength { - None, - Chunked, + Payload(PayloadType), Upgrade, - Length(u64), + None, } pub(crate) trait MessageTypeDecoder: Sized { @@ -55,7 +54,7 @@ pub(crate) trait MessageTypeDecoder: Sized { fn decode(src: &mut BytesMut) -> Result, ParseError>; - fn process_headers( + fn set_headers( &mut self, slice: &Bytes, version: Version, @@ -140,10 +139,17 @@ pub(crate) trait MessageTypeDecoder: Sized { self.keep_alive(); } + // https://tools.ietf.org/html/rfc7230#section-3.3.3 if chunked { - Ok(PayloadLength::Chunked) + // Chunked encoding + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::chunked(), + ))) } else if let Some(len) = content_length { - Ok(PayloadLength::Length(len)) + // Content-Length + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::length(len), + ))) } else if has_upgrade { Ok(PayloadLength::Upgrade) } else { @@ -166,7 +172,7 @@ impl MessageTypeDecoder for Request { // performance bump for pipeline benchmarks. let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, method, uri, version, headers_len) = { + let (len, method, uri, ver, h_len) = { let mut parsed: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; @@ -189,35 +195,24 @@ impl MessageTypeDecoder for Request { } }; - // convert headers let mut msg = Request::new(); - let len = msg.process_headers( - &src.split_to(len).freeze(), - version, - &headers[..headers_len], - )?; + // convert headers + let len = + msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?; - // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // payload decoder let decoder = match len { - PayloadLength::Chunked => { - // Chunked encoding - PayloadType::Payload(PayloadDecoder::chunked()) - } - PayloadLength::Length(len) => { - // Content-Length - PayloadType::Payload(PayloadDecoder::length(len)) - } + PayloadLength::Payload(pl) => pl, PayloadLength::Upgrade => { - // upgrade(websocket) or connect + // upgrade(websocket) PayloadType::Stream(PayloadDecoder::eof()) } PayloadLength::None => { if method == Method::CONNECT { - // upgrade(websocket) or connect PayloadType::Stream(PayloadDecoder::eof()) } else if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ParseError::TooLarge); } else { PayloadType::None @@ -230,7 +225,7 @@ impl MessageTypeDecoder for Request { inner.url.update(&uri); inner.head.uri = uri; inner.head.method = method; - inner.head.version = version; + inner.head.version = ver; } Ok(Some((msg, decoder))) @@ -251,7 +246,7 @@ impl MessageTypeDecoder for ClientResponse { // performance bump for pipeline benchmarks. let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, version, status, headers_len) = { + let (len, ver, status, h_len) = { let mut parsed: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; @@ -276,37 +271,26 @@ impl MessageTypeDecoder for ClientResponse { let mut msg = ClientResponse::new(); // convert headers - let len = msg.process_headers( - &src.split_to(len).freeze(), - version, - &headers[..headers_len], - )?; + let len = + msg.set_headers(&src.split_to(len).freeze(), ver, &headers[..h_len])?; - // https://tools.ietf.org/html/rfc7230#section-3.3.3 - let decoder = match len { - PayloadLength::Chunked => { - // Chunked encoding - PayloadType::Payload(PayloadDecoder::chunked()) - } - PayloadLength::Length(len) => { - // Content-Length - PayloadType::Payload(PayloadDecoder::length(len)) - } - _ => { - if status == StatusCode::SWITCHING_PROTOCOLS { - // switching protocol or connect - PayloadType::Stream(PayloadDecoder::eof()) - } else if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); - } else { - PayloadType::None - } + // message payload + let decoder = if let PayloadLength::Payload(pl) = len { + pl + } else { + if status == StatusCode::SWITCHING_PROTOCOLS { + // switching protocol or connect + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None } }; msg.head.status = status; - msg.head.version = Some(version); + msg.head.version = Some(ver); Ok(Some((msg, decoder))) } diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 470d3df9..4a0bce72 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -1,11 +1,12 @@ use std::collections::VecDeque; use std::fmt::Debug; +use std::mem; use std::time::Instant; use actix_net::codec::Framed; use actix_net::service::Service; -use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use futures::{Async, Future, Poll, Sink, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; @@ -37,12 +38,19 @@ bitflags! { /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher +where + S::Error: Debug, +{ + inner: Option>, +} + +struct InnerDispatcher where S::Error: Debug, { service: S, flags: Flags, - framed: Option>, + framed: Framed, error: Option>, config: ServiceConfig, @@ -63,7 +71,6 @@ enum DispatcherMessage { enum State { None, ServiceCall(S::Future), - SendResponse(Option<(Message, Body)>), SendPayload(BodyStream), } @@ -113,20 +120,29 @@ where }; Dispatcher { - payload: None, - state: State::None, - error: None, - messages: VecDeque::new(), - framed: Some(framed), - unhandled: None, - service, - flags, - config, - ka_expire, - ka_timer, + inner: Some(InnerDispatcher { + framed, + payload: None, + state: State::None, + error: None, + messages: VecDeque::new(), + unhandled: None, + service, + flags, + config, + ka_expire, + ka_timer, + }), } } +} +impl InnerDispatcher +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Error: Debug, +{ fn can_read(&self) -> bool { if self.flags.contains(Flags::DISCONNECTED) { return false; @@ -150,7 +166,7 @@ where /// Flush stream fn poll_flush(&mut self) -> Poll<(), DispatchError> { if !self.flags.contains(Flags::FLUSHED) { - match self.framed.as_mut().unwrap().poll_complete() { + match self.framed.poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); @@ -170,90 +186,82 @@ where } } + fn send_response( + &mut self, + message: Response, + body: Body, + ) -> Result, DispatchError> { + self.framed + .force_send(Message::Item(message)) + .map_err(|err| { + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + DispatchError::Io(err) + })?; + + self.flags + .set(Flags::KEEPALIVE, self.framed.get_codec().keepalive()); + self.flags.remove(Flags::FLUSHED); + match body { + Body::Empty => Ok(State::None), + Body::Streaming(stream) => Ok(State::SendPayload(stream)), + Body::Binary(mut bin) => { + self.flags.remove(Flags::FLUSHED); + self.framed.force_send(Message::Chunk(Some(bin.take())))?; + self.framed.force_send(Message::Chunk(None))?; + Ok(State::None) + } + } + } + fn poll_response(&mut self) -> Result<(), DispatchError> { let mut retry = self.can_read(); - - // process loop { - let state = match self.state { - State::None => if let Some(msg) = self.messages.pop_front() { - match msg { - DispatcherMessage::Item(req) => Some(self.handle_request(req)?), - DispatcherMessage::Error(res) => Some(State::SendResponse( - Some((Message::Item(res), Body::Empty)), - )), + let state = match mem::replace(&mut self.state, State::None) { + State::None => match self.messages.pop_front() { + Some(DispatcherMessage::Item(req)) => { + Some(self.handle_request(req)?) } - } else { - None + Some(DispatcherMessage::Error(res)) => { + Some(self.send_response(res, Body::Empty)?) + } + None => None, }, - // call inner service - State::ServiceCall(ref mut fut) => { + State::ServiceCall(mut fut) => { match fut.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - self.framed - .as_mut() - .unwrap() - .get_codec_mut() - .prepare_te(&mut res); + self.framed.get_codec_mut().prepare_te(&mut res); let body = res.replace_body(Body::Empty); - Some(State::SendResponse(Some((Message::Item(res), body)))) + Some(self.send_response(res, body)?) } - Async::NotReady => None, - } - } - // send respons - State::SendResponse(ref mut item) => { - let (msg, body) = item.take().expect("SendResponse is empty"); - let framed = self.framed.as_mut().unwrap(); - match framed.start_send(msg) { - Ok(AsyncSink::Ready) => { - self.flags - .set(Flags::KEEPALIVE, framed.get_codec().keepalive()); - self.flags.remove(Flags::FLUSHED); - match body { - Body::Empty => Some(State::None), - Body::Streaming(stream) => { - Some(State::SendPayload(stream)) - } - Body::Binary(mut bin) => { - self.flags.remove(Flags::FLUSHED); - framed - .force_send(Message::Chunk(Some(bin.take())))?; - framed.force_send(Message::Chunk(None))?; - Some(State::None) - } - } - } - Ok(AsyncSink::NotReady(msg)) => { - *item = Some((msg, body)); - return Ok(()); - } - Err(err) => { - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete(None)); - } - return Err(DispatchError::Io(err)); + Async::NotReady => { + self.state = State::ServiceCall(fut); + None } } } - // Send payload - State::SendPayload(ref mut stream) => { - let mut framed = self.framed.as_mut().unwrap(); + State::SendPayload(mut stream) => { loop { - if !framed.is_write_buf_full() { + if !self.framed.is_write_buf_full() { match stream.poll().map_err(|_| DispatchError::Unknown)? { Async::Ready(Some(item)) => { self.flags.remove(Flags::FLUSHED); - framed.force_send(Message::Chunk(Some(item)))?; + self.framed + .force_send(Message::Chunk(Some(item)))?; continue; } Async::Ready(None) => { self.flags.remove(Flags::FLUSHED); - framed.force_send(Message::Chunk(None))?; + self.framed.force_send(Message::Chunk(None))?; + } + Async::NotReady => { + self.state = State::SendPayload(stream); + return Ok(()); } - Async::NotReady => return Ok(()), } } else { + self.state = State::SendPayload(stream); return Ok(()); } break; @@ -266,7 +274,7 @@ where Some(state) => self.state = state, None => { // if read-backpressure is enabled and we consumed some data. - // we may read more dataand retry + // we may read more data and retry if !retry && self.can_read() && self.poll_request()? { retry = self.can_read(); continue; @@ -286,13 +294,9 @@ where let mut task = self.service.call(req); match task.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - self.framed - .as_mut() - .unwrap() - .get_codec_mut() - .prepare_te(&mut res); + self.framed.get_codec_mut().prepare_te(&mut res); let body = res.replace_body(Body::Empty); - Ok(State::SendResponse(Some((Message::Item(res), body)))) + self.send_response(res, body) } Async::NotReady => Ok(State::ServiceCall(task)), } @@ -307,20 +311,14 @@ where let mut updated = false; loop { - match self.framed.as_mut().unwrap().poll() { + match self.framed.poll() { Ok(Async::Ready(Some(msg))) => { updated = true; self.flags.insert(Flags::STARTED); match msg { Message::Item(req) => { - match self - .framed - .as_ref() - .unwrap() - .get_codec() - .message_type() - { + match self.framed.get_codec().message_type() { MessageType::Payload => { let (ps, pl) = Payload::new(false); *req.inner.payload.borrow_mut() = Some(pl); @@ -406,52 +404,58 @@ where /// keep-alive timer fn poll_keepalive(&mut self) -> Result<(), DispatchError> { - if let Some(ref mut timer) = self.ka_timer { - match timer.poll() { - Ok(Async::Ready(_)) => { - // if we get timer during shutdown, just drop connection - if self.flags.contains(Flags::SHUTDOWN) { - return Err(DispatchError::DisconnectTimeout); - } else if timer.deadline() >= self.ka_expire { - // check for any outstanding response processing - if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { - if self.flags.contains(Flags::STARTED) { - trace!("Keep-alive timeout, close connection"); - self.flags.insert(Flags::SHUTDOWN); + if self.ka_timer.is_some() { + return Ok(()); + } + match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { + error!("Timer error {:?}", e); + DispatchError::Unknown + })? { + Async::Ready(_) => { + // if we get timeout during shutdown, drop connection + if self.flags.contains(Flags::SHUTDOWN) { + return Err(DispatchError::DisconnectTimeout); + } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { + // check for any outstanding response processing + if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { + if self.flags.contains(Flags::STARTED) { + trace!("Keep-alive timeout, close connection"); + self.flags.insert(Flags::SHUTDOWN); - // start shutdown timer - if let Some(deadline) = - self.config.client_disconnect_timer() - { + // start shutdown timer + if let Some(deadline) = self.config.client_disconnect_timer() + { + self.ka_timer.as_mut().map(|timer| { timer.reset(deadline); let _ = timer.poll(); - } else { - return Ok(()); - } + }); } else { - // timeout on first request (slow request) return 408 - trace!("Slow request timeout"); - self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); - self.state = State::SendResponse(Some(( - Message::Item(Response::RequestTimeout().finish()), - Body::Empty, - ))); + return Ok(()); } - } else if let Some(deadline) = self.config.keep_alive_expire() { + } else { + // timeout on first request (slow request) return 408 + trace!("Slow request timeout"); + self.flags.insert(Flags::STARTED | Flags::DISCONNECTED); + self.state = self.send_response( + Response::RequestTimeout().finish(), + Body::Empty, + )?; + } + } else if let Some(deadline) = self.config.keep_alive_expire() { + self.ka_timer.as_mut().map(|timer| { timer.reset(deadline); let _ = timer.poll(); - } - } else { - timer.reset(self.ka_expire); - let _ = timer.poll(); + }); } - } - Ok(Async::NotReady) => (), - Err(e) => { - error!("Timer error {:?}", e); - return Err(DispatchError::Unknown); + } else { + let expire = self.ka_expire; + self.ka_timer.as_mut().map(|timer| { + timer.reset(expire); + let _ = timer.poll(); + }); } } + Async::NotReady => (), } Ok(()) @@ -469,43 +473,53 @@ where #[inline] fn poll(&mut self) -> Poll { - if self.flags.contains(Flags::SHUTDOWN) { - self.poll_keepalive()?; - try_ready!(self.poll_flush()); - let io = self.framed.take().unwrap().into_inner(); - Ok(Async::Ready(H1ServiceResult::Shutdown(io))) - } else { - self.poll_keepalive()?; - self.poll_request()?; - self.poll_response()?; - self.poll_flush()?; - - // keep-alive and stream errors - if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { - if let Some(err) = self.error.take() { - Err(err) - } else if self.flags.contains(Flags::DISCONNECTED) { - Ok(Async::Ready(H1ServiceResult::Disconnected)) - } - // unhandled request (upgrade or connect) - else if self.unhandled.is_some() { - let req = self.unhandled.take().unwrap(); - let framed = self.framed.take().unwrap(); - Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed))) - } - // disconnect if keep-alive is not enabled - else if self.flags.contains(Flags::STARTED) && !self - .flags - .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) - { - let io = self.framed.take().unwrap().into_inner(); - Ok(Async::Ready(H1ServiceResult::Shutdown(io))) - } else { - Ok(Async::NotReady) - } + let shutdown = if let Some(ref mut inner) = self.inner { + if inner.flags.contains(Flags::SHUTDOWN) { + inner.poll_keepalive()?; + try_ready!(inner.poll_flush()); + true } else { - Ok(Async::NotReady) + inner.poll_keepalive()?; + inner.poll_request()?; + inner.poll_response()?; + inner.poll_flush()?; + + // keep-alive and stream errors + if inner.state.is_empty() && inner.flags.contains(Flags::FLUSHED) { + if let Some(err) = inner.error.take() { + return Err(err); + } else if inner.flags.contains(Flags::DISCONNECTED) { + return Ok(Async::Ready(H1ServiceResult::Disconnected)); + } + // unhandled request (upgrade or connect) + else if inner.unhandled.is_some() { + false + } + // disconnect if keep-alive is not enabled + else if inner.flags.contains(Flags::STARTED) && !inner + .flags + .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) + { + true + } else { + return Ok(Async::NotReady); + } + } else { + return Ok(Async::NotReady); + } } + } else { + unreachable!() + }; + + let mut inner = self.inner.take().unwrap(); + if shutdown { + Ok(Async::Ready(H1ServiceResult::Shutdown( + inner.framed.into_inner(), + ))) + } else { + let req = inner.unhandled.take().unwrap(); + Ok(Async::Ready(H1ServiceResult::Unhandled(req, inner.framed))) } } }