1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

ws services

This commit is contained in:
Nikolay Kim 2018-10-22 09:59:20 -07:00
parent 20c693b39c
commit 9b94eaa6a8
7 changed files with 284 additions and 20 deletions

View File

@ -632,6 +632,13 @@ where
} }
} }
/// Convert Response to a Error
impl From<Response> for Error {
fn from(res: Response) -> Error {
InternalError::from_response("", res).into()
}
}
/// Helper function that creates wrapper of any error and generate *BAD /// Helper function that creates wrapper of any error and generate *BAD
/// REQUEST* response. /// REQUEST* response.
#[allow(non_snake_case)] #[allow(non_snake_case)]

View File

@ -267,7 +267,10 @@ pub struct OneRequest<T> {
_t: PhantomData<T>, _t: PhantomData<T>,
} }
impl<T> OneRequest<T> { impl<T> OneRequest<T>
where
T: AsyncRead + AsyncWrite,
{
/// Create new `H1SimpleService` instance. /// Create new `H1SimpleService` instance.
pub fn new() -> Self { pub fn new() -> Self {
OneRequest { OneRequest {

View File

@ -110,6 +110,7 @@ mod json;
mod payload; mod payload;
mod request; mod request;
mod response; mod response;
mod service;
pub mod uri; pub mod uri;
pub mod error; pub mod error;
@ -123,6 +124,7 @@ pub use extensions::Extensions;
pub use httpmessage::HttpMessage; pub use httpmessage::HttpMessage;
pub use request::Request; pub use request::Request;
pub use response::Response; pub use response::Response;
pub use service::{SendError, SendResponse};
pub use self::config::{KeepAlive, ServiceConfig, ServiceConfigBuilder}; pub use self::config::{KeepAlive, ServiceConfig, ServiceConfigBuilder};

185
src/service.rs Normal file
View File

@ -0,0 +1,185 @@
use std::io;
use std::marker::PhantomData;
use actix_net::codec::Framed;
use actix_net::service::{NewService, Service};
use futures::future::{ok, Either, FutureResult};
use futures::{Async, AsyncSink, Future, Poll, Sink};
use tokio_io::AsyncWrite;
use error::ResponseError;
use h1::{Codec, OutMessage};
use response::Response;
pub struct SendError<T, R, E>(PhantomData<(T, R, E)>);
impl<T, R, E> Default for SendError<T, R, E>
where
T: AsyncWrite,
E: ResponseError,
{
fn default() -> Self {
SendError(PhantomData)
}
}
impl<T, R, E> NewService for SendError<T, R, E>
where
T: AsyncWrite,
E: ResponseError,
{
type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R;
type Error = (E, Framed<T, Codec>);
type InitError = ();
type Service = SendError<T, R, E>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self) -> Self::Future {
ok(SendError(PhantomData))
}
}
impl<T, R, E> Service for SendError<T, R, E>
where
T: AsyncWrite,
E: ResponseError,
{
type Request = Result<R, (E, Framed<T, Codec>)>;
type Response = R;
type Error = (E, Framed<T, Codec>);
type Future = Either<FutureResult<R, (E, Framed<T, Codec>)>, SendErrorFut<T, R, E>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, req: Self::Request) -> Self::Future {
match req {
Ok(r) => Either::A(ok(r)),
Err((e, framed)) => Either::B(SendErrorFut {
framed: Some(framed),
res: Some(OutMessage::Response(e.error_response())),
err: Some(e),
_t: PhantomData,
}),
}
}
}
pub struct SendErrorFut<T, R, E> {
res: Option<OutMessage>,
framed: Option<Framed<T, Codec>>,
err: Option<E>,
_t: PhantomData<R>,
}
impl<T, R, E> Future for SendErrorFut<T, R, E>
where
E: ResponseError,
T: AsyncWrite,
{
type Item = R;
type Error = (E, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() {
match self.framed.as_mut().unwrap().start_send(res) {
Ok(AsyncSink::Ready) => (),
Ok(AsyncSink::NotReady(res)) => {
self.res = Some(res);
return Ok(Async::NotReady);
}
Err(_) => {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()))
}
}
}
match self.framed.as_mut().unwrap().poll_complete() {
Ok(Async::Ready(_)) => {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(_) => {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()))
}
}
}
}
pub struct SendResponse<T>(PhantomData<(T,)>);
impl<T> Default for SendResponse<T>
where
T: AsyncWrite,
{
fn default() -> Self {
SendResponse(PhantomData)
}
}
impl<T> NewService for SendResponse<T>
where
T: AsyncWrite,
{
type Request = (Response, Framed<T, Codec>);
type Response = Framed<T, Codec>;
type Error = io::Error;
type InitError = ();
type Service = SendResponse<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self) -> Self::Future {
ok(SendResponse(PhantomData))
}
}
impl<T> Service for SendResponse<T>
where
T: AsyncWrite,
{
type Request = (Response, Framed<T, Codec>);
type Response = Framed<T, Codec>;
type Error = io::Error;
type Future = SendResponseFut<T>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (res, framed): Self::Request) -> Self::Future {
SendResponseFut {
res: Some(OutMessage::Response(res)),
framed: Some(framed),
}
}
}
pub struct SendResponseFut<T> {
res: Option<OutMessage>,
framed: Option<Framed<T, Codec>>,
}
impl<T> Future for SendResponseFut<T>
where
T: AsyncWrite,
{
type Item = Framed<T, Codec>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() {
match self.framed.as_mut().unwrap().start_send(res)? {
AsyncSink::Ready => (),
AsyncSink::NotReady(res) => {
self.res = Some(res);
return Ok(Async::NotReady);
}
}
}
match self.framed.as_mut().unwrap().poll_complete()? {
Async::Ready(_) => Ok(Async::Ready(self.framed.take().unwrap())),
Async::NotReady => Ok(Async::NotReady),
}
}
}

