1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-28 09:42:40 +01:00

do not handle upgrade and connect requests

This commit is contained in:
Nikolay Kim 2018-10-13 23:57:31 -07:00
parent b960b5827c
commit d39c018c93
7 changed files with 141 additions and 57 deletions

View File

@ -4,7 +4,7 @@ use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder}; use tokio_codec::{Decoder, Encoder};
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder}; use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType};
use super::encoder::{ResponseEncoder, ResponseLength}; use super::encoder::{ResponseEncoder, ResponseLength};
use body::{Binary, Body}; use body::{Binary, Body};
use config::ServiceConfig; use config::ServiceConfig;
@ -17,10 +17,11 @@ use response::Response;
bitflags! { bitflags! {
struct Flags: u8 { struct Flags: u8 {
const HEAD = 0b0000_0001; const HEAD = 0b0000_0001;
const UPGRADE = 0b0000_0010; const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100; const KEEPALIVE = 0b0000_0100;
const KEEPALIVE_ENABLED = 0b0001_0000; const KEEPALIVE_ENABLED = 0b0000_1000;
const UNHANDLED = 0b0001_0000;
} }
} }
@ -39,11 +40,19 @@ pub enum OutMessage {
#[derive(Debug)] #[derive(Debug)]
pub enum InMessage { pub enum InMessage {
/// Request /// Request
Message { req: Request, payload: bool }, Message(Request, InMessageType),
/// Payload chunk /// Payload chunk
Chunk(Option<Bytes>), Chunk(Option<Bytes>),
} }
/// Incoming request type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InMessageType {
None,
Payload,
Unhandled,
}
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct Codec { pub struct Codec {
config: ServiceConfig, config: ServiceConfig,
@ -246,6 +255,8 @@ impl Decoder for Codec {
Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)), Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)),
None => None, None => None,
}) })
} else if self.flags.contains(Flags::UNHANDLED) {
Ok(None)
} else if let Some((req, payload)) = self.decoder.decode(src)? { } else if let Some((req, payload)) = self.decoder.decode(src)? {
self.flags self.flags
.set(Flags::HEAD, req.inner.method == Method::HEAD); .set(Flags::HEAD, req.inner.method == Method::HEAD);
@ -253,11 +264,21 @@ impl Decoder for Codec {
if self.flags.contains(Flags::KEEPALIVE_ENABLED) { if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive()); self.flags.set(Flags::KEEPALIVE, req.keep_alive());
} }
self.payload = payload; let payload = match payload {
Ok(Some(InMessage::Message { RequestPayloadType::None => {
req, self.payload = None;
payload: self.payload.is_some(), InMessageType::None
})) }
RequestPayloadType::Payload(pl) => {
self.payload = Some(pl);
InMessageType::Payload
}
RequestPayloadType::Unhandled => {
self.payload = None;
InMessageType::Unhandled
}
};
Ok(Some(InMessage::Message(req, payload)))
} else { } else {
Ok(None) Ok(None)
} }

View File

