From 9963a5ef54119f3a4b791cdd76ae607397334d4d Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Fri, 30 Oct 2020 02:03:26 +0000 Subject: [PATCH] expose on_connect v2 (#1754) Co-authored-by: Mikail Bagishov --- CHANGES.md | 2 + Cargo.toml | 14 ++-- actix-http/CHANGES.md | 7 ++ actix-http/src/builder.rs | 41 +++++++++-- actix-http/src/extensions.rs | 30 +++++++- actix-http/src/h1/dispatcher.rs | 18 ++++- actix-http/src/h1/service.rs | 35 ++++++++-- actix-http/src/h2/dispatcher.rs | 8 +++ actix-http/src/h2/service.rs | 46 ++++++++---- actix-http/src/helpers.rs | 1 + actix-http/src/lib.rs | 4 +- actix-http/src/service.rs | 65 ++++++++++++----- actix-http/tests/test_openssl.rs | 2 + actix-http/tests/test_server.rs | 2 + examples/on_connect.rs | 51 ++++++++++++++ src/server.rs | 116 ++++++++++++++++++++++++++----- 16 files changed, 372 insertions(+), 70 deletions(-) create mode 100644 examples/on_connect.rs diff --git a/CHANGES.md b/CHANGES.md index fb4fde4f1..15d44b75c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,7 @@ * Add request-local data extractor `web::ReqData`. [#1748] * Add ability to register closure for request middleware logging. [#1749] * Add `app_data` to `ServiceConfig`. [#1757] +* Expose `on_connect` for access to the connection stream before request is handled. [#1754] ### Changed * Print non-configured `Data` type when attempting extraction. [#1743] @@ -16,6 +17,7 @@ [#1743]: https://github.com/actix/actix-web/pull/1743 [#1748]: https://github.com/actix/actix-web/pull/1748 [#1750]: https://github.com/actix/actix-web/pull/1750 +[#1754]: https://github.com/actix/actix-web/pull/1754 [#1749]: https://github.com/actix/actix-web/pull/1749 diff --git a/Cargo.toml b/Cargo.toml index 5d64cfd91..4fafc61c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,14 @@ required-features = ["compress"] name = "test_server" required-features = ["compress"] +[[example]] +name = "on_connect" +required-features = [] + +[[example]] +name = "client" +required-features = ["rustls"] + [dependencies] actix-codec = "0.3.0" actix-service = "1.0.6" @@ -105,7 +113,7 @@ tinyvec = { version = "1", features = ["alloc"] } actix = "0.10.0" actix-http = { version = "2.0.0", features = ["actors"] } rand = "0.7" -env_logger = "0.7" +env_logger = "0.8" serde_derive = "1.0" brotli2 = "0.3.2" flate2 = "1.0.13" @@ -125,10 +133,6 @@ actix-files = { path = "actix-files" } actix-multipart = { path = "actix-multipart" } awc = { path = "awc" } -[[example]] -name = "client" -required-features = ["rustls"] - [[bench]] name = "server" harness = false diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 990c9c071..0afb63a6d 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,9 +1,16 @@ # Changes ## Unreleased - 2020-xx-xx +### Added +* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754] + +### Changed * Upgrade `base64` to `0.13`. * Upgrade `pin-project` to `1.0`. +[#1754]: https://github.com/actix/actix-web/pull/1754 + + ## 2.0.0 - 2020-09-11 * No significant changes from `2.0.0-beta.4`. diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index 271abd43f..b28c69761 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -14,10 +14,11 @@ use crate::helpers::{Data, DataFactory}; use crate::request::Request; use crate::response::Response; use crate::service::HttpService; +use crate::{ConnectCallback, Extensions}; -/// A http service builder +/// A HTTP service builder /// -/// This type can be used to construct an instance of `http service` through a +/// This type can be used to construct an instance of [`HttpService`] through a /// builder-like pattern. pub struct HttpServiceBuilder> { keep_alive: KeepAlive, @@ -27,7 +28,9 @@ pub struct HttpServiceBuilder> { local_addr: Option, expect: X, upgrade: Option, + // DEPRECATED: in favor of on_connect_ext on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, S)>, } @@ -49,6 +52,7 @@ where expect: ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -138,6 +142,7 @@ where expect: expect.into_factory(), upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -167,14 +172,16 @@ where expect: self.expect, upgrade: Some(upgrade.into_factory()), on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } /// Set on-connect callback. /// - /// It get called once per connection and result of the call - /// get stored to the request's extensions. + /// Called once per connection. Return value of the call is stored in request extensions. + /// + /// *SOFT DEPRECATED*: Prefer the `on_connect_ext` style callback. pub fn on_connect(mut self, f: F) -> Self where F: Fn(&T) -> I + 'static, @@ -184,7 +191,20 @@ where self } - /// Finish service configuration and create *http service* for HTTP/1 protocol. + /// Sets the callback to be run on connection establishment. + /// + /// Has mutable access to a data container that will be merged into request extensions. + /// This enables transport layer data (like client certificates) to be accessed in middleware + /// and handlers. + pub fn on_connect_ext(mut self, f: F) -> Self + where + F: Fn(&T, &mut Extensions) + 'static, + { + self.on_connect_ext = Some(Rc::new(f)); + self + } + + /// Finish service configuration and create a HTTP Service for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where B: MessageBody, @@ -200,13 +220,15 @@ where self.secure, self.local_addr, ); + H1Service::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } - /// Finish service configuration and create *http service* for HTTP/2 protocol. + /// Finish service configuration and create a HTTP service for HTTP/2 protocol. pub fn h2(self, service: F) -> H2Service where B: MessageBody + 'static, @@ -223,7 +245,10 @@ where self.secure, self.local_addr, ); - H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) + + H2Service::with_config(cfg, service.into_factory()) + .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } /// Finish service configuration and create `HttpService` instance. @@ -243,9 +268,11 @@ where self.secure, self.local_addr, ); + HttpService::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) .on_connect(self.on_connect) + .on_connect_ext(self.on_connect_ext) } } diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index 09f1b711f..7dda74731 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,5 +1,5 @@ use std::any::{Any, TypeId}; -use std::fmt; +use std::{fmt, mem}; use fxhash::FxHashMap; @@ -66,6 +66,11 @@ impl Extensions { pub fn extend(&mut self, other: Extensions) { self.map.extend(other.map); } + + /// Sets (or overrides) items from `other` into this map. + pub(crate) fn drain_from(&mut self, other: &mut Self) { + self.map.extend(mem::take(&mut other.map)); + } } impl fmt::Debug for Extensions { @@ -213,4 +218,27 @@ mod tests { assert_eq!(extensions.get(), Some(&20u8)); assert_eq!(extensions.get_mut(), Some(&mut 20u8)); } + + #[test] + fn test_drain_from() { + let mut ext = Extensions::new(); + ext.insert(2isize); + + let mut more_ext = Extensions::new(); + + more_ext.insert(5isize); + more_ext.insert(5usize); + + assert_eq!(ext.get::(), Some(&2isize)); + assert_eq!(ext.get::(), None); + assert_eq!(more_ext.get::(), Some(&5isize)); + assert_eq!(more_ext.get::(), Some(&5usize)); + + ext.drain_from(&mut more_ext); + + assert_eq!(ext.get::(), Some(&5isize)); + assert_eq!(ext.get::(), Some(&5usize)); + assert_eq!(more_ext.get::(), None); + assert_eq!(more_ext.get::(), None); + } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 7c4de9707..ace4144e3 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -12,7 +12,6 @@ use bytes::{Buf, BytesMut}; use log::{error, trace}; use pin_project::pin_project; -use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; @@ -21,6 +20,10 @@ use crate::helpers::DataFactory; use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::Response; +use crate::{ + body::{Body, BodySize, MessageBody, ResponseBody}, + Extensions, +}; use super::codec::Codec; use super::payload::{Payload, PayloadSender, PayloadStatus}; @@ -88,6 +91,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, flags: Flags, peer_addr: Option, error: Option, @@ -167,7 +171,7 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - /// Create http/1 dispatcher. + /// Create HTTP/1 dispatcher. pub(crate) fn new( stream: T, config: ServiceConfig, @@ -175,6 +179,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, peer_addr: Option, ) -> Self { Dispatcher::with_timeout( @@ -187,6 +192,7 @@ where expect, upgrade, on_connect, + on_connect_data, peer_addr, ) } @@ -202,6 +208,7 @@ where expect: CloneableService, upgrade: Option>, on_connect: Option>, + on_connect_data: Extensions, peer_addr: Option, ) -> Self { let keepalive = config.keep_alive_enabled(); @@ -234,6 +241,7 @@ where expect, upgrade, on_connect, + on_connect_data, flags, peer_addr, ka_expire, @@ -526,11 +534,15 @@ where let pl = this.codec.message_type(); req.head_mut().peer_addr = *this.peer_addr; + // DEPRECATED // set on_connect data if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } + // merge on_connect_ext data into request extensions + req.extensions_mut().drain_from(this.on_connect_data); + if pl == MessageType::Stream && this.upgrade.is_some() { this.messages.push_back(DispatcherMessage::Upgrade(req)); break; @@ -927,8 +939,10 @@ mod tests { CloneableService::new(ExpectHandler), None, None, + Extensions::new(), None, ); + match Pin::new(&mut h1).poll(cx) { Poll::Pending => panic!(), Poll::Ready(res) => assert!(res.is_err()), diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 6aafd4089..5008791c0 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -18,6 +18,7 @@ use crate::error::{DispatchError, Error, ParseError}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; +use crate::{ConnectCallback, Extensions}; use super::codec::Codec; use super::dispatcher::Dispatcher; @@ -30,6 +31,7 @@ pub struct H1Service> { expect: X, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -52,6 +54,7 @@ where expect: ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -213,6 +216,7 @@ where srv: self.srv, upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -229,6 +233,7 @@ where srv: self.srv, expect: self.expect, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -241,6 +246,12 @@ where self.on_connect = f; self } + + /// Set on connect callback. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl ServiceFactory for H1Service @@ -274,6 +285,7 @@ where expect: None, upgrade: None, on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), cfg: Some(self.cfg.clone()), _t: PhantomData, } @@ -303,6 +315,7 @@ where expect: Option, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: Option, _t: PhantomData<(T, B)>, } @@ -352,23 +365,26 @@ where Poll::Ready(result.map(|service| { let this = self.as_mut().project(); + H1ServiceHandler::new( this.cfg.take().unwrap(), service, this.expect.take().unwrap(), this.upgrade.take(), this.on_connect.clone(), + this.on_connect_ext.clone(), ) })) } } -/// `Service` implementation for HTTP1 transport +/// `Service` implementation for HTTP/1 transport pub struct H1ServiceHandler { srv: CloneableService, expect: CloneableService, upgrade: Option>, on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: ServiceConfig, _t: PhantomData<(T, B)>, } @@ -390,6 +406,7 @@ where expect: X, upgrade: Option, on_connect: Option Box>>, + on_connect_ext: Option>>, ) -> H1ServiceHandler { H1ServiceHandler { srv: CloneableService::new(srv), @@ -397,6 +414,7 @@ where upgrade: upgrade.map(CloneableService::new), cfg, on_connect, + on_connect_ext, _t: PhantomData, } } @@ -462,11 +480,13 @@ where } fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + + let mut connect_extensions = Extensions::new(); + if let Some(ref handler) = self.on_connect_ext { + // run on_connect_ext callback, populating connect extensions + handler(&io, &mut connect_extensions); + } Dispatcher::new( io, @@ -474,7 +494,8 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), - on_connect, + deprecated_on_connect, + connect_extensions, addr, ) } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index daa651f4d..594339121 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -24,6 +24,7 @@ use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; use crate::response::Response; +use crate::Extensions; const CHUNK_SIZE: usize = 16_384; @@ -36,6 +37,7 @@ where service: CloneableService, connection: Connection, on_connect: Option>, + on_connect_data: Extensions, config: ServiceConfig, peer_addr: Option, ka_expire: Instant, @@ -56,6 +58,7 @@ where service: CloneableService, connection: Connection, on_connect: Option>, + on_connect_data: Extensions, config: ServiceConfig, timeout: Option, peer_addr: Option, @@ -82,6 +85,7 @@ where peer_addr, connection, on_connect, + on_connect_data, ka_expire, ka_timer, _t: PhantomData, @@ -130,11 +134,15 @@ where head.headers = parts.headers.into(); head.peer_addr = this.peer_addr; + // DEPRECATED // set on_connect data if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } + // merge on_connect_ext data into request extensions + req.extensions_mut().drain_from(&mut this.on_connect_data); + actix_rt::spawn(ServiceResponse::< S::Future, S::Response, diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 6b5620e02..3103127b4 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{net, rc}; +use std::{net, rc::Rc}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; @@ -23,6 +23,7 @@ use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; +use crate::{ConnectCallback, Extensions}; use super::dispatcher::Dispatcher; @@ -30,7 +31,8 @@ use super::dispatcher::Dispatcher; pub struct H2Service { srv: S, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -50,19 +52,27 @@ where H2Service { cfg, on_connect: None, + on_connect_ext: None, srv: service.into_factory(), _t: PhantomData, } } /// Set on connect callback. + pub(crate) fn on_connect( mut self, - f: Option Box>>, + f: Option Box>>, ) -> Self { self.on_connect = f; self } + + /// Set on connect callback. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl H2Service @@ -203,6 +213,7 @@ where fut: self.srv.new_service(()), cfg: Some(self.cfg.clone()), on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), _t: PhantomData, } } @@ -214,7 +225,8 @@ pub struct H2ServiceResponse { #[pin] fut: S::Future, cfg: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -237,6 +249,7 @@ where H2ServiceHandler::new( this.cfg.take().unwrap(), this.on_connect.clone(), + this.on_connect_ext.clone(), service, ) })) @@ -247,7 +260,8 @@ where pub struct H2ServiceHandler { srv: CloneableService, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -261,12 +275,14 @@ where { fn new( cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, srv: S, ) -> H2ServiceHandler { H2ServiceHandler { cfg, on_connect, + on_connect_ext, srv: CloneableService::new(srv), _t: PhantomData, } @@ -296,18 +312,21 @@ where } fn call(&mut self, (io, addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + + let mut connect_extensions = Extensions::new(); + if let Some(ref handler) = self.on_connect_ext { + // run on_connect_ext callback, populating connect extensions + handler(&io, &mut connect_extensions); + } H2ServiceHandlerResponse { state: State::Handshake( Some(self.srv.clone()), Some(self.cfg.clone()), addr, - on_connect, + deprecated_on_connect, + Some(connect_extensions), server::handshake(io), ), } @@ -325,6 +344,7 @@ where Option, Option, Option>, + Option, Handshake, ), } @@ -360,6 +380,7 @@ where ref mut config, ref peer_addr, ref mut on_connect, + ref mut on_connect_data, ref mut handshake, ) => match Pin::new(handshake).poll(cx) { Poll::Ready(Ok(conn)) => { @@ -367,6 +388,7 @@ where srv.take().unwrap(), conn, on_connect.take(), + on_connect_data.take().unwrap(), config.take().unwrap(), None, *peer_addr, diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs index bbf358b66..ac0e0f118 100644 --- a/actix-http/src/helpers.rs +++ b/actix-http/src/helpers.rs @@ -50,6 +50,7 @@ impl<'a> io::Write for Writer<'a> { self.0.extend_from_slice(buf); Ok(buf.len()) } + fn flush(&mut self) -> io::Result<()> { Ok(()) } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index fab91be2b..e57a3727e 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -1,4 +1,4 @@ -//! Basic http primitives for actix-net framework. +//! Basic HTTP primitives for the Actix ecosystem. #![deny(rust_2018_idioms)] #![allow( @@ -78,3 +78,5 @@ pub enum Protocol { Http1, Http2, } + +type ConnectCallback = dyn Fn(&IO, &mut Extensions); diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 9ee579702..75745209c 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{fmt, net, rc}; +use std::{fmt, net, rc::Rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; @@ -20,15 +20,17 @@ use crate::error::{DispatchError, Error}; use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; -use crate::{h1, h2::Dispatcher, Protocol}; +use crate::{h1, h2::Dispatcher, ConnectCallback, Extensions, Protocol}; -/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation +/// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. pub struct HttpService> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, - on_connect: Option Box>>, + // DEPRECATED: in favor of on_connect_ext + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B)>, } @@ -66,6 +68,7 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -81,6 +84,7 @@ where expect: h1::ExpectHandler, upgrade: None, on_connect: None, + on_connect_ext: None, _t: PhantomData, } } @@ -113,6 +117,7 @@ where srv: self.srv, upgrade: self.upgrade, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -138,6 +143,7 @@ where srv: self.srv, expect: self.expect, on_connect: self.on_connect, + on_connect_ext: self.on_connect_ext, _t: PhantomData, } } @@ -145,11 +151,17 @@ where /// Set on connect callback. pub(crate) fn on_connect( mut self, - f: Option Box>>, + f: Option Box>>, ) -> Self { self.on_connect = f; self } + + /// Set connect callback with mutable access to request data container. + pub(crate) fn on_connect_ext(mut self, f: Option>>) -> Self { + self.on_connect_ext = f; + self + } } impl HttpService @@ -355,6 +367,7 @@ where expect: None, upgrade: None, on_connect: self.on_connect.clone(), + on_connect_ext: self.on_connect_ext.clone(), cfg: self.cfg.clone(), _t: PhantomData, } @@ -378,7 +391,8 @@ pub struct HttpServiceResponse< fut_upg: Option, expect: Option, upgrade: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, cfg: ServiceConfig, _t: PhantomData<(T, B)>, } @@ -429,6 +443,7 @@ where .fut .poll(cx) .map_err(|e| log::error!("Init http service error: {:?}", e))); + Poll::Ready(result.map(|service| { let this = self.as_mut().project(); HttpServiceHandler::new( @@ -437,6 +452,7 @@ where this.expect.take().unwrap(), this.upgrade.take(), this.on_connect.clone(), + this.on_connect_ext.clone(), ) })) } @@ -448,7 +464,8 @@ pub struct HttpServiceHandler { expect: CloneableService, upgrade: Option>, cfg: ServiceConfig, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, _t: PhantomData<(T, B, X)>, } @@ -469,11 +486,13 @@ where srv: S, expect: X, upgrade: Option, - on_connect: Option Box>>, + on_connect: Option Box>>, + on_connect_ext: Option>>, ) -> HttpServiceHandler { HttpServiceHandler { cfg, on_connect, + on_connect_ext, srv: CloneableService::new(srv), expect: CloneableService::new(expect), upgrade: upgrade.map(CloneableService::new), @@ -543,11 +562,12 @@ where } fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { - let on_connect = if let Some(ref on_connect) = self.on_connect { - Some(on_connect(&io)) - } else { - None - }; + let mut connect_extensions = Extensions::new(); + + let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io)); + if let Some(ref handler) = self.on_connect_ext { + handler(&io, &mut connect_extensions); + } match proto { Protocol::Http2 => HttpServiceHandlerResponse { @@ -555,10 +575,12 @@ where server::handshake(io), self.cfg.clone(), self.srv.clone(), - on_connect, + deprecated_on_connect, + connect_extensions, peer_addr, ))), }, + Protocol::Http1 => HttpServiceHandlerResponse { state: State::H1(h1::Dispatcher::new( io, @@ -566,7 +588,8 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), - on_connect, + deprecated_on_connect, + connect_extensions, peer_addr, )), }, @@ -595,6 +618,7 @@ where ServiceConfig, CloneableService, Option>, + Extensions, Option, )>, ), @@ -670,9 +694,16 @@ where } else { panic!() }; - let (_, cfg, srv, on_connect, peer_addr) = data.take().unwrap(); + let (_, cfg, srv, on_connect, on_connect_data, peer_addr) = + data.take().unwrap(); self.set(State::H2(Dispatcher::new( - srv, conn, on_connect, cfg, None, peer_addr, + srv, + conn, + on_connect, + on_connect_data, + cfg, + None, + peer_addr, ))); self.poll(cx) } diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index 795deacdc..05f01d240 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -411,8 +411,10 @@ async fn test_h2_on_connect() { let srv = test_server(move || { HttpService::build() .on_connect(|_| 10usize) + .on_connect_ext(|_, data| data.insert(20isize)) .h2(|req: Request| { assert!(req.extensions().contains::()); + assert!(req.extensions().contains::()); ok::<_, ()>(Response::Ok().finish()) }) .openssl(ssl_acceptor()) diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 0375b6f66..de6368fda 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -663,8 +663,10 @@ async fn test_h1_on_connect() { let srv = test_server(|| { HttpService::build() .on_connect(|_| 10usize) + .on_connect_ext(|_, data| data.insert(20isize)) .h1(|req: Request| { assert!(req.extensions().contains::()); + assert!(req.extensions().contains::()); future::ok::<_, ()>(Response::Ok().finish()) }) .tcp() diff --git a/examples/on_connect.rs b/examples/on_connect.rs new file mode 100644 index 000000000..bdad7e67e --- /dev/null +++ b/examples/on_connect.rs @@ -0,0 +1,51 @@ +//! This example shows how to use `actix_web::HttpServer::on_connect` to access a lower-level socket +//! properties and pass them to a handler through request-local data. +//! +//! For an example of extracting a client TLS certificate, see: +//! + +use std::{any::Any, env, io, net::SocketAddr}; + +use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; + +#[derive(Debug, Clone)] +struct ConnectionInfo { + bind: SocketAddr, + peer: SocketAddr, + ttl: Option, +} + +async fn route_whoami(conn_info: web::ReqData) -> String { + format!( + "Here is some info about your connection:\n\n{:#?}", + conn_info + ) +} + +fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { + if let Some(sock) = connection.downcast_ref::() { + data.insert(ConnectionInfo { + bind: sock.local_addr().unwrap(), + peer: sock.peer_addr().unwrap(), + ttl: sock.ttl().ok(), + }); + } else { + unreachable!("connection should only be plaintext since no TLS is set up"); + } +} + +#[actix_web::main] +async fn main() -> io::Result<()> { + if env::var("RUST_LOG").is_err() { + env::set_var("RUST_LOG", "info"); + } + + env_logger::init(); + + HttpServer::new(|| App::new().default_service(web::to(route_whoami))) + .on_connect(get_conn_info) + .bind(("127.0.0.1", 8080))? + .workers(1) + .run() + .await +} diff --git a/src/server.rs b/src/server.rs index 2b86f7416..3badb6e8d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,14 @@ -use std::marker::PhantomData; -use std::sync::{Arc, Mutex}; -use std::{fmt, io, net}; +use std::{ + any::Any, + fmt, io, + marker::PhantomData, + net, + sync::{Arc, Mutex}, +}; -use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Response}; +use actix_http::{ + body::MessageBody, Error, Extensions, HttpService, KeepAlive, Request, Response, +}; use actix_server::{Server, ServerBuilder}; use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; @@ -64,6 +70,7 @@ where backlog: i32, sockets: Vec, builder: ServerBuilder, + on_connect_fn: Option>, _t: PhantomData<(S, B)>, } @@ -91,6 +98,32 @@ where backlog: 1024, sockets: Vec::new(), builder: ServerBuilder::default(), + on_connect_fn: None, + _t: PhantomData, + } + } + + /// Sets function that will be called once before each connection is handled. + /// It will receive a `&std::any::Any`, which contains underlying connection type and an + /// [Extensions] container so that request-local data can be passed to middleware and handlers. + /// + /// For example: + /// - `actix_tls::openssl::SslStream` when using openssl. + /// - `actix_tls::rustls::TlsStream` when using rustls. + /// - `actix_web::rt::net::TcpStream` when no encryption is used. + /// + /// See `on_connect` example for additional details. + pub fn on_connect(self, f: CB) -> HttpServer + where + CB: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static, + { + HttpServer { + factory: self.factory, + config: self.config, + backlog: self.backlog, + sockets: self.sockets, + builder: self.builder, + on_connect_fn: Some(Arc::new(f)), _t: PhantomData, } } @@ -240,6 +273,7 @@ where addr, scheme: "http", }); + let on_connect_fn = self.on_connect_fn.clone(); self.builder = self.builder.listen( format!("actix-web-service-{}", addr), @@ -252,11 +286,20 @@ where c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .local_addr(addr) - .finish(map_config(factory(), move |_| cfg.clone())) + .local_addr(addr); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .tcp() }, )?; @@ -289,6 +332,8 @@ where scheme: "https", }); + let on_connect_fn = self.on_connect_fn.clone(); + self.builder = self.builder.listen( format!("actix-web-service-{}", addr), lst, @@ -299,11 +344,21 @@ where addr, c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .finish(map_config(factory(), move |_| cfg.clone())) + .client_disconnect(c.client_shutdown); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (&*handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) }, )?; @@ -336,6 +391,8 @@ where scheme: "https", }); + let on_connect_fn = self.on_connect_fn.clone(); + self.builder = self.builder.listen( format!("actix-web-service-{}", addr), lst, @@ -346,11 +403,21 @@ where addr, c.host.clone().unwrap_or_else(|| format!("{}", addr)), ); - HttpService::build() + + let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) - .finish(map_config(factory(), move |_| cfg.clone())) + .client_disconnect(c.client_shutdown); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) }, )?; @@ -441,7 +508,7 @@ where } #[cfg(unix)] - /// Start listening for unix domain connections on existing listener. + /// Start listening for unix domain (UDS) connections on existing listener. pub fn listen_uds( mut self, lst: std::os::unix::net::UnixListener, @@ -460,6 +527,7 @@ where }); let addr = format!("actix-web-service-{:?}", lst.local_addr()?); + let on_connect_fn = self.on_connect_fn.clone(); self.builder = self.builder.listen_uds(addr, lst, move || { let c = cfg.lock().unwrap(); @@ -468,11 +536,23 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); + pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .finish(map_config(factory(), move |_| config.clone())), + { + let svc = HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout); + + let svc = if let Some(handler) = on_connect_fn.clone() { + svc.on_connect_ext(move |io: &_, ext: _| { + (&*handler)(io as &dyn Any, ext) + }) + } else { + svc + }; + + svc.finish(map_config(factory(), move |_| config.clone())) + }, ) })?; Ok(self)