1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 16:02:59 +01:00

do not handle upgrade and connect requests

This commit is contained in:
Nikolay Kim 2018-10-13 23:57:31 -07:00
parent b960b5827c
commit d39c018c93
7 changed files with 141 additions and 57 deletions

View File

@ -4,7 +4,7 @@ use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder};
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder};
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType};
use super::encoder::{ResponseEncoder, ResponseLength};
use body::{Binary, Body};
use config::ServiceConfig;
@ -20,7 +20,8 @@ bitflags! {
const HEAD = 0b0000_0001;
const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const KEEPALIVE_ENABLED = 0b0001_0000;
const KEEPALIVE_ENABLED = 0b0000_1000;
const UNHANDLED = 0b0001_0000;
}
}
@ -39,11 +40,19 @@ pub enum OutMessage {
#[derive(Debug)]
pub enum InMessage {
/// Request
Message { req: Request, payload: bool },
Message(Request, InMessageType),
/// Payload chunk
Chunk(Option<Bytes>),
}
/// Incoming request type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InMessageType {
None,
Payload,
Unhandled,
}
/// HTTP/1 Codec
pub struct Codec {
config: ServiceConfig,
@ -246,6 +255,8 @@ impl Decoder for Codec {
Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)),
None => None,
})
} else if self.flags.contains(Flags::UNHANDLED) {
Ok(None)
} else if let Some((req, payload)) = self.decoder.decode(src)? {
self.flags
.set(Flags::HEAD, req.inner.method == Method::HEAD);
@ -253,11 +264,21 @@ impl Decoder for Codec {
if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive());
}
self.payload = payload;
Ok(Some(InMessage::Message {
req,
payload: self.payload.is_some(),
}))
let payload = match payload {
RequestPayloadType::None => {
self.payload = None;
InMessageType::None
}
RequestPayloadType::Payload(pl) => {
self.payload = Some(pl);
InMessageType::Payload
}
RequestPayloadType::Unhandled => {
self.payload = None;
InMessageType::Unhandled
}
};
Ok(Some(InMessage::Message(req, payload)))
} else {
Ok(None)
}

View File

