use std::marker::PhantomData; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::BodySize; use actix_http::error::ResponseError; use actix_http::h1::{Codec, Message}; use actix_http::ws::{verify_handshake, HandshakeError}; use actix_http::{Request, Response}; use actix_service::{NewService, Service}; use futures::future::{ok, Either, FutureResult}; use futures::{Async, Future, IntoFuture, Poll, Sink}; /// Service that verifies incoming request if it is valid websocket /// upgrade request. In case of error returns `HandshakeError` pub struct VerifyWebSockets { _t: PhantomData<(T, C)>, } impl Default for VerifyWebSockets { fn default() -> Self { VerifyWebSockets { _t: PhantomData } } } impl NewService for VerifyWebSockets { type Config = C; type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); type InitError = (); type Service = VerifyWebSockets; type Future = FutureResult; fn new_service(&self, _: &C) -> Self::Future { ok(VerifyWebSockets { _t: PhantomData }) } } impl Service for VerifyWebSockets { type Request = (Request, Framed); type Response = (Request, Framed); type Error = (HandshakeError, Framed); type Future = FutureResult; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { match verify_handshake(req.head()) { Err(e) => Err((e, framed)).into_future(), Ok(_) => Ok((req, framed)).into_future(), } } } /// Send http/1 error response pub struct SendError(PhantomData<(T, R, E, C)>); impl Default for SendError where T: AsyncRead + AsyncWrite, E: ResponseError, { fn default() -> Self { SendError(PhantomData) } } impl NewService for SendError where T: AsyncRead + AsyncWrite + 'static, R: 'static, E: ResponseError + 'static, { type Config = C; type Request = Result)>; type Response = R; type Error = (E, Framed); type InitError = (); type Service = SendError; type Future = FutureResult; fn new_service(&self, _: &C) -> Self::Future { ok(SendError(PhantomData)) } } impl Service for SendError where T: AsyncRead + AsyncWrite + 'static, R: 'static, E: ResponseError + 'static, { type Request = Result)>; type Response = R; type Error = (E, Framed); type Future = Either)>, SendErrorFut>; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } fn call(&mut self, req: Result)>) -> Self::Future { match req { Ok(r) => Either::A(ok(r)), Err((e, framed)) => { let res = e.error_response().drop_body(); Either::B(SendErrorFut { framed: Some(framed), res: Some((res, BodySize::Empty).into()), err: Some(e), _t: PhantomData, }) } } } } pub struct SendErrorFut { res: Option, BodySize)>>, framed: Option>, err: Option, _t: PhantomData, } impl Future for SendErrorFut where E: ResponseError, T: AsyncRead + AsyncWrite, { type Item = R; type Error = (E, Framed); fn poll(&mut self) -> Poll { if let Some(res) = self.res.take() { if self.framed.as_mut().unwrap().force_send(res).is_err() { return Err((self.err.take().unwrap(), self.framed.take().unwrap())); } } match self.framed.as_mut().unwrap().poll_complete() { Ok(Async::Ready(_)) => { Err((self.err.take().unwrap(), self.framed.take().unwrap())) } Ok(Async::NotReady) => Ok(Async::NotReady), Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())), } } }