1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-08-19 04:15:38 +02:00

expose on_connect v2 (#1754)

Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
This commit is contained in:
Rob Ede
2020-10-30 02:03:26 +00:00
committed by GitHub
parent 4519db36b2
commit 9963a5ef54
16 changed files with 372 additions and 70 deletions

View File

@@ -2,7 +2,7 @@ use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{net, rc};
use std::{net, rc::Rc};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::net::TcpStream;
@@ -23,6 +23,7 @@ use crate::error::{DispatchError, Error};
use crate::helpers::DataFactory;
use crate::request::Request;
use crate::response::Response;
use crate::{ConnectCallback, Extensions};
use super::dispatcher::Dispatcher;
@@ -30,7 +31,8 @@ use super::dispatcher::Dispatcher;
pub struct H2Service<T, S, B> {
srv: S,
cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>,
}
@@ -50,19 +52,27 @@ where
H2Service {
cfg,
on_connect: None,
on_connect_ext: None,
srv: service.into_factory(),
_t: PhantomData,
}
}
/// Set on connect callback.
pub(crate) fn on_connect(
mut self,
f: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self {
self.on_connect = f;
self
}
/// Set on connect callback.
pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
self.on_connect_ext = f;
self
}
}
impl<S, B> H2Service<TcpStream, S, B>
@@ -203,6 +213,7 @@ where
fut: self.srv.new_service(()),
cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(),
on_connect_ext: self.on_connect_ext.clone(),
_t: PhantomData,
}
}
@@ -214,7 +225,8 @@ pub struct H2ServiceResponse<T, S: ServiceFactory, B> {
#[pin]
fut: S::Future,
cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>,
}
@@ -237,6 +249,7 @@ where
H2ServiceHandler::new(
this.cfg.take().unwrap(),
this.on_connect.clone(),
this.on_connect_ext.clone(),
service,
)
}))
@@ -247,7 +260,8 @@ where
pub struct H2ServiceHandler<T, S: Service, B> {
srv: CloneableService<S>,
cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>,
}
@@ -261,12 +275,14 @@ where
{
fn new(
cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
srv: S,
) -> H2ServiceHandler<T, S, B> {
H2ServiceHandler {
cfg,
on_connect,
on_connect_ext,
srv: CloneableService::new(srv),
_t: PhantomData,
}
@@ -296,18 +312,21 @@ where
}
fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
let on_connect = if let Some(ref on_connect) = self.on_connect {
Some(on_connect(&io))
} else {
None
};
let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io));
let mut connect_extensions = Extensions::new();
if let Some(ref handler) = self.on_connect_ext {
// run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
H2ServiceHandlerResponse {
state: State::Handshake(
Some(self.srv.clone()),
Some(self.cfg.clone()),
addr,
on_connect,
deprecated_on_connect,
Some(connect_extensions),
server::handshake(io),
),
}
@@ -325,6 +344,7 @@ where
Option<ServiceConfig>,
Option<net::SocketAddr>,
Option<Box<dyn DataFactory>>,
Option<Extensions>,
Handshake<T, Bytes>,
),
}
@@ -360,6 +380,7 @@ where
ref mut config,
ref peer_addr,
ref mut on_connect,
ref mut on_connect_data,
ref mut handshake,
) => match Pin::new(handshake).poll(cx) {
Poll::Ready(Ok(conn)) => {
@@ -367,6 +388,7 @@ where
srv.take().unwrap(),
conn,
on_connect.take(),
on_connect_data.take().unwrap(),
config.take().unwrap(),
None,
*peer_addr,