1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-30 18:34:36 +01:00

ws verifyciation takes RequestHead; add SendError utility service

This commit is contained in:
Nikolay Kim 2019-04-11 14:00:32 -07:00
parent 6420a2fe1f
commit d115b3b3ed
10 changed files with 204 additions and 119 deletions

View File

@ -2,6 +2,7 @@ mod app;
mod helpers; mod helpers;
mod request; mod request;
mod route; mod route;
mod service;
mod state; mod state;
// re-export for convinience // re-export for convinience
@ -10,4 +11,5 @@ pub use actix_http::{http, Error, HttpMessage, Response, ResponseError};
pub use self::app::{FramedApp, FramedAppService}; pub use self::app::{FramedApp, FramedAppService};
pub use self::request::FramedRequest; pub use self::request::FramedRequest;
pub use self::route::FramedRoute; pub use self::route::FramedRoute;
pub use self::service::{SendError, VerifyWebSockets};
pub use self::state::State; pub use self::state::State;

112
actix-framed/src/service.rs Normal file
View File

@ -0,0 +1,112 @@
use std::marker::PhantomData;
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::error::{Error, ResponseError};
use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{h1, Request};
use actix_service::{NewService, Service};
use futures::future::{ok, Either, FutureResult};
use futures::{Async, Future, IntoFuture, Poll};
/// Service that verifies incoming request if it is valid websocket
/// upgrade request. In case of error returns `HandshakeError`
pub struct VerifyWebSockets<T> {
_t: PhantomData<T>,
}
impl<T> Default for VerifyWebSockets<T> {
fn default() -> Self {
VerifyWebSockets { _t: PhantomData }
}
}
impl<T> NewService for VerifyWebSockets<T> {
type Request = (Request, Framed<T, h1::Codec>);
type Response = (Request, Framed<T, h1::Codec>);
type Error = (HandshakeError, Framed<T, h1::Codec>);
type InitError = ();
type Service = VerifyWebSockets<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &()) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
}
}
impl<T> Service for VerifyWebSockets<T> {
type Request = (Request, Framed<T, h1::Codec>);
type Response = (Request, Framed<T, h1::Codec>);
type Error = (HandshakeError, Framed<T, h1::Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (req, framed): (Request, Framed<T, h1::Codec>)) -> Self::Future {
match verify_handshake(req.head()) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
}
}
}
/// Send http/1 error response
pub struct SendError<T, R, E>(PhantomData<(T, R, E)>);
impl<T, R, E> Default for SendError<T, R, E>
where
T: AsyncRead + AsyncWrite,
E: ResponseError,
{
fn default() -> Self {
SendError(PhantomData)
}
}
impl<T, R, E> NewService for SendError<T, R, E>
where
T: AsyncRead + AsyncWrite + 'static,
R: 'static,
E: ResponseError + 'static,
{
type Request = Result<R, (E, Framed<T, h1::Codec>)>;
type Response = R;
type Error = Error;
type InitError = ();
type Service = SendError<T, R, E>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &()) -> Self::Future {
ok(SendError(PhantomData))
}
}
impl<T, R, E> Service for SendError<T, R, E>
where
T: AsyncRead + AsyncWrite + 'static,
R: 'static,
E: ResponseError + 'static,
{
type Request = Result<R, (E, Framed<T, h1::Codec>)>;
type Response = R;
type Error = Error;
type Future = Either<FutureResult<R, Error>, Box<Future<Item = R, Error = Error>>>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, req: Result<R, (E, Framed<T, h1::Codec>)>) -> Self::Future {
match req {
Ok(r) => Either::A(ok(r)),
Err((e, framed)) => {
let res = e.render_response();
let e = Error::from(e);
Either::B(Box::new(
h1::SendResponse::new(framed, res).then(move |_| Err(e)),
))
}
}
}
}

View File

@ -12,7 +12,7 @@ fn ws_service<T: AsyncRead + AsyncWrite>(
req: FramedRequest<T>, req: FramedRequest<T>,
) -> impl Future<Item = (), Error = Error> { ) -> impl Future<Item = (), Error = Error> {
let (req, framed, _) = req.into_parts(); let (req, framed, _) = req.into_parts();
let res = ws::handshake(&req).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())

View File

@ -12,6 +12,8 @@
* MessageBody::length() renamed to MessageBody::size() for consistency * MessageBody::length() renamed to MessageBody::size() for consistency
* ws handshake verification functions take RequestHead instead of Request
## [0.1.0-alpha.4] - 2019-04-08 ## [0.1.0-alpha.4] - 2019-04-08

View File

@ -46,7 +46,7 @@ secure-cookies = ["ring"]
[dependencies] [dependencies]
actix-service = "0.3.6" actix-service = "0.3.6"
actix-codec = "0.1.2" actix-codec = "0.1.2"
actix-connect = "0.1.2" actix-connect = "0.1.3"
actix-utils = "0.3.5" actix-utils = "0.3.5"
actix-server-config = "0.1.0" actix-server-config = "0.1.0"
actix-threadpool = "0.1.0" actix-threadpool = "0.1.0"