@ -16,6 +16,13 @@ const MAX_HEADERS: usize = 96;
pub struct RequestDecoder(&'static RequestPool);
/// Incoming request type
pub enum RequestPayloadType {
None,
Payload(PayloadDecoder),
Unhandled,
}
impl RequestDecoder {
pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder {
RequestDecoder(pool)
@ -29,7 +36,7 @@ impl Default for RequestDecoder {
}
impl Decoder for RequestDecoder {
type Item = (Request, Option<PayloadDecoder>);
type Item = (Request, RequestPayloadType);
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
@ -149,18 +156,18 @@ impl Decoder for RequestDecoder {
// https://tools.ietf.org/html/rfc7230#section-3.3.3
let decoder = if chunked {
// Chunked encoding
Some(PayloadDecoder::chunked())
RequestPayloadType::Payload(PayloadDecoder::chunked())
} else if let Some(len) = content_length {
// Content-Length
Some(PayloadDecoder::length(len))
RequestPayloadType::Payload(PayloadDecoder::length(len))
} else if has_upgrade || msg.inner.method == Method::CONNECT {
// upgrade(websocket) or connect
Some(PayloadDecoder::eof())
RequestPayloadType::Unhandled
} else if src.len() >= MAX_BUFFER_SIZE {
error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
return Err(ParseError::TooLarge);
} else {
None
RequestPayloadType::None
};
Ok(Some((msg, decoder)))
@ -481,20 +488,36 @@ mod tests {
use super::*;
use error::ParseError;
use h1::InMessage;
use h1::{InMessage, InMessageType};
use httpmessage::HttpMessage;
use request::Request;
impl RequestPayloadType {
fn unwrap(self) -> PayloadDecoder {
match self {
RequestPayloadType::Payload(pl) => pl,
_ => panic!(),
}
}
fn is_unhandled(&self) -> bool {
match self {
RequestPayloadType::Unhandled => true,
_ => false,
}
}
}
impl InMessage {
fn message(self) -> Request {
match self {
InMessage::Message { req, payload: _ } => req,
InMessage::Message(req, _) => req,
_ => panic!("error"),
}
}
fn is_payload(&self) -> bool {
match *self {
InMessage::Message { req: _, payload } => payload,
InMessage::Message(_, payload) => payload == InMessageType::Payload,
_ => panic!("error"),
}
}
@ -919,13 +942,9 @@ mod tests {
);
let mut reader = RequestDecoder::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(!req.keep_alive());
assert!(req.upgrade());
assert_eq!(
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"some raw data"
);
assert!(pl.is_unhandled());
}
#[test]

View File

@ -18,7 +18,8 @@ use error::DispatchError;
use request::Request;
use response::Response;
use super::codec::{Codec, InMessage, OutMessage};
use super::codec::{Codec, InMessage, InMessageType, OutMessage};
use super::H1ServiceResult;
const MAX_PIPELINED_MESSAGES: usize = 16;
@ -41,13 +42,14 @@ where
{
service: S,
flags: Flags,
framed: Framed<T, Codec>,
framed: Option<Framed<T, Codec>>,
error: Option<DispatchError<S::Error>>,
config: ServiceConfig,
state: State<S>,
payload: Option<PayloadSender>,
messages: VecDeque<Message>,
unhandled: Option<Request>,
ka_expire: Instant,
ka_timer: Option<Delay>,
@ -112,9 +114,10 @@ where
state: State::None,
error: None,
messages: VecDeque::new(),
framed: Some(framed),
unhandled: None,
service,
flags,
framed,
config,
ka_expire,
ka_timer,
@ -144,7 +147,7 @@ where
/// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if !self.flags.contains(Flags::FLUSHED) {
match self.framed.poll_complete() {
match self.framed.as_mut().unwrap().poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => {
debug!("Error sending data: {}", err);
@ -187,7 +190,11 @@ where
State::ServiceCall(ref mut fut) => {
match fut.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => {
self.framed.get_codec_mut().prepare_te(&mut res);
self.framed
.as_mut()
.unwrap()
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty);
Some(State::SendResponse(Some((
OutMessage::Response(res),
@ -200,11 +207,11 @@ where
// send respons
State::SendResponse(ref mut item) => {
let (msg, body) = item.take().expect("SendResponse is empty");
match self.framed.start_send(msg) {
match self.framed.as_mut().unwrap().start_send(msg) {
Ok(AsyncSink::Ready) => {
self.flags.set(
Flags::KEEPALIVE,
self.framed.get_codec().keepalive(),
self.framed.as_mut().unwrap().get_codec().keepalive(),
);
self.flags.remove(Flags::FLUSHED);
match body {
@ -233,7 +240,7 @@ where
// Send payload
State::SendPayload(ref mut stream, ref mut bin) => {
if let Some(item) = bin.take() {
match self.framed.start_send(item) {
match self.framed.as_mut().unwrap().start_send(item) {
Ok(AsyncSink::Ready) => {
self.flags.remove(Flags::FLUSHED);
}
@ -248,6 +255,8 @@ where
match stream.poll() {
Ok(Async::Ready(Some(item))) => match self
.framed
.as_mut()
.unwrap()
.start_send(OutMessage::Chunk(Some(item.into())))
{
Ok(AsyncSink::Ready) => {
@ -297,7 +306,11 @@ where
let mut task = self.service.call(req);
match task.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => {
self.framed.get_codec_mut().prepare_te(&mut res);
self.framed
.as_mut()
.unwrap()
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty);
Ok(State::SendResponse(Some((OutMessage::Response(res), body))))
}
@ -314,18 +327,25 @@ where
let mut updated = false;
'outer: loop {
match self.framed.poll() {
match self.framed.as_mut().unwrap().poll() {
Ok(Async::Ready(Some(msg))) => {
updated = true;
self.flags.insert(Flags::STARTED);
match msg {
InMessage::Message { req, payload } => {
if payload {
InMessage::Message(req, payload) => {
match payload {
InMessageType::Payload => {
let (ps, pl) = Payload::new(false);
*req.inner.payload.borrow_mut() = Some(pl);
self.payload = Some(ps);
}
InMessageType::Unhandled => {
self.unhandled = Some(req);
return Ok(updated);
}
_ => (),
}
// handle request early
if self.state.is_empty() {
@ -454,15 +474,16 @@ where
S: Service<Request = Request, Response = Response>,
S::Error: Debug,
{
type Item = ();
type Item = H1ServiceResult<T>;
type Error = DispatchError<S::Error>;
#[inline]
fn poll(&mut self) -> Poll<(), Self::Error> {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if self.flags.contains(Flags::SHUTDOWN) {
self.poll_keepalive()?;
try_ready!(self.poll_flush());
Ok(AsyncWrite::shutdown(self.framed.get_mut())?)
let io = self.framed.take().unwrap().into_inner();
Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
} else {
self.poll_keepalive()?;
self.poll_request()?;
@ -474,15 +495,21 @@ where
if let Some(err) = self.error.take() {
Err(err)
} else if self.flags.contains(Flags::DISCONNECTED) {
Ok(Async::Ready(()))
Ok(Async::Ready(H1ServiceResult::Disconnected))
}
// unhandled request (upgrade or connect)
else if self.unhandled.is_some() {
let req = self.unhandled.take().unwrap();
let framed = self.framed.take().unwrap();
Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed)))
}
// disconnect if keep-alive is not enabled
else if self.flags.contains(Flags::STARTED) && !self
.flags
.intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
{
self.flags.insert(Flags::SHUTDOWN);
self.poll()
let io = self.framed.take().unwrap().into_inner();
Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
} else {
Ok(Async::NotReady)
}

View File

@ -1,11 +1,22 @@
//! HTTP/1 implementation
use actix_net::codec::Framed;
mod codec;
mod decoder;
mod dispatcher;
mod encoder;
mod service;
pub use self::codec::{Codec, InMessage, OutMessage};
pub use self::codec::{Codec, InMessage, InMessageType, OutMessage};
pub use self::decoder::{PayloadDecoder, RequestDecoder};
pub use self::dispatcher::Dispatcher;
pub use self::service::{H1Service, H1ServiceHandler};
use request::Request;
/// H1 service response type
pub enum H1ServiceResult<T> {
Disconnected,
Shutdown(T),
Unhandled(Request, Framed<T, Codec>),
}

View File

@ -12,6 +12,7 @@ use request::Request;
use response::Response;
use super::dispatcher::Dispatcher;
use super::H1ServiceResult;
/// `NewService` implementation for HTTP1 transport
pub struct H1Service<T, S> {
@ -51,7 +52,7 @@ where
S::Error: Debug,
{
type Request = T;
type Response = ();
type Response = H1ServiceResult<T>;
type Error = DispatchError<S::Error>;
type InitError = S::InitError;
type Service = H1ServiceHandler<T, S::Service>;
@ -243,7 +244,7 @@ where
S::Error: Debug,
{
type Request = T;
type Response = ();
type Response = H1ServiceResult<T>;
type Error = DispatchError<S::Error>;
type Future = Dispatcher<T, S>;

View File

@ -9,6 +9,7 @@ use std::{io::Read, io::Write, net, thread, time};
use actix::System;
use actix_net::server::Server;
use actix_net::service::NewServiceExt;
use actix_web::{client, test, HttpMessage};
use bytes::Bytes;
use futures::future::{self, ok};
@ -29,6 +30,7 @@ fn test_h1_v2() {
.server_hostname("localhost")
.server_address(addr)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap()
.run();
});
@ -53,6 +55,7 @@ fn test_slow_request() {
h1::H1Service::build()
.client_timeout(100)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap()
.run();
});
@ -72,6 +75,7 @@ fn test_malformed_request() {
Server::new()
.bind("test", addr, move || {
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap()
.run();
});
@ -106,7 +110,7 @@ fn test_content_length() {
StatusCode::NOT_FOUND,
];
future::ok::<_, ()>(Response::new(statuses[indx]))
})
}).map(|_| ())
}).unwrap()
.run();
});
@ -172,7 +176,7 @@ fn test_headers() {
);
}
future::ok::<_, ()>(builder.body(data.clone()))
})
}).map(|_| ())
})
.unwrap()
.run()
@ -221,6 +225,7 @@ fn test_body() {
Server::new()
.bind("test", addr, move || {
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map(|_| ())
}).unwrap()
.run();
});
@ -246,7 +251,7 @@ fn test_head_empty() {
.bind("test", addr, move || {
h1::H1Service::new(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish())
})
}).map(|_| ())
}).unwrap()
.run()
});
@ -282,7 +287,7 @@ fn test_head_binary() {
ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
}).map(|_| ())
}).unwrap()
.run()
});
@ -314,7 +319,7 @@ fn test_head_binary2() {
thread::spawn(move || {
Server::new()
.bind("test", addr, move || {
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR)))
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ())
}).unwrap()
.run()
});
@ -349,7 +354,7 @@ fn test_body_length() {
.content_length(STR.len() as u64)
.body(Body::Streaming(Box::new(body))),
)
})
}).map(|_| ())
}).unwrap()
.run()
});
@ -380,7 +385,7 @@ fn test_body_chunked_explicit() {
.chunked()
.body(Body::Streaming(Box::new(body))),
)
})
}).map(|_| ())
}).unwrap()
.run()
});
@ -409,7 +414,7 @@ fn test_body_chunked_implicit() {
h1::H1Service::new(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body))))
})
}).map(|_| ())
}).unwrap()
.run()
});

View File

@ -51,7 +51,7 @@ fn test_simple() {
.and_then(TakeItem::new().map_err(|_| ()))
.and_then(|(req, framed): (_, Framed<_, _>)| {
// validate request
if let Some(h1::InMessage::Message { req, payload: _ }) = req {
if let Some(h1::InMessage::Message(req, _)) = req {
match ws::handshake(&req) {
Err(e) => {
// validation failed