From 561f83d044510ba693ffc5a292dc59568d7d9289 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 8 Apr 2019 17:49:27 -0700 Subject: [PATCH] add upgrade service support to h1 dispatcher --- actix-http/CHANGES.md | 9 ++ actix-http/src/builder.rs | 21 +++ actix-http/src/error.rs | 3 + actix-http/src/h1/dispatcher.rs | 259 ++++++++++++++++++++------------ actix-http/tests/test_client.rs | 4 +- actix-http/tests/test_ws.rs | 76 ++++++++++ 6 files changed, 271 insertions(+), 101 deletions(-) create mode 100644 actix-http/tests/test_ws.rs diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 4cc18b479..cca4560c5 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,9 +1,18 @@ # Changes +## [0.1.0-alpha.5] - 2019-04-xx + +### Added + +* Allow to use custom service for upgrade requests + + ## [0.1.0-alpha.4] - 2019-04-08 ### Added +* Allow to use custom `Expect` handler + * Add minimal `std::error::Error` impl for `Error` ### Changed diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 7b07d30e3..56f144bd8 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -115,6 +115,27 @@ where } } + /// Provide service for custom `Connection: UPGRADE` support. + /// + /// If service is provided then normal requests handling get halted + /// and this service get called with original request and framed object. + pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder + where + F: IntoNewService, + U1: NewService), Response = ()>, + U1::Error: fmt::Display, + U1::InitError: fmt::Debug, + { + HttpServiceBuilder { + keep_alive: self.keep_alive, + client_timeout: self.client_timeout, + client_disconnect: self.client_disconnect, + expect: self.expect, + upgrade: Some(upgrade.into_new_service()), + _t: PhantomData, + } + } + /// Finish service configuration and create *http service* for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index fc37d3243..6573c8ce6 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -350,6 +350,9 @@ pub enum DispatchError { /// Service error Service(Error), + /// Upgrade service error + Upgrade, + /// An `io::Error` that occurred while trying to read or write to a network /// stream. #[display(fmt = "IO error: {}", _0)] diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index eccf2412b..9014047d7 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use std::time::Instant; use std::{fmt, io}; -use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; +use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_service::Service; use actix_utils::cloneable::CloneableService; use bitflags::bitflags; @@ -34,7 +34,7 @@ bitflags! { const SHUTDOWN = 0b0000_1000; const READ_DISCONNECT = 0b0001_0000; const WRITE_DISCONNECT = 0b0010_0000; - const DROPPING = 0b0100_0000; + const UPGRADE = 0b0100_0000; } } @@ -49,7 +49,22 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - inner: Option>, + inner: DispatcherState, +} + +enum DispatcherState +where + S: Service, + S::Error: Into, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + Normal(InnerDispatcher), + Upgrade(U::Future), + None, } struct InnerDispatcher @@ -83,6 +98,7 @@ where enum DispatcherMessage { Item(Request), + Upgrade(Request), Error(Response<()>), } @@ -121,18 +137,24 @@ where } } -impl fmt::Debug for State -where - S: Service, - X: Service, - B: MessageBody, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +enum PollResponse { + Upgrade(Request), + DoNothing, + DrainWriteBuf, +} + +impl PartialEq for PollResponse { + fn eq(&self, other: &PollResponse) -> bool { match self { - State::None => write!(f, "State::None"), - State::ExpectCall(_) => write!(f, "State::ExceptCall"), - State::ServiceCall(_) => write!(f, "State::ServiceCall"), - State::SendPayload(_) => write!(f, "State::SendPayload"), + PollResponse::DrainWriteBuf => match other { + PollResponse::DrainWriteBuf => true, + _ => false, + }, + PollResponse::DoNothing => match other { + PollResponse::DoNothing => true, + _ => false, + }, + _ => false, } } } @@ -197,7 +219,7 @@ where }; Dispatcher { - inner: Some(InnerDispatcher { + inner: DispatcherState::Normal(InnerDispatcher { io, codec, read_buf, @@ -230,7 +252,10 @@ where U::Error: fmt::Display, { fn can_read(&self) -> bool { - if self.flags.contains(Flags::READ_DISCONNECT) { + if self + .flags + .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) + { false } else if let Some(ref info) = self.payload { info.need_read() == PayloadStatus::Read @@ -315,7 +340,7 @@ where .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); } - fn poll_response(&mut self) -> Result { + fn poll_response(&mut self) -> Result { loop { let state = match self.state { State::None => match self.messages.pop_front() { @@ -325,6 +350,9 @@ where Some(DispatcherMessage::Error(res)) => { Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) } + Some(DispatcherMessage::Upgrade(req)) => { + return Ok(PollResponse::Upgrade(req)); + } None => None, }, State::ExpectCall(ref mut fut) => match fut.poll() { @@ -374,10 +402,10 @@ where )?; self.state = State::None; } - Async::NotReady => return Ok(false), + Async::NotReady => return Ok(PollResponse::DoNothing), } } else { - return Ok(true); + return Ok(PollResponse::DrainWriteBuf); } break; } @@ -405,7 +433,7 @@ where break; } - Ok(false) + Ok(PollResponse::DoNothing) } fn handle_request(&mut self, req: Request) -> Result, DispatchError> { @@ -461,15 +489,18 @@ where match msg { Message::Item(mut req) => { - match self.codec.message_type() { - MessageType::Payload | MessageType::Stream => { - let (ps, pl) = Payload::create(false); - let (req1, _) = - req.replace_payload(crate::Payload::H1(pl)); - req = req1; - self.payload = Some(ps); - } - _ => (), + let pl = self.codec.message_type(); + + if pl == MessageType::Stream && self.upgrade.is_some() { + self.messages.push_back(DispatcherMessage::Upgrade(req)); + break; + } + if pl == MessageType::Payload || pl == MessageType::Stream { + let (ps, pl) = Payload::create(false); + let (req1, _) = + req.replace_payload(crate::Payload::H1(pl)); + req = req1; + self.payload = Some(ps); } // handle request early @@ -633,80 +664,112 @@ where #[inline] fn poll(&mut self) -> Poll { - let inner = self.inner.as_mut().unwrap(); - inner.poll_keepalive()?; + match self.inner { + DispatcherState::Normal(ref mut inner) => { + inner.poll_keepalive()?; - if inner.flags.contains(Flags::SHUTDOWN) { - if inner.flags.contains(Flags::WRITE_DISCONNECT) { - Ok(Async::Ready(())) - } else { - // flush buffer - inner.poll_flush()?; - if !inner.write_buf.is_empty() { - Ok(Async::NotReady) + if inner.flags.contains(Flags::SHUTDOWN) { + if inner.flags.contains(Flags::WRITE_DISCONNECT) { + Ok(Async::Ready(())) + } else { + // flush buffer + inner.poll_flush()?; + if !inner.write_buf.is_empty() { + Ok(Async::NotReady) + } else { + match inner.io.shutdown()? { + Async::Ready(_) => Ok(Async::Ready(())), + Async::NotReady => Ok(Async::NotReady), + } + } + } } else { - match inner.io.shutdown()? { - Async::Ready(_) => Ok(Async::Ready(())), - Async::NotReady => Ok(Async::NotReady), + // read socket into a buf + if !inner.flags.contains(Flags::READ_DISCONNECT) { + if let Some(true) = + read_available(&mut inner.io, &mut inner.read_buf)? + { + inner.flags.insert(Flags::READ_DISCONNECT) + } + } + + inner.poll_request()?; + loop { + if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { + inner.write_buf.reserve(HW_BUFFER_SIZE); + } + let result = inner.poll_response()?; + let drain = result == PollResponse::DrainWriteBuf; + + // switch to upgrade handler + if let PollResponse::Upgrade(req) = result { + if let DispatcherState::Normal(inner) = + std::mem::replace(&mut self.inner, DispatcherState::None) + { + let mut parts = FramedParts::with_read_buf( + inner.io, + inner.codec, + inner.read_buf, + ); + parts.write_buf = inner.write_buf; + let framed = Framed::from_parts(parts); + self.inner = DispatcherState::Upgrade( + inner.upgrade.unwrap().call((req, framed)), + ); + return self.poll(); + } else { + panic!() + } + } + + // we didnt get WouldBlock from write operation, + // so data get written to kernel completely (OSX) + // and we have to write again otherwise response can get stuck + if inner.poll_flush()? || !drain { + break; + } + } + + // client is gone + if inner.flags.contains(Flags::WRITE_DISCONNECT) { + return Ok(Async::Ready(())); + } + + let is_empty = inner.state.is_empty(); + + // read half is closed and we do not processing any responses + if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { + inner.flags.insert(Flags::SHUTDOWN); + } + + // keep-alive and stream errors + if is_empty && inner.write_buf.is_empty() { + if let Some(err) = inner.error.take() { + Err(err) + } + // disconnect if keep-alive is not enabled + else if inner.flags.contains(Flags::STARTED) + && !inner.flags.intersects(Flags::KEEPALIVE) + { + inner.flags.insert(Flags::SHUTDOWN); + self.poll() + } + // disconnect if shutdown + else if inner.flags.contains(Flags::SHUTDOWN) { + self.poll() + } else { + Ok(Async::NotReady) + } + } else { + Ok(Async::NotReady) } } } - } else { - // read socket into a buf - if !inner.flags.contains(Flags::READ_DISCONNECT) { - if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? { - inner.flags.insert(Flags::READ_DISCONNECT) - } - } - - inner.poll_request()?; - loop { - if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { - inner.write_buf.reserve(HW_BUFFER_SIZE); - } - let need_write = inner.poll_response()?; - - // we didnt get WouldBlock from write operation, - // so data get written to kernel completely (OSX) - // and we have to write again otherwise response can get stuck - if inner.poll_flush()? || !need_write { - break; - } - } - - // client is gone - if inner.flags.contains(Flags::WRITE_DISCONNECT) { - return Ok(Async::Ready(())); - } - - let is_empty = inner.state.is_empty(); - - // read half is closed and we do not processing any responses - if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { - inner.flags.insert(Flags::SHUTDOWN); - } - - // keep-alive and stream errors - if is_empty && inner.write_buf.is_empty() { - if let Some(err) = inner.error.take() { - Err(err) - } - // disconnect if keep-alive is not enabled - else if inner.flags.contains(Flags::STARTED) - && !inner.flags.intersects(Flags::KEEPALIVE) - { - inner.flags.insert(Flags::SHUTDOWN); - self.poll() - } - // disconnect if shutdown - else if inner.flags.contains(Flags::SHUTDOWN) { - self.poll() - } else { - Ok(Async::NotReady) - } - } else { - Ok(Async::NotReady) - } + DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| { + error!("Upgrade handler error: {}", e); + DispatchError::Upgrade + }), + DispatcherState::None => panic!(), } } } diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index cfe0999fd..6d382478f 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -31,9 +31,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ fn test_h1_v2() { env_logger::init(); let mut srv = TestServer::new(move || { - HttpService::build() - .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) - .map(|_| ()) + HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) }); let response = srv.block_on(srv.get("/").send()).unwrap(); assert!(response.status().is_success()); diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs new file mode 100644 index 000000000..b6be748bd --- /dev/null +++ b/actix-http/tests/test_ws.rs @@ -0,0 +1,76 @@ +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; +use actix_http_test::TestServer; +use actix_utils::framed::FramedTransport; +use bytes::{Bytes, BytesMut}; +use futures::future::{self, ok}; +use futures::{Future, Sink, Stream}; + +fn ws_service( + (req, framed): (Request, Framed), +) -> impl Future { + let res = ws::handshake(&req).unwrap().message_body(()); + + framed + .send((res, body::BodySize::None).into()) + .map_err(|_| panic!()) + .and_then(|framed| { + FramedTransport::new(framed.into_framed(ws::Codec::new()), service) + .map_err(|_| panic!()) + }) +} + +fn service(msg: ws::Frame) -> impl Future { + let msg = match msg { + ws::Frame::Ping(msg) => ws::Message::Pong(msg), + ws::Frame::Text(text) => { + ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string()) + } + ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()), + ws::Frame::Close(reason) => ws::Message::Close(reason), + _ => panic!(), + }; + ok(msg) +} + +#[test] +fn test_simple() { + let mut srv = TestServer::new(|| { + HttpService::build() + .upgrade(ws_service) + .finish(|_| future::ok::<_, ()>(Response::NotFound())) + }); + + // client service + let framed = srv.ws().unwrap(); + let framed = srv + .block_on(framed.send(ws::Message::Text("text".to_string()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + + let framed = srv + .block_on(framed.send(ws::Message::Binary("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + ); + + let framed = srv + .block_on(framed.send(ws::Message::Ping("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + + let framed = srv + .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + .unwrap(); + + let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + ); +}