1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 07:53:00 +01:00

add upgrade service support to h1 dispatcher

This commit is contained in:
Nikolay Kim 2019-04-08 17:49:27 -07:00
parent 43d325a139
commit 561f83d044
6 changed files with 271 additions and 101 deletions

View File

@ -1,9 +1,18 @@
# Changes # Changes
## [0.1.0-alpha.5] - 2019-04-xx
### Added
* Allow to use custom service for upgrade requests
## [0.1.0-alpha.4] - 2019-04-08 ## [0.1.0-alpha.4] - 2019-04-08
### Added ### Added
* Allow to use custom `Expect` handler
* Add minimal `std::error::Error` impl for `Error` * Add minimal `std::error::Error` impl for `Error`
### Changed ### Changed

View File

@ -115,6 +115,27 @@ where
} }
} }
/// Provide service for custom `Connection: UPGRADE` support.
///
/// If service is provided then normal requests handling get halted
/// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where
F: IntoNewService<U1>,
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display,
U1::InitError: fmt::Debug,
{
HttpServiceBuilder {
keep_alive: self.keep_alive,
client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect,
expect: self.expect,
upgrade: Some(upgrade.into_new_service()),
_t: PhantomData,
}
}
/// Finish service configuration and create *http service* for HTTP/1 protocol. /// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U> pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where where

View File

