// Source: actix-extras/actix-framed/src/service.rs
// Mirror of https://github.com/actix/actix-extras.git (synced 2024-11-28 17:52:40 +01:00).
// File dated 2019-12-02 21:37:13 +06:00; 157 lines, 4.7 KiB, Rust.

use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::body::BodySize;
use actix_http::error::ResponseError;
use actix_http::h1::{Codec, Message};
use actix_http::ws::{verify_handshake, HandshakeError};
use actix_http::{Request, Response};
use actix_service::{Service, ServiceFactory};
use futures::future::{err, ok, Either, Ready};
use futures::Future;
/// Service that checks whether an incoming request is a valid WebSocket
/// upgrade request.
///
/// On success the request and its framed transport are handed back
/// unchanged; on failure the service yields the `HandshakeError` together
/// with the transport.
pub struct VerifyWebSockets<T, C> {
    // Zero-sized marker: ties the service to its transport type `T` and
    // factory config type `C` without storing either.
    _t: PhantomData<(T, C)>,
}

impl<T, C> Default for VerifyWebSockets<T, C> {
    /// Creates the (stateless) verifier.
    fn default() -> Self {
        Self { _t: PhantomData }
    }
}
/// Factory side of `VerifyWebSockets`.
///
/// The service holds no state, so construction cannot fail and the config
/// value is ignored.
impl<T, C> ServiceFactory for VerifyWebSockets<T, C> {
    type Config = C;
    type Request = (Request, Framed<T, Codec>);
    type Response = (Request, Framed<T, Codec>);
    type Error = (HandshakeError, Framed<T, Codec>);
    type InitError = ();
    type Service = VerifyWebSockets<T, C>;
    type Future = Ready<Result<Self::Service, Self::InitError>>;

    /// Returns an already-resolved future holding a fresh verifier.
    fn new_service(&self, _cfg: C) -> Self::Future {
        let service = Self { _t: PhantomData };
        ok(service)
    }
}
/// Request-handling side of `VerifyWebSockets`.
impl<T, C> Service for VerifyWebSockets<T, C> {
    type Request = (Request, Framed<T, Codec>);
    type Response = (Request, Framed<T, Codec>);
    type Error = (HandshakeError, Framed<T, Codec>);
    type Future = Ready<Result<Self::Response, Self::Error>>;

    /// Always ready: the service keeps no state between calls.
    fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs `verify_handshake` on the request head.
    ///
    /// Resolves immediately to `Ok((req, framed))` when the handshake is
    /// accepted, or to `Err((error, framed))` otherwise. The framed
    /// transport is returned to the caller in both cases.
    fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
        // Guard clause: bail out with the handshake error and the transport.
        if let Err(handshake_err) = verify_handshake(req.head()) {
            return err((handshake_err, framed));
        }
        ok((req, framed))
    }
}
/// Service that sends an HTTP/1 error response.
///
/// It accepts a `Result`: an `Ok` value is passed through untouched, while
/// an `Err((error, framed))` causes the error's response to be written to
/// the framed transport before the error is propagated.
///
/// The type parameters are only markers (`PhantomData`); the value itself
/// is zero-sized.
pub struct SendError<T, R, E, C>(PhantomData<(T, R, E, C)>);

// No trait bounds here: constructing the marker needs nothing from `T` or
// `E`. The I/O and `ResponseError` requirements belong on the
// `ServiceFactory`/`Service` impls that actually use them. Dropping the
// bounds is backward-compatible (every previously valid call still
// compiles) and lets `Default` be used in generic code that does not carry
// those bounds.
impl<T, R, E, C> Default for SendError<T, R, E, C> {
    /// Creates the (stateless) error-sending service.
    fn default() -> Self {
        SendError(PhantomData)
    }
}
/// Factory side of `SendError`.
///
/// The service is stateless, so the config value is ignored and
/// construction always succeeds.
impl<T, R, E, C> ServiceFactory for SendError<T, R, E, C>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
    R: 'static,
    E: ResponseError + 'static,
{
    type Config = C;
    type Request = Result<R, (E, Framed<T, Codec>)>;
    type Response = R;
    type Error = (E, Framed<T, Codec>);
    type InitError = ();
    type Service = SendError<T, R, E, C>;
    type Future = Ready<Result<Self::Service, Self::InitError>>;

    /// Returns an already-resolved future holding a fresh `SendError`.
    fn new_service(&self, _cfg: C) -> Self::Future {
        let service = Self(PhantomData);
        ok(service)
    }
}
/// Request-handling side of `SendError`.
impl<T, R, E, C> Service for SendError<T, R, E, C>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
    R: 'static,
    E: ResponseError + 'static,
{
    type Request = Result<R, (E, Framed<T, Codec>)>;
    type Response = R;
    type Error = (E, Framed<T, Codec>);
    type Future = Either<Ready<Result<R, (E, Framed<T, Codec>)>>, SendErrorFut<T, R, E>>;

    /// Always ready: the service keeps no state between calls.
    fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Passes `Ok` values straight through; for `Err((e, framed))` returns a
    /// `SendErrorFut` that sends `e`'s response over `framed`.
    fn call(&mut self, req: Result<R, (E, Framed<T, Codec>)>) -> Self::Future {
        // Success path: nothing to send, resolve immediately.
        let (e, framed) = match req {
            Ok(item) => return Either::Left(ok(item)),
            Err(failure) => failure,
        };

        // Build the response for this error with its body dropped, and
        // pair it with an empty body size for the codec.
        let head = e.error_response().drop_body();
        let message = (head, BodySize::Empty).into();

        Either::Right(SendErrorFut {
            res: Some(message),
            framed: Some(framed),
            err: Some(e),
            _t: PhantomData,
        })
    }
}
/// Future returned by `SendError` for the error case.
///
/// When polled it writes the pending response to the framed transport,
/// flushes it, and then resolves to `Err((error, framed))`. It never
/// resolves to `Ok`.
#[pin_project::pin_project]
pub struct SendErrorFut<T, R, E> {
// Response message still to be written; taken (set to `None`) on the first poll.
res: Option<Message<(Response<()>, BodySize)>>,
// Transport the response is written to; taken when the future completes.
framed: Option<Framed<T, Codec>>,
// Original error, returned alongside the transport on completion.
err: Option<E>,
// Marker for the `Ok` type of the output; no `R` value is ever stored.
_t: PhantomData<R>,
}
/// Drives the error response onto the wire and then yields the original
/// error together with the transport.
impl<T, R, E> Future for SendErrorFut<T, R, E>
where
    E: ResponseError,
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<R, (E, Framed<T, Codec>)>;

    /// Writes the response (first poll only), then flushes the transport.
    ///
    /// Returns `Poll::Pending` while the flush is pending. In every other
    /// case (write failure, flush success, or flush failure) it resolves
    /// to `Err((error, framed))`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has resolved, because the error and
    /// transport have already been taken.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // The response is present only on the first poll; later polls go
        // straight to the flush.
        let write_failed = match self.res.take() {
            Some(msg) => self.framed.as_mut().unwrap().write(msg).is_err(),
            None => false,
        };

        // A failed write skips the flush entirely. Otherwise keep flushing
        // until the transport reports completion, successful or not.
        if !write_failed && self.framed.as_mut().unwrap().flush(cx).is_pending() {
            return Poll::Pending;
        }

        // Every terminal path hands back the original error and the transport.
        let e = self.err.take().unwrap();
        let framed = self.framed.take().unwrap();
        Poll::Ready(Err((e, framed)))
    }
}