mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-23 23:51:06 +01:00
add upgrade service support to h1 dispatcher
This commit is contained in:
parent
43d325a139
commit
561f83d044
@ -1,9 +1,18 @@
|
||||
# Changes
|
||||
|
||||
## [0.1.0-alpha.5] - 2019-04-xx
|
||||
|
||||
### Added
|
||||
|
||||
* Allow to use custom service for upgrade requests
|
||||
|
||||
|
||||
## [0.1.0-alpha.4] - 2019-04-08
|
||||
|
||||
### Added
|
||||
|
||||
* Allow to use custom `Expect` handler
|
||||
|
||||
* Add minimal `std::error::Error` impl for `Error`
|
||||
|
||||
### Changed
|
||||
|
@ -115,6 +115,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Provide service for custom `Connection: UPGRADE` support.
|
||||
///
|
||||
/// If service is provided then normal requests handling get halted
|
||||
/// and this service get called with original request and framed object.
|
||||
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
|
||||
where
|
||||
F: IntoNewService<U1>,
|
||||
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||
U1::Error: fmt::Display,
|
||||
U1::InitError: fmt::Debug,
|
||||
{
|
||||
HttpServiceBuilder {
|
||||
keep_alive: self.keep_alive,
|
||||
client_timeout: self.client_timeout,
|
||||
client_disconnect: self.client_disconnect,
|
||||
expect: self.expect,
|
||||
upgrade: Some(upgrade.into_new_service()),
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish service configuration and create *http service* for HTTP/1 protocol.
|
||||
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
|
||||
where
|
||||
|
@ -350,6 +350,9 @@ pub enum DispatchError {
|
||||
/// Service error
|
||||
Service(Error),
|
||||
|
||||
/// Upgrade service error
|
||||
Upgrade,
|
||||
|
||||
/// An `io::Error` that occurred while trying to read or write to a network
|
||||
/// stream.
|
||||
#[display(fmt = "IO error: {}", _0)]
|
||||
|
@ -2,7 +2,7 @@ use std::collections::VecDeque;
|
||||
use std::time::Instant;
|
||||
use std::{fmt, io};
|
||||
|
||||
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
|
||||
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
|
||||
use actix_service::Service;
|
||||
use actix_utils::cloneable::CloneableService;
|
||||
use bitflags::bitflags;
|
||||
@ -34,7 +34,7 @@ bitflags! {
|
||||
const SHUTDOWN = 0b0000_1000;
|
||||
const READ_DISCONNECT = 0b0001_0000;
|
||||
const WRITE_DISCONNECT = 0b0010_0000;
|
||||
const DROPPING = 0b0100_0000;
|
||||
const UPGRADE = 0b0100_0000;
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,7 +49,22 @@ where
|
||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
{
|
||||
inner: Option<InnerDispatcher<T, S, B, X, U>>,
|
||||
inner: DispatcherState<T, S, B, X, U>,
|
||||
}
|
||||
|
||||
enum DispatcherState<T, S, B, X, U>
|
||||
where
|
||||
S: Service<Request = Request>,
|
||||
S::Error: Into<Error>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: Into<Error>,
|
||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
{
|
||||
Normal(InnerDispatcher<T, S, B, X, U>),
|
||||
Upgrade(U::Future),
|
||||
None,
|
||||
}
|
||||
|
||||
struct InnerDispatcher<T, S, B, X, U>
|
||||
@ -83,6 +98,7 @@ where
|
||||
|
||||
enum DispatcherMessage {
|
||||
Item(Request),
|
||||
Upgrade(Request),
|
||||
Error(Response<()>),
|
||||
}
|
||||
|
||||
@ -121,18 +137,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, X> fmt::Debug for State<S, B, X>
|
||||
where
|
||||
S: Service<Request = Request>,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
B: MessageBody,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
enum PollResponse {
|
||||
Upgrade(Request),
|
||||
DoNothing,
|
||||
DrainWriteBuf,
|
||||
}
|
||||
|
||||
impl PartialEq for PollResponse {
|
||||
fn eq(&self, other: &PollResponse) -> bool {
|
||||
match self {
|
||||
State::None => write!(f, "State::None"),
|
||||
State::ExpectCall(_) => write!(f, "State::ExceptCall"),
|
||||
State::ServiceCall(_) => write!(f, "State::ServiceCall"),
|
||||
State::SendPayload(_) => write!(f, "State::SendPayload"),
|
||||
PollResponse::DrainWriteBuf => match other {
|
||||
PollResponse::DrainWriteBuf => true,
|
||||
_ => false,
|
||||
},
|
||||
PollResponse::DoNothing => match other {
|
||||
PollResponse::DoNothing => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -197,7 +219,7 @@ where
|
||||
};
|
||||
|
||||
Dispatcher {
|
||||
inner: Some(InnerDispatcher {
|
||||
inner: DispatcherState::Normal(InnerDispatcher {
|
||||
io,
|
||||
codec,
|
||||
read_buf,
|
||||
@ -230,7 +252,10 @@ where
|
||||
U::Error: fmt::Display,
|
||||
{
|
||||
fn can_read(&self) -> bool {
|
||||
if self.flags.contains(Flags::READ_DISCONNECT) {
|
||||
if self
|
||||
.flags
|
||||
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
|
||||
{
|
||||
false
|
||||
} else if let Some(ref info) = self.payload {
|
||||
info.need_read() == PayloadStatus::Read
|
||||
@ -315,7 +340,7 @@ where
|
||||
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
|
||||
}
|
||||
|
||||
fn poll_response(&mut self) -> Result<bool, DispatchError> {
|
||||
fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
|
||||
loop {
|
||||
let state = match self.state {
|
||||
State::None => match self.messages.pop_front() {
|
||||
@ -325,6 +350,9 @@ where
|
||||
Some(DispatcherMessage::Error(res)) => {
|
||||
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
|
||||
}
|
||||
Some(DispatcherMessage::Upgrade(req)) => {
|
||||
return Ok(PollResponse::Upgrade(req));
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
State::ExpectCall(ref mut fut) => match fut.poll() {
|
||||
@ -374,10 +402,10 @@ where
|
||||
)?;
|
||||
self.state = State::None;
|
||||
}
|
||||
Async::NotReady => return Ok(false),
|
||||
Async::NotReady => return Ok(PollResponse::DoNothing),
|
||||
}
|
||||
} else {
|
||||
return Ok(true);
|
||||
return Ok(PollResponse::DrainWriteBuf);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -405,7 +433,7 @@ where
|
||||
break;
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
Ok(PollResponse::DoNothing)
|
||||
}
|
||||
|
||||
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
|
||||
@ -461,15 +489,18 @@ where
|
||||
|
||||
match msg {
|
||||
Message::Item(mut req) => {
|
||||
match self.codec.message_type() {
|
||||
MessageType::Payload | MessageType::Stream => {
|
||||
let (ps, pl) = Payload::create(false);
|
||||
let (req1, _) =
|
||||
req.replace_payload(crate::Payload::H1(pl));
|
||||
req = req1;
|
||||
self.payload = Some(ps);
|
||||
}
|
||||
_ => (),
|
||||
let pl = self.codec.message_type();
|
||||
|
||||
if pl == MessageType::Stream && self.upgrade.is_some() {
|
||||
self.messages.push_back(DispatcherMessage::Upgrade(req));
|
||||
break;
|
||||
}
|
||||
if pl == MessageType::Payload || pl == MessageType::Stream {
|
||||
let (ps, pl) = Payload::create(false);
|
||||
let (req1, _) =
|
||||
req.replace_payload(crate::Payload::H1(pl));
|
||||
req = req1;
|
||||
self.payload = Some(ps);
|
||||
}
|
||||
|
||||
// handle request early
|
||||
@ -633,80 +664,112 @@ where
|
||||
|
||||
#[inline]
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let inner = self.inner.as_mut().unwrap();
|
||||
inner.poll_keepalive()?;
|
||||
match self.inner {
|
||||
DispatcherState::Normal(ref mut inner) => {
|
||||
inner.poll_keepalive()?;
|
||||
|
||||
if inner.flags.contains(Flags::SHUTDOWN) {
|
||||
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
|
||||
Ok(Async::Ready(()))
|
||||
} else {
|
||||
// flush buffer
|
||||
inner.poll_flush()?;
|
||||
if !inner.write_buf.is_empty() {
|
||||
Ok(Async::NotReady)
|
||||
if inner.flags.contains(Flags::SHUTDOWN) {
|
||||
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
|
||||
Ok(Async::Ready(()))
|
||||
} else {
|
||||
// flush buffer
|
||||
inner.poll_flush()?;
|
||||
if !inner.write_buf.is_empty() {
|
||||
Ok(Async::NotReady)
|
||||
} else {
|
||||
match inner.io.shutdown()? {
|
||||
Async::Ready(_) => Ok(Async::Ready(())),
|
||||
Async::NotReady => Ok(Async::NotReady),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match inner.io.shutdown()? {
|
||||
Async::Ready(_) => Ok(Async::Ready(())),
|
||||
Async::NotReady => Ok(Async::NotReady),
|
||||
// read socket into a buf
|
||||
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
||||
if let Some(true) =
|
||||
read_available(&mut inner.io, &mut inner.read_buf)?
|
||||
{
|
||||
inner.flags.insert(Flags::READ_DISCONNECT)
|
||||
}
|
||||
}
|
||||
|
||||
inner.poll_request()?;
|
||||
loop {
|
||||
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
|
||||
inner.write_buf.reserve(HW_BUFFER_SIZE);
|
||||
}
|
||||
let result = inner.poll_response()?;
|
||||
let drain = result == PollResponse::DrainWriteBuf;
|
||||
|
||||
// switch to upgrade handler
|
||||
if let PollResponse::Upgrade(req) = result {
|
||||
if let DispatcherState::Normal(inner) =
|
||||
std::mem::replace(&mut self.inner, DispatcherState::None)
|
||||
{
|
||||
let mut parts = FramedParts::with_read_buf(
|
||||
inner.io,
|
||||
inner.codec,
|
||||
inner.read_buf,
|
||||
);
|
||||
parts.write_buf = inner.write_buf;
|
||||
let framed = Framed::from_parts(parts);
|
||||
self.inner = DispatcherState::Upgrade(
|
||||
inner.upgrade.unwrap().call((req, framed)),
|
||||
);
|
||||
return self.poll();
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
// we didnt get WouldBlock from write operation,
|
||||
// so data get written to kernel completely (OSX)
|
||||
// and we have to write again otherwise response can get stuck
|
||||
if inner.poll_flush()? || !drain {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// client is gone
|
||||
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
|
||||
let is_empty = inner.state.is_empty();
|
||||
|
||||
// read half is closed and we do not processing any responses
|
||||
if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty {
|
||||
inner.flags.insert(Flags::SHUTDOWN);
|
||||
}
|
||||
|
||||
// keep-alive and stream errors
|
||||
if is_empty && inner.write_buf.is_empty() {
|
||||
if let Some(err) = inner.error.take() {
|
||||
Err(err)
|
||||
}
|
||||
// disconnect if keep-alive is not enabled
|
||||
else if inner.flags.contains(Flags::STARTED)
|
||||
&& !inner.flags.intersects(Flags::KEEPALIVE)
|
||||
{
|
||||
inner.flags.insert(Flags::SHUTDOWN);
|
||||
self.poll()
|
||||
}
|
||||
// disconnect if shutdown
|
||||
else if inner.flags.contains(Flags::SHUTDOWN) {
|
||||
self.poll()
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// read socket into a buf
|
||||
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
||||
if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? {
|
||||
inner.flags.insert(Flags::READ_DISCONNECT)
|
||||
}
|
||||
}
|
||||
|
||||
inner.poll_request()?;
|
||||
loop {
|
||||
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
|
||||
inner.write_buf.reserve(HW_BUFFER_SIZE);
|
||||
}
|
||||
let need_write = inner.poll_response()?;
|
||||
|
||||
// we didnt get WouldBlock from write operation,
|
||||
// so data get written to kernel completely (OSX)
|
||||
// and we have to write again otherwise response can get stuck
|
||||
if inner.poll_flush()? || !need_write {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// client is gone
|
||||
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
|
||||
let is_empty = inner.state.is_empty();
|
||||
|
||||
// read half is closed and we do not processing any responses
|
||||
if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty {
|
||||
inner.flags.insert(Flags::SHUTDOWN);
|
||||
}
|
||||
|
||||
// keep-alive and stream errors
|
||||
if is_empty && inner.write_buf.is_empty() {
|
||||
if let Some(err) = inner.error.take() {
|
||||
Err(err)
|
||||
}
|
||||
// disconnect if keep-alive is not enabled
|
||||
else if inner.flags.contains(Flags::STARTED)
|
||||
&& !inner.flags.intersects(Flags::KEEPALIVE)
|
||||
{
|
||||
inner.flags.insert(Flags::SHUTDOWN);
|
||||
self.poll()
|
||||
}
|
||||
// disconnect if shutdown
|
||||
else if inner.flags.contains(Flags::SHUTDOWN) {
|
||||
self.poll()
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
|
||||
error!("Upgrade handler error: {}", e);
|
||||
DispatchError::Upgrade
|
||||
}),
|
||||
DispatcherState::None => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,9 +31,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
|
||||
fn test_h1_v2() {
|
||||
env_logger::init();
|
||||
let mut srv = TestServer::new(move || {
|
||||
HttpService::build()
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
|
||||
.map(|_| ())
|
||||
HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
|
||||
});
|
||||
let response = srv.block_on(srv.get("/").send()).unwrap();
|
||||
assert!(response.status().is_success());
|
||||
|
76
actix-http/tests/test_ws.rs
Normal file
76
actix-http/tests/test_ws.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
|
||||
use actix_http_test::TestServer;
|
||||
use actix_utils::framed::FramedTransport;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::future::{self, ok};
|
||||
use futures::{Future, Sink, Stream};
|
||||
|
||||
fn ws_service<T: AsyncRead + AsyncWrite>(
|
||||
(req, framed): (Request, Framed<T, h1::Codec>),
|
||||
) -> impl Future<Item = (), Error = Error> {
|
||||
let res = ws::handshake(&req).unwrap().message_body(());
|
||||
|
||||
framed
|
||||
.send((res, body::BodySize::None).into())
|
||||
.map_err(|_| panic!())
|
||||
.and_then(|framed| {
|
||||
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
|
||||
.map_err(|_| panic!())
|
||||
})
|
||||
}
|
||||
|
||||
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
|
||||
let msg = match msg {
|
||||
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||
ws::Frame::Text(text) => {
|
||||
ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
|
||||
}
|
||||
ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
|
||||
ws::Frame::Close(reason) => ws::Message::Close(reason),
|
||||
_ => panic!(),
|
||||
};
|
||||
ok(msg)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple() {
|
||||
let mut srv = TestServer::new(|| {
|
||||
HttpService::build()
|
||||
.upgrade(ws_service)
|
||||
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
|
||||
});
|
||||
|
||||
// client service
|
||||
let framed = srv.ws().unwrap();
|
||||
let framed = srv
|
||||
.block_on(framed.send(ws::Message::Text("text".to_string())))
|
||||
.unwrap();
|
||||
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
|
||||
|
||||
let framed = srv
|
||||
.block_on(framed.send(ws::Message::Binary("text".into())))
|
||||
.unwrap();
|
||||
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||
assert_eq!(
|
||||
item,
|
||||
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
|
||||
);
|
||||
|
||||
let framed = srv
|
||||
.block_on(framed.send(ws::Message::Ping("text".into())))
|
||||
.unwrap();
|
||||
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
|
||||
|
||||
let framed = srv
|
||||
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
|
||||
.unwrap();
|
||||
|
||||
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||
assert_eq!(
|
||||
item,
|
||||
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
|
||||
);
|
||||
}
|
Loading…
Reference in New Issue
Block a user