1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-22 23:05:56 +01:00

refactor h1 dispatcher

This commit is contained in:
Nikolay Kim 2019-04-06 00:16:04 -07:00
parent fbedaec661
commit 3872d3ba5a
8 changed files with 275 additions and 168 deletions

View File

@ -9,7 +9,7 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCOD
use http::{Method, StatusCode, Version}; use http::{Method, StatusCode, Version};
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::{decoder, encoder, reserve_readbuf}; use super::{decoder, encoder};
use super::{Message, MessageType}; use super::{Message, MessageType};
use crate::body::BodySize; use crate::body::BodySize;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
@ -31,7 +31,7 @@ const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct Codec { pub struct Codec {
config: ServiceConfig, pub(crate) config: ServiceConfig,
decoder: decoder::MessageDecoder<Request>, decoder: decoder::MessageDecoder<Request>,
payload: Option<PayloadDecoder>, payload: Option<PayloadDecoder>,
version: Version, version: Version,
@ -78,16 +78,25 @@ impl Codec {
} }
} }
#[inline]
/// Check if request is upgrade /// Check if request is upgrade
pub fn upgrade(&self) -> bool { pub fn upgrade(&self) -> bool {
self.ctype == ConnectionType::Upgrade self.ctype == ConnectionType::Upgrade
} }
#[inline]
/// Check if last response is keep-alive /// Check if last response is keep-alive
pub fn keepalive(&self) -> bool { pub fn keepalive(&self) -> bool {
self.ctype == ConnectionType::KeepAlive self.ctype == ConnectionType::KeepAlive
} }
#[inline]
/// Check if keep-alive enabled on server level
pub fn keepalive_enabled(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE_ENABLED)
}
#[inline]
/// Check last request's message type /// Check last request's message type
pub fn message_type(&self) -> MessageType { pub fn message_type(&self) -> MessageType {
if self.flags.contains(Flags::STREAM) { if self.flags.contains(Flags::STREAM) {
@ -107,10 +116,7 @@ impl Decoder for Codec {
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.payload.is_some() { if self.payload.is_some() {
Ok(match self.payload.as_mut().unwrap().decode(src)? { Ok(match self.payload.as_mut().unwrap().decode(src)? {
Some(PayloadItem::Chunk(chunk)) => { Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
reserve_readbuf(src);
Some(Message::Chunk(Some(chunk)))
}
Some(PayloadItem::Eof) => { Some(PayloadItem::Eof) => {
self.payload.take(); self.payload.take();
Some(Message::Chunk(None)) Some(Message::Chunk(None))
@ -135,7 +141,6 @@ impl Decoder for Codec {
self.flags.insert(Flags::STREAM); self.flags.insert(Flags::STREAM);
} }
} }
reserve_readbuf(src);
Ok(Some(Message::Item(req))) Ok(Some(Message::Item(req)))
} else { } else {
Ok(None) Ok(None)

View File

@ -84,7 +84,9 @@ pub(crate) trait MessageType: Sized {
header::CONTENT_LENGTH => { header::CONTENT_LENGTH => {
if let Ok(s) = value.to_str() { if let Ok(s) = value.to_str() {
if let Ok(len) = s.parse::<u64>() { if let Ok(len) = s.parse::<u64>() {
content_length = Some(len); if len != 0 {
content_length = Some(len);
}
} else { } else {
debug!("illegal Content-Length: {:?}", s); debug!("illegal Content-Length: {:?}", s);
return Err(ParseError::Header); return Err(ParseError::Header);

View File

@ -1,20 +1,20 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::mem;
use std::time::Instant; use std::time::Instant;
use std::{fmt, io};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use actix_service::Service; use actix_service::Service;
use actix_utils::cloneable::CloneableService; use actix_utils::cloneable::CloneableService;
use bitflags::bitflags; use bitflags::bitflags;
use futures::{Async, Future, Poll, Sink, Stream}; use bytes::{BufMut, BytesMut};
use log::{debug, error, trace}; use futures::{Async, Future, Poll};
use log::{error, trace};
use tokio_timer::Delay; use tokio_timer::Delay;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::error::{ParseError, PayloadError}; use crate::error::{ParseError, PayloadError};
use crate::http::StatusCode;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
@ -22,17 +22,19 @@ use super::codec::Codec;
use super::payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; use super::payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter};
use super::{Message, MessageType}; use super::{Message, MessageType};
const LW_BUFFER_SIZE: usize = 4096;
const HW_BUFFER_SIZE: usize = 32_768;
const MAX_PIPELINED_MESSAGES: usize = 16; const MAX_PIPELINED_MESSAGES: usize = 16;
bitflags! { bitflags! {
pub struct Flags: u8 { pub struct Flags: u8 {
const STARTED = 0b0000_0001; const STARTED = 0b0000_0001;
const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100; const POLLED = 0b0000_0100;
const POLLED = 0b0000_1000; const SHUTDOWN = 0b0000_1000;
const SHUTDOWN = 0b0010_0000; const READ_DISCONNECT = 0b0001_0000;
const DISCONNECTED = 0b0100_0000; const WRITE_DISCONNECT = 0b0010_0000;
const DROPPING = 0b1000_0000; const DROPPING = 0b0100_0000;
} }
} }
@ -59,9 +61,7 @@ where
service: CloneableService<S>, service: CloneableService<S>,
expect: CloneableService<X>, expect: CloneableService<X>,
flags: Flags, flags: Flags,
framed: Framed<T, Codec>,
error: Option<DispatchError>, error: Option<DispatchError>,
config: ServiceConfig,
state: State<S, B, X>, state: State<S, B, X>,
payload: Option<PayloadSender>, payload: Option<PayloadSender>,
@ -69,6 +69,11 @@ where
ka_expire: Instant, ka_expire: Instant,
ka_timer: Option<Delay>, ka_timer: Option<Delay>,
io: T,
read_buf: BytesMut,
write_buf: BytesMut,
codec: Codec,
} }
enum DispatcherMessage { enum DispatcherMessage {
@ -101,6 +106,30 @@ where
false false
} }
} }
fn is_call(&self) -> bool {
if let State::ServiceCall(_) = self {
true
} else {
false
}
}
}
impl<S, B, X> fmt::Debug for State<S, B, X>
where
S: Service<Request = Request>,
X: Service<Request = Request, Response = Request>,
B: MessageBody,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
State::None => write!(f, "State::None"),
State::ExpectCall(_) => write!(f, "State::ExceptCall"),
State::ServiceCall(_) => write!(f, "State::ServiceCall"),
State::SendPayload(_) => write!(f, "State::SendPayload"),
}
}
} }
impl<T, S, B, X> Dispatcher<T, S, B, X> impl<T, S, B, X> Dispatcher<T, S, B, X>
@ -121,8 +150,10 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
) -> Self { ) -> Self {
Dispatcher::with_timeout( Dispatcher::with_timeout(
Framed::new(stream, Codec::new(config.clone())), stream,
Codec::new(config.clone()),
config, config,
BytesMut::with_capacity(HW_BUFFER_SIZE),
None, None,
service, service,
expect, expect,
@ -131,15 +162,17 @@ where
/// Create http/1 dispatcher with slow request timeout. /// Create http/1 dispatcher with slow request timeout.
pub fn with_timeout( pub fn with_timeout(
framed: Framed<T, Codec>, io: T,
codec: Codec,
config: ServiceConfig, config: ServiceConfig,
read_buf: BytesMut,
timeout: Option<Delay>, timeout: Option<Delay>,
service: CloneableService<S>, service: CloneableService<S>,
expect: CloneableService<X>, expect: CloneableService<X>,
) -> Self { ) -> Self {
let keepalive = config.keep_alive_enabled(); let keepalive = config.keep_alive_enabled();
let flags = if keepalive { let flags = if keepalive {
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED Flags::KEEPALIVE
} else { } else {
Flags::empty() Flags::empty()
}; };
@ -155,7 +188,10 @@ where
Dispatcher { Dispatcher {
inner: Some(InnerDispatcher { inner: Some(InnerDispatcher {
framed, io,
codec,
read_buf,
write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
payload: None, payload: None,
state: State::None, state: State::None,
error: None, error: None,
@ -163,7 +199,6 @@ where
service, service,
expect, expect,
flags, flags,
config,
ka_expire, ka_expire,
ka_timer, ka_timer,
}), }),
@ -182,11 +217,9 @@ where
X::Error: Into<Error>, X::Error: Into<Error>,
{ {
fn can_read(&self) -> bool { fn can_read(&self) -> bool {
if self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::READ_DISCONNECT) {
return false; return false;
} } else if let Some(ref info) = self.payload {
if let Some(ref info) = self.payload {
info.need_read() == PayloadStatus::Read info.need_read() == PayloadStatus::Read
} else { } else {
true true
@ -195,32 +228,52 @@ where
// if checked is set to true, delay disconnect until all tasks have finished. // if checked is set to true, delay disconnect until all tasks have finished.
fn client_disconnected(&mut self) { fn client_disconnected(&mut self) {
self.flags.insert(Flags::DISCONNECTED); self.flags
.insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
} }
} }
/// Flush stream /// Flush stream
fn poll_flush(&mut self) -> Poll<bool, DispatchError> { ///
if !self.framed.is_write_buf_empty() { /// true - got whouldblock
match self.framed.poll_complete() { /// false - didnt get whouldblock
Ok(Async::NotReady) => Ok(Async::NotReady), fn poll_flush(&mut self) -> Result<bool, DispatchError> {
Err(err) => { if self.write_buf.is_empty() {
debug!("Error sending data: {}", err); return Ok(false);
Err(err.into())
}
Ok(Async::Ready(_)) => {
// if payload is not consumed we can not use connection
if self.payload.is_some() && self.state.is_empty() {
return Err(DispatchError::PayloadIsNotConsumed);
}
Ok(Async::Ready(true))
}
}
} else {
Ok(Async::Ready(false))
} }
let len = self.write_buf.len();
let mut written = 0;
while written < len {
match self.io.write(&self.write_buf[written..]) {
Ok(0) => {
return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero,
"",
)));
}
Ok(n) => {
written += n;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if written > 0 {
let _ = self.write_buf.split_to(written);
}
return Ok(true);
}
Err(err) => return Err(DispatchError::Io(err)),
}
}
if written > 0 {
if written == self.write_buf.len() {
unsafe { self.write_buf.set_len(0) }
} else {
let _ = self.write_buf.split_to(written);
}
}
Ok(false)
} }
fn send_response( fn send_response(
@ -228,8 +281,8 @@ where
message: Response<()>, message: Response<()>,
body: ResponseBody<B>, body: ResponseBody<B>,
) -> Result<State<S, B, X>, DispatchError> { ) -> Result<State<S, B, X>, DispatchError> {
self.framed self.codec
.force_send(Message::Item((message, body.length()))) .encode(Message::Item((message, body.length())), &mut self.write_buf)
.map_err(|err| { .map_err(|err| {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
@ -237,113 +290,109 @@ where
DispatchError::Io(err) DispatchError::Io(err)
})?; })?;
self.flags self.flags.set(Flags::KEEPALIVE, self.codec.keepalive());
.set(Flags::KEEPALIVE, self.framed.get_codec().keepalive());
match body.length() { match body.length() {
BodySize::None | BodySize::Empty => Ok(State::None), BodySize::None | BodySize::Empty => Ok(State::None),
_ => Ok(State::SendPayload(body)), _ => Ok(State::SendPayload(body)),
} }
} }
fn send_continue(&mut self) -> Result<(), DispatchError> { fn send_continue(&mut self) {
self.framed self.write_buf
.force_send(Message::Item(( .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
Response::empty(StatusCode::CONTINUE),
BodySize::Empty,
)))
.map_err(|err| DispatchError::Io(err))
} }
fn poll_response(&mut self) -> Result<(), DispatchError> { fn poll_response(&mut self) -> Result<bool, DispatchError> {
let mut retry = self.can_read();
loop { loop {
let state = match mem::replace(&mut self.state, State::None) { let state = match self.state {
State::None => match self.messages.pop_front() { State::None => match self.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => { Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req)?) Some(self.handle_request(req)?)
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
self.send_response(res, ResponseBody::Other(Body::Empty))?; Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
None
} }
None => None, None => None,
}, },
State::ExpectCall(mut fut) => match fut.poll() { State::ExpectCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(req)) => { Ok(Async::Ready(req)) => {
self.send_continue()?; self.send_continue();
Some(State::ServiceCall(self.service.call(req))) self.state = State::ServiceCall(self.service.call(req));
} continue;
Ok(Async::NotReady) => {
self.state = State::ExpectCall(fut);
None
} }
Ok(Async::NotReady) => None,
Err(e) => { Err(e) => {
let e = e.into(); let res: Response = e.into().into();
let res: Response = e.into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.send_response(res, body.into_body())?)
} }
}, },
State::ServiceCall(mut fut) => match fut.poll() { State::ServiceCall(ref mut fut) => match fut.poll() {
Ok(Async::Ready(res)) => { Ok(Async::Ready(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
Some(self.send_response(res, body)?) self.state = self.send_response(res, body)?;
continue;
} }
Ok(Async::NotReady) => { Ok(Async::NotReady) => None,
self.state = State::ServiceCall(fut); Err(e) => {
None let res: Response = e.into().into();
}
Err(_e) => {
let res: Response = Response::InternalServerError().finish();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.send_response(res, body.into_body())?)
} }
}, },
State::SendPayload(mut stream) => { State::SendPayload(ref mut stream) => {
loop { loop {
if !self.framed.is_write_buf_full() { if self.write_buf.len() < HW_BUFFER_SIZE {
match stream match stream
.poll_next() .poll_next()
.map_err(|_| DispatchError::Unknown)? .map_err(|_| DispatchError::Unknown)?
{ {
Async::Ready(Some(item)) => { Async::Ready(Some(item)) => {
self.framed self.codec.encode(
.force_send(Message::Chunk(Some(item)))?; Message::Chunk(Some(item)),
&mut self.write_buf,
)?;
continue; continue;
} }
Async::Ready(None) => { Async::Ready(None) => {
self.framed.force_send(Message::Chunk(None))?; self.codec.encode(
} Message::Chunk(None),
Async::NotReady => { &mut self.write_buf,
self.state = State::SendPayload(stream); )?;
return Ok(()); self.state = State::None;
} }
Async::NotReady => return Ok(false),
} }
} else { } else {
self.state = State::SendPayload(stream); return Ok(true);
return Ok(());
} }
break; break;
} }
None continue;
} }
}; };
match state { // set new state
Some(state) => self.state = state, if let Some(state) = state {
None => { self.state = state;
// if read-backpressure is enabled and we consumed some data. if !self.state.is_empty() {
// we may read more data and retry continue;
if !retry && self.can_read() && self.poll_request()? { }
retry = self.can_read(); } else {
// if read-backpressure is enabled and we consumed some data.
// we may read more data and retry
if self.state.is_call() {
if self.poll_request()? {
continue; continue;
} }
break; } else if !self.messages.is_empty() {
continue;
} }
} }
break;
} }
Ok(()) Ok(false)
} }
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> { fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
@ -352,7 +401,7 @@ where
let mut task = self.expect.call(req); let mut task = self.expect.call(req);
match task.poll() { match task.poll() {
Ok(Async::Ready(req)) => { Ok(Async::Ready(req)) => {
self.send_continue()?; self.send_continue();
req req
} }
Ok(Async::NotReady) => return Ok(State::ExpectCall(task)), Ok(Async::NotReady) => return Ok(State::ExpectCall(task)),
@ -375,8 +424,8 @@ where
self.send_response(res, body) self.send_response(res, body)
} }
Ok(Async::NotReady) => Ok(State::ServiceCall(task)), Ok(Async::NotReady) => Ok(State::ServiceCall(task)),
Err(_e) => { Err(e) => {
let res: Response = Response::InternalServerError().finish(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.send_response(res, body.into_body()) self.send_response(res, body.into_body())
} }
@ -386,20 +435,20 @@ where
/// Process one incoming requests /// Process one incoming requests
pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> { pub(self) fn poll_request(&mut self) -> Result<bool, DispatchError> {
// limit a mount of non processed requests // limit a mount of non processed requests
if self.messages.len() >= MAX_PIPELINED_MESSAGES { if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read() {
return Ok(false); return Ok(false);
} }
let mut updated = false; let mut updated = false;
loop { loop {
match self.framed.poll() { match self.codec.decode(&mut self.read_buf) {
Ok(Async::Ready(Some(msg))) => { Ok(Some(msg)) => {
updated = true; updated = true;
self.flags.insert(Flags::STARTED); self.flags.insert(Flags::STARTED);
match msg { match msg {
Message::Item(mut req) => { Message::Item(mut req) => {
match self.framed.get_codec().message_type() { match self.codec.message_type() {
MessageType::Payload | MessageType::Stream => { MessageType::Payload | MessageType::Stream => {
let (ps, pl) = Payload::create(false); let (ps, pl) = Payload::create(false);
let (req1, _) = let (req1, _) =
@ -424,7 +473,7 @@ where
error!( error!(
"Internal server error: unexpected payload chunk" "Internal server error: unexpected payload chunk"
); );
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECT);
self.messages.push_back(DispatcherMessage::Error( self.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish().drop_body(), Response::InternalServerError().finish().drop_body(),
)); ));
@ -437,7 +486,7 @@ where
payload.feed_eof(); payload.feed_eof();
} else { } else {
error!("Internal server error: unexpected eof"); error!("Internal server error: unexpected eof");
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECT);
self.messages.push_back(DispatcherMessage::Error( self.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish().drop_body(), Response::InternalServerError().finish().drop_body(),
)); ));
@ -447,11 +496,7 @@ where
} }
} }
} }
Ok(Async::Ready(None)) => { Ok(None) => break,
self.client_disconnected();
break;
}
Ok(Async::NotReady) => break,
Err(ParseError::Io(e)) => { Err(ParseError::Io(e)) => {
self.client_disconnected(); self.client_disconnected();
self.error = Some(DispatchError::Io(e)); self.error = Some(DispatchError::Io(e));
@ -466,15 +511,15 @@ where
self.messages.push_back(DispatcherMessage::Error( self.messages.push_back(DispatcherMessage::Error(
Response::BadRequest().finish().drop_body(), Response::BadRequest().finish().drop_body(),
)); ));
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECT);
self.error = Some(e.into()); self.error = Some(e.into());
break; break;
} }
} }
} }
if self.ka_timer.is_some() && updated { if updated && self.ka_timer.is_some() {
if let Some(expire) = self.config.keep_alive_expire() { if let Some(expire) = self.codec.config.keep_alive_expire() {
self.ka_expire = expire; self.ka_expire = expire;
} }
} }
@ -486,10 +531,10 @@ where
if self.ka_timer.is_none() { if self.ka_timer.is_none() {
// shutdown timeout // shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.config.client_disconnect_timer() { if let Some(interval) = self.codec.config.client_disconnect_timer() {
self.ka_timer = Some(Delay::new(interval)); self.ka_timer = Some(Delay::new(interval));
} else { } else {
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::READ_DISCONNECT);
return Ok(()); return Ok(());
} }
} else { } else {
@ -507,13 +552,14 @@ where
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
} else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire {
// check for any outstanding tasks // check for any outstanding tasks
if self.state.is_empty() && self.framed.is_write_buf_empty() { if self.state.is_empty() && self.write_buf.is_empty() {
if self.flags.contains(Flags::STARTED) { if self.flags.contains(Flags::STARTED) {
trace!("Keep-alive timeout, close connection"); trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN); self.flags.insert(Flags::SHUTDOWN);
// start shutdown timer // start shutdown timer
if let Some(deadline) = self.config.client_disconnect_timer() if let Some(deadline) =
self.codec.config.client_disconnect_timer()
{ {
if let Some(timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
@ -521,7 +567,7 @@ where
} }
} else { } else {
// no shutdown timeout, drop socket // no shutdown timeout, drop socket
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::WRITE_DISCONNECT);
return Ok(()); return Ok(());
} }
} else { } else {
@ -538,7 +584,8 @@ where
self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); self.flags.insert(Flags::STARTED | Flags::SHUTDOWN);
self.state = State::None; self.state = State::None;
} }
} else if let Some(deadline) = self.config.keep_alive_expire() { } else if let Some(deadline) = self.codec.config.keep_alive_expire()
{
if let Some(timer) = self.ka_timer.as_mut() { if let Some(timer) = self.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = timer.poll(); let _ = timer.poll();
@ -572,34 +619,60 @@ where
#[inline] #[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let inner = self.inner.as_mut().unwrap(); let inner = self.inner.as_mut().unwrap();
inner.poll_keepalive()?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
inner.poll_keepalive()?; if inner.flags.contains(Flags::WRITE_DISCONNECT) {
if inner.flags.contains(Flags::DISCONNECTED) {
Ok(Async::Ready(())) Ok(Async::Ready(()))
} else { } else {
// try_ready!(inner.poll_flush()); // flush buffer
match inner.framed.get_mut().shutdown()? { inner.poll_flush()?;
Async::Ready(_) => Ok(Async::Ready(())), if !inner.write_buf.is_empty() {
Async::NotReady => Ok(Async::NotReady), Ok(Async::NotReady)
} else {
match inner.io.shutdown()? {
Async::Ready(_) => Ok(Async::Ready(())),
Async::NotReady => Ok(Async::NotReady),
}
} }
} }
} else { } else {
inner.poll_keepalive()?; // read socket into a buf
if !inner.flags.contains(Flags::READ_DISCONNECT) {
if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? {
inner.flags.insert(Flags::READ_DISCONNECT)
}
}
inner.poll_request()?; inner.poll_request()?;
loop { loop {
inner.poll_response()?; if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
if let Async::Ready(false) = inner.poll_flush()? { inner.write_buf.reserve(HW_BUFFER_SIZE);
}
let need_write = inner.poll_response()?;
// we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck
if inner.poll_flush()? || !need_write {
break; break;
} }
} }
if inner.flags.contains(Flags::DISCONNECTED) { // client is gone
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
return Ok(Async::Ready(())); return Ok(Async::Ready(()));
} }
let is_empty = inner.state.is_empty();
// read half is closed and we do not processing any responses
if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty {
inner.flags.insert(Flags::SHUTDOWN);
}
// keep-alive and stream errors // keep-alive and stream errors
if inner.state.is_empty() && inner.framed.is_write_buf_empty() { if is_empty && inner.write_buf.is_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner.error.take() {
Err(err) Err(err)
} }
@ -623,13 +696,52 @@ where
} }
} }
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
where
T: io::Read,
{
let mut read_some = false;
loop {
if buf.remaining_mut() < LW_BUFFER_SIZE {
buf.reserve(HW_BUFFER_SIZE);
}
let read = unsafe { io.read(buf.bytes_mut()) };
match read {
Ok(n) => {
if n == 0 {
return Ok(Some(true));
} else {
read_some = true;
unsafe {
buf.advance_mut(n);
}
}
}
Err(e) => {
return if e.kind() == io::ErrorKind::WouldBlock {
if read_some {
Ok(Some(false))
} else {
Ok(None)
}
} else if e.kind() == io::ErrorKind::ConnectionReset && read_some {
Ok(Some(true))
} else {
Err(e)
};
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{cmp, io}; use std::{cmp, io};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::IntoService; use actix_service::IntoService;
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes, BytesMut};
use futures::future::{lazy, ok}; use futures::future::{lazy, ok};
use super::*; use super::*;
@ -638,6 +750,7 @@ mod tests {
struct Buffer { struct Buffer {
buf: Bytes, buf: Bytes,
write_buf: BytesMut,
err: Option<io::Error>, err: Option<io::Error>,
} }
@ -645,6 +758,7 @@ mod tests {
fn new(data: &'static str) -> Buffer { fn new(data: &'static str) -> Buffer {
Buffer { Buffer {
buf: Bytes::from(data), buf: Bytes::from(data),
write_buf: BytesMut::new(),
err: None, err: None,
} }
} }
@ -670,6 +784,7 @@ mod tests {
impl io::Write for Buffer { impl io::Write for Buffer {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_buf.extend(buf);
Ok(buf.len()) Ok(buf.len())
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
@ -699,15 +814,17 @@ mod tests {
), ),
CloneableService::new(ExpectHandler), CloneableService::new(ExpectHandler),
); );
assert!(h1.poll().is_ok()); assert!(h1.poll().is_err());
assert!(h1.poll().is_ok());
assert!(h1 assert!(h1
.inner .inner
.as_ref() .as_ref()
.unwrap() .unwrap()
.flags .flags
.contains(Flags::DISCONNECTED)); .contains(Flags::READ_DISCONNECT));
// assert_eq!(h1.tasks.len(), 1); assert_eq!(
&h1.inner.as_ref().unwrap().io.write_buf[..26],
b"HTTP/1.1 400 Bad Request\r\n"
);
ok::<_, ()>(()) ok::<_, ()>(())
})); }));
} }