@ -350,6 +350,9 @@ pub enum DispatchError {
/// Service error /// Service error
Service(Error), Service(Error),
/// Upgrade service error
Upgrade,
/// An `io::Error` that occurred while trying to read or write to a network /// An `io::Error` that occurred while trying to read or write to a network
/// stream. /// stream.
#[display(fmt = "IO error: {}", _0)] #[display(fmt = "IO error: {}", _0)]

View File

@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::time::Instant; use std::time::Instant;
use std::{fmt, io}; use std::{fmt, io};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
use actix_service::Service; use actix_service::Service;
use actix_utils::cloneable::CloneableService; use actix_utils::cloneable::CloneableService;
use bitflags::bitflags; use bitflags::bitflags;
@ -34,7 +34,7 @@ bitflags! {
const SHUTDOWN = 0b0000_1000; const SHUTDOWN = 0b0000_1000;
const READ_DISCONNECT = 0b0001_0000; const READ_DISCONNECT = 0b0001_0000;
const WRITE_DISCONNECT = 0b0010_0000; const WRITE_DISCONNECT = 0b0010_0000;
const DROPPING = 0b0100_0000; const UPGRADE = 0b0100_0000;
} }
} }
@ -49,7 +49,22 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
inner: Option<InnerDispatcher<T, S, B, X, U>>, inner: DispatcherState<T, S, B, X, U>,
}
enum DispatcherState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
Normal(InnerDispatcher<T, S, B, X, U>),
Upgrade(U::Future),
None,
} }
struct InnerDispatcher<T, S, B, X, U> struct InnerDispatcher<T, S, B, X, U>
@ -83,6 +98,7 @@ where
enum DispatcherMessage { enum DispatcherMessage {
Item(Request), Item(Request),
Upgrade(Request),
Error(Response<()>), Error(Response<()>),
} }
@ -121,18 +137,24 @@ where
} }
} }
impl<S, B, X> fmt::Debug for State<S, B, X> enum PollResponse {
where Upgrade(Request),
S: Service<Request = Request>, DoNothing,
X: Service<Request = Request, Response = Request>, DrainWriteBuf,
B: MessageBody, }
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { impl PartialEq for PollResponse {
fn eq(&self, other: &PollResponse) -> bool {
match self { match self {
State::None => write!(f, "State::None"), PollResponse::DrainWriteBuf => match other {
State::ExpectCall(_) => write!(f, "State::ExceptCall"), PollResponse::DrainWriteBuf => true,
State::ServiceCall(_) => write!(f, "State::ServiceCall"), _ => false,
State::SendPayload(_) => write!(f, "State::SendPayload"), },
PollResponse::DoNothing => match other {
PollResponse::DoNothing => true,
_ => false,
},
_ => false,
} }
} }
} }
@ -197,7 +219,7 @@ where
}; };
Dispatcher { Dispatcher {
inner: Some(InnerDispatcher { inner: DispatcherState::Normal(InnerDispatcher {
io, io,
codec, codec,
read_buf, read_buf,
@ -230,7 +252,10 @@ where
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
fn can_read(&self) -> bool { fn can_read(&self) -> bool {
if self.flags.contains(Flags::READ_DISCONNECT) { if self
.flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{
false false
} else if let Some(ref info) = self.payload { } else if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read info.need_read() == PayloadStatus::Read
@ -315,7 +340,7 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
} }
fn poll_response(&mut self) -> Result<bool, DispatchError> { fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
loop { loop {
let state = match self.state { let state = match self.state {
State::None => match self.messages.pop_front() { State::None => match self.messages.pop_front() {
@ -325,6 +350,9 @@ where
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
} }
Some(DispatcherMessage::Upgrade(req)) => {
return Ok(PollResponse::Upgrade(req));
}
None => None, None => None,
}, },
State::ExpectCall(ref mut fut) => match fut.poll() { State::ExpectCall(ref mut fut) => match fut.poll() {
@ -374,10 +402,10 @@ where
)?; )?;
self.state = State::None; self.state = State::None;
} }
Async::NotReady => return Ok(false), Async::NotReady => return Ok(PollResponse::DoNothing),
} }
} else { } else {
return Ok(true); return Ok(PollResponse::DrainWriteBuf);
} }
break; break;
} }
@ -405,7 +433,7 @@ where
break; break;
} }
Ok(false) Ok(PollResponse::DoNothing)
} }
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> { fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
@ -461,16 +489,19 @@ where
match msg { match msg {
Message::Item(mut req) => { Message::Item(mut req) => {
match self.codec.message_type() { let pl = self.codec.message_type();
MessageType::Payload | MessageType::Stream => {
if pl == MessageType::Stream && self.upgrade.is_some() {
self.messages.push_back(DispatcherMessage::Upgrade(req));
break;
}
if pl == MessageType::Payload || pl == MessageType::Stream {
let (ps, pl) = Payload::create(false); let (ps, pl) = Payload::create(false);
let (req1, _) = let (req1, _) =
req.replace_payload(crate::Payload::H1(pl)); req.replace_payload(crate::Payload::H1(pl));
req = req1; req = req1;
self.payload = Some(ps); self.payload = Some(ps);
} }
_ => (),
}
// handle request early // handle request early
if self.state.is_empty() { if self.state.is_empty() {
@ -633,7 +664,8 @@ where
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let inner = self.inner.as_mut().unwrap(); match self.inner {
DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive()?; inner.poll_keepalive()?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
@ -654,7 +686,9 @@ where
} else { } else {
// read socket into a buf // read socket into a buf
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? { if let Some(true) =
read_available(&mut inner.io, &mut inner.read_buf)?
{
inner.flags.insert(Flags::READ_DISCONNECT) inner.flags.insert(Flags::READ_DISCONNECT)
} }
} }
@ -664,12 +698,34 @@ where
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE { if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE); inner.write_buf.reserve(HW_BUFFER_SIZE);
} }
let need_write = inner.poll_response()?; let result = inner.poll_response()?;
let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler
if let PollResponse::Upgrade(req) = result {
if let DispatcherState::Normal(inner) =
std::mem::replace(&mut self.inner, DispatcherState::None)
{
let mut parts = FramedParts::with_read_buf(
inner.io,
inner.codec,
inner.read_buf,
);
parts.write_buf = inner.write_buf;
let framed = Framed::from_parts(parts);
self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)),
);
return self.poll();
} else {
panic!()
}
}
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX) // so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck // and we have to write again otherwise response can get stuck
if inner.poll_flush()? || !need_write { if inner.poll_flush()? || !drain {
break; break;
} }
} }
@ -709,6 +765,13 @@ where
} }
} }
} }
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
error!("Upgrade handler error: {}", e);
DispatchError::Upgrade
}),
DispatcherState::None => panic!(),
}
}
} }
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error> fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>

View File

@ -31,9 +31,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
fn test_h1_v2() { fn test_h1_v2() {
env_logger::init(); env_logger::init();
let mut srv = TestServer::new(move || { let mut srv = TestServer::new(move || {
HttpService::build() HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map(|_| ())
}); });
let response = srv.block_on(srv.get("/").send()).unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());

View File

@ -0,0 +1,76 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
fn ws_service<T: AsyncRead + AsyncWrite>(
(req, framed): (Request, Framed<T, h1::Codec>),
) -> impl Future<Item = (), Error = Error> {
let res = ws::handshake(&req).unwrap().message_body(());
framed
.send((res, body::BodySize::None).into())
.map_err(|_| panic!())
.and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!())
})
}
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
};
ok(msg)
}
#[test]
fn test_simple() {
let mut srv = TestServer::new(|| {
HttpService::build()
.upgrade(ws_service)
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
});
// client service
let framed = srv.ws().unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
);
}