From 9559f6a1759c3605ef1243a510ff36ec15018586 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 3 Jan 2018 23:41:55 -0800 Subject: [PATCH] introduce IoStream trait for low level stream operations --- src/channel.rs | 186 ++++++++++++++++++++++++++++++++++++++++--------- src/h1.rs | 47 +++++++++---- src/server.rs | 20 +++--- src/worker.rs | 4 +- 4 files changed, 201 insertions(+), 56 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 963ef1065..baae32fa0 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,8 +1,8 @@ -use std::{ptr, mem, time}; +use std::{ptr, mem, time, io}; use std::rc::Rc; use std::net::{SocketAddr, Shutdown}; -use bytes::Bytes; +use bytes::{Bytes, Buf, BufMut}; use futures::{Future, Poll, Async}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::net::TcpStream; @@ -48,8 +48,7 @@ impl IntoHttpHandler for T { } } -enum HttpProtocol - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static +enum HttpProtocol { H1(h1::Http1), H2(h2::Http2), @@ -57,22 +56,14 @@ enum HttpProtocol #[doc(hidden)] pub struct HttpChannel - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static + where T: IoStream, H: HttpHandler + 'static { proto: Option>, node: Option>>, } -impl Drop for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static -{ - fn drop(&mut self) { - self.shutdown() - } -} - impl HttpChannel - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static + where T: IoStream, H: HttpHandler + 'static { pub(crate) fn new(h: Rc>, io: T, peer: Option, http2: bool) -> HttpChannel @@ -91,19 +82,12 @@ impl HttpChannel } } - fn io(&mut self) -> Option<&mut T> { - match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - Some(h1.io()) - } - _ => None, - } - } - fn shutdown(&mut self) { match self.proto { Some(HttpProtocol::H1(ref mut h1)) => { - let _ = h1.io().shutdown(); + let io = h1.io(); + let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); + let _ = IoStream::shutdown(io, Shutdown::Both); } Some(HttpProtocol::H2(ref mut h2)) => { h2.shutdown() @@ -122,7 +106,7 @@ impl HttpChannel }*/ impl Future for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static + where T: IoStream, H: HttpHandler + 'static { type Item = (); type Error = (); @@ -242,7 +226,7 @@ impl Node<()> { } } - pub(crate) fn traverse(&self) where H: HttpHandler + 'static { + pub(crate) fn traverse(&self) where T: IoStream, H: HttpHandler + 'static { let mut next = self.next.as_ref(); loop { if let Some(n) = next { @@ -251,13 +235,8 @@ impl Node<()> { next = n.next.as_ref(); if !n.element.is_null() { - let ch: &mut HttpChannel = mem::transmute( + let ch: &mut HttpChannel = mem::transmute( &mut *(n.element as *mut _)); - if let Some(io) = ch.io() { - let _ = TcpStream::set_linger(io, Some(time::Duration::new(0, 0))); - let _ = TcpStream::shutdown(io, Shutdown::Both); - continue; - } ch.shutdown(); } } @@ -267,3 +246,146 @@ impl Node<()> { } } } + + +pub trait IoStream: AsyncRead + AsyncWrite + 'static { + fn shutdown(&mut self, how: Shutdown) -> io::Result<()>; + + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()>; + + fn set_linger(&mut self, dur: Option) -> io::Result<()>; +} + +impl IoStream for TcpStream { + #[inline] + fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + TcpStream::shutdown(self, how) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + TcpStream::set_nodelay(self, nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + TcpStream::set_linger(self, dur) + } +} + + +pub(crate) struct WrapperStream where T: AsyncRead + AsyncWrite + 'static { + io: T, +} + +impl WrapperStream where T: AsyncRead + AsyncWrite + 'static +{ + pub fn new(io: T) -> Self { + WrapperStream{io: io} + } +} + +impl IoStream for WrapperStream + where T: AsyncRead + AsyncWrite + 'static +{ + #[inline] + fn shutdown(&mut self, _: Shutdown) -> io::Result<()> { + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, _: bool) -> io::Result<()> { + Ok(()) + } + + #[inline] + fn set_linger(&mut self, _: Option) -> io::Result<()> { + Ok(()) + } +} + +impl io::Read for WrapperStream + where T: AsyncRead + AsyncWrite + 'static +{ + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.io.read(buf) + } +} + +impl io::Write for WrapperStream + where T: AsyncRead + AsyncWrite + 'static +{ + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.write(buf) + } + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl AsyncRead for WrapperStream + where T: AsyncRead + AsyncWrite + 'static +{ + fn read_buf(&mut self, buf: &mut B) -> Poll { + self.io.read_buf(buf) + } +} + +impl AsyncWrite for WrapperStream + where T: AsyncRead + AsyncWrite + 'static +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.io.shutdown() + } + fn write_buf(&mut self, buf: &mut B) -> Poll { + self.io.write_buf(buf) + } +} + + +#[cfg(feature="alpn")] +use tokio_openssl::SslStream; + +#[cfg(feature="alpn")] +impl IoStream for SslStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } +} + +#[cfg(feature="tls")] +use tokio_tls::TlsStream; + +#[cfg(feature="tls")] +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } +} diff --git a/src/h1.rs b/src/h1.rs index e0358a159..e5f592dd2 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -10,12 +10,11 @@ use http::{Uri, Method, Version, HttpTryFrom, HeaderMap}; use http::header::{self, HeaderName, HeaderValue}; use bytes::{Bytes, BytesMut, BufMut}; use futures::{Future, Poll, Async}; -use tokio_io::{AsyncRead, AsyncWrite}; use tokio_core::reactor::Timeout; use pipeline::Pipeline; use encoding::PayloadType; -use channel::{HttpHandler, HttpHandlerTask}; +use channel::{HttpHandler, HttpHandlerTask, IoStream}; use h1writer::{Writer, H1Writer}; use worker::WorkerSettings; use httpcodes::HTTPNotFound; @@ -57,7 +56,7 @@ enum Item { Http2, } -pub(crate) struct Http1 { +pub(crate) struct Http1 { flags: Flags, settings: Rc>, addr: Option, @@ -74,8 +73,7 @@ struct Entry { } impl Http1 - where T: AsyncRead + AsyncWrite + 'static, - H: HttpHandler + 'static + where T: IoStream, H: HttpHandler + 'static { pub fn new(h: Rc>, stream: T, addr: Option) -> Self { let bytes = h.get_shared_bytes(); @@ -417,7 +415,7 @@ impl Reader { pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut, settings: &WorkerSettings) -> Poll - where T: AsyncRead + where T: IoStream { // read payload if self.payload.is_some() { @@ -507,8 +505,8 @@ impl Reader { } } - fn read_from_io(&mut self, io: &mut T, buf: &mut BytesMut) - -> Poll + fn read_from_io(&mut self, io: &mut T, buf: &mut BytesMut) + -> Poll { unsafe { if buf.remaining_mut() < LW_BUFFER_SIZE { @@ -894,14 +892,17 @@ impl ChunkedState { #[cfg(test)] mod tests { - use std::{io, cmp}; - use bytes::{Bytes, BytesMut}; - use futures::{Async}; - use tokio_io::AsyncRead; + use std::{io, cmp, time}; + use std::net::Shutdown; + use bytes::{Bytes, BytesMut, Buf}; + use futures::Async; + use tokio_io::{AsyncRead, AsyncWrite}; use http::{Version, Method}; + use super::*; use application::HttpApplication; use worker::WorkerSettings; + use channel::IoStream; struct Buffer { buf: Bytes, @@ -940,6 +941,28 @@ mod tests { } } + impl IoStream for Buffer { + fn shutdown(&self, _: Shutdown) -> io::Result<()> { + Ok(()) + } + fn set_nodelay(&self, _: bool) -> io::Result<()> { + Ok(()) + } + fn set_linger(&self, _: Option) -> io::Result<()> { + Ok(()) + } + } + impl io::Write for Buffer { + fn write(&mut self, buf: &[u8]) -> io::Result {Ok(buf.len())} + fn flush(&mut self) -> io::Result<()> {Ok(())} + } + impl AsyncWrite for Buffer { + fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(Async::Ready(())) } + fn write_buf(&mut self, _: &mut B) -> Poll { + Ok(Async::NotReady) + } + } + macro_rules! not_ready { ($e:expr) => (match $e { Ok(Async::NotReady) => (), diff --git a/src/server.rs b/src/server.rs index 1833e8ae2..d602e2769 100644 --- a/src/server.rs +++ b/src/server.rs @@ -31,7 +31,7 @@ use tokio_openssl::SslStream; use actix::actors::signal; use helpers; -use channel::{HttpChannel, HttpHandler, IntoHttpHandler}; +use channel::{HttpChannel, HttpHandler, IntoHttpHandler, IoStream, WrapperStream}; use worker::{Conn, Worker, WorkerSettings, StreamHandlerType, StopWorker}; /// Various server settings @@ -131,7 +131,7 @@ impl HttpServer HttpServer where A: 'static, - T: AsyncRead + AsyncWrite + 'static, + T: IoStream, H: HttpHandler, U: IntoIterator + 'static, V: IntoHttpHandler, @@ -450,7 +450,7 @@ impl HttpServer, net::SocketAddr, H, } } -impl HttpServer +impl HttpServer, A, H, U> where A: 'static, T: AsyncRead + AsyncWrite + 'static, H: HttpHandler, @@ -488,7 +488,7 @@ impl HttpServer // start server HttpServer::create(move |ctx| { ctx.add_stream(stream.map( - move |(t, _)| Conn{io: t, peer: None, http2: false})); + move |(t, _)| Conn{io: WrapperStream::new(t), peer: None, http2: false})); self }) } @@ -499,7 +499,7 @@ impl HttpServer /// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and send `SystemExit(0)` /// message to `System` actor. impl Handler for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static, @@ -530,13 +530,13 @@ impl Handler for HttpServer } impl StreamHandler, io::Error> for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static {} impl Handler, io::Error> for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static, @@ -573,7 +573,7 @@ pub struct StopServer { } impl Handler for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static, @@ -589,7 +589,7 @@ impl Handler for HttpServer } impl Handler for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static, @@ -605,7 +605,7 @@ impl Handler for HttpServer } impl Handler for HttpServer - where T: AsyncRead + AsyncWrite + 'static, + where T: IoStream, H: HttpHandler + 'static, U: 'static, A: 'static, diff --git a/src/worker.rs b/src/worker.rs index c6127d2a0..d0f73f63b 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -135,7 +135,7 @@ impl Worker { slf.shutdown_timeout(ctx, tx, d); } else { info!("Force shutdown http worker, {} connections", num); - slf.settings.head().traverse::(); + slf.settings.head().traverse::(); let _ = tx.send(false); Arbiter::arbiter().send(StopArbiter(0)); } @@ -187,7 +187,7 @@ impl Handler for Worker Self::async_reply(rx.map_err(|_| ()).actfuture()) } else { info!("Force shutdown http worker, {} connections", num); - self.settings.head().traverse::(); + self.settings.head().traverse::(); Self::reply(false) } }