From f369d9af0eecb88b38ce48c1f9579523d10fce81 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 Nov 2017 13:08:15 -0800 Subject: [PATCH] make remote addr available to http request --- src/channel.rs | 30 ++++++++++++++++-------------- src/h1.rs | 15 +++++++++------ src/h2.rs | 25 ++++++++++++++++--------- src/httprequest.rs | 19 +++++++++++++++++++ src/server.rs | 36 +++++++++++++++++++----------------- 5 files changed, 79 insertions(+), 46 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 3f5ece3d1..2403b46c2 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use std::rc::Rc; +use std::net::SocketAddr; use actix::dev::*; use bytes::Bytes; @@ -19,23 +20,24 @@ pub trait HttpHandler: 'static { fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task; } -enum HttpProtocol - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +enum HttpProtocol + where T: AsyncRead + AsyncWrite + 'static, H: 'static { - H1(h1::Http1), - H2(h2::Http2), + H1(h1::Http1), + H2(h2::Http2), } -pub struct HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +pub struct HttpChannel + where T: AsyncRead + AsyncWrite + 'static, H: 'static { - proto: Option>, + proto: Option>, } -impl HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +impl HttpChannel + where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static { - pub fn new(stream: T, addr: A, router: Rc>, http2: bool) -> HttpChannel { + pub fn new(stream: T, addr: Option, router: Rc>, http2: bool) + -> HttpChannel { if http2 { HttpChannel { proto: Some(HttpProtocol::H2( @@ -54,14 +56,14 @@ impl HttpChannel } }*/ -impl Actor for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +impl Actor for HttpChannel + where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static { type Context = Context; } -impl Future for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +impl Future for HttpChannel + where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static { type Item = (); type Error = (); diff --git a/src/h1.rs b/src/h1.rs index 97f9aa086..ca13aa67d 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -1,6 +1,7 @@ use std::{self, io, ptr}; use std::rc::Rc; use std::cell::UnsafeCell; +use std::net::SocketAddr; use std::time::Duration; use std::collections::VecDeque; @@ -35,10 +36,10 @@ pub(crate) enum Http1Result { Switch, } -pub(crate) struct Http1 { +pub(crate) struct Http1 { router: Rc>, #[allow(dead_code)] - addr: A, + addr: Option, stream: H1Writer, reader: Reader, read_buf: BytesMut, @@ -57,12 +58,11 @@ struct Entry { finished: bool, } -impl Http1 +impl Http1 where T: AsyncRead + AsyncWrite + 'static, - A: 'static, H: HttpHandler + 'static { - pub fn new(stream: T, addr: A, router: Rc>) -> Self { + pub fn new(stream: T, addr: Option, router: Rc>) -> Self { Http1{ router: router, addr: addr, stream: H1Writer::new(stream), @@ -75,7 +75,7 @@ impl Http1 h2: false } } - pub fn into_inner(mut self) -> (T, A, Rc>, Bytes) { + pub fn into_inner(mut self) -> (T, Option, Rc>, Bytes) { (self.stream.unwrap(), self.addr, self.router, self.read_buf.freeze()) } @@ -172,6 +172,9 @@ impl Http1 Ok(Async::Ready(Item::Http1(mut req, payload))) => { not_ready = false; + // set remote addr + req.set_remove_addr(self.addr.clone()); + // stop keepalive timer self.keepalive_timer.take(); diff --git a/src/h2.rs b/src/h2.rs index 22ae0fe73..039a057b6 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -3,6 +3,7 @@ use std::rc::Rc; use std::io::{Read, Write}; use std::cell::UnsafeCell; use std::time::Duration; +use std::net::SocketAddr; use std::collections::VecDeque; use actix::Arbiter; @@ -24,12 +25,12 @@ use payload::{Payload, PayloadError, PayloadWriter}; const KEEPALIVE_PERIOD: u64 = 15; // seconds -pub(crate) struct Http2 - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +pub(crate) struct Http2 + where T: AsyncRead + AsyncWrite + 'static, H: 'static { router: Rc>, #[allow(dead_code)] - addr: A, + addr: Option, state: State>, disconnected: bool, tasks: VecDeque, @@ -42,12 +43,11 @@ enum State { Empty, } -impl Http2 +impl Http2 where T: AsyncRead + AsyncWrite + 'static, - A: 'static, H: HttpHandler + 'static { - pub fn new(stream: T, addr: A, router: Rc>, buf: Bytes) -> Self { + pub fn new(stream: T, addr: Option, router: Rc>, buf: Bytes) -> Self { Http2{ router: router, addr: addr, disconnected: false, @@ -132,12 +132,15 @@ impl Http2 entry.task.disconnected() } }, - Ok(Async::Ready(Some((req, resp)))) => { + Ok(Async::Ready(Some((mut req, resp)))) => { not_ready = false; let (parts, body) = req.into_parts(); - self.tasks.push_back( - Entry::new(parts, body, resp, &self.router)); + + // stop keepalive timer self.keepalive_timer.take(); + + self.tasks.push_back( + Entry::new(parts, body, resp, self.addr.clone(), &self.router)); } Ok(Async::NotReady) => { // start keep-alive timer @@ -210,6 +213,7 @@ impl Entry { fn new(parts: Parts, recv: RecvStream, resp: Respond, + addr: Option, router: &Rc>) -> Entry where H: HttpHandler + 'static { @@ -219,6 +223,9 @@ impl Entry { let mut req = HttpRequest::new( parts.method, path, parts.version, parts.headers, query); + // set remote addr + req.set_remove_addr(addr); + // Payload and Content-Encoding let (psender, payload) = Payload::new(false); diff --git a/src/httprequest.rs b/src/httprequest.rs index 5fdf6f040..e75131e5b 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -1,5 +1,6 @@ //! HTTP Request message related code. use std::{str, fmt}; +use std::net::SocketAddr; use std::collections::HashMap; use bytes::BytesMut; use futures::{Async, Future, Stream, Poll}; @@ -25,6 +26,7 @@ pub struct HttpRequest { cookies: Vec>, cookies_loaded: bool, extensions: Extensions, + addr: Option, } impl HttpRequest { @@ -43,6 +45,7 @@ impl HttpRequest { cookies: Vec::new(), cookies_loaded: false, extensions: Extensions::new(), + addr: None, } } @@ -57,6 +60,7 @@ impl HttpRequest { cookies: Vec::new(), cookies_loaded: false, extensions: Extensions::new(), + addr: None, } } @@ -86,6 +90,21 @@ impl HttpRequest { &self.path } + /// Remote IP of client initiated HTTP request. + /// + /// The IP is resolved through the following headers, in this order: + /// + /// - Forwarded + /// - X-Forwarded-For + /// - peername of opened socket + pub fn remote(&self) -> Option<&SocketAddr> { + self.addr.as_ref() + } + + pub(crate) fn set_remove_addr(&mut self, addr: Option) { + self.addr = addr + } + /// Return a new iterator that yields pairs of `Cow` for query parameters #[inline] pub fn query(&self) -> HashMap { diff --git a/src/server.rs b/src/server.rs index 45b50848a..d3fd36147 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ use std::{io, net}; use std::rc::Rc; +use std::net::SocketAddr; use std::marker::PhantomData; use actix::dev::*; @@ -7,6 +8,8 @@ use futures::Stream; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::net::{TcpListener, TcpStream}; +#[cfg(feature="tls")] +use futures::Future; #[cfg(feature="tls")] use native_tls::TlsAcceptor; #[cfg(feature="tls")] @@ -64,7 +67,7 @@ impl HttpServer S: Stream + 'static { Ok(HttpServer::create(move |ctx| { - ctx.add_stream(stream.map(|(t, a)| IoStream(t, a, false))); + ctx.add_stream(stream.map(|(t, _)| IoStream(t, None, false))); self })) } @@ -109,7 +112,7 @@ impl HttpServer { Ok(HttpServer::create(move |ctx| { for (addr, tcp) in addrs { info!("Starting http server on {}", addr); - ctx.add_stream(tcp.incoming().map(|(t, a)| IoStream(t, a, false))); + ctx.add_stream(tcp.incoming().map(|(t, a)| IoStream(t, Some(a), false))); } self })) @@ -146,7 +149,7 @@ impl HttpServer, net::SocketAddr, H> { ctx.add_stream(tcp.incoming().and_then(move |(stream, addr)| { TlsAcceptorExt::accept_async(acc.as_ref(), stream) .map(move |t| { - IoStream(t, addr) + IoStream(t, Some(addr), false) }) .map_err(|err| { trace!("Error during handling tls connection: {}", err); @@ -200,7 +203,7 @@ impl HttpServer, net::SocketAddr, H> { } else { false }; - IoStream(stream, addr, http2) + IoStream(stream, Some(addr), http2) }) .map_err(|err| { trace!("Error during handling tls connection: {}", err); @@ -213,32 +216,31 @@ impl HttpServer, net::SocketAddr, H> { } } -struct IoStream(T, A, bool); +struct IoStream(T, Option, bool); -impl ResponseType for IoStream - where T: AsyncRead + AsyncWrite + 'static, - A: 'static +impl ResponseType for IoStream + where T: AsyncRead + AsyncWrite + 'static { type Item = (); type Error = (); } -impl StreamHandler, io::Error> for HttpServer +impl StreamHandler, io::Error> for HttpServer where T: AsyncRead + AsyncWrite + 'static, - A: 'static, - H: HttpHandler + 'static {} - -impl Handler, io::Error> for HttpServer - where T: AsyncRead + AsyncWrite + 'static, - A: 'static, H: HttpHandler + 'static, + A: 'static {} + +impl Handler, io::Error> for HttpServer + where T: AsyncRead + AsyncWrite + 'static, + H: HttpHandler + 'static, + A: 'static, { fn error(&mut self, err: io::Error, _: &mut Context) { debug!("Error handling request: {}", err) } - fn handle(&mut self, msg: IoStream, _: &mut Context) - -> Response> + fn handle(&mut self, msg: IoStream, _: &mut Context) + -> Response> { Arbiter::handle().spawn( HttpChannel::new(msg.0, msg.1, Rc::clone(&self.h), msg.2));