View File

@ -34,23 +34,35 @@ where
B: MessageBody, B: MessageBody,
{ {
type Item = Framed<T, Codec>; type Item = Framed<T, Codec>;
type Error = Error; type Error = (Error, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
let mut body_ready = self.body.is_some(); let mut body_ready = self.body.is_some();
let framed = self.framed.as_mut().unwrap();
// send body // send body
if self.res.is_none() && self.body.is_some() { if self.res.is_none() && self.body.is_some() {
while body_ready && self.body.is_some() && !framed.is_write_buf_full() { while body_ready
match self.body.as_mut().unwrap().poll_next()? { && self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self
.body
.as_mut()
.unwrap()
.poll_next()
.map_err(|e| (e, self.framed.take().unwrap()))?
{
Async::Ready(item) => { Async::Ready(item) => {
// body is done // body is done
if item.is_none() { if item.is_none() {
let _ = self.body.take(); let _ = self.body.take();
} }
framed.force_send(Message::Chunk(item))?; self.framed
.as_mut()
.unwrap()
.force_send(Message::Chunk(item))
.map_err(|e| (e.into(), self.framed.take().unwrap()))?;
} }
Async::NotReady => body_ready = false, Async::NotReady => body_ready = false,
} }
@ -58,8 +70,14 @@ where
} }
// flush write buffer // flush write buffer
if !framed.is_write_buf_empty() { if !self.framed.as_ref().unwrap().is_write_buf_empty() {
match framed.poll_complete()? { match self
.framed
.as_mut()
.unwrap()
.poll_complete()
.map_err(|e| (e.into(), self.framed.take().unwrap()))?
{
Async::Ready(_) => { Async::Ready(_) => {
if body_ready { if body_ready {
continue; continue;
@ -73,7 +91,11 @@ where
// send response // send response
if let Some(res) = self.res.take() { if let Some(res) = self.res.take() {
framed.force_send(res)?; self.framed
.as_mut()
.unwrap()
.force_send(res)
.map_err(|e| (e.into(), self.framed.take().unwrap()))?;
continue; continue;
} }

View File

@ -9,21 +9,18 @@ use derive_more::{Display, From};
use http::{header, Method, StatusCode}; use http::{header, Method, StatusCode};
use crate::error::ResponseError; use crate::error::ResponseError;
use crate::httpmessage::HttpMessage; use crate::message::RequestHead;
use crate::request::Request;
use crate::response::{Response, ResponseBuilder}; use crate::response::{Response, ResponseBuilder};
mod codec; mod codec;
mod frame; mod frame;
mod mask; mod mask;
mod proto; mod proto;
mod service;
mod transport; mod transport;
pub use self::codec::{Codec, Frame, Message}; pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Parser; pub use self::frame::Parser;
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::service::VerifyWebSockets;
pub use self::transport::Transport; pub use self::transport::Transport;
/// Websocket protocol errors /// Websocket protocol errors
@ -112,7 +109,7 @@ impl ResponseError for HandshakeError {
// /// `protocols` is a sequence of known protocols. On successful handshake, // /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list // /// the returned response headers contain the first protocol in this list
// /// which the server also knows. // /// which the server also knows.
pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> { pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
verify_handshake(req)?; verify_handshake(req)?;
Ok(handshake_response(req)) Ok(handshake_response(req))
} }
@ -121,9 +118,9 @@ pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
// /// `protocols` is a sequence of known protocols. On successful handshake, // /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list // /// the returned response headers contain the first protocol in this list
// /// which the server also knows. // /// which the server also knows.
pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
// WebSocket accepts only GET // WebSocket accepts only GET
if *req.method() != Method::GET { if req.method != Method::GET {
return Err(HandshakeError::GetMethodRequired); return Err(HandshakeError::GetMethodRequired);
} }
@ -171,7 +168,7 @@ pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
/// Create websocket's handshake response /// Create websocket's handshake response
/// ///
/// This function returns handshake `Response`, ready to send to peer. /// This function returns handshake `Response`, ready to send to peer.
pub fn handshake_response(req: &Request) -> ResponseBuilder { pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
let key = { let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
proto::hash_key(key.as_ref()) proto::hash_key(key.as_ref())
@ -195,13 +192,13 @@ mod tests {
let req = TestRequest::default().method(Method::POST).finish(); let req = TestRequest::default().method(Method::POST).finish();
assert_eq!( assert_eq!(
HandshakeError::GetMethodRequired, HandshakeError::GetMethodRequired,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default().finish(); let req = TestRequest::default().finish();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -209,7 +206,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::NoWebsocketUpgrade, HandshakeError::NoWebsocketUpgrade,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -220,7 +217,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::NoConnectionUpgrade, HandshakeError::NoConnectionUpgrade,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -235,7 +232,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::NoVersionHeader, HandshakeError::NoVersionHeader,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -254,7 +251,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::UnsupportedVersion, HandshakeError::UnsupportedVersion,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -273,7 +270,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
HandshakeError::BadWebsocketKey, HandshakeError::BadWebsocketKey,
verify_handshake(&req).err().unwrap() verify_handshake(req.head()).err().unwrap()
); );
let req = TestRequest::default() let req = TestRequest::default()
@ -296,7 +293,7 @@ mod tests {
.finish(); .finish();
assert_eq!( assert_eq!(
StatusCode::SWITCHING_PROTOCOLS, StatusCode::SWITCHING_PROTOCOLS,
handshake_response(&req).finish().status() handshake_response(req.head()).finish().status()
); );
} }

View File

@ -1,52 +0,0 @@
use std::marker::PhantomData;
use actix_codec::Framed;
use actix_service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, IntoFuture, Poll};
use crate::h1::Codec;
use crate::request::Request;
use super::{verify_handshake, HandshakeError};
pub struct VerifyWebSockets<T> {
_t: PhantomData<T>,
}
impl<T> Default for VerifyWebSockets<T> {
fn default() -> Self {
VerifyWebSockets { _t: PhantomData }
}
}
impl<T> NewService for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type InitError = ();
type Service = VerifyWebSockets<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &()) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
}
}
impl<T> Service for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(&req) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
}
}
}

View File

@ -9,7 +9,7 @@ use futures::{Future, Sink, Stream};
fn ws_service<T: AsyncRead + AsyncWrite>( fn ws_service<T: AsyncRead + AsyncWrite>(
(req, framed): (Request, Framed<T, h1::Codec>), (req, framed): (Request, Framed<T, h1::Codec>),
) -> impl Future<Item = (), Error = Error> { ) -> impl Future<Item = (), Error = Error> {
let res = ws::handshake(&req).unwrap().message_body(()); let res = ws::handshake(req.head()).unwrap().message_body(());
framed framed
.send((res, body::BodySize::None).into()) .send((res, body::BodySize::None).into())

View File

@ -11,7 +11,7 @@ use futures::future::{ok, Either};
use futures::{Future, Sink, Stream}; use futures::{Future, Sink, Stream};
use tokio_tcp::TcpStream; use tokio_tcp::TcpStream;
use actix_http::{body::BodySize, h1, ws, ResponseError, ServiceConfig}; use actix_http::{body::BodySize, h1, ws, Request, ResponseError, ServiceConfig};
fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> { fn ws_service(req: ws::Frame) -> impl Future<Item = ws::Message, Error = io::Error> {
match req { match req {
@ -40,47 +40,49 @@ fn test_simple() {
fn_service(|io: Io<TcpStream>| Ok(io.into_parts().0)) fn_service(|io: Io<TcpStream>| Ok(io.into_parts().0))
.and_then(IntoFramed::new(|| h1::Codec::new(ServiceConfig::default()))) .and_then(IntoFramed::new(|| h1::Codec::new(ServiceConfig::default())))
.and_then(TakeItem::new().map_err(|_| ())) .and_then(TakeItem::new().map_err(|_| ()))
.and_then(|(req, framed): (_, Framed<_, _>)| { .and_then(
// validate request |(req, framed): (Option<h1::Message<Request>>, Framed<_, _>)| {
if let Some(h1::Message::Item(req)) = req { // validate request
match ws::verify_handshake(&req) { if let Some(h1::Message::Item(req)) = req {
Err(e) => { match ws::verify_handshake(req.head()) {
// validation failed Err(e) => {
let res = e.error_response(); // validation failed
Either::A( let res = e.error_response();
framed Either::A(
.send(h1::Message::Item(( framed
res.drop_body(), .send(h1::Message::Item((
BodySize::Empty, res.drop_body(),
))) BodySize::Empty,
.map_err(|_| ()) )))
.map(|_| ()), .map_err(|_| ())
) .map(|_| ()),
} )
Ok(_) => { }
let res = ws::handshake_response(&req).finish(); Ok(_) => {
Either::B( let res = ws::handshake_response(req.head()).finish();
// send handshake response Either::B(
framed // send handshake response
.send(h1::Message::Item(( framed
res.drop_body(), .send(h1::Message::Item((
BodySize::None, res.drop_body(),
))) BodySize::None,
.map_err(|_| ()) )))
.and_then(|framed| { .map_err(|_| ())
// start websocket service .and_then(|framed| {
let framed = // start websocket service
framed.into_framed(ws::Codec::new()); let framed =
ws::Transport::with(framed, ws_service) framed.into_framed(ws::Codec::new());
.map_err(|_| ()) ws::Transport::with(framed, ws_service)
}), .map_err(|_| ())
) }),
)
}
} }
} else {
panic!()
} }
} else { },
panic!() )
}
})
}); });
// client service // client service