diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 1311a098..ea8f91e0 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -130,8 +130,8 @@ where B: MessageBody, { None, - ExpectCall(Pin>), - ServiceCall(Pin>), + ExpectCall(#[pin] X::Future), + ServiceCall(#[pin] S::Future), SendPayload(#[pin] ResponseBody), } @@ -347,7 +347,7 @@ where self: Pin<&mut Self>, message: Response<()>, body: ResponseBody, - ) -> Result, DispatchError> { + ) -> Result<(), DispatchError> { let mut this = self.project(); this.codec .encode(Message::Item((message, body.size())), &mut this.write_buf) @@ -360,9 +360,10 @@ where this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); match body.size() { - BodySize::None | BodySize::Empty => Ok(State::None), - _ => Ok(State::SendPayload(body)), - } + BodySize::None | BodySize::Empty => this.state.set(State::None), + _ => this.state.set(State::SendPayload(body)), + }; + Ok(()) } fn send_continue(self: Pin<&mut Self>) { @@ -377,49 +378,52 @@ where ) -> Result { loop { let mut this = self.as_mut().project(); - let state = match this.state.project() { + // state is not changed on Poll::Pending. + // other variant and conditions always trigger a state change(or an error). + let state_change = match this.state.project() { StateProj::None => match this.messages.pop_front() { Some(DispatcherMessage::Item(req)) => { - Some(self.as_mut().handle_request(req, cx)?) + self.as_mut().handle_request(req, cx)?; + true } - Some(DispatcherMessage::Error(res)) => Some( + Some(DispatcherMessage::Error(res)) => { self.as_mut() - .send_response(res, ResponseBody::Other(Body::Empty))?, - ), + .send_response(res, ResponseBody::Other(Body::Empty))?; + true + } Some(DispatcherMessage::Upgrade(req)) => { return Ok(PollResponse::Upgrade(req)); } - None => None, + None => false, }, - StateProj::ExpectCall(fut) => match fut.as_mut().poll(cx) { + StateProj::ExpectCall(fut) => match fut.poll(cx) { Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); this = self.as_mut().project(); - this.state - .set(State::ServiceCall(Box::pin(this.service.call(req)))); + this.state.set(State::ServiceCall(this.service.call(req))); continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.as_mut().send_response(res, body.into_body())?) + self.as_mut().send_response(res, body.into_body())?; + true } - Poll::Pending => None, + Poll::Pending => false, }, - StateProj::ServiceCall(fut) => match fut.as_mut().poll(cx) { + StateProj::ServiceCall(fut) => match fut.poll(cx) { Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); - let state = self.as_mut().send_response(res, body)?; - this = self.as_mut().project(); - this.state.set(state); + self.as_mut().send_response(res, body)?; continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.as_mut().send_response(res, body.into_body())?) + self.as_mut().send_response(res, body.into_body())?; + true } - Poll::Pending => None, + Poll::Pending => false, }, StateProj::SendPayload(mut stream) => { loop { @@ -454,11 +458,8 @@ where } }; - this = self.as_mut().project(); - - // set new state - if let Some(state) = state { - this.state.set(state); + // state is changed and continue when the state is not Empty + if state_change { if !self.state.is_empty() { continue; } @@ -483,39 +484,67 @@ where mut self: Pin<&mut Self>, req: Request, cx: &mut Context<'_>, - ) -> Result, DispatchError> { + ) -> Result<(), DispatchError> { // Handle `EXPECT: 100-Continue` header - let req = if req.head().expect() { - let mut task = Box::pin(self.as_mut().project().expect.call(req)); - match task.as_mut().poll(cx) { - Poll::Ready(Ok(req)) => { - self.as_mut().send_continue(); - req - } - Poll::Pending => return Ok(State::ExpectCall(task)), - Poll::Ready(Err(e)) => { - let e = e.into(); - let res: Response = e.into(); - let (res, body) = res.replace_body(()); - return self.send_response(res, body.into_body()); - } - } + if req.head().expect() { + // set dispatcher state so the future is pinned. + let task = self.as_mut().project().expect.call(req); + self.as_mut().project().state.set(State::ExpectCall(task)); } else { - req + // the same as above. + let task = self.as_mut().project().service.call(req); + self.as_mut().project().state.set(State::ServiceCall(task)); }; - // Call service - let mut task = Box::pin(self.as_mut().project().service.call(req)); - match task.as_mut().poll(cx) { - Poll::Ready(Ok(res)) => { - let (res, body) = res.into().replace_body(()); - self.send_response(res, body) - } - Poll::Pending => Ok(State::ServiceCall(task)), - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - self.send_response(res, body.into_body()) + // eagerly poll the future for once(or twice if expect is resolved immediately). + loop { + match self.as_mut().project().state.project() { + StateProj::ExpectCall(fut) => { + match fut.poll(cx) { + // expect is resolved. continue loop and poll the service call branch. + Poll::Ready(Ok(req)) => { + self.as_mut().send_continue(); + let task = self.as_mut().project().service.call(req); + self.as_mut().project().state.set(State::ServiceCall(task)); + continue; + } + // future is pending. return Ok(()) to notify that a new state is + // set and the outer loop should be continue. + Poll::Pending => return Ok(()), + // future is error. send response and return a result. On success + // to notify the dispatcher a new state is set and the outer loop + // should be continue. + Poll::Ready(Err(e)) => { + let e = e.into(); + let res: Response = e.into(); + let (res, body) = res.replace_body(()); + return self.send_response(res, body.into_body()); + } + } + } + StateProj::ServiceCall(fut) => { + // return no matter the service call future's result. + return match fut.poll(cx) { + // future is resolved. send response and return a result. On success + // to notify the dispatcher a new state is set and the outer loop + // should be continue. + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.send_response(res, body) + } + // see the comment on ExpectCall state branch's Pending. + Poll::Pending => Ok(()), + // see the comment on ExpectCall state branch's Ready(Err(e)). + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + self.send_response(res, body.into_body()) + } + }; + } + _ => unreachable!( + "State must be set to ServiceCall or ExceptCall in handle_request" + ), } } } @@ -566,9 +595,8 @@ where // handle request early if this.state.is_empty() { - let state = self.as_mut().handle_request(req, cx)?; + self.as_mut().handle_request(req, cx)?; this = self.as_mut().project(); - this.state.set(state); } else { this.messages.push_back(DispatcherMessage::Item(req)); } @@ -1004,7 +1032,7 @@ mod tests { lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, ServiceConfig::default(), CloneableService::new(ok_service()), @@ -1015,15 +1043,17 @@ mod tests { None, ); - match Pin::new(&mut h1).poll(cx) { + futures_util::pin_mut!(h1); + + match h1.as_mut().poll(cx) { Poll::Pending => panic!(), Poll::Ready(res) => assert!(res.is_err()), } - if let DispatcherState::Normal(ref mut inner) = h1.inner { + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert_eq!( - &inner.io.take().unwrap().write_buf[..26], + &inner.project().io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n" ); } @@ -1043,7 +1073,7 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, CloneableService::new(echo_path_service()), @@ -1054,9 +1084,11 @@ mod tests { None, ); + futures_util::pin_mut!(h1); + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); - match Pin::new(&mut h1).poll(cx) { + match h1.as_mut().poll(cx) { Poll::Pending => panic!("first poll should not be pending"), Poll::Ready(res) => assert!(res.is_ok()), } @@ -1064,8 +1096,8 @@ mod tests { // polls: initial => shutdown assert_eq!(h1.poll_count, 2); - if let DispatcherState::Normal(ref mut inner) = h1.inner { - let res = &mut inner.io.take().unwrap().write_buf[..]; + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + let res = &mut inner.project().io.take().unwrap().write_buf[..]; stabilize_date_header(res); let exp = b"\ @@ -1096,7 +1128,7 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, CloneableService::new(echo_path_service()), @@ -1107,9 +1139,11 @@ mod tests { None, ); + futures_util::pin_mut!(h1); + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); - match Pin::new(&mut h1).poll(cx) { + match h1.as_mut().poll(cx) { Poll::Pending => panic!("first poll should not be pending"), Poll::Ready(res) => assert!(res.is_err()), } @@ -1117,8 +1151,8 @@ mod tests { // polls: initial => shutdown assert_eq!(h1.poll_count, 1); - if let DispatcherState::Normal(ref mut inner) = h1.inner { - let res = &mut inner.io.take().unwrap().write_buf[..]; + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + let res = &mut inner.project().io.take().unwrap().write_buf[..]; stabilize_date_header(res); let exp = b"\ @@ -1144,7 +1178,7 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( buf.clone(), cfg, CloneableService::new(echo_payload_service()), @@ -1164,7 +1198,9 @@ mod tests { ", ); - assert!(Pin::new(&mut h1).poll(cx).is_pending()); + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_pending()); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); // polls: manual @@ -1181,7 +1217,7 @@ mod tests { } buf.extend_read_buf("12345"); - assert!(Pin::new(&mut h1).poll(cx).is_ready()); + assert!(h1.as_mut().poll(cx).is_ready()); // polls: manual manual shutdown assert_eq!(h1.poll_count, 3); @@ -1214,7 +1250,7 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( buf.clone(), cfg, CloneableService::new(echo_path_service()), @@ -1234,7 +1270,9 @@ mod tests { ", ); - assert!(Pin::new(&mut h1).poll(cx).is_ready()); + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_ready()); assert!(matches!(&h1.inner, DispatcherState::Normal(_))); // polls: manual shutdown @@ -1272,7 +1310,7 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( buf.clone(), cfg, CloneableService::new(ok_service()), @@ -1292,7 +1330,9 @@ mod tests { ", ); - assert!(Pin::new(&mut h1).poll(cx).is_ready()); + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_ready()); assert!(matches!(&h1.inner, DispatcherState::Upgrade(_))); // polls: manual shutdown