From ebc59cf7b941a9004e1d37a97dbd407fd74db125 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 21 Jun 2018 11:20:21 +0600 Subject: [PATCH] add unsafe checks #331 --- src/with.rs | 80 +++++++++++++++++++++++++++++------------- tests/test_handlers.rs | 22 ++++++++++++ 2 files changed, 78 insertions(+), 24 deletions(-) diff --git a/src/with.rs b/src/with.rs index 4cb1546a7..423e558a2 100644 --- a/src/with.rs +++ b/src/with.rs @@ -111,8 +111,18 @@ where T: FromRequest, S: 'static, { - hnd: Rc>, + hnd: Rc>, cfg: ExtractorConfig, +} + +pub struct WithHnd +where + F: Fn(T) -> R, + T: FromRequest, + S: 'static, +{ + hnd: Rc>, + _t: PhantomData, _s: PhantomData, } @@ -125,8 +135,11 @@ where pub fn new(f: F, cfg: ExtractorConfig) -> Self { With { cfg, - hnd: Rc::new(UnsafeCell::new(f)), - _s: PhantomData, + hnd: Rc::new(WithHnd { + hnd: Rc::new(UnsafeCell::new(f)), + _t: PhantomData, + _s: PhantomData, + }), } } } @@ -166,7 +179,7 @@ where S: 'static, { started: bool, - hnd: Rc>, + hnd: Rc>, cfg: ExtractorConfig, req: HttpRequest, fut1: Option>>, @@ -206,20 +219,28 @@ where } }; - let hnd: &mut F = unsafe { &mut *self.hnd.get() }; - let item = match (*hnd)(item).respond_to(&self.req) { - Ok(item) => item.into(), - Err(e) => return Err(e.into()), - }; - - match item.into() { - AsyncResultItem::Err(err) => Err(err), - AsyncResultItem::Ok(resp) => Ok(Async::Ready(resp)), - AsyncResultItem::Future(fut) => { - self.fut2 = Some(fut); - self.poll() + let fut = { + // clone handler, inicrease ref counter + let h = self.hnd.as_ref().hnd.clone(); + // Enforce invariants before entering unsafe code. + // Only two references could exists With struct owns one, and line above + if Rc::weak_count(&h) != 0 && Rc::strong_count(&h) != 2 { + panic!("Multiple copies of handler are in use") } - } + let hnd: &mut F = unsafe { &mut *h.as_ref().get() }; + let item = match (*hnd)(item).respond_to(&self.req) { + Ok(item) => item.into(), + Err(e) => return Err(e.into()), + }; + + match item.into() { + AsyncResultItem::Err(err) => return Err(err), + AsyncResultItem::Ok(resp) => return Ok(Async::Ready(resp)), + AsyncResultItem::Future(fut) => fut, + } + }; + self.fut2 = Some(fut); + self.poll() } } @@ -232,9 +253,8 @@ where T: FromRequest, S: 'static, { - hnd: Rc>, + hnd: Rc>, cfg: ExtractorConfig, - _s: PhantomData, } impl WithAsync @@ -249,8 +269,11 @@ where pub fn new(f: F, cfg: ExtractorConfig) -> Self { WithAsync { cfg, - hnd: Rc::new(UnsafeCell::new(f)), - _s: PhantomData, + hnd: Rc::new(WithHnd { + hnd: Rc::new(UnsafeCell::new(f)), + _s: PhantomData, + _t: PhantomData, + }), } } } @@ -295,7 +318,7 @@ where S: 'static, { started: bool, - hnd: Rc>, + hnd: Rc>, cfg: ExtractorConfig, req: HttpRequest, fut1: Option>>, @@ -356,8 +379,17 @@ where } }; - let hnd: &mut F = unsafe { &mut *self.hnd.get() }; - self.fut2 = Some((*hnd)(item)); + self.fut2 = { + // clone handler, inicrease ref counter + let h = self.hnd.as_ref().hnd.clone(); + // Enforce invariants before entering unsafe code. + // Only two references could exists With struct owns one, and line above + if Rc::weak_count(&h) != 0 && Rc::strong_count(&h) != 2 { + panic!("Multiple copies of handler are in use") + } + let hnd: &mut F = unsafe { &mut *h.as_ref().get() }; + Some((*hnd)(item)) + }; self.poll() } } diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs index 116112e27..57309f833 100644 --- a/tests/test_handlers.rs +++ b/tests/test_handlers.rs @@ -42,6 +42,28 @@ fn test_path_extractor() { assert_eq!(bytes, Bytes::from_static(b"Welcome test!")); } +#[test] +fn test_async_handler() { + let mut srv = test::TestServer::new(|app| { + app.resource("/{username}/index.html", |r| { + r.route().with(|p: Path| { + Delay::new(Instant::now() + Duration::from_millis(10)) + .and_then(move |_| Ok(format!("Welcome {}!", p.username))) + .responder() + }) + }); + }); + + // client request + let request = srv.get().uri(srv.url("/test/index.html")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(b"Welcome test!")); +} + #[test] fn test_query_extractor() { let mut srv = test::TestServer::new(|app| {