1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-06-25 22:49:21 +02:00

expose on_connect v2 (#1754)

Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
This commit is contained in:
Rob Ede
2020-10-30 02:03:26 +00:00
committed by GitHub
parent 4519db36b2
commit 9963a5ef54
16 changed files with 372 additions and 70 deletions

View File

@ -12,7 +12,6 @@ use bytes::{Buf, BytesMut};
use log::{error, trace};
use pin_project::pin_project;
use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error};
@ -21,6 +20,10 @@ use crate::helpers::DataFactory;
use crate::httpmessage::HttpMessage;
use crate::request::Request;
use crate::response::Response;
use crate::{
body::{Body, BodySize, MessageBody, ResponseBody},
Extensions,
};
use super::codec::Codec;
use super::payload::{Payload, PayloadSender, PayloadStatus};
@ -88,6 +91,7 @@ where
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
flags: Flags,
peer_addr: Option<net::SocketAddr>,
error: Option<DispatchError>,
@ -167,7 +171,7 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
/// Create http/1 dispatcher.
/// Create HTTP/1 dispatcher.
pub(crate) fn new(
stream: T,
config: ServiceConfig,
@ -175,6 +179,7 @@ where
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>,
) -> Self {
Dispatcher::with_timeout(
@ -187,6 +192,7 @@ where
expect,
upgrade,
on_connect,
on_connect_data,
peer_addr,
)
}
@ -202,6 +208,7 @@ where
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>,
) -> Self {
let keepalive = config.keep_alive_enabled();
@ -234,6 +241,7 @@ where
expect,
upgrade,
on_connect,
on_connect_data,
flags,
peer_addr,
ka_expire,
@ -526,11 +534,15 @@ where
let pl = this.codec.message_type();
req.head_mut().peer_addr = *this.peer_addr;
// DEPRECATED
// set on_connect data
if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut());
}
// merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(this.on_connect_data);
if pl == MessageType::Stream && this.upgrade.is_some() {
this.messages.push_back(DispatcherMessage::Upgrade(req));
break;
@ -927,8 +939,10 @@ mod tests {
CloneableService::new(ExpectHandler),
None,
None,
Extensions::new(),
None,
);
match Pin::new(&mut h1).poll(cx) {
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),

View File

@ -18,6 +18,7 @@ use crate::error::{DispatchError, Error, ParseError};
use crate::helpers::DataFactory;
use crate::request::Request;
use crate::response::Response;
use crate::{ConnectCallback, Extensions};
use super::codec::Codec;
use super::dispatcher::Dispatcher;
@ -30,6 +31,7 @@ pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
expect: X,
upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>,
}
@ -52,6 +54,7 @@ where
expect: ExpectHandler,
upgrade: None,
on_connect: None,
on_connect_ext: None,
_t: PhantomData,
}
}
@ -213,6 +216,7 @@ where
srv: self.srv,
upgrade: self.upgrade,
on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData,
}
}
@ -229,6 +233,7 @@ where
srv: self.srv,
expect: self.expect,
on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData,
}
}
@ -241,6 +246,12 @@ where
self.on_connect = f;
self
}
/// Set on connect callback.
pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
self.on_connect_ext = f;
self
}
}
impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U>
@ -274,6 +285,7 @@ where
expect: None,
upgrade: None,
on_connect: self.on_connect.clone(),
on_connect_ext: self.on_connect_ext.clone(),
cfg: Some(self.cfg.clone()),
_t: PhantomData,
}
@ -303,6 +315,7 @@ where
expect: Option<X::Service>,
upgrade: Option<U::Service>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: Option<ServiceConfig>,
_t: PhantomData<(T, B)>,
}
@ -352,23 +365,26 @@ where
Poll::Ready(result.map(|service| {
let this = self.as_mut().project();
H1ServiceHandler::new(
this.cfg.take().unwrap(),
service,
this.expect.take().unwrap(),
this.upgrade.take(),
this.on_connect.clone(),
this.on_connect_ext.clone(),
)
}))
}
}
/// `Service` implementation for HTTP1 transport
/// `Service` implementation for HTTP/1 transport
pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> {
srv: CloneableService<S>,
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig,
_t: PhantomData<(T, B)>,
}
@ -390,6 +406,7 @@ where
expect: X,
upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
) -> H1ServiceHandler<T, S, B, X, U> {
H1ServiceHandler {
srv: CloneableService::new(srv),
@ -397,6 +414,7 @@ where
upgrade: upgrade.map(CloneableService::new),
cfg,
on_connect,
on_connect_ext,
_t: PhantomData,
}
}
@ -462,11 +480,13 @@ where
}
fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
let on_connect = if let Some(ref on_connect) = self.on_connect {
Some(on_connect(&io))
} else {
None
};
let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io));
let mut connect_extensions = Extensions::new();
if let Some(ref handler) = self.on_connect_ext {
// run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
Dispatcher::new(
io,
@ -474,7 +494,8 @@ where
self.srv.clone(),
self.expect.clone(),
self.upgrade.clone(),
on_connect,
deprecated_on_connect,
connect_extensions,
addr,
)
}