1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-28 01:52:57 +01:00
actix-web/awc/src/ws.rs
2021-12-27 16:15:33 +00:00

539 lines
16 KiB
Rust

//! Websockets client
//!
//! Type definitions required to use [`awc::Client`](super::Client) as a WebSocket client.
//!
//! # Example
//!
//! ```no_run
//! use awc::{Client, ws};
//! use futures_util::{sink::SinkExt as _, stream::StreamExt as _};
//!
//! #[actix_rt::main]
//! async fn main() {
//!     let (_resp, mut connection) = Client::new()
//!         .ws("ws://echo.websocket.org")
//!         .connect()
//!         .await
//!         .unwrap();
//!
//!     connection
//!         .send(ws::Message::Text("Echo".into()))
//!         .await
//!         .unwrap();
//!     let response = connection.next().await.unwrap().unwrap();
//!
//!     assert_eq!(response, ws::Frame::Text("Echo".as_bytes().into()));
//! }
//! ```
use std::{convert::TryFrom, fmt, net::SocketAddr, str};
use actix_codec::Framed;
use actix_http::{ws, Payload, RequestHead};
use actix_rt::time::timeout;
use actix_service::Service as _;
pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
use crate::{
client::ClientConfig,
connect::{BoxedSocket, ConnectRequest},
error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
http::{
header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION},
ConnectionType, Method, StatusCode, Uri, Version,
},
ClientResponse,
};
#[cfg(feature = "cookies")]
use crate::cookie::{Cookie, CookieJar};
/// Builder for a WebSocket handshake request.
///
/// The builder methods configure the request; [`WebsocketsRequest::connect`] sends it and
/// validates the server's response.
pub struct WebsocketsRequest {
/// Request head (method, URI, version and headers) sent for the handshake.
pub(crate) head: RequestHead,
/// Most recent error recorded by a builder method; returned by `connect` before any I/O.
err: Option<HttpError>,
/// Value for the `Origin` header, inserted by `connect` when set.
origin: Option<HeaderValue>,
/// Comma-separated subprotocol list, sent as `Sec-WebSocket-Protocol` when set.
protocols: Option<String>,
/// Optional explicit socket address handed to the connector together with the request head.
addr: Option<SocketAddr>,
/// Maximum frame size passed to the codec (`new` sets 65_536).
max_size: usize,
/// When `true`, the resulting codec is not switched into client mode.
server_mode: bool,
/// Client configuration; `connect` reads its connector and optional timeout.
config: ClientConfig,
/// Cookies to send; `connect` writes the jar's delta into the `Cookie` header.
#[cfg(feature = "cookies")]
cookies: Option<CookieJar>,
}
impl WebsocketsRequest {
/// Create new WebSocket connection
pub(crate) fn new<U>(uri: U, config: ClientConfig) -> Self
where
Uri: TryFrom<U>,
<Uri as TryFrom<U>>::Error: Into<HttpError>,
{
let mut err = None;
#[allow(clippy::field_reassign_with_default)]
let mut head = {
let mut head = RequestHead::default();
head.method = Method::GET;
head.version = Version::HTTP_11;
head
};
match Uri::try_from(uri) {
Ok(uri) => head.uri = uri,
Err(e) => err = Some(e.into()),
}
WebsocketsRequest {
head,
err,
config,
addr: None,
origin: None,
protocols: None,
max_size: 65_536,
server_mode: false,
#[cfg(feature = "cookies")]
cookies: None,
}
}
/// Set socket address of the server.
///
/// This address is used for connection. If address is not
/// provided url's host name get resolved.
pub fn address(mut self, addr: SocketAddr) -> Self {
self.addr = Some(addr);
self
}
/// Set the supported WebSocket protocols.
///
/// The items are stored as a single comma-separated string, in iteration order.
pub fn protocols<U, V>(mut self, protos: U) -> Self
where
    U: IntoIterator<Item = V>,
    V: AsRef<str>,
{
    let mut joined = String::new();
    for proto in protos {
        joined.push_str(proto.as_ref());
        joined.push(',');
    }
    // Drop the trailing separator (no-op when no protocols were given).
    joined.pop();

    self.protocols = Some(joined);
    self
}
/// Add a cookie to the request.
///
/// The cookie jar is created on first use.
#[cfg(feature = "cookies")]
pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
    self.cookies
        .get_or_insert_with(CookieJar::new)
        .add(cookie.into_owned());
    self
}
/// Set request Origin
pub fn origin<V, E>(mut self, origin: V) -> Self
where
HeaderValue: TryFrom<V, Error = E>,
HttpError: From<E>,
{
match HeaderValue::try_from(origin) {
Ok(value) => self.origin = Some(value),
Err(e) => self.err = Some(e.into()),
}
self
}
/// Set max frame size
///
/// By default max size is set to 64kB
pub fn max_frame_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Disable payload masking. By default ws client masks frame payload.
pub fn server_mode(mut self) -> Self {
self.server_mode = true;
self
}
/// Append a header.
///
/// The value is added alongside any existing values for the same name.
/// To replace an existing header use the `set_header()` method.
///
/// Conversion failures of the name or the value are stored and reported by `connect`.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
    HeaderName: TryFrom<K>,
    <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
    V: TryIntoHeaderValue,
{
    // The name is validated first; the value is only converted for a valid name.
    let name = match HeaderName::try_from(key) {
        Ok(name) => name,
        Err(cause) => {
            self.err = Some(cause.into());
            return self;
        }
    };

    match value.try_into_value() {
        Ok(val) => {
            self.head.headers.append(name, val);
        }
        Err(cause) => self.err = Some(cause.into()),
    }

    self
}
/// Insert a header, replacing any existing header with the same name.
///
/// Conversion failures of the name or the value are stored and reported by `connect`.
pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
where
    HeaderName: TryFrom<K>,
    <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
    V: TryIntoHeaderValue,
{
    // The name is validated first; the value is only converted for a valid name.
    let name = match HeaderName::try_from(key) {
        Ok(name) => name,
        Err(cause) => {
            self.err = Some(cause.into());
            return self;
        }
    };

    match value.try_into_value() {
        Ok(val) => {
            self.head.headers.insert(name, val);
        }
        Err(cause) => self.err = Some(cause.into()),
    }

    self
}
/// Insert a header only if it is not yet set.
///
/// When the header is already present the value is not converted at all. Conversion
/// failures of the name or the value are stored and reported by `connect`.
pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
where
    HeaderName: TryFrom<K>,
    <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
    V: TryIntoHeaderValue,
{
    let name = match HeaderName::try_from(key) {
        Ok(name) => name,
        Err(cause) => {
            self.err = Some(cause.into());
            return self;
        }
    };

    // An existing header wins; leave the request untouched.
    if self.head.headers.contains_key(&name) {
        return self;
    }

    match value.try_into_value() {
        Ok(val) => {
            self.head.headers.insert(name, val);
        }
        Err(cause) => self.err = Some(cause.into()),
    }

    self
}
/// Set the HTTP basic authorization header.
///
/// The credentials are `username:password` (or `username:` without a password),
/// base64-encoded and prefixed with `Basic `.
pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
where
    U: fmt::Display,
{
    // A missing password still yields the trailing colon.
    let credentials = format!("{}:{}", username, password.unwrap_or(""));
    let encoded = base64::encode(&credentials);
    self.header(AUTHORIZATION, format!("Basic {}", encoded))
}
/// Set the HTTP bearer authentication header (`Bearer <token>`).
pub fn bearer_auth<T>(self, token: T) -> Self
where
    T: fmt::Display,
{
    let value = format!("Bearer {}", token);
    self.header(AUTHORIZATION, value)
}
/// Complete request construction and connect to a WebSocket server.
pub async fn connect(
mut self,
) -> Result<(ClientResponse, Framed<BoxedSocket, Codec>), WsClientError> {
if let Some(e) = self.err.take() {
return Err(e.into());
}
// validate uri
let uri = &self.head.uri;
if uri.host().is_none() {
return Err(InvalidUrl::MissingHost.into());
} else if uri.scheme().is_none() {
return Err(InvalidUrl::MissingScheme.into());
} else if let Some(scheme) = uri.scheme() {
match scheme.as_str() {
"http" | "ws" | "https" | "wss" => {}
_ => return Err(InvalidUrl::UnknownScheme.into()),
}
} else {
return Err(InvalidUrl::UnknownScheme.into());
}
if !self.head.headers.contains_key(header::HOST) {
self.head.headers.insert(
header::HOST,
HeaderValue::from_str(uri.host().unwrap()).unwrap(),
);
}
// set cookies
#[cfg(feature = "cookies")]
if let Some(ref mut jar) = self.cookies {
let cookie: String = jar
.delta()
// ensure only name=value is written to cookie header
.map(|c| c.stripped().encoded().to_string())
.collect::<Vec<_>>()
.join("; ");
if !cookie.is_empty() {
self.head
.headers
.insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
}
}
// origin
if let Some(origin) = self.origin.take() {
self.head.headers.insert(header::ORIGIN, origin);
}
self.head.set_connection_type(ConnectionType::Upgrade);
#[allow(clippy::declare_interior_mutable_const)]
const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET);
#[allow(clippy::declare_interior_mutable_const)]
const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13");
self.head
.headers
.insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN);
if let Some(protocols) = self.protocols.take() {
self.head.headers.insert(
header::SEC_WEBSOCKET_PROTOCOL,
HeaderValue::try_from(protocols.as_str()).unwrap(),
);
}
// Generate a random key for the `Sec-WebSocket-Key` header which is a base64-encoded
// (see RFC 4648 §4) value that, when decoded, is 16 bytes in length (RFC 6455 §1.3).
let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key);
self.head.headers.insert(
header::SEC_WEBSOCKET_KEY,
HeaderValue::try_from(key.as_str()).unwrap(),
);
let head = self.head;
let max_size = self.max_size;
let server_mode = self.server_mode;
let req = ConnectRequest::Tunnel(head, self.addr);
let fut = self.config.connector.call(req);
// set request timeout
let res = if let Some(to) = self.config.timeout {
timeout(to, fut)
.await
.map_err(|_| SendRequestError::Timeout)??
} else {
fut.await?
};
let (head, framed) = res.into_tunnel_response();
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
return Err(WsClientError::InvalidResponseStatus(head.status));
}
// check for "UPGRADE" to WebSocket header
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
log::trace!("Invalid upgrade header");
return Err(WsClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = head.headers.get(&header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_ascii_lowercase().contains("upgrade") {
log::trace!("Invalid connection header: {}", s);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Invalid connection header: {:?}", conn);
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
log::trace!("Missing connection header");
return Err(WsClientError::MissingConnectionHeader);
}
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
let encoded = ws::hash_key(key.as_ref());
if hdr_key.as_bytes() != encoded {
log::trace!(
"Invalid challenge response: expected: {:?} received: {:?}",
&encoded,
key
);
return Err(WsClientError::InvalidChallengeResponse(
encoded,
hdr_key.clone(),
));
}
} else {
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
framed.into_map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
}
}),
))
}
}
/// Multi-line debug output: the request method and URI followed by every header.
impl fmt::Debug for WebsocketsRequest {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let head = &self.head;

