mirror of
https://github.com/actix/actix-extras.git
synced 2024-11-30 18:34:36 +01:00
Fix poll_ready call for WebSockets upgrade (#1219)
* Fix poll_ready call for WebSockets upgrade * Poll upgrade service from H1ServiceHandler too
This commit is contained in:
parent
29ac6463e1
commit
3b860ebdc7
@ -1,5 +1,11 @@
|
|||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## [1.0.xx] - 2019-12-xx
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
* Poll upgrade service's readiness from HTTP service handlers
|
||||||
|
|
||||||
## [1.0.0] - 2019-12-13
|
## [1.0.0] - 2019-12-13
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
@ -66,7 +66,6 @@ where
|
|||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display,
|
||||||
{
|
{
|
||||||
Normal(InnerDispatcher<T, S, B, X, U>),
|
Normal(InnerDispatcher<T, S, B, X, U>),
|
||||||
UpgradeReadiness(InnerDispatcher<T, S, B, X, U>, Request),
|
|
||||||
Upgrade(U::Future),
|
Upgrade(U::Future),
|
||||||
None,
|
None,
|
||||||
}
|
}
|
||||||
@ -764,8 +763,16 @@ where
|
|||||||
if let DispatcherState::Normal(inner) =
|
if let DispatcherState::Normal(inner) =
|
||||||
std::mem::replace(&mut self.inner, DispatcherState::None)
|
std::mem::replace(&mut self.inner, DispatcherState::None)
|
||||||
{
|
{
|
||||||
self.inner =
|
let mut parts = FramedParts::with_read_buf(
|
||||||
DispatcherState::UpgradeReadiness(inner, req);
|
inner.io,
|
||||||
|
inner.codec,
|
||||||
|
inner.read_buf,
|
||||||
|
);
|
||||||
|
parts.write_buf = inner.write_buf;
|
||||||
|
let framed = Framed::from_parts(parts);
|
||||||
|
self.inner = DispatcherState::Upgrade(
|
||||||
|
inner.upgrade.unwrap().call((req, framed)),
|
||||||
|
);
|
||||||
return self.poll(cx);
|
return self.poll(cx);
|
||||||
} else {
|
} else {
|
||||||
panic!()
|
panic!()
|
||||||
@ -815,35 +822,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DispatcherState::UpgradeReadiness(ref mut inner, _) => {
|
|
||||||
let upgrade = inner.upgrade.as_mut().unwrap();
|
|
||||||
match upgrade.poll_ready(cx) {
|
|
||||||
Poll::Ready(Ok(_)) => {
|
|
||||||
if let DispatcherState::UpgradeReadiness(inner, req) =
|
|
||||||
std::mem::replace(&mut self.inner, DispatcherState::None)
|
|
||||||
{
|
|
||||||
let mut parts = FramedParts::with_read_buf(
|
|
||||||
inner.io,
|
|
||||||
inner.codec,
|
|
||||||
inner.read_buf,
|
|
||||||
);
|
|
||||||
parts.write_buf = inner.write_buf;
|
|
||||||
let framed = Framed::from_parts(parts);
|
|
||||||
self.inner = DispatcherState::Upgrade(
|
|
||||||
inner.upgrade.unwrap().call((req, framed)),
|
|
||||||
);
|
|
||||||
self.poll(cx)
|
|
||||||
} else {
|
|
||||||
panic!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
Poll::Ready(Err(e)) => {
|
|
||||||
error!("Upgrade handler readiness check error: {}", e);
|
|
||||||
Poll::Ready(Err(DispatchError::Upgrade))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DispatcherState::Upgrade(ref mut fut) => {
|
DispatcherState::Upgrade(ref mut fut) => {
|
||||||
unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
|
unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
|
||||||
error!("Upgrade handler error: {}", e);
|
error!("Upgrade handler error: {}", e);
|
||||||
|
@ -72,7 +72,7 @@ where
|
|||||||
Request = (Request, Framed<TcpStream, Codec>),
|
Request = (Request, Framed<TcpStream, Codec>),
|
||||||
Response = (),
|
Response = (),
|
||||||
>,
|
>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
{
|
{
|
||||||
/// Create simple tcp stream service
|
/// Create simple tcp stream service
|
||||||
@ -115,7 +115,7 @@ mod openssl {
|
|||||||
Request = (Request, Framed<SslStream<TcpStream>, Codec>),
|
Request = (Request, Framed<SslStream<TcpStream>, Codec>),
|
||||||
Response = (),
|
Response = (),
|
||||||
>,
|
>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
{
|
{
|
||||||
/// Create openssl based service
|
/// Create openssl based service
|
||||||
@ -255,7 +255,7 @@ where
|
|||||||
X::Error: Into<Error>,
|
X::Error: Into<Error>,
|
||||||
X::InitError: fmt::Debug,
|
X::InitError: fmt::Debug,
|
||||||
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>,
|
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
{
|
{
|
||||||
type Config = ();
|
type Config = ();
|
||||||
@ -412,7 +412,7 @@ where
|
|||||||
X: Service<Request = Request, Response = Request>,
|
X: Service<Request = Request, Response = Request>,
|
||||||
X::Error: Into<Error>,
|
X::Error: Into<Error>,
|
||||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
{
|
{
|
||||||
type Request = (T, Option<net::SocketAddr>);
|
type Request = (T, Option<net::SocketAddr>);
|
||||||
type Response = ();
|
type Response = ();
|
||||||
@ -440,6 +440,19 @@ where
|
|||||||
})?
|
})?
|
||||||
.is_ready()
|
.is_ready()
|
||||||
&& ready;
|
&& ready;
|
||||||
|
|
||||||
|
let ready = if let Some(ref mut upg) = self.upgrade {
|
||||||
|
upg.poll_ready(cx)
|
||||||
|
.map_err(|e| {
|
||||||
|
let e = e.into();
|
||||||
|
log::error!("Http service readiness error: {:?}", e);
|
||||||
|
DispatchError::Service(e)
|
||||||
|
})?
|
||||||
|
.is_ready()
|
||||||
|
&& ready
|
||||||
|
} else {
|
||||||
|
ready
|
||||||
|
};
|
||||||
|
|
||||||
if ready {
|
if ready {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
|
@ -169,7 +169,7 @@ where
|
|||||||
Request = (Request, Framed<TcpStream, h1::Codec>),
|
Request = (Request, Framed<TcpStream, h1::Codec>),
|
||||||
Response = (),
|
Response = (),
|
||||||
>,
|
>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
<U::Service as Service>::Future: 'static,
|
<U::Service as Service>::Future: 'static,
|
||||||
{
|
{
|
||||||
@ -214,7 +214,7 @@ mod openssl {
|
|||||||
Request = (Request, Framed<SslStream<TcpStream>, h1::Codec>),
|
Request = (Request, Framed<SslStream<TcpStream>, h1::Codec>),
|
||||||
Response = (),
|
Response = (),
|
||||||
>,
|
>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
<U::Service as Service>::Future: 'static,
|
<U::Service as Service>::Future: 'static,
|
||||||
{
|
{
|
||||||
@ -335,7 +335,7 @@ where
|
|||||||
Request = (Request, Framed<T, h1::Codec>),
|
Request = (Request, Framed<T, h1::Codec>),
|
||||||
Response = (),
|
Response = (),
|
||||||
>,
|
>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
U::InitError: fmt::Debug,
|
U::InitError: fmt::Debug,
|
||||||
<U::Service as Service>::Future: 'static,
|
<U::Service as Service>::Future: 'static,
|
||||||
{
|
{
|
||||||
@ -493,7 +493,7 @@ where
|
|||||||
X: Service<Request = Request, Response = Request>,
|
X: Service<Request = Request, Response = Request>,
|
||||||
X::Error: Into<Error>,
|
X::Error: Into<Error>,
|
||||||
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
|
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display + Into<Error>,
|
||||||
{
|
{
|
||||||
type Request = (T, Protocol, Option<net::SocketAddr>);
|
type Request = (T, Protocol, Option<net::SocketAddr>);
|
||||||
type Response = ();
|
type Response = ();
|
||||||
@ -522,6 +522,19 @@ where
|
|||||||
.is_ready()
|
.is_ready()
|
||||||
&& ready;
|
&& ready;
|
||||||
|
|
||||||
|
let ready = if let Some(ref mut upg) = self.upgrade {
|
||||||
|
upg.poll_ready(cx)
|
||||||
|
.map_err(|e| {
|
||||||
|
let e = e.into();
|
||||||
|
log::error!("Http service readiness error: {:?}", e);
|
||||||
|
DispatchError::Service(e)
|
||||||
|
})?
|
||||||
|
.is_ready()
|
||||||
|
&& ready
|
||||||
|
} else {
|
||||||
|
ready
|
||||||
|
};
|
||||||
|
|
||||||
if ready {
|
if ready {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,24 +1,70 @@
|
|||||||
|
use std::cell::Cell;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
||||||
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
|
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
|
||||||
use actix_http_test::test_server;
|
use actix_http_test::test_server;
|
||||||
|
use actix_service::{fn_factory, Service};
|
||||||
use actix_utils::framed::Dispatcher;
|
use actix_utils::framed::Dispatcher;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::future;
|
use futures::future;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::task::{Context, Poll};
|
||||||
|
use futures::{Future, SinkExt, StreamExt};
|
||||||
|
|
||||||
async fn ws_service<T: AsyncRead + AsyncWrite + Unpin>(
|
struct WsService<T>(Arc<Mutex<(PhantomData<T>, Cell<bool>)>>);
|
||||||
(req, mut framed): (Request, Framed<T, h1::Codec>),
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
let res = ws::handshake(req.head()).unwrap().message_body(());
|
|
||||||
|
|
||||||
framed
|
impl<T> WsService<T> {
|
||||||
.send((res, body::BodySize::None).into())
|
fn new() -> Self {
|
||||||
.await
|
WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false)))))
|
||||||
.unwrap();
|
}
|
||||||
|
|
||||||
Dispatcher::new(framed.into_framed(ws::Codec::new()), service)
|
fn set_polled(&mut self) {
|
||||||
.await
|
*self.0.lock().unwrap().1.get_mut() = true;
|
||||||
.map_err(|_| panic!())
|
}
|
||||||
|
|
||||||
|
fn was_polled(&self) -> bool {
|
||||||
|
self.0.lock().unwrap().1.get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Clone for WsService<T> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
WsService(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Service for WsService<T>
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||||
|
{
|
||||||
|
type Request = (Request, Framed<T, h1::Codec>);
|
||||||
|
type Response = ();
|
||||||
|
type Error = Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<(), Error>>>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.set_polled();
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, (req, mut framed): Self::Request) -> Self::Future {
|
||||||
|
let fut = async move {
|
||||||
|
let res = ws::handshake(req.head()).unwrap().message_body(());
|
||||||
|
|
||||||
|
framed
|
||||||
|
.send((res, body::BodySize::None).into())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Dispatcher::new(framed.into_framed(ws::Codec::new()), service)
|
||||||
|
.await
|
||||||
|
.map_err(|_| panic!())
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(fut)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
|
async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
|
||||||
@ -37,11 +83,16 @@ async fn service(msg: ws::Frame) -> Result<ws::Message, Error> {
|
|||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn test_simple() {
|
async fn test_simple() {
|
||||||
let mut srv = test_server(|| {
|
let ws_service = WsService::new();
|
||||||
HttpService::build()
|
let mut srv = test_server({
|
||||||
.upgrade(actix_service::fn_service(ws_service))
|
let ws_service = ws_service.clone();
|
||||||
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
|
move || {
|
||||||
.tcp()
|
let ws_service = ws_service.clone();
|
||||||
|
HttpService::build()
|
||||||
|
.upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone())))
|
||||||
|
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
|
||||||
|
.tcp()
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// client service
|
// client service
|
||||||
@ -138,4 +189,6 @@ async fn test_simple() {
|
|||||||
item.unwrap().unwrap(),
|
item.unwrap().unwrap(),
|
||||||
ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
|
ws::Frame::Close(Some(ws::CloseCode::Normal.into()))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
assert!(ws_service.was_polled());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user