From 234c60d473a529f353ab51c04f3c7cb9e402fb70 Mon Sep 17 00:00:00 2001 From: Jef Date: Wed, 20 Jun 2018 10:50:56 +0200 Subject: [PATCH] Fix some unsoundness This improves the sound implementation of `fn route`. Previously this function would iterate twice but we can reduce the overhead without using `unsafe`. --- src/application.rs | 41 +++++++++++++++++++++-------------------- src/multipart.rs | 8 ++++---- src/server/h1writer.rs | 31 ++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/src/application.rs b/src/application.rs index 93008b3d2..b28c1829f 100644 --- a/src/application.rs +++ b/src/application.rs @@ -335,33 +335,34 @@ where T: FromRequest + 'static, { { - let parts = self.parts.as_mut().expect("Use after finish"); + let parts: &mut ApplicationParts = self.parts.as_mut().expect("Use after finish"); - // get resource handler - let mut found = false; - for &mut (ref pattern, ref handler) in &mut parts.resources { - if handler.is_some() && pattern.pattern() == path { - found = true; - break; - } - } + let out = { + // get resource handler + let mut iterator = parts.resources.iter_mut(); - if !found { - let mut handler = ResourceHandler::default(); - handler.method(method).with(f); - let pattern = Resource::new(handler.get_name(), path); - parts.resources.push((pattern, Some(handler))); - } else { - for &mut (ref pattern, ref mut handler) in &mut parts.resources { - if let Some(ref mut handler) = *handler { - if pattern.pattern() == path { - handler.method(method).with(f); - break; + loop { + if let Some(&mut (ref pattern, ref mut handler)) = iterator.next() { + if let Some(ref mut handler) = *handler { + if pattern.pattern() == path { + handler.method(method).with(f); + break None; + } } + } else { + let mut handler = ResourceHandler::default(); + handler.method(method).with(f); + let pattern = Resource::new(handler.get_name(), path); + break Some((pattern, Some(handler))); } } + }; + + if let Some(out) = out { + parts.resources.push(out); } } + self } diff --git a/src/multipart.rs b/src/multipart.rs index 9c5c0380c..7c93b5657 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -1,5 +1,5 @@ //! Multipart requests support -use std::cell::RefCell; +use std::cell::{RefCell, UnsafeCell}; use std::marker::PhantomData; use std::rc::Rc; use std::{cmp, fmt}; @@ -590,7 +590,7 @@ where } struct PayloadRef { - payload: Rc>, + payload: Rc>>, } impl PayloadRef @@ -599,7 +599,7 @@ where { fn new(payload: PayloadHelper) -> PayloadRef { PayloadRef { - payload: Rc::new(payload), + payload: Rc::new(payload.into()), } } @@ -609,7 +609,7 @@ where { if s.current() { let payload: &mut PayloadHelper = - unsafe { &mut *(self.payload.as_ref() as *const _ as *mut _) }; + unsafe { &mut *self.payload.get() }; Some(payload) } else { None diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index d174964b9..ebb0fff32 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -73,12 +73,11 @@ impl H1Writer { self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) } - fn write_data(&mut self, data: &[u8]) -> io::Result { + fn write_data(stream: &mut T, data: &[u8]) -> io::Result { let mut written = 0; while written < data.len() { - match self.stream.write(&data[written..]) { + match stream.write(&data[written..]) { Ok(0) => { - self.disconnected(); return Err(io::Error::new(io::ErrorKind::WriteZero, "")); } Ok(n) => { @@ -243,7 +242,16 @@ impl Writer for H1Writer { if self.flags.contains(Flags::UPGRADE) { if self.buffer.is_empty() { let pl: &[u8] = payload.as_ref(); - let n = self.write_data(pl)?; + let n = match Self::write_data(&mut self.stream, pl) { + Err(err) => { + if err.kind() == io::ErrorKind::WriteZero { + self.disconnected(); + } + + return Err(err); + } + Ok(val) => val, + }; if n < pl.len() { self.buffer.extend_from_slice(&pl[n..]); return Ok(WriterState::Done); @@ -284,9 +292,18 @@ impl Writer for H1Writer { #[inline] fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { if !self.buffer.is_empty() { - let buf: &[u8] = - unsafe { &mut *(self.buffer.as_ref() as *const _ as *mut _) }; - let written = self.write_data(buf)?; + let written = { + match Self::write_data(&mut self.stream, self.buffer.as_ref()) { + Err(err) => { + if err.kind() == io::ErrorKind::WriteZero { + self.disconnected(); + } + + return Err(err); + } + Ok(val) => val, + } + }; let _ = self.buffer.split_to(written); if shutdown && !self.buffer.is_empty() || (self.buffer.len() > self.buffer_capacity)