        writeln!(f, "\nWebsocketsRequest {}:{}", head.method, head.uri)?;
        writeln!(f, " headers:")?;

        // One line per header; stops at the first formatter error.
        head.headers
            .iter()
            .try_for_each(|(name, value)| writeln!(f, " {:?}: {:?}", name, value))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Client;
// The debug output names the type and lists custom headers.
#[actix_rt::test]
async fn test_debug() {
    let rendered = format!("{:?}", Client::new().ws("/").header("x-test", "111"));

    for needle in ["WebsocketsRequest", "x-test"].iter() {
        assert!(rendered.contains(*needle));
    }
}
// `set_header` replaces a header that came from the client's default headers.
#[actix_rt::test]
async fn test_header_override() {
    let client = Client::builder()
        .add_default_header((header::CONTENT_TYPE, "111"))
        .finish();
    let req = client.ws("/").set_header(header::CONTENT_TYPE, "222");

    let content_type = req
        .head
        .headers
        .get(header::CONTENT_TYPE)
        .unwrap()
        .to_str()
        .unwrap();

    assert_eq!(content_type, "222");
}
// Basic credentials are base64-encoded, with and without a password.
#[actix_rt::test]
async fn basic_auth() {
    let cases = [
        (Some("password"), "Basic dXNlcm5hbWU6cGFzc3dvcmQ="),
        (None, "Basic dXNlcm5hbWU6"),
    ];

    for (password, expected) in cases.iter() {
        let req = Client::new().ws("/").basic_auth("username", *password);

        let authorization = req
            .head
            .headers
            .get(header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap();

        assert_eq!(authorization, *expected);
    }
}
// The bearer token ends up verbatim in the `Authorization` header.
#[actix_rt::test]
async fn bearer_auth() {
    let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");

    assert_eq!(
        req.head
            .headers
            .get(header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap(),
        "Bearer someS3cr3tAutht0k3n"
    );

    // Futures are lazy: without `.await` the connect path would never run. The URI "/" has
    // no host, so this fails during URI validation before any network I/O.
    assert!(req.connect().await.is_err());
}
// Builder settings are stored as given and malformed URLs are rejected by `connect`.
#[actix_rt::test]
async fn basics() {
    let req = Client::new()
        .ws("http://localhost/")
        .origin("test-origin")
        .max_frame_size(100)
        .server_mode()
        .protocols(&["v1", "v2"])
        .set_header_if_none(header::CONTENT_TYPE, "json")
        .set_header_if_none(header::CONTENT_TYPE, "text");

    // `Cookie` and `.cookie()` only exist with the "cookies" feature; gating this step keeps
    // the test module compiling when that feature is disabled.
    #[cfg(feature = "cookies")]
    let req = req.cookie(Cookie::build("cookie1", "value1").finish());

    assert_eq!(
        req.origin.as_ref().unwrap().to_str().unwrap(),
        "test-origin"
    );
    assert_eq!(req.max_size, 100);
    assert!(req.server_mode);
    assert_eq!(req.protocols, Some("v1,v2".to_string()));
    // The first `set_header_if_none` wins; the second call must not overwrite it.
    assert_eq!(
        req.head.headers.get(header::CONTENT_TYPE).unwrap(),
        header::HeaderValue::from_static("json")
    );

    let _ = req.connect().await;

    // Missing scheme/host, empty host and unknown scheme are all rejected.
    assert!(Client::new().ws("/").connect().await.is_err());
    assert!(Client::new().ws("http:///test").connect().await.is_err());
    assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
}
}