1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 00:21:08 +01:00

change behavior of default upgrade handler (#2071)

This commit is contained in:
fakeshadow 2021-03-13 14:20:18 -08:00 committed by GitHub
parent 22dcc31193
commit 515d0e3fb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 39 deletions

View File

@ -951,14 +951,15 @@ mod tests {
use std::str; use std::str;
use actix_service::fn_service; use actix_service::fn_service;
use futures_util::future::{lazy, ready}; use futures_util::future::{lazy, ready, Ready};
use super::*; use super::*;
use crate::test::TestBuffer;
use crate::{error::Error, KeepAlive};
use crate::{ use crate::{
error::Error,
h1::{ExpectHandler, UpgradeHandler}, h1::{ExpectHandler, UpgradeHandler},
test::TestSeqBuffer, http::Method,
test::{TestBuffer, TestSeqBuffer},
HttpMessage, KeepAlive,
}; };
fn find_slice(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> { fn find_slice(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
@ -1282,14 +1283,30 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_upgrade() { async fn test_upgrade() {
struct TestUpgrade;
impl<T> Service<(Request, Framed<T, Codec>)> for TestUpgrade {
type Response = ();
type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>;
actix_service::always_ready!();
fn call(&self, (req, _framed): (Request, Framed<T, Codec>)) -> Self::Future {
assert_eq!(req.method(), Method::GET);
assert!(req.upgrade());
assert_eq!(req.headers().get("upgrade").unwrap(), "websocket");
ready(Ok(()))
}
}
lazy(|cx| { lazy(|cx| {
let mut buf = TestSeqBuffer::empty(); let mut buf = TestSeqBuffer::empty();
let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None);
let services = let services = HttpFlow::new(ok_service(), ExpectHandler, Some(TestUpgrade));
HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler));
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, TestUpgrade>::new(
buf.clone(), buf.clone(),
cfg, cfg,
services, services,

View File

@ -2,7 +2,7 @@ use std::task::Poll;
use actix_codec::Framed; use actix_codec::Framed;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{ready, Ready}; use futures_core::future::LocalBoxFuture;
use crate::error::Error; use crate::error::Error;
use crate::h1::Codec; use crate::h1::Codec;
@ -16,7 +16,7 @@ impl<T> ServiceFactory<(Request, Framed<T, Codec>)> for UpgradeHandler {
type Config = (); type Config = ();
type Service = UpgradeHandler; type Service = UpgradeHandler;
type InitError = Error; type InitError = Error;
type Future = Ready<Result<Self::Service, Self::InitError>>; type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
unimplemented!() unimplemented!()
@ -26,11 +26,11 @@ impl<T> ServiceFactory<(Request, Framed<T, Codec>)> for UpgradeHandler {
impl<T> Service<(Request, Framed<T, Codec>)> for UpgradeHandler { impl<T> Service<(Request, Framed<T, Codec>)> for UpgradeHandler {
type Response = (); type Response = ();
type Error = Error; type Error = Error;
type Future = Ready<Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
actix_service::always_ready!(); actix_service::always_ready!();
fn call(&self, _: (Request, Framed<T, Codec>)) -> Self::Future { fn call(&self, _: (Request, Framed<T, Codec>)) -> Self::Future {
ready(Ok(())) unimplemented!()
} }
} }

View File

@ -12,13 +12,6 @@ use actix_http::{
use actix_server::{Server, ServerBuilder}; use actix_server::{Server, ServerBuilder};
use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory};
#[cfg(unix)]
use actix_http::Protocol;
#[cfg(unix)]
use actix_service::pipeline_factory;
#[cfg(unix)]
use futures_util::future::ok;
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder};
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
@ -489,7 +482,9 @@ where
#[cfg(unix)] #[cfg(unix)]
/// Start listening for unix domain (UDS) connections on existing listener. /// Start listening for unix domain (UDS) connections on existing listener.
pub fn listen_uds(mut self, lst: std::os::unix::net::UnixListener) -> io::Result<Self> { pub fn listen_uds(mut self, lst: std::os::unix::net::UnixListener) -> io::Result<Self> {
use actix_http::Protocol;
use actix_rt::net::UnixStream; use actix_rt::net::UnixStream;
use actix_service::pipeline_factory;
let cfg = self.config.clone(); let cfg = self.config.clone();
let factory = self.factory.clone(); let factory = self.factory.clone();
@ -511,13 +506,16 @@ where
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
); );
pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then({ pipeline_factory(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) })
.and_then({
let svc = HttpService::build() let svc = HttpService::build()
.keep_alive(c.keep_alive) .keep_alive(c.keep_alive)
.client_timeout(c.client_timeout); .client_timeout(c.client_timeout);
let svc = if let Some(handler) = on_connect_fn.clone() { let svc = if let Some(handler) = on_connect_fn.clone() {
svc.on_connect_ext(move |io: &_, ext: _| (&*handler)(io as &dyn Any, ext)) svc.on_connect_ext(move |io: &_, ext: _| {
(&*handler)(io as &dyn Any, ext)
})
} else { } else {
svc svc
}; };
@ -534,7 +532,9 @@ where
where where
A: AsRef<std::path::Path>, A: AsRef<std::path::Path>,
{ {
use actix_http::Protocol;
use actix_rt::net::UnixStream; use actix_rt::net::UnixStream;
use actix_service::pipeline_factory;
let cfg = self.config.clone(); let cfg = self.config.clone();
let factory = self.factory.clone(); let factory = self.factory.clone();
@ -555,7 +555,8 @@ where
socket_addr, socket_addr,
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
); );
pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( pipeline_factory(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) })
.and_then(
HttpService::build() HttpService::build()
.keep_alive(c.keep_alive) .keep_alive(c.keep_alive)
.client_timeout(c.client_timeout) .client_timeout(c.client_timeout)