diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 3368c4ebc..886049d19 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -42,8 +42,8 @@ fn main() { .header("LOCATION", "/index.html") .body(Body::Empty) }))) - .serve_tls::<_, ()>("127.0.0.1:8080", pkcs12).unwrap(); + .serve_tls::<_, ()>("127.0.0.1:8443", pkcs12).unwrap(); - println!("Started http server: 127.0.0.1:8080"); + println!("Started http server: 127.0.0.1:8443"); let _ = sys.run(); } diff --git a/src/httprequest.rs b/src/httprequest.rs index 6a167d7dd..75aa88036 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -9,12 +9,14 @@ use url::form_urlencoded; use http::{header, Uri, Method, Version, HeaderMap, Extensions}; use {Cookie, HttpRange}; +use info::ConnectionInfo; use recognizer::Params; use payload::Payload; use multipart::Multipart; use error::{ParseError, PayloadError, MultipartError, CookieParseError, HttpRangeError, UrlencodedError}; + struct HttpMessage { version: Version, method: Method, @@ -27,6 +29,7 @@ struct HttpMessage { cookies_loaded: bool, addr: Option, payload: Payload, + info: Option>, } impl Default for HttpMessage { @@ -44,6 +47,7 @@ impl Default for HttpMessage { addr: None, payload: Payload::empty(), extensions: Extensions::new(), + info: None, } } } @@ -70,6 +74,7 @@ impl HttpRequest<()> { addr: None, payload: payload, extensions: Extensions::new(), + info: None, }), Rc::new(()) ) @@ -106,6 +111,15 @@ impl HttpRequest { &mut self.as_mut().extensions } + pub(crate) fn set_prefix(&mut self, idx: usize) { + self.as_mut().prefix = idx; + } + + #[doc(hidden)] + pub fn prefix_len(&self) -> usize { + self.0.prefix + } + /// Read the Request Uri. #[inline] pub fn uri(&self) -> &Uri { &self.0.uri } @@ -132,13 +146,15 @@ impl HttpRequest { self.0.uri.path() } - pub(crate) fn set_prefix(&mut self, idx: usize) { - self.as_mut().prefix = idx; - } - - #[doc(hidden)] - pub fn prefix_len(&self) -> usize { - self.0.prefix + /// Load *ConnectionInfo* for currect request. + #[inline] + pub fn load_connection_info(&mut self) -> &ConnectionInfo { + if self.0.info.is_none() { + let info: ConnectionInfo<'static> = unsafe{ + mem::transmute(ConnectionInfo::new(self))}; + self.as_mut().info = Some(info); + } + self.0.info.as_ref().unwrap() } /// Remote IP of client initiated HTTP request. diff --git a/src/info.rs b/src/info.rs new file mode 100644 index 000000000..76420ee6c --- /dev/null +++ b/src/info.rs @@ -0,0 +1,148 @@ +use std::str::FromStr; +use http::header::{self, HeaderName}; +use httprequest::HttpRequest; + +const X_FORWARDED_HOST: &str = "X-FORWARDED-HOST"; +const X_FORWARDED_PROTO: &str = "X-FORWARDED-PROTO"; + + +/// `HttpRequest` connection information +/// +/// While it is possible to create `ConnectionInfo` directly, +/// consider using `HttpRequest::load_connection_info()` which cache result. +pub struct ConnectionInfo<'a> { + scheme: &'a str, + host: &'a str, + remote: String, + forwarded_for: Vec<&'a str>, + forwarded_by: Vec<&'a str>, +} + +impl<'a> ConnectionInfo<'a> { + + /// Create *ConnectionInfo* instance for a request. + pub fn new(req: &'a HttpRequest) -> ConnectionInfo<'a> { + let mut host = None; + let mut scheme = None; + let mut forwarded_for = Vec::new(); + let mut forwarded_by = Vec::new(); + + // load forwarded header + for hdr in req.headers().get_all(header::FORWARDED) { + if let Ok(val) = hdr.to_str() { + for pair in val.split(';') { + for el in pair.split(',') { + let mut items = el.splitn(1, '='); + if let Some(name) = items.next() { + if let Some(val) = items.next() { + match &name.to_lowercase() as &str { + "for" => forwarded_for.push(val.trim()), + "by" => forwarded_by.push(val.trim()), + "proto" => if scheme.is_none() { + scheme = Some(val.trim()); + }, + "host" => if host.is_none() { + host = Some(val.trim()); + }, + _ => (), + } + } + } + } + } + } + } + + // scheme + if scheme.is_none() { + if let Some(h) = req.headers().get( + HeaderName::from_str(X_FORWARDED_PROTO).unwrap()) { + if let Ok(h) = h.to_str() { + scheme = h.split(',').next().map(|v| v.trim()); + } + } + if scheme.is_none() { + if let Some(a) = req.uri().scheme_part() { + scheme = Some(a.as_str()) + } + } + } + + // host + if host.is_none() { + if let Some(h) = req.headers().get(HeaderName::from_str(X_FORWARDED_HOST).unwrap()) { + if let Ok(h) = h.to_str() { + host = h.split(',').next().map(|v| v.trim()); + } + } + if host.is_none() { + if let Some(h) = req.headers().get(header::HOST) { + if let Ok(h) = h.to_str() { + host = Some(h); + } + } + if host.is_none() { + if let Some(a) = req.uri().authority_part() { + host = Some(a.as_str()) + } + } + } + } + + ConnectionInfo { + scheme: scheme.unwrap_or("http"), + host: host.unwrap_or("localhost"), + remote: String::new(), + forwarded_for: forwarded_for, + forwarded_by: forwarded_by, + } + } + + /// Scheme of the request. + /// + /// Scheme is resolved through the following headers, in this order: + /// + /// - Forwarded + /// - X-Forwarded-Proto + /// - Uri + #[inline] + pub fn scheme(&self) -> &str { + self.scheme + } + + /// Hostname of the request. + /// + /// Hostname is resolved through the following headers, in this order: + /// + /// - Forwarded + /// - X-Forwarded-Host + /// - Host + /// - Uri + pub fn host(&self) -> &str { + self.host + } + + /// Remote IP of client initiated HTTP request. + /// + /// The IP is resolved through the following headers, in this order: + /// + /// - Forwarded + /// - X-Forwarded-For + /// - peername of opened socket + #[inline] + pub fn remote(&self) -> &str { + &self.remote + } + + /// List of the nodes making the request to the proxy. + #[inline] + pub fn forwarded_for(&self) -> &Vec<&str> { + &self.forwarded_for + } + + /// List of the user-agent facing interface of the proxies + #[inline] + pub fn forwarded_by(&self) -> &Vec<&str> { + &self.forwarded_by + } +} diff --git a/src/lib.rs b/src/lib.rs index d365c87c7..90f2cdba8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,7 @@ mod encoding; mod httprequest; mod httpresponse; mod payload; +mod info; mod route; mod resource; mod recognizer; @@ -81,6 +82,7 @@ pub use error::{Error, Result}; pub use encoding::ContentEncoding; pub use body::{Body, Binary}; pub use application::Application; +pub use info::ConnectionInfo; pub use httprequest::{HttpRequest, UrlEncoded}; pub use httpresponse::HttpResponse; pub use payload::{Payload, PayloadItem}; diff --git a/src/recognizer.rs b/src/recognizer.rs index be9667195..9776f3950 100644 --- a/src/recognizer.rs +++ b/src/recognizer.rs @@ -209,28 +209,17 @@ pub struct RouteRecognizer { patterns: HashMap, } -impl Default for RouteRecognizer { - - fn default() -> Self { - RouteRecognizer { - prefix: 0, - re: RegexSet::new([""].iter()).unwrap(), - routes: Vec::new(), - patterns: HashMap::new(), - } - } -} - impl RouteRecognizer { - pub fn new, U>(prefix: P, routes: U) -> Self - where U: IntoIterator, T)> + pub fn new, U, K>(prefix: P, routes: U) -> Self + where U: IntoIterator, T)>, + K: Into, { let mut paths = Vec::new(); let mut handlers = Vec::new(); let mut patterns = HashMap::new(); for item in routes { - let (pat, elements) = parse(&item.0); + let (pat, elements) = parse(&item.0.into()); let pattern = Pattern::new(&pat, elements); if let Some(ref name) = item.1 { let _ = patterns.insert(name.clone(), pattern.clone()); @@ -399,8 +388,6 @@ mod tests { #[test] fn test_recognizer() { - let mut rec = RouteRecognizer::::default(); - let routes = vec![ ("/name", None, 1), ("/name/{val}", None, 2), @@ -408,7 +395,7 @@ mod tests { ("/v{val}/{val2}/index.html", None, 4), ("/v/{tail:.*}", None, 5), ]; - rec.set_routes(routes); + let mut rec = RouteRecognizer::new("/", routes); let (params, val) = rec.recognize("/name").unwrap(); assert_eq!(*val, 1); diff --git a/src/ws.rs b/src/ws.rs index 21b630be7..cbde61e59 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -66,13 +66,9 @@ use wsframe; use wsproto::*; pub use wsproto::CloseCode; -#[doc(hidden)] const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; -#[doc(hidden)] const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; -#[doc(hidden)] const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION"; -// #[doc(hidden)] // const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL";