diff --git a/src/client/request.rs b/src/client/request.rs index 392ef6b9..92960d7f 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -27,6 +27,14 @@ pub struct ClientRequest { encoding: ContentEncoding, response_decompress: bool, buffer_capacity: Option<(usize, usize)>, + conn: ConnectionType, + +} + +enum ConnectionType { + Default, + Connector(Addr), + Connection(Connection), } impl Default for ClientRequest { @@ -43,6 +51,7 @@ impl Default for ClientRequest { encoding: ContentEncoding::Auto, response_decompress: true, buffer_capacity: None, + conn: ConnectionType::Default, } } } @@ -190,18 +199,14 @@ impl ClientRequest { } /// Send request - pub fn send(self) -> SendRequest { - SendRequest::new(self) - } - - /// Send request using custom connector - pub fn with_connector(self, conn: Addr) -> SendRequest { - SendRequest::with_connector(self, conn) - } - - /// Send request using existing Connection - pub fn with_connection(self, conn: Connection) -> SendRequest { - SendRequest::with_connection(self, conn) + /// + /// This method returns future that resolves to a ClientResponse + pub fn send(mut self) -> SendRequest { + match mem::replace(&mut self.conn, ConnectionType::Default) { + ConnectionType::Default => SendRequest::new(self), + ConnectionType::Connector(conn) => SendRequest::with_connector(self, conn), + ConnectionType::Connection(conn) => SendRequest::with_connection(self, conn), + } } } @@ -451,6 +456,22 @@ impl ClientRequestBuilder { self } + /// Send request using custom connector + pub fn with_connector(&mut self, conn: Addr) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.conn = ConnectionType::Connector(conn); + } + self + } + + /// Send request using existing Connection + pub fn with_connection(&mut self, conn: Connection) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.conn = ConnectionType::Connection(conn); + } + self + } + /// This method calls provided closure with builder reference if value is true. pub fn if_true(&mut self, value: bool, f: F) -> &mut Self where F: FnOnce(&mut ClientRequestBuilder) diff --git a/src/ws/client.rs b/src/ws/client.rs index 15239146..bd5d8f8b 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -194,6 +194,7 @@ impl WsClient { self.request.set_header(header::UPGRADE, "websocket"); self.request.set_header(header::CONNECTION, "upgrade"); self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); + self.request.with_connector(self.conn.clone()); if let Some(protocols) = self.protocols.take() { self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str()); @@ -215,7 +216,7 @@ impl WsClient { } // start handshake - WsClientHandshake::new(request, &self.conn, self.max_size) + WsClientHandshake::new(request, self.max_size) } } } @@ -235,8 +236,7 @@ pub struct WsClientHandshake { } impl WsClientHandshake { - fn new(mut request: ClientRequest, - conn: &Addr, max_size: usize) -> WsClientHandshake + fn new(mut request: ClientRequest, max_size: usize) -> WsClientHandshake { // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, @@ -256,7 +256,7 @@ impl WsClientHandshake { WsClientHandshake { key, max_size, - request: Some(request.with_connector(conn.clone())), + request: Some(request.send()), tx: Some(tx), error: None, }