1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-07-01 12:15:08 +02:00

ws services

This commit is contained in:
Nikolay Kim
2018-10-22 09:59:20 -07:00
parent 20c693b39c
commit 9b94eaa6a8
7 changed files with 284 additions and 20 deletions

View File

@ -14,11 +14,13 @@ mod codec;
mod frame;
mod mask;
mod proto;
mod service;
mod transport;
pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Parser;
pub use self::proto::{CloseCode, CloseReason, OpCode};
pub use self::service::VerifyWebSockets;
pub use self::transport::Transport;
/// Websocket protocol errors
@ -109,15 +111,20 @@ impl ResponseError for HandshakeError {
}
}
/// Prepare `WebSocket` handshake response.
///
/// This function returns handshake `Response`, ready to send to peer.
/// It does not perform any IO.
///
/// Verify `WebSocket` handshake request and create handshake reponse.
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
verify_handshake(req)?;
Ok(handshake_response(req))
}
/// Verify `WebSocket` handshake request.
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
// WebSocket accepts only GET
if *req.method() != Method::GET {
return Err(HandshakeError::GetMethodRequired);
@ -161,17 +168,24 @@ pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(HandshakeError::BadWebsocketKey);
}
Ok(())
}
/// Create websocket's handshake response
///
/// This function returns handshake `Response`, ready to send to peer.
pub fn handshake_response(req: &Request) -> ResponseBuilder {
let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
proto::hash_key(key.as_ref())
};
Ok(Response::build(StatusCode::SWITCHING_PROTOCOLS)
Response::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade)
.header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take())
.take()
}
#[cfg(test)]
@ -185,13 +199,13 @@ mod tests {
let req = TestRequest::default().method(Method::POST).finish();
assert_eq!(
HandshakeError::GetMethodRequired,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default().finish();
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -199,7 +213,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -209,7 +223,7 @@ mod tests {
).finish();
assert_eq!(
HandshakeError::NoConnectionUpgrade,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -222,7 +236,7 @@ mod tests {
).finish();
assert_eq!(
HandshakeError::NoVersionHeader,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -238,7 +252,7 @@ mod tests {
).finish();
assert_eq!(
HandshakeError::UnsupportedVersion,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -254,7 +268,7 @@ mod tests {
).finish();
assert_eq!(
HandshakeError::BadWebsocketKey,
handshake(&req).err().unwrap()
verify_handshake(&req).err().unwrap()
);
let req = TestRequest::default()
@ -273,7 +287,7 @@ mod tests {
).finish();
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status()
handshake_response(&req).finish().status()
);
}

52
src/ws/service.rs Normal file
View File

@ -0,0 +1,52 @@
use std::marker::PhantomData;
use actix_net::codec::Framed;
use actix_net::service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, IntoFuture, Poll};
use h1::Codec;
use request::Request;
use super::{verify_handshake, HandshakeError};
pub struct VerifyWebSockets<T> {
_t: PhantomData<T>,
}
impl<T> Default for VerifyWebSockets<T> {
fn default() -> Self {
VerifyWebSockets { _t: PhantomData }
}
}
impl<T> NewService for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type InitError = ();
type Service = VerifyWebSockets<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
}
}
impl<T> Service for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (req, framed): Self::Request) -> Self::Future {
match verify_handshake(&req) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
}
}
}