From 2e83c5924d0de509d793b2b62c4056745e841cbe Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 12 Dec 2017 21:32:58 -0800 Subject: [PATCH] cleanup and optimize some code --- examples/basic.rs | 8 +++++--- src/encoding.rs | 43 ++++++++++++++++++++++++++--------------- src/error.rs | 13 ++++++++++++- src/h1.rs | 47 ++++++++++++++++++++++++++++++++++----------- src/h2.rs | 2 +- src/handler.rs | 4 ++++ src/httprequest.rs | 46 ++++++++++++++++++++++++++------------------ src/httpresponse.rs | 32 ++++++++++++++++++++---------- src/ws.rs | 21 +++++++++++++------- 9 files changed, 148 insertions(+), 68 deletions(-) diff --git a/examples/basic.rs b/examples/basic.rs index 304eabd27..22dfaba37 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -13,9 +13,11 @@ use futures::future::{FutureResult, result}; /// simple handler fn index(mut req: HttpRequest) -> Result { println!("{:?}", req); - if let Ok(ch) = req.payload_mut().readany() { - if let futures::Async::Ready(Some(d)) = ch { - println!("{}", String::from_utf8_lossy(d.0.as_ref())); + if let Some(payload) = req.payload_mut() { + if let Ok(ch) = payload.readany() { + if let futures::Async::Ready(Some(d)) = ch { + println!("{}", String::from_utf8_lossy(d.0.as_ref())); + } } } diff --git a/src/encoding.rs b/src/encoding.rs index 1c8c88272..be44990b7 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -402,14 +402,14 @@ impl PayloadEncoder { resp.headers_mut().insert( CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", b.len()).as_str()).unwrap()); + HeaderValue::from_str(&b.len().to_string()).unwrap()); *bytes = Binary::from(b); encoding = ContentEncoding::Identity; TransferEncoding::eof() } else { resp.headers_mut().insert( CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); + HeaderValue::from_str(&bytes.len().to_string()).unwrap()); resp.headers_mut().remove(TRANSFER_ENCODING); TransferEncoding::length(bytes.len() as u64) } @@ -478,22 +478,27 @@ impl PayloadEncoder { impl PayloadEncoder { + #[inline] pub fn len(&self) -> usize { self.0.get_ref().len() } + #[inline] pub fn get_mut(&mut self) -> &mut BytesMut { self.0.get_mut() } + #[inline] pub fn is_eof(&self) -> bool { self.0.is_eof() } + #[inline] pub fn write(&mut self, payload: &[u8]) -> Result<(), io::Error> { self.0.write(payload) } + #[inline] pub fn write_eof(&mut self) -> Result<(), io::Error> { self.0.write_eof() } @@ -508,6 +513,7 @@ enum ContentEncoder { impl ContentEncoder { + #[inline] pub fn is_eof(&self) -> bool { match *self { ContentEncoder::Br(ref encoder) => @@ -521,6 +527,7 @@ impl ContentEncoder { } } + #[inline] pub fn get_ref(&self) -> &BytesMut { match *self { ContentEncoder::Br(ref encoder) => @@ -534,6 +541,7 @@ impl ContentEncoder { } } + #[inline] pub fn get_mut(&mut self) -> &mut BytesMut { match *self { ContentEncoder::Br(ref mut encoder) => @@ -547,6 +555,7 @@ impl ContentEncoder { } } + #[inline] pub fn write_eof(&mut self) -> Result<(), io::Error> { let encoder = mem::replace(self, ContentEncoder::Identity(TransferEncoding::eof())); @@ -555,7 +564,6 @@ impl ContentEncoder { match encoder.finish() { Ok(mut writer) => { writer.encode_eof(); - *self = ContentEncoder::Identity(writer); Ok(()) }, Err(err) => Err(err), @@ -565,7 +573,6 @@ impl ContentEncoder { match encoder.finish() { Ok(mut writer) => { writer.encode_eof(); - *self = ContentEncoder::Identity(writer); Ok(()) }, Err(err) => Err(err), @@ -575,7 +582,6 @@ impl ContentEncoder { match encoder.finish() { Ok(mut writer) => { writer.encode_eof(); - *self = ContentEncoder::Identity(writer); Ok(()) }, Err(err) => Err(err), @@ -583,19 +589,18 @@ impl ContentEncoder { }, ContentEncoder::Identity(mut writer) => { writer.encode_eof(); - *self = ContentEncoder::Identity(writer); Ok(()) } } } + #[inline] pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { match *self { ContentEncoder::Br(ref mut encoder) => { match encoder.write(data) { - Ok(_) => { - encoder.flush() - }, + Ok(_) => + encoder.flush(), Err(err) => { trace!("Error decoding br encoding: {}", err); Err(err) @@ -604,20 +609,18 @@ impl ContentEncoder { }, ContentEncoder::Gzip(ref mut encoder) => { match encoder.write(data) { - Ok(_) => { - encoder.flush() - }, + Ok(_) => + encoder.flush(), Err(err) => { - trace!("Error decoding br encoding: {}", err); + trace!("Error decoding gzip encoding: {}", err); Err(err) }, } } ContentEncoder::Deflate(ref mut encoder) => { match encoder.write(data) { - Ok(_) => { - encoder.flush() - }, + Ok(_) => + encoder.flush(), Err(err) => { trace!("Error decoding deflate encoding: {}", err); Err(err) @@ -655,6 +658,7 @@ enum TransferEncodingKind { impl TransferEncoding { + #[inline] pub fn eof() -> TransferEncoding { TransferEncoding { kind: TransferEncodingKind::Eof, @@ -662,6 +666,7 @@ impl TransferEncoding { } } + #[inline] pub fn chunked() -> TransferEncoding { TransferEncoding { kind: TransferEncodingKind::Chunked(false), @@ -669,6 +674,7 @@ impl TransferEncoding { } } + #[inline] pub fn length(len: u64) -> TransferEncoding { TransferEncoding { kind: TransferEncodingKind::Length(len), @@ -676,6 +682,7 @@ impl TransferEncoding { } } + #[inline] pub fn is_eof(&self) -> bool { match self.kind { TransferEncodingKind::Eof => true, @@ -687,6 +694,7 @@ impl TransferEncoding { } /// Encode message. Return `EOF` state of encoder + #[inline] pub fn encode(&mut self, msg: &[u8]) -> bool { match self.kind { TransferEncodingKind::Eof => { @@ -724,6 +732,7 @@ impl TransferEncoding { } /// Encode eof. Return `EOF` state of encoder + #[inline] pub fn encode_eof(&mut self) { match self.kind { TransferEncodingKind::Eof | TransferEncodingKind::Length(_) => (), @@ -739,11 +748,13 @@ impl TransferEncoding { impl io::Write for TransferEncoding { + #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { self.encode(buf); Ok(buf.len()) } + #[inline] fn flush(&mut self) -> io::Result<()> { Ok(()) } diff --git a/src/error.rs b/src/error.rs index e79b12939..37c0d0e5c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -262,6 +262,9 @@ pub enum MultipartError { /// Multipart boundary is not found #[fail(display="Multipart boundary is not found")] Boundary, + /// Request does not contain payload + #[fail(display="Request does not contain payload")] + NoPayload, /// Error during field parsing #[fail(display="{}", _0)] Parse(#[cause] ParseError), @@ -329,6 +332,9 @@ pub enum WsHandshakeError { /// Websocket key is not set or wrong #[fail(display="Unknown websocket key")] BadWebsocketKey, + /// Request does not contain payload + #[fail(display="Request does not contain payload")] + NoPayload, } impl ResponseError for WsHandshakeError { @@ -351,7 +357,9 @@ impl ResponseError for WsHandshakeError { WsHandshakeError::UnsupportedVersion => HTTPBadRequest.with_reason("Unsupported version"), WsHandshakeError::BadWebsocketKey => - HTTPBadRequest.with_reason("Handshake error") + HTTPBadRequest.with_reason("Handshake error"), + WsHandshakeError::NoPayload => + HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR, Body::Empty), } } } @@ -371,6 +379,9 @@ pub enum UrlencodedError { /// Content type error #[fail(display="Content type error")] ContentType, + /// Request does not contain payload + #[fail(display="Request does not contain payload")] + NoPayload, } /// Return `BadRequest` for `UrlencodedError` diff --git a/src/h1.rs b/src/h1.rs index 5a46a3bb4..10b5cff2d 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -545,28 +545,25 @@ impl Reader { } } - let (mut psender, payload) = Payload::new(false); - let msg = HttpRequest::new(method, uri, version, headers, payload); - - let decoder = if msg.upgrade() { + let decoder = if upgrade(&method, &headers) { Decoder::eof() } else { - let has_len = msg.headers().contains_key(header::CONTENT_LENGTH); + let has_len = headers.contains_key(header::CONTENT_LENGTH); // Chunked encoding - if msg.chunked()? { + if chunked(&headers)? { if has_len { return Err(ParseError::Header) } Decoder::chunked() } else { if !has_len { - psender.feed_eof(); + let msg = HttpRequest::new(method, uri, version, headers, None); return Ok(Message::Http1(msg, None)) } // Content-Length - let len = msg.headers().get(header::CONTENT_LENGTH).unwrap(); + let len = headers.get(header::CONTENT_LENGTH).unwrap(); if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { Decoder::length(len) @@ -581,11 +578,13 @@ impl Reader { } }; - let payload = PayloadInfo { - tx: PayloadType::new(msg.headers(), psender), + let (psender, payload) = Payload::new(false); + let info = PayloadInfo { + tx: PayloadType::new(&headers, psender), decoder: decoder, }; - Ok(Message::Http1(msg, Some(payload))) + let msg = HttpRequest::new(method, uri, version, headers, Some(payload)); + Ok(Message::Http1(msg, Some(info))) } } @@ -610,6 +609,32 @@ fn record_header_indices(bytes: &[u8], } } +/// Check if request is UPGRADE +fn upgrade(method: &Method, headers: &HeaderMap) -> bool { + if let Some(conn) = headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + s.to_lowercase().contains("upgrade") + } else { + *method == Method::CONNECT + } + } else { + *method == Method::CONNECT + } +} + +/// Check if request has chunked transfer encoding +fn chunked(headers: &HeaderMap) -> Result { + if let Some(encodings) = headers.get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + Ok(s.to_lowercase().contains("chunked")) + } else { + Err(ParseError::Header) + } + } else { + Ok(false) + } +} + /// Decoders to handle different Transfer-Encodings. /// /// If a message body does not include a Transfer-Encoding, it *should* diff --git a/src/h2.rs b/src/h2.rs index d3f357aef..625681623 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -237,7 +237,7 @@ impl Entry { let (psender, payload) = Payload::new(false); let mut req = HttpRequest::new( - parts.method, parts.uri, parts.version, parts.headers, payload); + parts.method, parts.uri, parts.version, parts.headers, Some(payload)); // set remote addr req.set_peer_addr(addr); diff --git a/src/handler.rs b/src/handler.rs index 241eabf3c..5c9ba94a8 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -58,6 +58,7 @@ pub(crate) enum ReplyItem { impl Reply { /// Create actor response + #[inline] pub fn actor(ctx: HttpContext) -> Reply where A: Actor>, S: 'static { @@ -65,6 +66,7 @@ impl Reply { } /// Create async response + #[inline] pub fn async(fut: F) -> Reply where F: Future + 'static { @@ -72,10 +74,12 @@ impl Reply { } /// Send response + #[inline] pub fn response>(response: R) -> Reply { Reply(ReplyItem::Message(response.into())) } + #[inline] pub(crate) fn into(self) -> ReplyItem { self.0 } diff --git a/src/httprequest.rs b/src/httprequest.rs index 3752cdf36..266a4e945 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -28,7 +28,7 @@ pub struct HttpMessage { pub params: Params<'static>, pub cookies: Option>>, pub addr: Option, - pub payload: Payload, + pub payload: Option, pub info: Option>, } @@ -43,7 +43,7 @@ impl Default for HttpMessage { params: Params::default(), cookies: None, addr: None, - payload: Payload::empty(), + payload: None, extensions: Extensions::new(), info: None, } @@ -72,13 +72,13 @@ impl HttpMessage { } /// An HTTP Request -pub struct HttpRequest(Rc, Rc, Option>); +pub struct HttpRequest(Rc, Option>, Option>); impl HttpRequest<()> { /// Construct a new Request. #[inline] pub fn new(method: Method, uri: Uri, - version: Version, headers: HeaderMap, payload: Payload) -> HttpRequest + version: Version, headers: HeaderMap, payload: Option) -> HttpRequest { HttpRequest( Rc::new(HttpMessage { @@ -93,7 +93,7 @@ impl HttpRequest<()> { extensions: Extensions::new(), info: None, }), - Rc::new(()), + None, None, ) } @@ -118,14 +118,14 @@ impl HttpRequest<()> { extensions: Extensions::new(), info: None, }), - Rc::new(()), + None, None, ) } /// Construct new http request with state. pub fn with_state(self, state: Rc, router: Router) -> HttpRequest { - HttpRequest(self.0, state, Some(router)) + HttpRequest(self.0, Some(state), Some(router)) } } @@ -133,7 +133,7 @@ impl HttpRequest { /// Construct new http request without state. pub fn clone_without_state(&self) -> HttpRequest { - HttpRequest(Rc::clone(&self.0), Rc::new(()), None) + HttpRequest(Rc::clone(&self.0), None, None) } // get mutable reference for inner message @@ -153,7 +153,7 @@ impl HttpRequest { /// Shared application state #[inline] pub fn state(&self) -> &S { - &self.1 + self.1.as_ref().unwrap() } /// Protocol extensions. @@ -377,20 +377,20 @@ impl HttpRequest { /// Returns reference to the associated http payload. #[inline] - pub fn payload(&self) -> &Payload { - &self.0.payload + pub fn payload(&self) -> Option<&Payload> { + self.0.payload.as_ref() } /// Returns mutable reference to the associated http payload. #[inline] - pub fn payload_mut(&mut self) -> &mut Payload { - &mut self.as_mut().payload + pub fn payload_mut(&mut self) -> Option<&mut Payload> { + self.as_mut().payload.as_mut() } /// Return payload #[inline] - pub fn take_payload(&mut self) -> Payload { - mem::replace(&mut self.as_mut().payload, Payload::empty()) + pub fn take_payload(&mut self) -> Option { + self.as_mut().payload.take() } /// Return stream to process BODY as multipart. @@ -398,7 +398,11 @@ impl HttpRequest { /// Content-type: multipart/form-data; pub fn multipart(&mut self) -> Result { let boundary = Multipart::boundary(&self.0.headers)?; - Ok(Multipart::new(boundary, self.take_payload())) + if let Some(payload) = self.take_payload() { + Ok(Multipart::new(boundary, payload)) + } else { + Err(MultipartError::NoPayload) + } } /// Parse `application/x-www-form-urlencoded` encoded body. @@ -441,7 +445,11 @@ impl HttpRequest { }; if t { - Ok(UrlEncoded{pl: self.take_payload(), body: BytesMut::new()}) + if let Some(payload) = self.take_payload() { + Ok(UrlEncoded{pl: payload, body: BytesMut::new()}) + } else { + Err(UrlencodedError::NoPayload) + } } else { Err(UrlencodedError::ContentType) } @@ -452,13 +460,13 @@ impl Default for HttpRequest<()> { /// Construct default request fn default() -> HttpRequest { - HttpRequest(Rc::new(HttpMessage::default()), Rc::new(()), None) + HttpRequest(Rc::new(HttpMessage::default()), None, None) } } impl Clone for HttpRequest { fn clone(&self) -> HttpRequest { - HttpRequest(Rc::clone(&self.0), Rc::clone(&self.1), None) + HttpRequest(Rc::clone(&self.0), self.1.clone(), None) } } diff --git a/src/httpresponse.rs b/src/httpresponse.rs index e877a761a..ba4cb7b23 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -222,20 +222,20 @@ struct Parts { chunked: bool, encoding: ContentEncoding, connection_type: Option, - cookies: CookieJar, + cookies: Option, } impl Parts { fn new(status: StatusCode) -> Self { Parts { version: None, - headers: HeaderMap::new(), + headers: HeaderMap::with_capacity(8), status: status, reason: None, chunked: false, encoding: ContentEncoding::Auto, connection_type: None, - cookies: CookieJar::new(), + cookies: None, } } } @@ -359,7 +359,13 @@ impl HttpResponseBuilder { /// Set a cookie pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { if let Some(parts) = parts(&mut self.parts, &self.err) { - parts.cookies.add(cookie.into_owned()); + if parts.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + parts.cookies = Some(jar) + } else { + parts.cookies.as_mut().unwrap().add(cookie.into_owned()); + } } self } @@ -367,9 +373,13 @@ impl HttpResponseBuilder { /// Remote cookie, cookie has to be cookie from `HttpRequest::cookies()` method. pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { if let Some(parts) = parts(&mut self.parts, &self.err) { + if parts.cookies.is_none() { + parts.cookies = Some(CookieJar::new()) + } + let mut jar = parts.cookies.as_mut().unwrap(); let cookie = cookie.clone().into_owned(); - parts.cookies.add_original(cookie.clone()); - parts.cookies.remove(cookie); + jar.add_original(cookie.clone()); + jar.remove(cookie); } self } @@ -391,10 +401,12 @@ impl HttpResponseBuilder { if let Some(e) = self.err.take() { return Err(e) } - for cookie in parts.cookies.delta() { - parts.headers.append( - header::SET_COOKIE, - HeaderValue::from_str(&cookie.to_string())?); + if let Some(jar) = parts.cookies { + for cookie in jar.delta() { + parts.headers.append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie.to_string())?); + } } Ok(HttpResponse { version: parts.version, diff --git a/src/ws.rs b/src/ws.rs index 2578cdc0d..cd23526fc 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -96,11 +96,15 @@ pub fn start(mut req: HttpRequest, actor: A) -> Result { let resp = handshake(&req)?; - let stream = WsStream::new(&mut req); - let mut ctx = HttpContext::new(req, actor); - ctx.start(resp); - ctx.add_stream(stream); - Ok(ctx.into()) + if let Some(payload) = req.take_payload() { + let stream = WsStream::new(payload); + let mut ctx = HttpContext::new(req, actor); + ctx.start(resp); + ctx.add_stream(stream); + Ok(ctx.into()) + } else { + Err(WsHandshakeError::NoPayload.into()) + } } /// Prepare `WebSocket` handshake response. @@ -178,8 +182,11 @@ pub struct WsStream { } impl WsStream { - pub fn new(req: &mut HttpRequest) -> WsStream { - WsStream { rx: req.take_payload(), buf: BytesMut::new(), closed: false, error_sent: false } + pub fn new(payload: Payload) -> WsStream { + WsStream { rx: payload, + buf: BytesMut::new(), + closed: false, + error_sent: false } } }