From a63d3f9a7a0e5f4982404b66802c73eb9e6c65fd Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Sun, 9 Sep 2018 14:14:53 -0700
Subject: [PATCH] cleanup ServerFactory trait

---
 src/server/http.rs   | 130 ++++++++++++++++++++++++-------------------
 tests/test_server.rs |   1 +
 2 files changed, 75 insertions(+), 56 deletions(-)

diff --git a/src/server/http.rs b/src/server/http.rs
index 41161ed3f..5cdeb5642 100644
--- a/src/server/http.rs
+++ b/src/server/http.rs
@@ -351,28 +351,36 @@ where
         Ok(self)
     }
 
-    // /// Start listening for incoming connections with supplied acceptor.
-    // #[doc(hidden)]
-    // #[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))]
-    // pub fn bind_with<S, A>(mut self, addr: S, acceptor: A) -> io::Result<Self>
-    // where
-    //     S: net::ToSocketAddrs,
-    //     A: AcceptorService<TcpStream> + Send + 'static,
-    // {
-    //     let sockets = self.bind2(addr)?;
+    /// Start listening for incoming connections with supplied acceptor.
+    #[doc(hidden)]
+    #[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))]
+    pub fn bind_with<S, A>(mut self, addr: S, acceptor: A) -> io::Result<Self>
+    where
+        S: net::ToSocketAddrs,
+        A: AcceptorServiceFactory,
+    {
+        let sockets = self.bind2(addr)?;
 
-    //     for lst in sockets {
-    //         let token = Token(self.handlers.len());
-    //         let addr = lst.local_addr().unwrap();
-    //         self.handlers.push(Box::new(StreamHandler::new(
-    //             lst.local_addr().unwrap(),
-    //             acceptor.clone(),
-    //         )));
-    //         self.sockets.push(Socket { lst, addr, token })
-    //     }
+        for lst in sockets {
+            let addr = lst.local_addr().unwrap();
+            self.sockets.push(Socket {
+                lst,
+                addr,
+                scheme: "https",
+                handler: Box::new(HttpServiceBuilder::new(
+                    acceptor.clone(),
+                    DefaultPipelineFactory::new(
+                        self.factory.clone(),
+                        self.host.clone(),
+                        addr,
+                        self.keep_alive,
+                    ),
+                )),
+            });
+        }
 
-    //     Ok(self)
-    // }
+        Ok(self)
+    }
 
     fn bind2<S: net::ToSocketAddrs>(
         &self, addr: S,
@@ -416,25 +424,50 @@ where
     //     self.bind_with(addr, NativeTlsAcceptor::new(acceptor))
     // }
 
-    // #[cfg(feature = "alpn")]
-    // /// Start listening for incoming tls connections.
-    // ///
-    // /// This method sets alpn protocols to "h2" and "http/1.1"
-    // pub fn bind_ssl<S>(self, addr: S, builder: SslAcceptorBuilder) -> io::Result<Self>
-    // where
-    //     S: net::ToSocketAddrs,
-    // {
-    //     use super::{OpensslAcceptor, ServerFlags};
+    #[cfg(any(feature = "alpn", feature = "ssl"))]
+    /// Start listening for incoming tls connections.
+    ///
+    /// This method sets alpn protocols to "h2" and "http/1.1"
+    pub fn bind_ssl<S>(
+        mut self, addr: S, builder: SslAcceptorBuilder,
+    ) -> io::Result<Self>
+    where
+        S: net::ToSocketAddrs,
+    {
+        use super::{openssl_acceptor_with_flags, ServerFlags};
 
-    //     // alpn support
-    //     let flags = if !self.no_http2 {
-    //         ServerFlags::HTTP1
-    //     } else {
-    //         ServerFlags::HTTP1 | ServerFlags::HTTP2
-    //     };
+        let sockets = self.bind2(addr)?;
 
-    //     self.bind_with(addr, OpensslAcceptor::with_flags(builder, flags)?)
-    // }
+        // alpn support
+        let flags = if !self.no_http2 {
+            ServerFlags::HTTP1
+        } else {
+            ServerFlags::HTTP1 | ServerFlags::HTTP2
+        };
+
+        let acceptor = openssl_acceptor_with_flags(builder, flags)?;
+
+        for lst in sockets {
+            let addr = lst.local_addr().unwrap();
+            let accpt = acceptor.clone();
+            self.sockets.push(Socket {
+                lst,
+                addr,
+                scheme: "https",
+                handler: Box::new(HttpServiceBuilder::new(
+                    move || ssl::OpensslAcceptor::new(accpt.clone()).map_err(|_| ()),
+                    DefaultPipelineFactory::new(
+                        self.factory.clone(),
+                        self.host.clone(),
+                        addr,
+                        self.keep_alive,
+                    ),
+                )),
+            });
+        }
+
+        Ok(self)
+    }
 
     // #[cfg(feature = "rust-tls")]
     // /// Start listening for incoming tls connections.
@@ -500,13 +533,7 @@ impl<H: IntoHttpHandler, F: Fn() -> Vec<H> + Send + Clone> HttpServer<H, F> {
         let sockets = mem::replace(&mut self.sockets, Vec::new());
 
         for socket in sockets {
-            let Socket {
-                lst,
-                handler,
-                addr: _,
-                scheme: _,
-            } = socket;
-            srv = handler.register(srv, lst, self.host.clone(), self.keep_alive);
+            srv = socket.handler.register(srv, socket.lst);
         }
         srv.start()
     }
@@ -700,10 +727,7 @@ trait ServiceFactory<H>
 where
     H: IntoHttpHandler,
 {
-    fn register(
-        &self, server: Server, lst: net::TcpListener, host: Option<String>,
-        keep_alive: KeepAlive,
-    ) -> Server;
+    fn register(&self, server: Server, lst: net::TcpListener) -> Server;
 }
 
 struct SimpleFactory<H, F, P>
@@ -737,10 +761,7 @@ where
     F: Fn() -> Vec<H> + Send + Clone + 'static,
     P: HttpPipelineFactory<Io = TcpStream>,
 {
-    fn register(
-        &self, server: Server, lst: net::TcpListener, _host: Option<String>,
-        _keep_alive: KeepAlive,
-    ) -> Server {
+    fn register(&self, server: Server, lst: net::TcpListener) -> Server {
         let pipeline = self.pipeline.clone();
         server.listen(lst, move || pipeline.create())
     }
@@ -814,10 +835,7 @@ where
     P: HttpPipelineFactory<Io = A::Io>,
     H: IntoHttpHandler,
 {
-    fn register(
-        &self, server: Server, lst: net::TcpListener, _host: Option<String>,
-        _keep_alive: KeepAlive,
-    ) -> Server {
+    fn register(&self, server: Server, lst: net::TcpListener) -> Server {
         server.listen(lst, self.finish())
     }
 }
diff --git a/tests/test_server.rs b/tests/test_server.rs
index 30ee13fb3..41f4bcf39 100644
--- a/tests/test_server.rs
+++ b/tests/test_server.rs
@@ -9,6 +9,7 @@ extern crate h2;
 extern crate http as modhttp;
 extern crate rand;
 extern crate tokio;
+extern crate tokio_current_thread;
 extern crate tokio_reactor;
 extern crate tokio_tcp;
 extern crate tokio_current_thread as current_thread;