1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 07:53:00 +01:00

add upgrade service support to h1 dispatcher

This commit is contained in:
Nikolay Kim 2019-04-08 17:49:27 -07:00
parent 43d325a139
commit 561f83d044
6 changed files with 271 additions and 101 deletions

View File

@ -1,9 +1,18 @@
# Changes
## [0.1.0-alpha.5] - 2019-04-xx
### Added
* Allow to use custom service for upgrade requests
## [0.1.0-alpha.4] - 2019-04-08
### Added
* Allow to use custom `Expect` handler
* Add minimal `std::error::Error` impl for `Error`
### Changed

View File

@ -115,6 +115,27 @@ where
}
}
/// Provide service for custom `Connection: UPGRADE` support.
///
/// If service is provided then normal requests handling get halted
/// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
where
F: IntoNewService<U1>,
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display,
U1::InitError: fmt::Debug,
{
HttpServiceBuilder {
keep_alive: self.keep_alive,
client_timeout: self.client_timeout,
client_disconnect: self.client_disconnect,
expect: self.expect,
upgrade: Some(upgrade.into_new_service()),
_t: PhantomData,
}
}
/// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
where

View File

@ -350,6 +350,9 @@ pub enum DispatchError {
/// Service error
Service(Error),
/// Upgrade service error
Upgrade,
/// An `io::Error` that occurred while trying to read or write to a network
/// stream.
#[display(fmt = "IO error: {}", _0)]

View File

@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::time::Instant;
use std::{fmt, io};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
use actix_service::Service;
use actix_utils::cloneable::CloneableService;
use bitflags::bitflags;
@ -34,7 +34,7 @@ bitflags! {
const SHUTDOWN = 0b0000_1000;
const READ_DISCONNECT = 0b0001_0000;
const WRITE_DISCONNECT = 0b0010_0000;
const DROPPING = 0b0100_0000;
const UPGRADE = 0b0100_0000;
}
}
@ -49,7 +49,22 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
inner: Option<InnerDispatcher<T, S, B, X, U>>,
inner: DispatcherState<T, S, B, X, U>,
}
enum DispatcherState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
Normal(InnerDispatcher<T, S, B, X, U>),
Upgrade(U::Future),
None,
}
struct InnerDispatcher<T, S, B, X, U>
@ -83,6 +98,7 @@ where
enum DispatcherMessage {
Item(Request),
Upgrade(Request),
Error(Response<()>),
}
@ -121,18 +137,24 @@ where
}
}
impl<S, B, X> fmt::Debug for State<S, B, X>
where
S: Service<Request = Request>,
X: Service<Request = Request, Response = Request>,
B: MessageBody,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
enum PollResponse {
Upgrade(Request),
DoNothing,
DrainWriteBuf,
}
impl PartialEq for PollResponse {
fn eq(&self, other: &PollResponse) -> bool {
match self {
State::None => write!(f, "State::None"),
State::ExpectCall(_) => write!(f, "State::ExceptCall"),
State::ServiceCall(_) => write!(f, "State::ServiceCall"),
State::SendPayload(_) => write!(f, "State::SendPayload"),
PollResponse::DrainWriteBuf => match other {
PollResponse::DrainWriteBuf => true,
_ => false,
},
PollResponse::DoNothing => match other {
PollResponse::DoNothing => true,
_ => false,
},
_ => false,
}
}
}
@ -197,7 +219,7 @@ where
};
Dispatcher {
inner: Some(InnerDispatcher {
inner: DispatcherState::Normal(InnerDispatcher {
io,
codec,
read_buf,
@ -230,7 +252,10 @@ where
U::Error: fmt::Display,
{
fn can_read(&self) -> bool {
if self.flags.contains(Flags::READ_DISCONNECT) {
if self
.flags
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
{
false
} else if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read
@ -315,7 +340,7 @@ where
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
}
fn poll_response(&mut self) -> Result<bool, DispatchError> {
fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
loop {
let state = match self.state {
State::None => match self.messages.pop_front() {
@ -325,6 +350,9 @@ where
Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
}
Some(DispatcherMessage::Upgrade(req)) => {
return Ok(PollResponse::Upgrade(req));
}
None => None,
},
State::ExpectCall(ref mut fut) => match fut.poll() {
@ -374,10 +402,10 @@ where
)?;
self.state = State::None;
}
Async::NotReady => return Ok(false),
Async::NotReady => return Ok(PollResponse::DoNothing),
}
} else {
return Ok(true);
return Ok(PollResponse::DrainWriteBuf);
}
break;
}
@ -405,7 +433,7 @@ where
break;
}
Ok(false)
Ok(PollResponse::DoNothing)
}
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
@ -461,16 +489,19 @@ where
match msg {
Message::Item(mut req) => {
match self.codec.message_type() {
MessageType::Payload | MessageType::Stream => {
let pl = self.codec.message_type();
if pl == MessageType::Stream && self.upgrade.is_some() {
self.messages.push_back(DispatcherMessage::Upgrade(req));
break;
}
if pl == MessageType::Payload || pl == MessageType::Stream {
let (ps, pl) = Payload::create(false);
let (req1, _) =
req.replace_payload(crate::Payload::H1(pl));
req = req1;
self.payload = Some(ps);
}
_ => (),
}
// handle request early
if self.state.is_empty() {
@ -633,7 +664,8 @@ where
#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let inner = self.inner.as_mut().unwrap();
match self.inner {
DispatcherState::Normal(ref mut inner) => {
inner.poll_keepalive()?;
if inner.flags.contains(Flags::SHUTDOWN) {
@ -654,7 +686,9 @@ where
} else {
// read socket into a buf
if !inner.flags.contains(Flags::READ_DISCONNECT) {
if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? {
if let Some(true) =
read_available(&mut inner.io, &mut inner.read_buf)?
{
inner.flags.insert(Flags::READ_DISCONNECT)
}
}
@ -664,12 +698,34 @@ where
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE);
}
let need_write = inner.poll_response()?;
let result = inner.poll_response()?;
let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler
if let PollResponse::Upgrade(req) = result {
if let DispatcherState::Normal(inner) =
std::mem::replace(&mut self.inner, DispatcherState::None)
{
let mut parts = FramedParts::with_read_buf(
inner.io,
inner.codec,
inner.read_buf,
);
parts.write_buf = inner.write_buf;
let framed = Framed::from_parts(parts);
self.inner = DispatcherState::Upgrade(
inner.upgrade.unwrap().call((req, framed)),
);
return self.poll();
} else {
panic!()
}
}
// we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck
if inner.poll_flush()? || !need_write {
if inner.poll_flush()? || !drain {
break;
}
}
@ -709,6 +765,13 @@ where
}
}
}
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
error!("Upgrade handler error: {}", e);
DispatchError::Upgrade
}),
DispatcherState::None => panic!(),
}
}
}
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>

View File

@ -31,9 +31,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
fn test_h1_v2() {
env_logger::init();
let mut srv = TestServer::new(move || {
HttpService::build()
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map(|_| ())
HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
});
let response = srv.block_on(srv.get("/").send()).unwrap();
assert!(response.status().is_success());

View File

@ -0,0 +1,76 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
use actix_http_test::TestServer;
use actix_utils::framed::FramedTransport;
use bytes::{Bytes, BytesMut};
use futures::future::{self, ok};
use futures::{Future, Sink, Stream};
fn ws_service<T: AsyncRead + AsyncWrite>(
(req, framed): (Request, Framed<T, h1::Codec>),
) -> impl Future<Item = (), Error = Error> {
let res = ws::handshake(&req).unwrap().message_body(());
framed
.send((res, body::BodySize::None).into())
.map_err(|_| panic!())
.and_then(|framed| {
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
.map_err(|_| panic!())
})
}
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
};
ok(msg)
}
#[test]
fn test_simple() {
let mut srv = TestServer::new(|| {
HttpService::build()
.upgrade(ws_service)
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
});
// client service
let framed = srv.ws().unwrap();
let framed = srv
.block_on(framed.send(ws::Message::Text("text".to_string())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = srv
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
);
let framed = srv
.block_on(framed.send(ws::Message::Ping("text".into())))
.unwrap();
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = srv
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
);
}