diff --git a/src/test.rs b/src/test.rs index faad063f..fe3422b5 100644 --- a/src/test.rs +++ b/src/test.rs @@ -182,7 +182,7 @@ impl TestServer { /// Connect to websocket server pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> { let url = self.url("/"); - self.system.run_until_complete(WsClient::new(url).connect().unwrap()) + self.system.run_until_complete(WsClient::new(url).connect()) } /// Create `GET` request diff --git a/src/ws/client.rs b/src/ws/client.rs index 2d385239..2023fdd2 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -31,9 +31,6 @@ use super::Message; use super::frame::Frame; use super::proto::{CloseCode, OpCode}; -pub type WsClientFuture = - Future; - /// Websocket client error #[derive(Fail, Debug)] @@ -140,7 +137,7 @@ impl WsClient { } /// Set cookie for handshake request - pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self { + pub fn cookie(mut self, cookie: Cookie) -> Self { self.request.cookie(cookie); self } @@ -165,49 +162,46 @@ impl WsClient { } /// Connect to websocket server and do ws handshake - pub fn connect(&mut self) -> Result, WsClientError> { + pub fn connect(&mut self) -> WsHandshake { if let Some(e) = self.err.take() { - return Err(e) + WsHandshake::new(None, Some(e), &self.conn) } - if let Some(e) = self.http_err.take() { - return Err(e.into()) - } - - // origin - if let Some(origin) = self.origin.take() { - self.request.set_header(header::ORIGIN, origin); - } - - self.request.upgrade(); - self.request.set_header(header::UPGRADE, "websocket"); - self.request.set_header(header::CONNECTION, "upgrade"); - self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); - - if let Some(protocols) = self.protocols.take() { - self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str()); - } - let request = self.request.finish()?; - - if request.uri().host().is_none() { - return Err(WsClientError::InvalidUrl) - } - if let Some(scheme) = request.uri().scheme_part() { - if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { - return Err(WsClientError::InvalidUrl); - } + else if let Some(e) = self.http_err.take() { + WsHandshake::new(None, Some(e.into()), &self.conn) } else { - return Err(WsClientError::InvalidUrl); - } + // origin + if let Some(origin) = self.origin.take() { + self.request.set_header(header::ORIGIN, origin); + } - // get connection and start handshake - Ok(Box::new( - self.conn.send(Connect(request.uri().clone())) - .map_err(|_| WsClientError::Disconnected) - .and_then(|res| match res { - Ok(stream) => Either::A(WsHandshake::new(stream, request)), - Err(err) => Either::B(FutErr(err.into())), - }) - )) + self.request.upgrade(); + self.request.set_header(header::UPGRADE, "websocket"); + self.request.set_header(header::CONNECTION, "upgrade"); + self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); + + if let Some(protocols) = self.protocols.take() { + self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str()); + } + let request = match self.request.finish() { + Ok(req) => req, + Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn), + }; + + if request.uri().host().is_none() { + return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) + } + if let Some(scheme) = request.uri().scheme_part() { + if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { + return WsHandshake::new( + None, Some(WsClientError::InvalidUrl), &self.conn) + } + } else { + return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn) + } + + // start handshake + WsHandshake::new(Some(request), None, &self.conn) + } } } @@ -220,39 +214,53 @@ struct WsInner { error_sent: bool, } -struct WsHandshake { +pub struct WsHandshake { inner: Option, - request: ClientRequest, + request: Option, sent: bool, key: String, + error: Option, + stream: Option, Error=WsClientError>>>, } impl WsHandshake { - fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake { + fn new(request: Option, + err: Option, + conn: &Addr) -> WsHandshake + { // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, // when decoded, is 16 bytes in length (RFC 6455) let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); - request.headers_mut().insert( - HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), - HeaderValue::try_from(key.as_str()).unwrap()); + if let Some(mut request) = request { + let stream = Box::new( + conn.send(Connect(request.uri().clone())) + .map(|res| res.map_err(|e| e.into())) + .map_err(|_| WsClientError::Disconnected)); - let inner = WsInner { - conn: conn, - writer: HttpClientWriter::new(SharedBytes::default()), - parser: HttpResponseParser::default(), - parser_buf: BytesMut::new(), - closed: false, - error_sent: false, - }; + request.headers_mut().insert( + HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), + HeaderValue::try_from(key.as_str()).unwrap()); - WsHandshake { - key: key, - inner: Some(inner), - request: request, - sent: false, + WsHandshake { + key: key, + inner: None, + request: Some(request), + sent: false, + error: err, + stream: Some(stream), + } + } else { + WsHandshake { + key: key, + inner: None, + request: None, + sent: false, + error: err, + stream: None, + } } } } @@ -262,11 +270,36 @@ impl Future for WsHandshake { type Error = WsClientError; fn poll(&mut self) -> Poll { + if let Some(err) = self.error.take() { + return Err(err) + } + + if self.stream.is_some() { + match self.stream.as_mut().unwrap().poll()? { + Async::Ready(result) => match result { + Ok(conn) => { + let inner = WsInner { + conn: conn, + writer: HttpClientWriter::new(SharedBytes::default()), + parser: HttpResponseParser::default(), + parser_buf: BytesMut::new(), + closed: false, + error_sent: false, + }; + self.stream.take(); + self.inner = Some(inner); + } + Err(err) => return Err(err), + }, + Async::NotReady => return Ok(Async::NotReady) + } + } + let mut inner = self.inner.take().unwrap(); if !self.sent { self.sent = true; - inner.writer.start(&mut self.request)?; + inner.writer.start(self.request.as_mut().unwrap())?; } if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { return Err(err.into()) diff --git a/src/ws/mod.rs b/src/ws/mod.rs index d9bf0f10..91258cab 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -65,7 +65,7 @@ use self::frame::Frame; use self::proto::{hash_key, OpCode}; pub use self::proto::CloseCode; pub use self::context::WebsocketContext; -pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture}; +pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake}; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";