diff --git a/CHANGES.md b/CHANGES.md index f0f7cd7e..a8c4e1e5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,10 @@ # Changes +## 0.6.5 (2018-05-15) + +* Fix error handling during request decoding #222 + + ## 0.6.4 (2018-05-11) * Fix segfault in ServerSettings::get_response_builder() diff --git a/src/server/h1.rs b/src/server/h1.rs index b6de5cc5..933ce0a8 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -403,8 +403,12 @@ where #[cfg(test)] mod tests { - use bytes::{Bytes, BytesMut}; + use std::net::Shutdown; + use std::{cmp, time}; + + use bytes::{Buf, Bytes, BytesMut}; use http::{Method, Version}; + use tokio_io::{AsyncRead, AsyncWrite}; use super::*; use application::HttpApplication; @@ -468,6 +472,101 @@ mod tests { }}; } + struct Buffer { + buf: Bytes, + err: Option, + } + + impl Buffer { + fn new(data: &'static str) -> Buffer { + Buffer { + buf: Bytes::from(data), + err: None, + } + } + fn feed_data(&mut self, data: &'static str) { + let mut b = BytesMut::from(self.buf.as_ref()); + b.extend(data.as_bytes()); + self.buf = b.take().freeze(); + } + } + + impl AsyncRead for Buffer {} + impl io::Read for Buffer { + fn read(&mut self, dst: &mut [u8]) -> Result { + if self.buf.is_empty() { + if self.err.is_some() { + Err(self.err.take().unwrap()) + } else { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + } + } else { + let size = cmp::min(self.buf.len(), dst.len()); + let b = self.buf.split_to(size); + dst[..size].copy_from_slice(&b); + Ok(size) + } + } + } + + impl IoStream for Buffer { + fn shutdown(&mut self, _: Shutdown) -> io::Result<()> { + Ok(()) + } + fn set_nodelay(&mut self, _: bool) -> io::Result<()> { + Ok(()) + } + fn set_linger(&mut self, _: Option) -> io::Result<()> { + Ok(()) + } + } + impl io::Write for Buffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + impl AsyncWrite for Buffer { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(Async::Ready(())) + } + fn write_buf(&mut self, _: &mut B) -> Poll { + Ok(Async::NotReady) + } + } + + #[test] + fn test_req_parse() { + let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); + let readbuf = BytesMut::new(); + let settings = Rc::new(WorkerSettings::::new( + Vec::new(), + KeepAlive::Os, + )); + + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); + h1.poll_io(); + h1.parse(); + assert_eq!(h1.tasks.len(), 1); + } + + #[test] + fn test_req_parse_err() { + let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); + let readbuf = BytesMut::new(); + let settings = Rc::new(WorkerSettings::::new( + Vec::new(), + KeepAlive::Os, + )); + + let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); + h1.poll_io(); + h1.parse(); + assert!(h1.flags.contains(Flags::ERROR)); + } + #[test] fn test_parse() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n");