View File

@ -14,11 +14,13 @@ mod codec;
mod frame; mod frame;
mod mask; mod mask;
mod proto; mod proto;
mod service;
mod transport; mod transport;
pub use self::codec::{Codec, Frame, Message}; pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Parser; pub use self::frame::Parser;
pub use self::proto::{CloseCode, CloseReason, OpCode}; pub use self::proto::{CloseCode, CloseReason, OpCode};
pub use self::service::VerifyWebSockets;
pub use self::transport::Transport; pub use self::transport::Transport;
/// Websocket protocol errors /// Websocket protocol errors
@ -109,15 +111,20 @@ impl ResponseError for HandshakeError {
} }
} }
/// Prepare `WebSocket` handshake response. /// Verify `WebSocket` handshake request and create handshake reponse.
///
/// This function returns handshake `Response`, ready to send to peer.
/// It does not perform any IO.
///
// /// `protocols` is a sequence of known protocols. On successful handshake, // /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list // /// the returned response headers contain the first protocol in this list
// /// which the server also knows. // /// which the server also knows.
pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> { pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
verify_handshake(req)?;
Ok(handshake_response(req))
}
/// Verify `WebSocket` handshake request.
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
// WebSocket accepts only GET // WebSocket accepts only GET
if *req.method() != Method::GET { if *req.method() != Method::GET {
return Err(HandshakeError::GetMethodRequired); return Err(HandshakeError::GetMethodRequired);
@ -161,17 +168,24 @@ pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
return Err(HandshakeError::BadWebsocketKey); return Err(HandshakeError::BadWebsocketKey);
} }
Ok(())
}
/// Create websocket's handshake response
///
/// This function returns handshake `Response`, ready to send to peer.
pub fn handshake_response(req: &Request) -> ResponseBuilder {
let key = { let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
proto::hash_key(key.as_ref()) proto::hash_key(key.as_ref())
}; };
Ok(Response::build(StatusCode::SWITCHING_PROTOCOLS) Response::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade) .connection_type(ConnectionType::Upgrade)
.header(header::UPGRADE, "websocket") .header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked") .header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
.take()) .take()
} }
#[cfg(test)] #[cfg(test)]
@ -185,13 +199,13 @@ mod tests {
let req = TestRequest::default().method(Method::POST).finish(); let req = TestRequest::default().method(Method::POST).finish();
assert_eq!( assert_eq!(
HandshakeError::GetMethodRequired, HandshakeError::GetMethodRequired,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default().finish(); let req = TestRequest::default().finish();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -199,7 +213,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -209,7 +223,7 @@ mod tests {
).finish(); ).finish();
assert_eq!( assert_eq!(
HandshakeError::NoConnectionUpgrade, HandshakeError::NoConnectionUpgrade,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -222,7 +236,7 @@ mod tests {
).finish(); ).finish();
assert_eq!( assert_eq!(
HandshakeError::NoVersionHeader, HandshakeError::NoVersionHeader,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -238,7 +252,7 @@ mod tests {
).finish(); ).finish();
assert_eq!( assert_eq!(
HandshakeError::UnsupportedVersion, HandshakeError::UnsupportedVersion,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -254,7 +268,7 @@ mod tests {
).finish(); ).finish();
assert_eq!( assert_eq!(
HandshakeError::BadWebsocketKey, HandshakeError::BadWebsocketKey,
handshake(&req).err().unwrap() verify_handshake(&req).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -273,7 +287,7 @@ mod tests {
).finish(); ).finish();
assert_eq!( assert_eq!(
StatusCode::SWITCHING_PROTOCOLS, StatusCode::SWITCHING_PROTOCOLS,
handshake(&req).unwrap().finish().status() handshake_response(&req).finish().status()
); );
} }

52
src/ws/service.rs Normal file
View File

@ -0,0 +1,52 @@
use std::marker::PhantomData;
use actix_net::codec::Framed;
use actix_net::service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, IntoFuture, Poll};
use h1::Codec;
use request::Request;
use super::{verify_handshake, HandshakeError};
pub struct VerifyWebSockets<T> {
_t: PhantomData<T>,
}
impl<T> Default for VerifyWebSockets<T> {
fn default() -> Self {
VerifyWebSockets { _t: PhantomData }
}
}
impl<T> NewService for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type InitError = ();
type Service = VerifyWebSockets<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
}
}
impl<T> Service for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (req, framed): Self::Request) -> Self::Future {
match verify_handshake(&req) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
}
}
}

View File

@ -52,7 +52,7 @@ fn test_simple() {
.and_then(|(req, framed): (_, Framed<_, _>)| { .and_then(|(req, framed): (_, Framed<_, _>)| {
// validate request // validate request
if let Some(h1::InMessage::Message(req, _)) = req { if let Some(h1::InMessage::Message(req, _)) = req {
match ws::handshake(&req) { match ws::verify_handshake(&req) {
Err(e) => { Err(e) => {
// validation failed // validation failed
let resp = e.error_response(); let resp = e.error_response();
@ -63,11 +63,12 @@ fn test_simple() {
.map(|_| ()), .map(|_| ()),
) )
} }
Ok(mut resp) => Either::B( Ok(_) => Either::B(
// send response // send response
framed framed
.send(h1::OutMessage::Response(resp.finish())) .send(h1::OutMessage::Response(
.map_err(|_| ()) ws::handshake_response(&req).finish(),
)).map_err(|_| ())
.and_then(|framed| { .and_then(|framed| {
// start websocket service // start websocket service
let framed = let framed =