From a855c8b2c9ced44d6c57e6246707f7a42743eb53 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Sat, 24 Feb 2018 08:14:21 +0300
Subject: [PATCH] better ergonomics for WsClient::client()

---
 src/test.rs      |   2 +-
 src/ws/client.rs | 159 ++++++++++++++++++++++++++++-------------------
 src/ws/mod.rs    |   2 +-
 3 files changed, 98 insertions(+), 65 deletions(-)

diff --git a/src/test.rs b/src/test.rs
index faad063f2..fe3422b54 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -182,7 +182,7 @@ impl TestServer {
     /// Connect to websocket server
     pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
         let url = self.url("/");
-        self.system.run_until_complete(WsClient::new(url).connect().unwrap())
+        self.system.run_until_complete(WsClient::new(url).connect())
     }
 
     /// Create `GET` request
diff --git a/src/ws/client.rs b/src/ws/client.rs
index 2d385239d..2023fdd21 100644
--- a/src/ws/client.rs
+++ b/src/ws/client.rs
@@ -31,9 +31,6 @@ use super::Message;
 use super::frame::Frame;
 use super::proto::{CloseCode, OpCode};
 
-pub type WsClientFuture =
-    Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
-
 
 /// Websocket client error
 #[derive(Fail, Debug)]
@@ -140,7 +137,7 @@ impl WsClient {
     }
 
     /// Set cookie for handshake request
-    pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self {
+    pub fn cookie(mut self, cookie: Cookie) -> Self {
         self.request.cookie(cookie);
         self
     }
@@ -165,49 +162,46 @@ impl WsClient {
     }
 
     /// Connect to websocket server and do ws handshake
-    pub fn connect(&mut self) -> Result<Box<WsClientFuture>, WsClientError> {
+    pub fn connect(&mut self) -> WsHandshake {
         if let Some(e) = self.err.take() {
-            return Err(e)
+            WsHandshake::new(None, Some(e), &self.conn)
         }
-        if let Some(e) = self.http_err.take() {
-            return Err(e.into())
-        }
-
-        // origin
-        if let Some(origin) = self.origin.take() {
-            self.request.set_header(header::ORIGIN, origin);
-        }
-
-        self.request.upgrade();
-        self.request.set_header(header::UPGRADE, "websocket");
-        self.request.set_header(header::CONNECTION, "upgrade");
-        self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
-
-        if let Some(protocols) = self.protocols.take() {
-            self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
-        }
-        let request = self.request.finish()?;
-
-        if request.uri().host().is_none() {
-            return Err(WsClientError::InvalidUrl)
-        }
-        if let Some(scheme) = request.uri().scheme_part() {
-            if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
-                return Err(WsClientError::InvalidUrl);
-            }
+        else if let Some(e) = self.http_err.take() {
+            WsHandshake::new(None, Some(e.into()), &self.conn)
         } else {
-            return Err(WsClientError::InvalidUrl);
-        }
+            // origin
+            if let Some(origin) = self.origin.take() {
+                self.request.set_header(header::ORIGIN, origin);
+            }
 
-        // get connection and start handshake
-        Ok(Box::new(
-            self.conn.send(Connect(request.uri().clone()))
-                .map_err(|_| WsClientError::Disconnected)
-                .and_then(|res| match res {
-                    Ok(stream) => Either::A(WsHandshake::new(stream, request)),
-                    Err(err) => Either::B(FutErr(err.into())),
-                })
-        ))
+            self.request.upgrade();
+            self.request.set_header(header::UPGRADE, "websocket");
+            self.request.set_header(header::CONNECTION, "upgrade");
+            self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
+
+            if let Some(protocols) = self.protocols.take() {
+                self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
+            }
+            let request = match self.request.finish() {
+                Ok(req) => req,
+                Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn),
+            };
+
+            if request.uri().host().is_none() {
+                return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
+            }
+            if let Some(scheme) = request.uri().scheme_part() {
+                if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
+                    return WsHandshake::new(
+                        None, Some(WsClientError::InvalidUrl), &self.conn)
+                }
+            } else {
+                return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
+            }
+
+            // start handshake
+            WsHandshake::new(Some(request), None, &self.conn)
+        }
     }
 }
 
