From d39c018c9384963030d0adda0e29a0735d7d5179 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 13 Oct 2018 23:57:31 -0700 Subject: [PATCH] do not handle upgrade and connect requests --- src/h1/codec.rs | 43 ++++++++++++++++++++------- src/h1/decoder.rs | 45 ++++++++++++++++++++--------- src/h1/dispatcher.rs | 69 ++++++++++++++++++++++++++++++-------------- src/h1/mod.rs | 13 ++++++++- src/h1/service.rs | 5 ++-- tests/test_server.rs | 21 +++++++++----- tests/test_ws.rs | 2 +- 7 files changed, 141 insertions(+), 57 deletions(-) diff --git a/src/h1/codec.rs b/src/h1/codec.rs index 04cf395b6..a91f5cb34 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -4,7 +4,7 @@ use std::io::{self, Write}; use bytes::{BufMut, Bytes, BytesMut}; use tokio_codec::{Decoder, Encoder}; -use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder}; +use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType}; use super::encoder::{ResponseEncoder, ResponseLength}; use body::{Binary, Body}; use config::ServiceConfig; @@ -17,10 +17,11 @@ use response::Response; bitflags! { struct Flags: u8 { - const HEAD = 0b0000_0001; - const UPGRADE = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const KEEPALIVE_ENABLED = 0b0001_0000; + const HEAD = 0b0000_0001; + const UPGRADE = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const KEEPALIVE_ENABLED = 0b0000_1000; + const UNHANDLED = 0b0001_0000; } } @@ -39,11 +40,19 @@ pub enum OutMessage { #[derive(Debug)] pub enum InMessage { /// Request - Message { req: Request, payload: bool }, + Message(Request, InMessageType), /// Payload chunk Chunk(Option), } +/// Incoming request type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InMessageType { + None, + Payload, + Unhandled, +} + /// HTTP/1 Codec pub struct Codec { config: ServiceConfig, @@ -246,6 +255,8 @@ impl Decoder for Codec { Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)), None => None, }) + } else if self.flags.contains(Flags::UNHANDLED) { + Ok(None) } else if let Some((req, payload)) = self.decoder.decode(src)? { self.flags .set(Flags::HEAD, req.inner.method == Method::HEAD); @@ -253,11 +264,21 @@ impl Decoder for Codec { if self.flags.contains(Flags::KEEPALIVE_ENABLED) { self.flags.set(Flags::KEEPALIVE, req.keep_alive()); } - self.payload = payload; - Ok(Some(InMessage::Message { - req, - payload: self.payload.is_some(), - })) + let payload = match payload { + RequestPayloadType::None => { + self.payload = None; + InMessageType::None + } + RequestPayloadType::Payload(pl) => { + self.payload = Some(pl); + InMessageType::Payload + } + RequestPayloadType::Unhandled => { + self.payload = None; + InMessageType::Unhandled + } + }; + Ok(Some(InMessage::Message(req, payload))) } else { Ok(None) } diff --git a/src/h1/decoder.rs b/src/h1/decoder.rs index 5fe8b19c8..c2a8d0e99 100644 --- a/src/h1/decoder.rs +++ b/src/h1/decoder.rs @@ -16,6 +16,13 @@ const MAX_HEADERS: usize = 96; pub struct RequestDecoder(&'static RequestPool); +/// Incoming request type +pub enum RequestPayloadType { + None, + Payload(PayloadDecoder), + Unhandled, +} + impl RequestDecoder { pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder { RequestDecoder(pool) @@ -29,7 +36,7 @@ impl Default for RequestDecoder { } impl Decoder for RequestDecoder { - type Item = (Request, Option); + type Item = (Request, RequestPayloadType); type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -149,18 +156,18 @@ impl Decoder for RequestDecoder { // https://tools.ietf.org/html/rfc7230#section-3.3.3 let decoder = if chunked { // Chunked encoding - Some(PayloadDecoder::chunked()) + RequestPayloadType::Payload(PayloadDecoder::chunked()) } else if let Some(len) = content_length { // Content-Length - Some(PayloadDecoder::length(len)) + RequestPayloadType::Payload(PayloadDecoder::length(len)) } else if has_upgrade || msg.inner.method == Method::CONNECT { // upgrade(websocket) or connect - Some(PayloadDecoder::eof()) + RequestPayloadType::Unhandled } else if src.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ParseError::TooLarge); } else { - None + RequestPayloadType::None }; Ok(Some((msg, decoder))) @@ -481,20 +488,36 @@ mod tests { use super::*; use error::ParseError; - use h1::InMessage; + use h1::{InMessage, InMessageType}; use httpmessage::HttpMessage; use request::Request; + impl RequestPayloadType { + fn unwrap(self) -> PayloadDecoder { + match self { + RequestPayloadType::Payload(pl) => pl, + _ => panic!(), + } + } + + fn is_unhandled(&self) -> bool { + match self { + RequestPayloadType::Unhandled => true, + _ => false, + } + } + } + impl InMessage { fn message(self) -> Request { match self { - InMessage::Message { req, payload: _ } => req, + InMessage::Message(req, _) => req, _ => panic!("error"), } } fn is_payload(&self) -> bool { match *self { - InMessage::Message { req: _, payload } => payload, + InMessage::Message(_, payload) => payload == InMessageType::Payload, _ => panic!("error"), } } @@ -919,13 +942,9 @@ mod tests { ); let mut reader = RequestDecoder::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); assert!(!req.keep_alive()); assert!(req.upgrade()); - assert_eq!( - pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), - b"some raw data" - ); + assert!(pl.is_unhandled()); } #[test] diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index 8b7c29331..8ae2ae8ce 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -18,7 +18,8 @@ use error::DispatchError; use request::Request; use response::Response; -use super::codec::{Codec, InMessage, OutMessage}; +use super::codec::{Codec, InMessage, InMessageType, OutMessage}; +use super::H1ServiceResult; const MAX_PIPELINED_MESSAGES: usize = 16; @@ -41,13 +42,14 @@ where { service: S, flags: Flags, - framed: Framed, + framed: Option>, error: Option>, config: ServiceConfig, state: State, payload: Option, messages: VecDeque, + unhandled: Option, ka_expire: Instant, ka_timer: Option, @@ -112,9 +114,10 @@ where state: State::None, error: None, messages: VecDeque::new(), + framed: Some(framed), + unhandled: None, service, flags, - framed, config, ka_expire, ka_timer, @@ -144,7 +147,7 @@ where /// Flush stream fn poll_flush(&mut self) -> Poll<(), DispatchError> { if !self.flags.contains(Flags::FLUSHED) { - match self.framed.poll_complete() { + match self.framed.as_mut().unwrap().poll_complete() { Ok(Async::NotReady) => Ok(Async::NotReady), Err(err) => { debug!("Error sending data: {}", err); @@ -187,7 +190,11 @@ where State::ServiceCall(ref mut fut) => { match fut.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - self.framed.get_codec_mut().prepare_te(&mut res); + self.framed + .as_mut() + .unwrap() + .get_codec_mut() + .prepare_te(&mut res); let body = res.replace_body(Body::Empty); Some(State::SendResponse(Some(( OutMessage::Response(res), @@ -200,11 +207,11 @@ where // send respons State::SendResponse(ref mut item) => { let (msg, body) = item.take().expect("SendResponse is empty"); - match self.framed.start_send(msg) { + match self.framed.as_mut().unwrap().start_send(msg) { Ok(AsyncSink::Ready) => { self.flags.set( Flags::KEEPALIVE, - self.framed.get_codec().keepalive(), + self.framed.as_mut().unwrap().get_codec().keepalive(), ); self.flags.remove(Flags::FLUSHED); match body { @@ -233,7 +240,7 @@ where // Send payload State::SendPayload(ref mut stream, ref mut bin) => { if let Some(item) = bin.take() { - match self.framed.start_send(item) { + match self.framed.as_mut().unwrap().start_send(item) { Ok(AsyncSink::Ready) => { self.flags.remove(Flags::FLUSHED); } @@ -248,6 +255,8 @@ where match stream.poll() { Ok(Async::Ready(Some(item))) => match self .framed + .as_mut() + .unwrap() .start_send(OutMessage::Chunk(Some(item.into()))) { Ok(AsyncSink::Ready) => { @@ -297,7 +306,11 @@ where let mut task = self.service.call(req); match task.poll().map_err(DispatchError::Service)? { Async::Ready(mut res) => { - self.framed.get_codec_mut().prepare_te(&mut res); + self.framed + .as_mut() + .unwrap() + .get_codec_mut() + .prepare_te(&mut res); let body = res.replace_body(Body::Empty); Ok(State::SendResponse(Some((OutMessage::Response(res), body)))) } @@ -314,17 +327,24 @@ where let mut updated = false; 'outer: loop { - match self.framed.poll() { + match self.framed.as_mut().unwrap().poll() { Ok(Async::Ready(Some(msg))) => { updated = true; self.flags.insert(Flags::STARTED); match msg { - InMessage::Message { req, payload } => { - if payload { - let (ps, pl) = Payload::new(false); - *req.inner.payload.borrow_mut() = Some(pl); - self.payload = Some(ps); + InMessage::Message(req, payload) => { + match payload { + InMessageType::Payload => { + let (ps, pl) = Payload::new(false); + *req.inner.payload.borrow_mut() = Some(pl); + self.payload = Some(ps); + } + InMessageType::Unhandled => { + self.unhandled = Some(req); + return Ok(updated); + } + _ => (), } // handle request early @@ -454,15 +474,16 @@ where S: Service, S::Error: Debug, { - type Item = (); + type Item = H1ServiceResult; type Error = DispatchError; #[inline] - fn poll(&mut self) -> Poll<(), Self::Error> { + fn poll(&mut self) -> Poll { if self.flags.contains(Flags::SHUTDOWN) { self.poll_keepalive()?; try_ready!(self.poll_flush()); - Ok(AsyncWrite::shutdown(self.framed.get_mut())?) + let io = self.framed.take().unwrap().into_inner(); + Ok(Async::Ready(H1ServiceResult::Shutdown(io))) } else { self.poll_keepalive()?; self.poll_request()?; @@ -474,15 +495,21 @@ where if let Some(err) = self.error.take() { Err(err) } else if self.flags.contains(Flags::DISCONNECTED) { - Ok(Async::Ready(())) + Ok(Async::Ready(H1ServiceResult::Disconnected)) + } + // unhandled request (upgrade or connect) + else if self.unhandled.is_some() { + let req = self.unhandled.take().unwrap(); + let framed = self.framed.take().unwrap(); + Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed))) } // disconnect if keep-alive is not enabled else if self.flags.contains(Flags::STARTED) && !self .flags .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) { - self.flags.insert(Flags::SHUTDOWN); - self.poll() + let io = self.framed.take().unwrap().into_inner(); + Ok(Async::Ready(H1ServiceResult::Shutdown(io))) } else { Ok(Async::NotReady) } diff --git a/src/h1/mod.rs b/src/h1/mod.rs index 634136a47..4e196ad54 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -1,11 +1,22 @@ //! HTTP/1 implementation +use actix_net::codec::Framed; + mod codec; mod decoder; mod dispatcher; mod encoder; mod service; -pub use self::codec::{Codec, InMessage, OutMessage}; +pub use self::codec::{Codec, InMessage, InMessageType, OutMessage}; pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::dispatcher::Dispatcher; pub use self::service::{H1Service, H1ServiceHandler}; + +use request::Request; + +/// H1 service response type +pub enum H1ServiceResult { + Disconnected, + Shutdown(T), + Unhandled(Request, Framed), +} diff --git a/src/h1/service.rs b/src/h1/service.rs index aa7a51733..6e9d7d651 100644 --- a/src/h1/service.rs +++ b/src/h1/service.rs @@ -12,6 +12,7 @@ use request::Request; use response::Response; use super::dispatcher::Dispatcher; +use super::H1ServiceResult; /// `NewService` implementation for HTTP1 transport pub struct H1Service { @@ -51,7 +52,7 @@ where S::Error: Debug, { type Request = T; - type Response = (); + type Response = H1ServiceResult; type Error = DispatchError; type InitError = S::InitError; type Service = H1ServiceHandler; @@ -243,7 +244,7 @@ where S::Error: Debug, { type Request = T; - type Response = (); + type Response = H1ServiceResult; type Error = DispatchError; type Future = Dispatcher; diff --git a/tests/test_server.rs b/tests/test_server.rs index f382eafbd..c8de0290d 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -9,6 +9,7 @@ use std::{io::Read, io::Write, net, thread, time}; use actix::System; use actix_net::server::Server; +use actix_net::service::NewServiceExt; use actix_web::{client, test, HttpMessage}; use bytes::Bytes; use futures::future::{self, ok}; @@ -29,6 +30,7 @@ fn test_h1_v2() { .server_hostname("localhost") .server_address(addr) .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .map(|_| ()) }).unwrap() .run(); }); @@ -53,6 +55,7 @@ fn test_slow_request() { h1::H1Service::build() .client_timeout(100) .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .map(|_| ()) }).unwrap() .run(); }); @@ -72,6 +75,7 @@ fn test_malformed_request() { Server::new() .bind("test", addr, move || { h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish())) + .map(|_| ()) }).unwrap() .run(); }); @@ -106,7 +110,7 @@ fn test_content_length() { StatusCode::NOT_FOUND, ]; future::ok::<_, ()>(Response::new(statuses[indx])) - }) + }).map(|_| ()) }).unwrap() .run(); }); @@ -172,7 +176,7 @@ fn test_headers() { ); } future::ok::<_, ()>(builder.body(data.clone())) - }) + }).map(|_| ()) }) .unwrap() .run() @@ -221,6 +225,7 @@ fn test_body() { Server::new() .bind("test", addr, move || { h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) }).unwrap() .run(); }); @@ -246,7 +251,7 @@ fn test_head_empty() { .bind("test", addr, move || { h1::H1Service::new(|_| { ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish()) - }) + }).map(|_| ()) }).unwrap() .run() }); @@ -282,7 +287,7 @@ fn test_head_binary() { ok::<_, ()>( Response::Ok().content_length(STR.len() as u64).body(STR), ) - }) + }).map(|_| ()) }).unwrap() .run() }); @@ -314,7 +319,7 @@ fn test_head_binary2() { thread::spawn(move || { Server::new() .bind("test", addr, move || { - h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))) + h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ()) }).unwrap() .run() }); @@ -349,7 +354,7 @@ fn test_body_length() { .content_length(STR.len() as u64) .body(Body::Streaming(Box::new(body))), ) - }) + }).map(|_| ()) }).unwrap() .run() }); @@ -380,7 +385,7 @@ fn test_body_chunked_explicit() { .chunked() .body(Body::Streaming(Box::new(body))), ) - }) + }).map(|_| ()) }).unwrap() .run() }); @@ -409,7 +414,7 @@ fn test_body_chunked_implicit() { h1::H1Service::new(|_| { let body = once(Ok(Bytes::from_static(STR.as_ref()))); ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body)))) - }) + }).map(|_| ()) }).unwrap() .run() }); diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 91e212efd..f475cd22d 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -51,7 +51,7 @@ fn test_simple() { .and_then(TakeItem::new().map_err(|_| ())) .and_then(|(req, framed): (_, Framed<_, _>)| { // validate request - if let Some(h1::InMessage::Message { req, payload: _ }) = req { + if let Some(h1::InMessage::Message(req, _)) = req { match ws::handshake(&req) { Err(e) => { // validation failed