@ -16,6 +16,13 @@ const MAX_HEADERS: usize = 96;
pub struct RequestDecoder(&'static RequestPool); pub struct RequestDecoder(&'static RequestPool);
/// Incoming request type
pub enum RequestPayloadType {
None,
Payload(PayloadDecoder),
Unhandled,
}
impl RequestDecoder { impl RequestDecoder {
pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder { pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder {
RequestDecoder(pool) RequestDecoder(pool)
@ -29,7 +36,7 @@ impl Default for RequestDecoder {
} }
impl Decoder for RequestDecoder { impl Decoder for RequestDecoder {
type Item = (Request, Option<PayloadDecoder>); type Item = (Request, RequestPayloadType);
type Error = ParseError; type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
@ -149,18 +156,18 @@ impl Decoder for RequestDecoder {
// https://tools.ietf.org/html/rfc7230#section-3.3.3 // https://tools.ietf.org/html/rfc7230#section-3.3.3
let decoder = if chunked { let decoder = if chunked {
// Chunked encoding // Chunked encoding
Some(PayloadDecoder::chunked()) RequestPayloadType::Payload(PayloadDecoder::chunked())
} else if let Some(len) = content_length { } else if let Some(len) = content_length {
// Content-Length // Content-Length
Some(PayloadDecoder::length(len)) RequestPayloadType::Payload(PayloadDecoder::length(len))
} else if has_upgrade || msg.inner.method == Method::CONNECT { } else if has_upgrade || msg.inner.method == Method::CONNECT {
// upgrade(websocket) or connect // upgrade(websocket) or connect
Some(PayloadDecoder::eof()) RequestPayloadType::Unhandled
} else if src.len() >= MAX_BUFFER_SIZE { } else if src.len() >= MAX_BUFFER_SIZE {
error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
return Err(ParseError::TooLarge); return Err(ParseError::TooLarge);
} else { } else {
None RequestPayloadType::None
}; };
Ok(Some((msg, decoder))) Ok(Some((msg, decoder)))
@ -481,20 +488,36 @@ mod tests {
use super::*; use super::*;
use error::ParseError; use error::ParseError;
use h1::InMessage; use h1::{InMessage, InMessageType};
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use request::Request; use request::Request;
impl RequestPayloadType {
fn unwrap(self) -> PayloadDecoder {
match self {
RequestPayloadType::Payload(pl) => pl,
_ => panic!(),
}
}
fn is_unhandled(&self) -> bool {
match self {
RequestPayloadType::Unhandled => true,
_ => false,
}
}
}
impl InMessage { impl InMessage {
fn message(self) -> Request { fn message(self) -> Request {
match self { match self {
InMessage::Message { req, payload: _ } => req, InMessage::Message(req, _) => req,
_ => panic!("error"), _ => panic!("error"),
} }
} }
fn is_payload(&self) -> bool { fn is_payload(&self) -> bool {
match *self { match *self {
InMessage::Message { req: _, payload } => payload, InMessage::Message(_, payload) => payload == InMessageType::Payload,
_ => panic!("error"), _ => panic!("error"),
} }
} }
@ -919,13 +942,9 @@ mod tests {
); );
let mut reader = RequestDecoder::default(); let mut reader = RequestDecoder::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
assert!(!req.keep_alive()); assert!(!req.keep_alive());
assert!(req.upgrade()); assert!(req.upgrade());
assert_eq!( assert!(pl.is_unhandled());
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
b"some raw data"
);
} }
#[test] #[test]

View File

@ -18,7 +18,8 @@ use error::DispatchError;
use request::Request; use request::Request;
use response::Response; use response::Response;
use super::codec::{Codec, InMessage, OutMessage}; use super::codec::{Codec, InMessage, InMessageType, OutMessage};
use super::H1ServiceResult;
const MAX_PIPELINED_MESSAGES: usize = 16; const MAX_PIPELINED_MESSAGES: usize = 16;
@ -41,13 +42,14 @@ where
{ {
service: S, service: S,
flags: Flags, flags: Flags,
framed: Framed<T, Codec>, framed: Option<Framed<T, Codec>>,
error: Option<DispatchError<S::Error>>, error: Option<DispatchError<S::Error>>,
config: ServiceConfig, config: ServiceConfig,
state: State<S>, state: State<S>,
payload: Option<PayloadSender>, payload: Option<PayloadSender>,
messages: VecDeque<Message>, messages: VecDeque<Message>,
unhandled: Option<Request>,
ka_expire: Instant, ka_expire: Instant,
ka_timer: Option<Delay>, ka_timer: Option<Delay>,
@ -112,9 +114,10 @@ where
state: State::None, state: State::None,
error: None, error: None,
messages: VecDeque::new(), messages: VecDeque::new(),
framed: Some(framed),
unhandled: None,
service, service,
flags, flags,
framed,
config, config,
ka_expire, ka_expire,
ka_timer, ka_timer,
@ -144,7 +147,7 @@ where
/// Flush stream /// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> { fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if !self.flags.contains(Flags::FLUSHED) { if !self.flags.contains(Flags::FLUSHED) {
match self.framed.poll_complete() { match self.framed.as_mut().unwrap().poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => { Err(err) => {
debug!("Error sending data: {}", err); debug!("Error sending data: {}", err);
@ -187,7 +190,11 @@ where
State::ServiceCall(ref mut fut) => { State::ServiceCall(ref mut fut) => {
match fut.poll().map_err(DispatchError::Service)? { match fut.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => { Async::Ready(mut res) => {
self.framed.get_codec_mut().prepare_te(&mut res); self.framed
.as_mut()
.unwrap()
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty); let body = res.replace_body(Body::Empty);
Some(State::SendResponse(Some(( Some(State::SendResponse(Some((
OutMessage::Response(res), OutMessage::Response(res),
@ -200,11 +207,11 @@ where
// send respons // send respons
State::SendResponse(ref mut item) => { State::SendResponse(ref mut item) => {
let (msg, body) = item.take().expect("SendResponse is empty"); let (msg, body) = item.take().expect("SendResponse is empty");
match self.framed.start_send(msg) { match self.framed.as_mut().unwrap().start_send(msg) {
Ok(AsyncSink::Ready) => { Ok(AsyncSink::Ready) => {
self.flags.set( self.flags.set(
Flags::KEEPALIVE, Flags::KEEPALIVE,
self.framed.get_codec().keepalive(), self.framed.as_mut().unwrap().get_codec().keepalive(),
); );
self.flags.remove(Flags::FLUSHED); self.flags.remove(Flags::FLUSHED);
match body { match body {
@ -233,7 +240,7 @@ where
// Send payload // Send payload
State::SendPayload(ref mut stream, ref mut bin) => { State::SendPayload(ref mut stream, ref mut bin) => {
if let Some(item) = bin.take() { if let Some(item) = bin.take() {
match self.framed.start_send(item) { match self.framed.as_mut().unwrap().start_send(item) {
Ok(AsyncSink::Ready) => { Ok(AsyncSink::Ready) => {
self.flags.remove(Flags::FLUSHED); self.flags.remove(Flags::FLUSHED);
} }
@ -248,6 +255,8 @@ where
match stream.poll() { match stream.poll() {
Ok(Async::Ready(Some(item))) => match self Ok(Async::Ready(Some(item))) => match self
.framed .framed
.as_mut()
.unwrap()
.start_send(OutMessage::Chunk(Some(item.into()))) .start_send(OutMessage::Chunk(Some(item.into())))
{ {
Ok(AsyncSink::Ready) => { Ok(AsyncSink::Ready) => {
@ -297,7 +306,11 @@ where
let mut task = self.service.call(req); let mut task = self.service.call(req);
match task.poll().map_err(DispatchError::Service)? { match task.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => { Async::Ready(mut res) => {
self.framed.get_codec_mut().prepare_te(&mut res); self.framed
.as_mut()
.unwrap()
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty); let body = res.replace_body(Body::Empty);
Ok(State::SendResponse(Some((OutMessage::Response(res), body)))) Ok(State::SendResponse(Some((OutMessage::Response(res), body))))
} }
@ -314,17 +327,24 @@ where
let mut updated = false; let mut updated = false;
'outer: loop { 'outer: loop {
match self.framed.poll() { match self.framed.as_mut().unwrap().poll() {
Ok(Async::Ready(Some(msg))) => { Ok(Async::Ready(Some(msg))) => {
updated = true; updated = true;
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
match msg { match msg {
InMessage::Message { req, payload } => { InMessage::Message(req, payload) => {
if payload { match payload {
let (ps, pl) = Payload::new(false); InMessageType::Payload => {
*req.inner.payload.borrow_mut() = Some(pl); let (ps, pl) = Payload::new(false);
self.payload = Some(ps); *req.inner.payload.borrow_mut() = Some(pl);
self.payload = Some(ps);
}
InMessageType::Unhandled => {
self.unhandled = Some(req);
return Ok(updated);
}
_ => (),
} }
// handle request early // handle request early
@ -454,15 +474,16 @@ where
S: Service<Request = Request, Response = Response>, S: Service<Request = Request, Response = Response>,
S::Error: Debug, S::Error: Debug,
{ {
type Item = (); type Item = H1ServiceResult<T>;
type Error = DispatchError<S::Error>; type Error = DispatchError<S::Error>;
#[inline] #[inline]
fn poll(&mut self) -> Poll<(), Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
self.poll_keepalive()?; self.poll_keepalive()?;
try_ready!(self.poll_flush()); try_ready!(self.poll_flush());
Ok(AsyncWrite::shutdown(self.framed.get_mut())?) let io = self.framed.take().unwrap().into_inner();
Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
} else { } else {
self.poll_keepalive()?; self.poll_keepalive()?;
self.poll_request()?; self.poll_request()?;
@ -474,15 +495,21 @@ where
if let Some(err) = self.error.take() { if let Some(err) = self.error.take() {
Err(err) Err(err)
} else if self.flags.contains(Flags::DISCONNECTED) { } else if self.flags.contains(Flags::DISCONNECTED) {
Ok(Async::Ready(())) Ok(Async::Ready(H1ServiceResult::Disconnected))
}
// unhandled request (upgrade or connect)
else if self.unhandled.is_some() {
let req = self.unhandled.take().unwrap();
let framed = self.framed.take().unwrap();
Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed)))
} }
// disconnect if keep-alive is not enabled // disconnect if keep-alive is not enabled
else if self.flags.contains(Flags::STARTED) && !self else if self.flags.contains(Flags::STARTED) && !self
.flags .flags
.intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED) .intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
{ {
self.flags.insert(Flags::SHUTDOWN); let io = self.framed.take().unwrap().into_inner();
self.poll() Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
} else { } else {
Ok(Async::NotReady) Ok(Async::NotReady)
} }

View File

@ -1,11 +1,22 @@
//! HTTP/1 implementation //! HTTP/1 implementation
use actix_net::codec::Framed;
mod codec; mod codec;
mod decoder; mod decoder;
mod dispatcher; mod dispatcher;
mod encoder; mod encoder;
mod service; mod service;
pub use self::codec::{Codec, InMessage, OutMessage}; pub use self::codec::{Codec, InMessage, InMessageType, OutMessage};
pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::decoder::{PayloadDecoder, RequestDecoder};
pub use self::dispatcher::Dispatcher; pub use self::dispatcher::Dispatcher;
pub use self::service::{H1Service, H1ServiceHandler}; pub use self::service::{H1Service, H1ServiceHandler};
use request::Request;
/// H1 service response type
pub enum H1ServiceResult<T> {
Disconnected,
Shutdown(T),
Unhandled(Request, Framed<T, Codec>),
}

View File

@ -12,6 +12,7 @@ use request::Request;
use response::Response; use response::Response;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
use super::H1ServiceResult;
/// `NewService` implementation for HTTP1 transport /// `NewService` implementation for HTTP1 transport
pub struct H1Service<T, S> { pub struct H1Service<T, S> {
@ -51,7 +52,7 @@ where
S::Error: Debug, S::Error: Debug,
{ {
type Request = T; type Request = T;
type Response = (); type Response = H1ServiceResult<T>;
type Error = DispatchError<S::Error>; type Error = DispatchError<S::Error>;
type InitError = S::InitError; type InitError = S::InitError;
type Service = H1ServiceHandler<T, S::Service>; type Service = H1ServiceHandler<T, S::Service>;
@ -243,7 +244,7 @@ where
S::Error: Debug, S::Error: Debug,
{ {
type Request = T; type Request = T;
type Response = (); type Response = H1ServiceResult<T>;
type Error = DispatchError<S::Error>; type Error = DispatchError<S::Error>;
type Future = Dispatcher<T, S>; type Future = Dispatcher<T, S>;

View File

@ -9,6 +9,7 @@ use std::{io::Read, io::Write, net, thread, time};
use actix::System; use actix::System;
use actix_net::server::Server; use actix_net::server::Server;
use actix_net::service::NewServiceExt;
use actix_web::{client, test, HttpMessage}; use actix_web::{client, test, HttpMessage};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, ok}; use futures::future::{self, ok};
@ -29,6 +30,7 @@ fn test_h1_v2() {
.server_hostname("localhost") .server_hostname("localhost")
.server_address(addr) .server_address(addr)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -53,6 +55,7 @@ fn test_slow_request() {
h1::H1Service::build() h1::H1Service::build()
.client_timeout(100) .client_timeout(100)
.finish(|_| future::ok::<_, ()>(Response::Ok().finish())) .finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -72,6 +75,7 @@ fn test_malformed_request() {
Server::new() Server::new()
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish())) h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish()))
.map(|_| ())
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -106,7 +110,7 @@ fn test_content_length() {
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
]; ];
future::ok::<_, ()>(Response::new(statuses[indx])) future::ok::<_, ()>(Response::new(statuses[indx]))
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -172,7 +176,7 @@ fn test_headers() {
); );
} }
future::ok::<_, ()>(builder.body(data.clone())) future::ok::<_, ()>(builder.body(data.clone()))
}) }).map(|_| ())
}) })
.unwrap() .unwrap()
.run() .run()
@ -221,6 +225,7 @@ fn test_body() {
Server::new() Server::new()
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR))) h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
.map(|_| ())
}).unwrap() }).unwrap()
.run(); .run();
}); });
@ -246,7 +251,7 @@ fn test_head_empty() {
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::new(|_| { h1::H1Service::new(|_| {
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish()) ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish())
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });
@ -282,7 +287,7 @@ fn test_head_binary() {
ok::<_, ()>( ok::<_, ()>(
Response::Ok().content_length(STR.len() as u64).body(STR), Response::Ok().content_length(STR.len() as u64).body(STR),
) )
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });
@ -314,7 +319,7 @@ fn test_head_binary2() {
thread::spawn(move || { thread::spawn(move || {
Server::new() Server::new()
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))) h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });
@ -349,7 +354,7 @@ fn test_body_length() {
.content_length(STR.len() as u64) .content_length(STR.len() as u64)
.body(Body::Streaming(Box::new(body))), .body(Body::Streaming(Box::new(body))),
) )
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });
@ -380,7 +385,7 @@ fn test_body_chunked_explicit() {
.chunked() .chunked()
.body(Body::Streaming(Box::new(body))), .body(Body::Streaming(Box::new(body))),
) )
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });
@ -409,7 +414,7 @@ fn test_body_chunked_implicit() {
h1::H1Service::new(|_| { h1::H1Service::new(|_| {
let body = once(Ok(Bytes::from_static(STR.as_ref()))); let body = once(Ok(Bytes::from_static(STR.as_ref())));
ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body)))) ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body))))
}) }).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });

View File

@ -51,7 +51,7 @@ fn test_simple() {
.and_then(TakeItem::new().map_err(|_| ())) .and_then(TakeItem::new().map_err(|_| ()))
.and_then(|(req, framed): (_, Framed<_, _>)| { .and_then(|(req, framed): (_, Framed<_, _>)| {
// validate request // validate request
if let Some(h1::InMessage::Message { req, payload: _ }) = req { if let Some(h1::InMessage::Message(req, _)) = req {
match ws::handshake(&req) { match ws::handshake(&req) {
Err(e) => { Err(e) => {
// validation failed // validation failed