From b0ca6220f07fafd9862542cb703cefc1753c16e9 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 6 Oct 2018 22:36:57 -0700 Subject: [PATCH] refactor te encoding --- src/error.rs | 6 +- src/h1/codec.rs | 32 +++-------- src/h1/dispatcher.rs | 132 +++++++++++++++++++++---------------------- src/h1/service.rs | 13 +++-- 4 files changed, 83 insertions(+), 100 deletions(-) diff --git a/src/error.rs b/src/error.rs index 26b3ca56..277814d2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -374,7 +374,7 @@ impl ResponseError for cookie::ParseError { #[derive(Debug)] /// A set of errors that can occur during dispatching http requests -pub enum DispatchError { +pub enum DispatchError { /// Service error // #[fail(display = "Application specific error: {}", _0)] Service(E), @@ -413,13 +413,13 @@ pub enum DispatchError { Unknown, } -impl From for DispatchError { +impl From for DispatchError { fn from(err: ParseError) -> Self { DispatchError::Parse(err) } } -impl From for DispatchError { +impl From for DispatchError { fn from(err: io::Error) -> Self { DispatchError::Io(err) } diff --git a/src/h1/codec.rs b/src/h1/codec.rs index 8ab8f252..247b0f01 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -54,7 +54,6 @@ pub struct Codec { // encoder part flags: Flags, - written: u64, headers_size: u32, te: ResponseEncoder, } @@ -82,31 +81,30 @@ impl Codec { version: Version::HTTP_11, flags, - written: 0, headers_size: 0, te: ResponseEncoder::default(), } } - fn written(&self) -> u64 { - self.written - } - + /// Check if request is upgrade pub fn upgrade(&self) -> bool { self.flags.contains(Flags::UPGRADE) } + /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { self.flags.contains(Flags::KEEPALIVE) } + /// prepare transfer encoding + pub fn prepare_te(&mut self, res: &mut Response) { + self.te + .update(res, self.flags.contains(Flags::HEAD), self.version); + } + fn encode_response( &mut self, mut msg: Response, buffer: &mut BytesMut, ) -> io::Result<()> { - // prepare transfer encoding - self.te - .update(&mut msg, self.flags.contains(Flags::HEAD), self.version); - let ka = self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg .keep_alive() .unwrap_or_else(|| self.flags.contains(Flags::KEEPALIVE)); @@ -131,12 +129,11 @@ impl Codec { msg.headers_mut() .insert(CONNECTION, HeaderValue::from_static("close")); } - let body = msg.replace_body(Body::Empty); // render message { let reason = msg.reason().as_bytes(); - if let Body::Binary(ref bytes) = body { + if let Body::Binary(ref bytes) = msg.body() { buffer.reserve( 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len() @@ -229,16 +226,6 @@ impl Codec { self.headers_size = buffer.len() as u32; } - if let Body::Binary(bytes) = body { - self.written = bytes.len() as u64; - // buffer.write(bytes.as_ref())?; - buffer.extend_from_slice(bytes.as_ref()); - } else { - // capacity, makes sense only for streaming or actor - // self.buffer_capacity = msg.write_buffer_capacity(); - - msg.replace_body(body); - } Ok(()) } } @@ -282,7 +269,6 @@ impl Encoder for Codec { ) -> Result<(), Self::Error> { match item { OutMessage::Response(res) => { - self.written = 0; self.encode_response(res, dst)?; } OutMessage::Payload(bytes) => { diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index a39967a2..7b6d31fe 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -1,15 +1,15 @@ use std::collections::VecDeque; -use std::fmt::{Debug, Display}; use std::time::Instant; use actix_net::codec::Framed; use actix_net::service::Service; use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use log::Level::Debug; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; -use error::{ParseError, PayloadError}; +use error::{Error, ParseError, PayloadError}; use payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; use body::Body; @@ -38,7 +38,7 @@ bitflags! { /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher where - S::Error: Debug + Display, + S::Error: Into, { service: S, flags: Flags, @@ -81,7 +81,7 @@ impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, - S::Error: Debug + Display, + S::Error: Into, { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { @@ -177,52 +177,34 @@ where State::None => loop { break if let Some(msg) = self.messages.pop_front() { match msg { - Message::Item(msg) => { - let mut task = self.service.call(msg); - match task.poll() { - Ok(Async::Ready(res)) => { - if res.body().is_streaming() { - unimplemented!() - } else { - Some(Ok(State::SendResponse(Some( - OutMessage::Response(res), - )))) - } - } - Ok(Async::NotReady) => { - Some(Ok(State::Response(task))) - } - Err(err) => Some(Err(DispatchError::Service(err))), - } - } - Message::Error(res) => Some(Ok(State::SendResponse(Some( + Message::Item(req) => Some(self.handle_request(req)), + Message::Error(res) => Some(State::SendResponse(Some( OutMessage::Response(res), - )))), + ))), } } else { None }; }, State::Payload(ref mut _body) => unimplemented!(), - State::Response(ref mut fut) => { - match fut.poll() { - Ok(Async::Ready(res)) => { - if res.body().is_streaming() { - unimplemented!() - } else { - Some(Ok(State::SendResponse(Some( - OutMessage::Response(res), - )))) - } - } - Ok(Async::NotReady) => None, - Err(err) => { - // it is not possible to recover from error - // during pipe handling, so just drop connection - Some(Err(DispatchError::Service(err))) + State::Response(ref mut fut) => match fut.poll() { + Ok(Async::Ready(mut res)) => { + self.framed.get_codec_mut().prepare_te(&mut res); + if res.body().is_streaming() { + unimplemented!() + } else { + Some(State::SendResponse(Some(OutMessage::Response(res)))) } } - } + Ok(Async::NotReady) => None, + Err(err) => { + let err = err.into(); + if log_enabled!(Debug) { + debug!("{:?}", err); + } + Some(State::SendResponse(Some(OutMessage::Response(err.into())))) + } + }, State::SendResponse(ref mut item) => { let msg = item.take().expect("SendResponse is empty"); match self.framed.start_send(msg) { @@ -232,13 +214,19 @@ where self.framed.get_codec().keepalive(), ); self.flags.remove(Flags::FLUSHED); - Some(Ok(State::None)) + Some(State::None) } Ok(AsyncSink::NotReady(msg)) => { *item = Some(msg); return Ok(()); } - Err(err) => Some(Err(DispatchError::Io(err))), + Err(err) => { + self.flags.insert(Flags::READ_DISCONNECTED); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + return Err(DispatchError::Io(err)); + } } } State::SendResponseWithPayload(ref mut item) => { @@ -251,23 +239,25 @@ where self.framed.get_codec().keepalive(), ); self.flags.remove(Flags::FLUSHED); - Some(Ok(State::Payload(body))) + Some(State::Payload(body)) } Ok(AsyncSink::NotReady(msg)) => { *item = Some((msg, body)); return Ok(()); } - Err(err) => Some(Err(DispatchError::Io(err))), + Err(err) => { + self.flags.insert(Flags::READ_DISCONNECTED); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + return Err(DispatchError::Io(err)); + } } } }; match state { - Some(Ok(state)) => self.state = state, - Some(Err(err)) => { - self.client_disconnected(); - return Err(err); - } + Some(state) => self.state = state, None => { // if read-backpressure is enabled and we consumed some data. // we may read more dataand retry @@ -283,6 +273,28 @@ where Ok(()) } + fn handle_request(&mut self, req: Request) -> State { + let mut task = self.service.call(req); + match task.poll() { + Ok(Async::Ready(mut res)) => { + self.framed.get_codec_mut().prepare_te(&mut res); + if res.body().is_streaming() { + unimplemented!() + } else { + State::SendResponse(Some(OutMessage::Response(res))) + } + } + Ok(Async::NotReady) => State::Response(task), + Err(err) => { + let err = err.into(); + if log_enabled!(Debug) { + debug!("{:?}", err); + } + State::SendResponse(Some(OutMessage::Response(err.into()))) + } + } + } + fn one_message(&mut self, msg: InMessage) -> Result<(), DispatchError> { self.flags.insert(Flags::STARTED); @@ -290,23 +302,7 @@ where InMessage::Message(msg) => { // handle request early if self.state.is_empty() { - let mut task = self.service.call(msg); - match task.poll() { - Ok(Async::Ready(res)) => { - if res.body().is_streaming() { - unimplemented!() - } else { - self.state = - State::SendResponse(Some(OutMessage::Response(res))); - } - } - Ok(Async::NotReady) => self.state = State::Response(task), - Err(err) => { - error!("Unhandled application error: {}", err); - self.client_disconnected(); - return Err(DispatchError::Service(err)); - } - } + self.state = self.handle_request(msg); } else { self.messages.push_back(Message::Item(msg)); } @@ -449,7 +445,7 @@ impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, - S::Error: Debug + Display, + S::Error: Into, { type Item = (); type Error = DispatchError; diff --git a/src/h1/service.rs b/src/h1/service.rs index 02535fc8..3ac073ad 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -1,4 +1,3 @@ -use std::fmt::{Debug, Display}; use std::marker::PhantomData; use actix_net::codec::Framed; @@ -7,7 +6,7 @@ use futures::{future, Async, Future, Poll, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use config::ServiceConfig; -use error::{DispatchError, ParseError}; +use error::{DispatchError, Error, ParseError}; use request::Request; use response::Response; @@ -24,6 +23,8 @@ pub struct H1Service { impl H1Service where S: NewService, + S::Service: Clone, + S::Error: Into, { /// Create new `HttpService` instance. pub fn new>(cfg: ServiceConfig, service: F) -> Self { @@ -40,7 +41,7 @@ where T: AsyncRead + AsyncWrite, S: NewService + Clone, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Into, { type Request = T; type Response = (); @@ -69,7 +70,7 @@ where T: AsyncRead + AsyncWrite, S: NewService, S::Service: Clone, - S::Error: Debug + Display, + S::Error: Into, { type Item = H1ServiceHandler; type Error = S::InitError; @@ -93,7 +94,7 @@ pub struct H1ServiceHandler { impl H1ServiceHandler where S: Service + Clone, - S::Error: Debug + Display, + S::Error: Into, { fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler { H1ServiceHandler { @@ -108,7 +109,7 @@ impl Service for H1ServiceHandler where T: AsyncRead + AsyncWrite, S: Service + Clone, - S::Error: Debug + Display, + S::Error: Into, { type Request = T; type Response = ();