From 515d0e3fb45d2936fad00daf0ab709ab2cc62103 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sat, 13 Mar 2021 14:20:18 -0800 Subject: [PATCH] change behavior of default upgrade handler (#2071) --- actix-http/src/h1/dispatcher.rs | 39 ++++++++++++++++++-------- actix-http/src/h1/upgrade.rs | 8 +++--- src/server.rs | 49 +++++++++++++++++---------------- 3 files changed, 57 insertions(+), 39 deletions(-) diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 6df579c0a..e5989e5ee 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -662,7 +662,7 @@ where // got timeout during shutdown, drop connection if this.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); - // exceed deadline. check for any outstanding tasks + // exceed deadline. check for any outstanding tasks } else if timer.deadline() >= *this.ka_expire { // have no task at hand. if this.state.is_empty() && this.write_buf.is_empty() { @@ -695,15 +695,15 @@ where this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); this.state.set(State::None); } - // still have unfinished task. try to reset and register keep-alive. + // still have unfinished task. try to reset and register keep-alive. } else if let Some(deadline) = this.codec.config().keep_alive_expire() { timer.as_mut().reset(deadline); let _ = timer.poll(cx); } - // timer resolved but still have not met the keep-alive expire deadline. - // reset and register for later wakeup. + // timer resolved but still have not met the keep-alive expire deadline. + // reset and register for later wakeup. } else { timer.as_mut().reset(*this.ka_expire); let _ = timer.poll(cx); @@ -951,14 +951,15 @@ mod tests { use std::str; use actix_service::fn_service; - use futures_util::future::{lazy, ready}; + use futures_util::future::{lazy, ready, Ready}; use super::*; - use crate::test::TestBuffer; - use crate::{error::Error, KeepAlive}; use crate::{ + error::Error, h1::{ExpectHandler, UpgradeHandler}, - test::TestSeqBuffer, + http::Method, + test::{TestBuffer, TestSeqBuffer}, + HttpMessage, KeepAlive, }; fn find_slice(haystack: &[u8], needle: &[u8], from: usize) -> Option { @@ -1282,14 +1283,30 @@ mod tests { #[actix_rt::test] async fn test_upgrade() { + struct TestUpgrade; + + impl Service<(Request, Framed)> for TestUpgrade { + type Response = (); + type Error = Error; + type Future = Ready>; + + actix_service::always_ready!(); + + fn call(&self, (req, _framed): (Request, Framed)) -> Self::Future { + assert_eq!(req.method(), Method::GET); + assert!(req.upgrade()); + assert_eq!(req.headers().get("upgrade").unwrap(), "websocket"); + ready(Ok(())) + } + } + lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let services = - HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler)); + let services = HttpFlow::new(ok_service(), ExpectHandler, Some(TestUpgrade)); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, TestUpgrade>::new( buf.clone(), cfg, services, diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 5e24d84e3..fbfbc83c1 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -2,7 +2,7 @@ use std::task::Poll; use actix_codec::Framed; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, Ready}; +use futures_core::future::LocalBoxFuture; use crate::error::Error; use crate::h1::Codec; @@ -16,7 +16,7 @@ impl ServiceFactory<(Request, Framed)> for UpgradeHandler { type Config = (); type Service = UpgradeHandler; type InitError = Error; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { unimplemented!() @@ -26,11 +26,11 @@ impl ServiceFactory<(Request, Framed)> for UpgradeHandler { impl Service<(Request, Framed)> for UpgradeHandler { type Response = (); type Error = Error; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); fn call(&self, _: (Request, Framed)) -> Self::Future { - ready(Ok(())) + unimplemented!() } } diff --git a/src/server.rs b/src/server.rs index d69d6570d..99355593a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,13 +12,6 @@ use actix_http::{ use actix_server::{Server, ServerBuilder}; use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; -#[cfg(unix)] -use actix_http::Protocol; -#[cfg(unix)] -use actix_service::pipeline_factory; -#[cfg(unix)] -use futures_util::future::ok; - #[cfg(feature = "openssl")] use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; #[cfg(feature = "rustls")] @@ -489,7 +482,9 @@ where #[cfg(unix)] /// Start listening for unix domain (UDS) connections on existing listener. pub fn listen_uds(mut self, lst: std::os::unix::net::UnixListener) -> io::Result { + use actix_http::Protocol; use actix_rt::net::UnixStream; + use actix_service::pipeline_factory; let cfg = self.config.clone(); let factory = self.factory.clone(); @@ -511,19 +506,22 @@ where c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then({ - let svc = HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout); + pipeline_factory(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) }) + .and_then({ + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| (&*handler)(io as &dyn Any, ext)) - } else { - svc - }; + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (&*handler)(io as &dyn Any, ext) + }) + } else { + svc + }; - svc.finish(map_config(factory(), move |_| config.clone())) - }) + svc.finish(map_config(factory(), move |_| config.clone())) + }) })?; Ok(self) } @@ -534,7 +532,9 @@ where where A: AsRef, { + use actix_http::Protocol; use actix_rt::net::UnixStream; + use actix_service::pipeline_factory; let cfg = self.config.clone(); let factory = self.factory.clone(); @@ -555,12 +555,13 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .finish(map_config(factory(), move |_| config.clone())), - ) + pipeline_factory(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) }) + .and_then( + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .finish(map_config(factory(), move |_| config.clone())), + ) }, )?; Ok(self)