@@ -220,39 +214,53 @@ struct WsInner {
     error_sent: bool,
 }
 
-struct WsHandshake {
+pub struct WsHandshake {
     inner: Option<WsInner>,
-    request: ClientRequest,
+    request: Option<ClientRequest>,
     sent: bool,
     key: String,
+    error: Option<WsClientError>,
+    stream: Option<Box<Future<Item=Result<Connection, WsClientError>, Error=WsClientError>>>,
 }
 
 impl WsHandshake {
-    fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake {
+    fn new(request: Option<ClientRequest>,
+           err: Option<WsClientError>,
+           conn: &Addr<Unsync, ClientConnector>) -> WsHandshake
+    {
         // Generate a random key for the `Sec-WebSocket-Key` header.
         // a base64-encoded (see Section 4 of [RFC4648]) value that,
         // when decoded, is 16 bytes in length (RFC 6455)
         let sec_key: [u8; 16] = rand::random();
         let key = base64::encode(&sec_key);
 
-        request.headers_mut().insert(
-            HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
-            HeaderValue::try_from(key.as_str()).unwrap());
+        if let Some(mut request) = request {
+            let stream = Box::new(
+                conn.send(Connect(request.uri().clone()))
+                    .map(|res| res.map_err(|e| e.into()))
+                    .map_err(|_| WsClientError::Disconnected));
 
-        let inner = WsInner {
-            conn: conn,
-            writer: HttpClientWriter::new(SharedBytes::default()),
-            parser: HttpResponseParser::default(),
-            parser_buf: BytesMut::new(),
-            closed: false,
-            error_sent: false,
-        };
+            request.headers_mut().insert(
+                HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
+                HeaderValue::try_from(key.as_str()).unwrap());
 
-        WsHandshake {
-            key: key,
-            inner: Some(inner),
-            request: request,
-            sent: false,
+            WsHandshake {
+                key: key,
+                inner: None,
+                request: Some(request),
+                sent: false,
+                error: err,
+                stream: Some(stream),
+            }
+        } else {
+            WsHandshake {
+                key: key,
+                inner: None,
+                request: None,
+                sent: false,
+                error: err,
+                stream: None,
+            }
         }
     }
 }
@@ -262,11 +270,36 @@ impl Future for WsHandshake {
     type Error = WsClientError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        if let Some(err) = self.error.take() {
+            return Err(err)
+        }
+
+        if self.stream.is_some() {
+            match self.stream.as_mut().unwrap().poll()? {
+                Async::Ready(result) => match result {
+                    Ok(conn) => {
+                        let inner = WsInner {
+                            conn: conn,
+                            writer: HttpClientWriter::new(SharedBytes::default()),
+                            parser: HttpResponseParser::default(),
+                            parser_buf: BytesMut::new(),
+                            closed: false,
+                            error_sent: false,
+                        };
+                        self.stream.take();
+                        self.inner = Some(inner);
+                    }
+                    Err(err) => return Err(err),
+                },
+                Async::NotReady => return Ok(Async::NotReady)
+            }
+        }
+
         let mut inner = self.inner.take().unwrap();
 
         if !self.sent {
             self.sent = true;
-            inner.writer.start(&mut self.request)?;
+            inner.writer.start(self.request.as_mut().unwrap())?;
         }
         if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) {
             return Err(err.into())
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index d9bf0f103..91258cabe 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -65,7 +65,7 @@ use self::frame::Frame;
 use self::proto::{hash_key, OpCode};
 pub use self::proto::CloseCode;
 pub use self::context::WebsocketContext;
-pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture};
+pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake};
 
 const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
 const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";