View File

@ -41,8 +41,6 @@ impl<T: MessageType> Default for MessageEncoder<T> {
pub(crate) trait MessageType: Sized { pub(crate) trait MessageType: Sized {
fn status(&self) -> Option<StatusCode>; fn status(&self) -> Option<StatusCode>;
// fn connection_type(&self) -> Option<ConnectionType>;
fn headers(&self) -> &HeaderMap; fn headers(&self) -> &HeaderMap;
fn chunked(&self) -> bool; fn chunked(&self) -> bool;
@ -171,10 +169,6 @@ impl MessageType for Response<()> {
self.head().chunked() self.head().chunked()
} }
//fn connection_type(&self) -> Option<ConnectionType> {
// self.head().ctype
//}
fn headers(&self) -> &HeaderMap { fn headers(&self) -> &HeaderMap {
&self.head().headers &self.head().headers
} }

View File

@ -51,18 +51,6 @@ impl Response<Body> {
} }
} }
#[inline]
pub(crate) fn empty(status: StatusCode) -> Response<()> {
let mut head: Message<ResponseHead> = Message::new();
head.status = status;
Response {
head,
body: ResponseBody::Body(()),
error: None,
}
}
/// Constructs an error response /// Constructs an error response
#[inline] #[inline]
pub fn from_error(error: Error) -> Response { pub fn from_error(error: Error) -> Response {

View File

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::{fmt, io}; use std::{fmt, io};
use actix_codec::{AsyncRead, AsyncWrite, Framed, FramedParts}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_server_config::{Io as ServerIo, Protocol, ServerConfig as SrvConfig}; use actix_server_config::{Io as ServerIo, Protocol, ServerConfig as SrvConfig};
use actix_service::{IntoNewService, NewService, Service}; use actix_service::{IntoNewService, NewService, Service};
use actix_utils::cloneable::CloneableService; use actix_utils::cloneable::CloneableService;
@ -375,13 +375,14 @@ where
self.state = self.state =
State::Handshake(Some((server::handshake(io), cfg, srv))); State::Handshake(Some((server::handshake(io), cfg, srv)));
} else { } else {
let framed = Framed::from_parts(FramedParts::with_read_buf( self.state = State::H1(h1::Dispatcher::with_timeout(
io, io,
h1::Codec::new(cfg.clone()), h1::Codec::new(cfg.clone()),
cfg,
buf, buf,
)); None,
self.state = State::H1(h1::Dispatcher::with_timeout( srv,
framed, cfg, None, srv, expect, expect,
)) ))
} }
self.poll() self.poll()

View File

@ -911,11 +911,11 @@ fn test_h1_service_error() {
}); });
let response = srv.block_on(srv.get("/").send()).unwrap(); let response = srv.block_on(srv.get("/").send()).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
// read response // read response
let bytes = srv.load_body(response).unwrap(); let bytes = srv.load_body(response).unwrap();
assert!(bytes.is_empty()); assert_eq!(bytes, Bytes::from_static(b"error"));
} }
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]

View File

@ -770,7 +770,7 @@ fn test_reading_deflate_encoding_large_random_ssl() {
awc::Connector::new() awc::Connector::new()
.timeout(std::time::Duration::from_millis(500)) .timeout(std::time::Duration::from_millis(500))
.ssl(builder.build()) .ssl(builder.build())
.service(), .finish(),
) )
.finish() .finish()
}); });