1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-06-25 06:39:22 +02:00

ws verifyciation takes RequestHead; add SendError utility service

This commit is contained in:
Nikolay Kim
2019-04-11 14:00:32 -07:00
parent 6420a2fe1f
commit d115b3b3ed
10 changed files with 204 additions and 119 deletions

View File

@ -34,23 +34,35 @@ where
B: MessageBody,
{
type Item = Framed<T, Codec>;
type Error = Error;
type Error = (Error, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
let mut body_ready = self.body.is_some();
let framed = self.framed.as_mut().unwrap();
// send body
if self.res.is_none() && self.body.is_some() {
while body_ready && self.body.is_some() && !framed.is_write_buf_full() {
match self.body.as_mut().unwrap().poll_next()? {
while body_ready
&& self.body.is_some()
&& !self.framed.as_ref().unwrap().is_write_buf_full()
{
match self
.body
.as_mut()
.unwrap()
.poll_next()
.map_err(|e| (e, self.framed.take().unwrap()))?
{
Async::Ready(item) => {
// body is done
if item.is_none() {
let _ = self.body.take();
}
framed.force_send(Message::Chunk(item))?;
self.framed
.as_mut()
.unwrap()
.force_send(Message::Chunk(item))
.map_err(|e| (e.into(), self.framed.take().unwrap()))?;
}
Async::NotReady => body_ready = false,
}
@ -58,8 +70,14 @@ where
}
// flush write buffer
if !framed.is_write_buf_empty() {
match framed.poll_complete()? {
if !self.framed.as_ref().unwrap().is_write_buf_empty() {
match self
.framed
.as_mut()
.unwrap()
.poll_complete()
.map_err(|e| (e.into(), self.framed.take().unwrap()))?
{
Async::Ready(_) => {
if body_ready {
continue;
@ -73,7 +91,11 @@ where
// send response
if let Some(res) = self.res.take() {
framed.force_send(res)?;
self.framed
.as_mut()
.unwrap()
.force_send(res)
.map_err(|e| (e.into(), self.framed.take().unwrap()))?;
continue;
}

View File

@ -9,21 +9,18 @@ use derive_more::{Display, From};
use http::{header, Method, StatusCode};
use crate::error::ResponseError;
use crate::httpmessage::HttpMessage;
use crate::request::Request;
use crate::message::RequestHead;
use crate::response::{Response, ResponseBuilder};
mod codec;
mod frame;
mod mask;
mod proto;
mod service;
mod transport;
pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Parser;
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::service::VerifyWebSockets;
pub use self::transport::Transport;
/// Websocket protocol errors
@ -112,7 +109,7 @@ impl ResponseError for HandshakeError {
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
verify_handshake(req)?;
Ok(handshake_response(req))
}
@ -121,9 +118,9 @@ pub fn handshake(req: &Request) -> Result<ResponseBuilder, HandshakeError> {
// /// `protocols` is a sequence of known protocols. On successful handshake,
// /// the returned response headers contain the first protocol in this list
// /// which the server also knows.
pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
// WebSocket accepts only GET
if *req.method() != Method::GET {
if req.method != Method::GET {
return Err(HandshakeError::GetMethodRequired);
}
@ -171,7 +168,7 @@ pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> {
/// Create websocket's handshake response
///
/// This function returns handshake `Response`, ready to send to peer.
pub fn handshake_response(req: &Request) -> ResponseBuilder {
pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
let key = {
let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
proto::hash_key(key.as_ref())
@ -195,13 +192,13 @@ mod tests {
let req = TestRequest::default().method(Method::POST).finish();
assert_eq!(
HandshakeError::GetMethodRequired,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default().finish();
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -209,7 +206,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::NoWebsocketUpgrade,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -220,7 +217,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::NoConnectionUpgrade,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -235,7 +232,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::NoVersionHeader,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -254,7 +251,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::UnsupportedVersion,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -273,7 +270,7 @@ mod tests {
.finish();
assert_eq!(
HandshakeError::BadWebsocketKey,
verify_handshake(&req).err().unwrap()
verify_handshake(req.head()).err().unwrap()
);
let req = TestRequest::default()
@ -296,7 +293,7 @@ mod tests {
.finish();
assert_eq!(
StatusCode::SWITCHING_PROTOCOLS,
handshake_response(&req).finish().status()
handshake_response(req.head()).finish().status()
);
}

View File

@ -1,52 +0,0 @@
use std::marker::PhantomData;
use actix_codec::Framed;
use actix_service::{NewService, Service};
use futures::future::{ok, FutureResult};
use futures::{Async, IntoFuture, Poll};
use crate::h1::Codec;
use crate::request::Request;
use super::{verify_handshake, HandshakeError};
pub struct VerifyWebSockets<T> {
_t: PhantomData<T>,
}
impl<T> Default for VerifyWebSockets<T> {
fn default() -> Self {
VerifyWebSockets { _t: PhantomData }
}
}
impl<T> NewService for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type InitError = ();
type Service = VerifyWebSockets<T>;
type Future = FutureResult<Self::Service, Self::InitError>;
fn new_service(&self, _: &()) -> Self::Future {
ok(VerifyWebSockets { _t: PhantomData })
}
}
impl<T> Service for VerifyWebSockets<T> {
type Request = (Request, Framed<T, Codec>);
type Response = (Request, Framed<T, Codec>);
type Error = (HandshakeError, Framed<T, Codec>);
type Future = FutureResult<Self::Response, Self::Error>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
Ok(Async::Ready(()))
}
fn call(&mut self, (req, framed): (Request, Framed<T, Codec>)) -> Self::Future {
match verify_handshake(&req) {
Err(e) => Err((e, framed)).into_future(),
Ok(_) => Ok((req, framed)).into_future(),
}
}
}