//! Websockets client use std::cell::RefCell; use std::io::Write; use std::rc::Rc; use std::{fmt, str}; use actix_codec::Framed; use actix_http::{ws, Payload, RequestHead}; use bytes::{BufMut, BytesMut}; #[cfg(feature = "cookies")] use cookie::{Cookie, CookieJar}; use futures::future::{err, Either, Future}; use crate::connect::{BoxedSocket, Connect}; use crate::error::{InvalidUrl, WsClientError}; use crate::http::header::{ self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION, }; use crate::http::{ ConnectionType, Error as HttpError, HttpTryFrom, Method, StatusCode, Uri, Version, }; use crate::response::ClientResponse; /// `WebSocket` connection pub struct WebsocketsRequest { head: RequestHead, err: Option, origin: Option, protocols: Option, max_size: usize, server_mode: bool, default_headers: bool, #[cfg(feature = "cookies")] cookies: Option, connector: Rc>, } impl WebsocketsRequest { /// Create new websocket connection pub(crate) fn new(uri: U, connector: Rc>) -> Self where Uri: HttpTryFrom, { let mut err = None; let mut head = RequestHead::default(); head.method = Method::GET; head.version = Version::HTTP_11; match Uri::try_from(uri) { Ok(uri) => head.uri = uri, Err(e) => err = Some(e.into()), } WebsocketsRequest { head, err, connector, origin: None, protocols: None, max_size: 65_536, server_mode: false, #[cfg(feature = "cookies")] cookies: None, default_headers: true, } } /// Set supported websocket protocols pub fn protocols(mut self, protos: U) -> Self where U: IntoIterator + 'static, V: AsRef, { let mut protos = protos .into_iter() .fold(String::new(), |acc, s| acc + s.as_ref() + ","); protos.pop(); self.protocols = Some(protos); self } #[cfg(feature = "cookies")] /// Set a cookie pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self { if self.cookies.is_none() { let mut jar = CookieJar::new(); jar.add(cookie.into_owned()); self.cookies = Some(jar) } else { self.cookies.as_mut().unwrap().add(cookie.into_owned()); } self } /// Set request Origin pub fn origin(mut self, origin: V) -> Self where HeaderValue: HttpTryFrom, { match HeaderValue::try_from(origin) { Ok(value) => self.origin = Some(value), Err(e) => self.err = Some(e.into()), } self } /// Set max frame size /// /// By default max size is set to 64kb pub fn max_frame_size(mut self, size: usize) -> Self { self.max_size = size; self } /// Disable payload masking. By default ws client masks frame payload. pub fn server_mode(mut self) -> Self { self.server_mode = true; self } /// Do not add default request headers. /// By default `Date` and `User-Agent` headers are set. pub fn no_default_headers(mut self) -> Self { self.default_headers = false; self } /// Append a header. /// /// Header gets appended to existing header. /// To override header use `set_header()` method. pub fn header(mut self, key: K, value: V) -> Self where HeaderName: HttpTryFrom, V: IntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => match value.try_into() { Ok(value) => { self.head.headers.append(key, value); } Err(e) => self.err = Some(e.into()), }, Err(e) => self.err = Some(e.into()), } self } /// Insert a header, replaces existing header. pub fn set_header(mut self, key: K, value: V) -> Self where HeaderName: HttpTryFrom, V: IntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => match value.try_into() { Ok(value) => { self.head.headers.insert(key, value); } Err(e) => self.err = Some(e.into()), }, Err(e) => self.err = Some(e.into()), } self } /// Insert a header only if it is not yet set. pub fn set_header_if_none(mut self, key: K, value: V) -> Self where HeaderName: HttpTryFrom, V: IntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => { if !self.head.headers.contains_key(&key) { match value.try_into() { Ok(value) => { self.head.headers.insert(key, value); } Err(e) => self.err = Some(e.into()), } } } Err(e) => self.err = Some(e.into()), } self } /// Set HTTP basic authorization header pub fn basic_auth(self, username: U, password: Option

) -> Self where U: fmt::Display, P: fmt::Display, { let auth = match password { Some(password) => format!("{}:{}", username, password), None => format!("{}", username), }; self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) } /// Set HTTP bearer authentication header pub fn bearer_auth(self, token: T) -> Self where T: fmt::Display, { self.header(AUTHORIZATION, format!("Bearer {}", token)) } /// Complete request construction and connect. pub fn connect( mut self, ) -> impl Future< Item = (ClientResponse, Framed), Error = WsClientError, > { if let Some(e) = self.err.take() { return Either::A(err(e.into())); } // validate uri let uri = &self.head.uri; if uri.host().is_none() { return Either::A(err(InvalidUrl::MissingHost.into())); } else if uri.scheme_part().is_none() { return Either::A(err(InvalidUrl::MissingScheme.into())); } else if let Some(scheme) = uri.scheme_part() { match scheme.as_str() { "http" | "ws" | "https" | "wss" => (), _ => return Either::A(err(InvalidUrl::UnknownScheme.into())), } } else { return Either::A(err(InvalidUrl::UnknownScheme.into())); } // set default headers let mut slf = if self.default_headers { // set request host header if let Some(host) = self.head.uri.host() { if !self.head.headers.contains_key(header::HOST) { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); let _ = match self.head.uri.port_u16() { None | Some(80) | Some(443) => write!(wrt, "{}", host), Some(port) => write!(wrt, "{}:{}", host, port), }; match wrt.get_mut().take().freeze().try_into() { Ok(value) => { self.head.headers.insert(header::HOST, value); } Err(e) => return Either::A(err(HttpError::from(e).into())), } } } // user agent self.set_header_if_none( header::USER_AGENT, concat!("awc/", env!("CARGO_PKG_VERSION")), ) } else { self }; #[allow(unused_mut)] let mut head = slf.head; #[cfg(feature = "cookies")] { use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; use std::fmt::Write; // set cookies if let Some(ref mut jar) = slf.cookies { let mut cookie = String::new(); for c in jar.delta() { let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET); let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); let _ = write!(&mut cookie, "; {}={}", name, value); } head.headers.insert( header::COOKIE, HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), ); } } // origin if let Some(origin) = slf.origin.take() { head.headers.insert(header::ORIGIN, origin); } head.set_connection_type(ConnectionType::Upgrade); head.headers .insert(header::UPGRADE, HeaderValue::from_static("websocket")); head.headers.insert( header::SEC_WEBSOCKET_VERSION, HeaderValue::from_static("13"), ); if let Some(protocols) = slf.protocols.take() { head.headers.insert( header::SEC_WEBSOCKET_PROTOCOL, HeaderValue::try_from(protocols.as_str()).unwrap(), ); } // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, // when decoded, is 16 bytes in length (RFC 6455) let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); head.headers.insert( header::SEC_WEBSOCKET_KEY, HeaderValue::try_from(key.as_str()).unwrap(), ); let max_size = slf.max_size; let server_mode = slf.server_mode; let fut = slf .connector .borrow_mut() .open_tunnel(head) .from_err() .and_then(move |(head, framed)| { // verify response if head.status != StatusCode::SWITCHING_PROTOCOLS { return Err(WsClientError::InvalidResponseStatus(head.status)); } // Check for "UPGRADE" to websocket header let has_hdr = if let Some(hdr) = head.headers.get(header::UPGRADE) { if let Ok(s) = hdr.to_str() { s.to_ascii_lowercase().contains("websocket") } else { false } } else { false }; if !has_hdr { log::trace!("Invalid upgrade header"); return Err(WsClientError::InvalidUpgradeHeader); } // Check for "CONNECTION" header if let Some(conn) = head.headers.get(header::CONNECTION) { if let Ok(s) = conn.to_str() { if !s.to_ascii_lowercase().contains("upgrade") { log::trace!("Invalid connection header: {}", s); return Err(WsClientError::InvalidConnectionHeader( conn.clone(), )); } } else { log::trace!("Invalid connection header: {:?}", conn); return Err(WsClientError::InvalidConnectionHeader(conn.clone())); } } else { log::trace!("Missing connection header"); return Err(WsClientError::MissingConnectionHeader); } if let Some(hdr_key) = head.headers.get(header::SEC_WEBSOCKET_ACCEPT) { let encoded = ws::hash_key(key.as_ref()); if hdr_key.as_bytes() != encoded.as_bytes() { log::trace!( "Invalid challenge response: expected: {} received: {:?}", encoded, key ); return Err(WsClientError::InvalidChallengeResponse( encoded, hdr_key.clone(), )); } } else { log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); return Err(WsClientError::MissingWebSocketAcceptHeader); }; // response and ws framed Ok(( ClientResponse::new(head, Payload::None), framed.map_codec(|_| { if server_mode { ws::Codec::new().max_size(max_size) } else { ws::Codec::new().max_size(max_size).client_mode() } }), )) }); Either::B(fut) } }