From af1e0bac0804d6979bc20b6d754eac19a2338814 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 29 Oct 2017 06:05:31 -0700 Subject: [PATCH] add HttpContext::drain() --- Cargo.toml | 4 +- src/application.rs | 4 +- src/context.rs | 36 +++++++++++- src/lib.rs | 3 +- src/resource.rs | 10 +++- src/route.rs | 5 +- src/task.rs | 144 +++++++++++++++++++++++++++++++++------------ 7 files changed, 156 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 18c62c0ec..4b163aed3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,6 @@ mime = "0.3" mime_guess = "1.8" cookie = { version="0.10", features=["percent-encode"] } regex = "0.2" -slab = "0.4" sha1 = "0.2" url = "1.5" percent-encoding = "1.0" @@ -46,9 +45,8 @@ percent-encoding = "1.0" # tokio bytes = "0.4" futures = "0.1" -tokio-core = "0.1" tokio-io = "0.1" -tokio-proto = "0.1" +tokio-core = "0.1" # h2 = { git = 'https://github.com/carllerche/h2', optional = true } [dependencies.actix] diff --git a/src/application.rs b/src/application.rs index da6cee971..2bfa6dba0 100644 --- a/src/application.rs +++ b/src/application.rs @@ -91,7 +91,7 @@ impl Application<()> { parts: Some(ApplicationBuilderParts { state: (), prefix: prefix.to_string(), - default: Resource::default(), + default: Resource::default_not_found(), handlers: HashMap::new(), resources: HashMap::new(), middlewares: Vec::new(), @@ -110,7 +110,7 @@ impl Application where S: 'static { parts: Some(ApplicationBuilderParts { state: state, prefix: prefix.to_string(), - default: Resource::default(), + default: Resource::default_not_found(), handlers: HashMap::new(), resources: HashMap::new(), middlewares: Vec::new(), diff --git a/src/context.rs b/src/context.rs index 3606db593..ae4f6607a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,7 +1,9 @@ use std; use std::rc::Rc; +use std::cell::RefCell; use std::collections::VecDeque; -use futures::{Async, Stream, Poll}; +use std::marker::PhantomData; +use futures::{Async, Future, Stream, Poll}; use futures::sync::oneshot::Sender; use actix::{Actor, ActorState, ActorContext, AsyncContext, @@ -10,7 +12,7 @@ use actix::fut::ActorFuture; use actix::dev::{AsyncContextApi, ActorAddressCell, ActorItemsCell, ActorWaitCell, SpawnHandle, Envelope, ToEnvelope, RemoteEnvelope}; -use task::IoContext; +use task::{IoContext, DrainFut}; use body::BinaryBody; use route::{Route, Frame}; use httpresponse::HttpResponse; @@ -137,6 +139,14 @@ impl HttpContext where A: Actor + Route { self.stream.push_back(Frame::Payload(None)) } + /// Returns drain future + pub fn drain(&mut self) -> Drain { + let fut = Rc::new(RefCell::new(DrainFut::new())); + self.stream.push_back(Frame::Drain(fut.clone())); + self.modified = true; + Drain{ a: PhantomData, inner: fut } + } + /// Check if connection still open pub fn connected(&self) -> bool { !self.disconnected @@ -199,6 +209,10 @@ impl Stream for HttpContext where A: Actor + Route // check wait futures if self.wait.poll(act, ctx) { + // get frame + if let Some(frame) = self.stream.pop_front() { + return Ok(Async::Ready(Some(frame))) + } return Ok(Async::NotReady) } @@ -269,3 +283,21 @@ impl ToEnvelope for HttpContext RemoteEnvelope::new(msg, tx).into() } } + + +pub struct Drain { + a: PhantomData, + inner: Rc> +} + +impl ActorFuture for Drain + where A: Actor +{ + type Item = (); + type Error = (); + type Actor = A; + + fn poll(&mut self, _: &mut A, _: &mut ::Context) -> Poll<(), ()> { + self.inner.borrow_mut().poll() + } +} diff --git a/src/lib.rs b/src/lib.rs index fe0f1f8ec..960cc56e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,8 @@ extern crate sha1; extern crate regex; #[macro_use] extern crate futures; -extern crate tokio_core; extern crate tokio_io; -extern crate tokio_proto; +extern crate tokio_core; extern crate cookie; extern crate http; diff --git a/src/resource.rs b/src/resource.rs index a030d5ec9..5d55665bd 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -13,7 +13,7 @@ use payload::Payload; use context::HttpContext; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use httpcodes::HTTPMethodNotAllowed; +use httpcodes::{HTTPNotFound, HTTPMethodNotAllowed}; /// Http resource /// @@ -51,6 +51,14 @@ impl Default for Resource { impl Resource where S: 'static { + pub(crate) fn default_not_found() -> Self { + Resource { + name: String::new(), + state: PhantomData, + routes: HashMap::new(), + default: Box::new(HTTPNotFound)} + } + /// Set resource name pub fn set_name(&mut self, name: T) { self.name = name.to_string(); diff --git a/src/route.rs b/src/route.rs index f6f4e7bcb..e242a5e8f 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,12 +1,13 @@ use std::io; use std::rc::Rc; +use std::cell::RefCell; use std::marker::PhantomData; use actix::Actor; use http::{header, Version}; use futures::Stream; -use task::Task; +use task::{Task, DrainFut}; use body::BinaryBody; use context::HttpContext; use resource::Reply; @@ -21,9 +22,9 @@ use httpcodes::HTTPExpectationFailed; pub enum Frame { Message(HttpResponse), Payload(Option), + Drain(Rc>), } - /// Trait defines object that could be regestered as resource route #[allow(unused_variables)] pub trait RouteHandler: 'static { diff --git a/src/task.rs b/src/task.rs index f5f155e6d..992121c9e 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,6 +1,7 @@ use std::{mem, cmp, io}; use std::rc::Rc; use std::fmt::Write; +use std::cell::RefCell; use std::collections::VecDeque; use http::{StatusCode, Version}; @@ -8,6 +9,7 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; use bytes::BytesMut; use futures::{Async, Future, Poll, Stream}; +use futures::task::{Task as FutureTask, current as current_task}; use tokio_io::{AsyncRead, AsyncWrite}; use date; @@ -57,6 +59,45 @@ pub(crate) trait IoContext: Stream + 'static { fn disconnected(&mut self); } +#[doc(hidden)] +#[derive(Debug)] +pub struct DrainFut { + drained: bool, + task: Option, +} + +impl DrainFut { + + pub fn new() -> DrainFut { + DrainFut { + drained: false, + task: None, + } + } + + fn set(&mut self) { + self.drained = true; + if let Some(task) = self.task.take() { + task.notify() + } + } +} + +impl Future for DrainFut { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + if self.drained { + Ok(Async::Ready(())) + } else { + self.task = Some(current_task()); + Ok(Async::NotReady) + } + } +} + + pub struct Task { state: TaskRunningState, iostate: TaskIOState, @@ -64,6 +105,7 @@ pub struct Task { stream: TaskStream, encoder: Encoder, buffer: BytesMut, + drain: Vec>>, upgrade: bool, keepalive: bool, prepared: Option, @@ -82,6 +124,7 @@ impl Task { state: TaskRunningState::Running, iostate: TaskIOState::Done, frames: frames, + drain: Vec::new(), stream: TaskStream::None, encoder: Encoder::length(0), buffer: BytesMut::new(), @@ -103,6 +146,7 @@ impl Task { stream: TaskStream::Stream(Box::new(stream)), encoder: Encoder::length(0), buffer: BytesMut::new(), + drain: Vec::new(), upgrade: false, keepalive: false, prepared: None, @@ -120,6 +164,7 @@ impl Task { stream: TaskStream::Context(Box::new(ctx)), encoder: Encoder::length(0), buffer: BytesMut::new(), + drain: Vec::new(), upgrade: false, keepalive: false, prepared: None, @@ -275,47 +320,53 @@ impl Task { if self.frames.is_empty() && self.iostate.is_done() { return Ok(Async::Ready(self.state.is_done())); } else { - // poll stream - if self.state == TaskRunningState::Running { - match self.poll() { - Ok(Async::Ready(_)) => { - self.state = TaskRunningState::Done; - } - Ok(Async::NotReady) => (), - Err(_) => return Err(()) - } - } - - // use exiting frames - while let Some(frame) = self.frames.pop_front() { - trace!("IO Frame: {:?}", frame); - match frame { - Frame::Message(response) => { - if !self.disconnected { - self.prepare(req, response); + if self.drain.is_empty() { + // poll stream + if self.state == TaskRunningState::Running { + match self.poll() { + Ok(Async::Ready(_)) => { + self.state = TaskRunningState::Done; } + Ok(Async::NotReady) => (), + Err(_) => return Err(()) } - Frame::Payload(Some(chunk)) => { - if !self.disconnected { - if self.prepared.is_some() { - // TODO: add warning, write after EOF - self.encoder.encode(&mut self.buffer, chunk.as_ref()); - } else { - // might be response for EXCEPT - self.buffer.extend_from_slice(chunk.as_ref()) + } + + // use exiting frames + while let Some(frame) = self.frames.pop_front() { + trace!("IO Frame: {:?}", frame); + match frame { + Frame::Message(response) => { + if !self.disconnected { + self.prepare(req, response); } } - }, - Frame::Payload(None) => { - if !self.disconnected && - !self.encoder.encode(&mut self.buffer, [].as_ref()) - { - // TODO: add error "not eof"" - debug!("last payload item, but it is not EOF "); - return Err(()) + Frame::Payload(Some(chunk)) => { + if !self.disconnected { + if self.prepared.is_some() { + // TODO: add warning, write after EOF + self.encoder.encode(&mut self.buffer, chunk.as_ref()); + } else { + // might be response for EXCEPT + self.buffer.extend_from_slice(chunk.as_ref()) + } + } + }, + Frame::Payload(None) => { + if !self.disconnected && + !self.encoder.encode(&mut self.buffer, [].as_ref()) + { + // TODO: add error "not eof"" + debug!("last payload item, but it is not EOF "); + return Err(()) + } + break + }, + Frame::Drain(fut) => { + self.drain.push(fut); + break } - break - }, + } } } } @@ -347,6 +398,23 @@ impl Task { self.iostate = TaskIOState::Done; } + // drain + if self.buffer.is_empty() && !self.drain.is_empty() { + match io.flush() { + Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + return Ok(Async::NotReady) + } + Err(_) => return Err(()), + } + + for fut in &mut self.drain { + fut.borrow_mut().set() + } + self.drain.clear(); + // return self.poll_io(io, req); + } + // response is completed if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() { // run middlewares @@ -357,7 +425,6 @@ impl Task { } } } - Ok(Async::Ready(self.state.is_done())) } else { Ok(Async::NotReady) @@ -391,6 +458,7 @@ impl Task { return Err(()) } }, + _ => (), } self.frames.push_back(frame) }, @@ -399,7 +467,7 @@ impl Task { Ok(Async::NotReady) => return Ok(Async::NotReady), Err(_) => - return Err(()) + return Err(()), } } }