1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 00:21:08 +01:00

refactor dispatcher to avoid possible UB with DispatcherState Pin

This commit is contained in:
Maksym Vorobiov 2020-02-10 13:17:38 +02:00 committed by Yuki Okushi
parent 69dab0063c
commit c05f9475c5

View File

@ -71,7 +71,6 @@ where
{ {
Normal(#[pin] InnerDispatcher<T, S, B, X, U>), Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
Upgrade(#[pin] U::Future), Upgrade(#[pin] U::Future),
None,
} }
#[pin_project] #[pin_project]
@ -101,7 +100,7 @@ where
ka_expire: Instant, ka_expire: Instant,
ka_timer: Option<Delay>, ka_timer: Option<Delay>,
io: T, io: Option<T>,
read_buf: BytesMut, read_buf: BytesMut,
write_buf: BytesMut, write_buf: BytesMut,
codec: Codec, codec: Codec,
@ -148,22 +147,6 @@ where
} }
} }
} }
impl<T, S, B, X, U> DispatcherState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn take(self: Pin<&mut Self>) -> Self {
std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None)
}
}
enum PollResponse { enum PollResponse {
Upgrade(Request), Upgrade(Request),
DoNothing, DoNothing,
@ -258,7 +241,7 @@ where
state: State::None, state: State::None,
error: None, error: None,
messages: VecDeque::new(), messages: VecDeque::new(),
io, io: Some(io),
codec, codec,
read_buf, read_buf,
service, service,
@ -322,9 +305,10 @@ where
let len = self.write_buf.len(); let len = self.write_buf.len();
let mut written = 0; let mut written = 0;
#[project] #[project]
let InnerDispatcher { mut io, write_buf, .. } = self.project(); let InnerDispatcher { io, write_buf, .. } = self.project();
let mut io = Pin::new(io.as_mut().unwrap());
while written < len { while written < len {
match Pin::new(&mut io).poll_write(cx, &write_buf[written..]) match io.as_mut().poll_write(cx, &write_buf[written..])
{ {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new( return Err(DispatchError::Io(io::Error::new(
@ -751,10 +735,10 @@ where
} else { } else {
// flush buffer // flush buffer
inner.as_mut().poll_flush(cx)?; inner.as_mut().poll_flush(cx)?;
if !inner.write_buf.is_empty() { if !inner.write_buf.is_empty() || inner.io.is_none() {
Poll::Pending Poll::Pending
} else { } else {
match Pin::new(inner.project().io).poll_shutdown(cx) { match Pin::new(inner.project().io).as_pin_mut().unwrap().poll_shutdown(cx) {
Poll::Ready(res) => { Poll::Ready(res) => {
Poll::Ready(res.map_err(DispatchError::from)) Poll::Ready(res.map_err(DispatchError::from))
} }
@ -767,7 +751,7 @@ where
let should_disconnect = let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
let mut inner_p = inner.as_mut().project(); let mut inner_p = inner.as_mut().project();
read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)? read_available(cx, inner_p.io.as_mut().unwrap(), &mut inner_p.read_buf)?
} else { } else {
None None
}; };
@ -793,20 +777,17 @@ where
// switch to upgrade handler // switch to upgrade handler
if let PollResponse::Upgrade(req) = result { if let PollResponse::Upgrade(req) = result {
if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() { let inner_p = inner.as_mut().project();
let mut parts = FramedParts::with_read_buf( let mut parts = FramedParts::with_read_buf(
inner.io, inner_p.io.take().unwrap(),
inner.codec, std::mem::take(inner_p.codec),
inner.read_buf, std::mem::take(inner_p.read_buf),
); );
parts.write_buf = inner.write_buf; parts.write_buf = std::mem::take(inner_p.write_buf);
let framed = Framed::from_parts(parts); let framed = Framed::from_parts(parts);
let upgrade = inner.upgrade.unwrap().call((req, framed)); let upgrade = inner_p.upgrade.take().unwrap().call((req, framed));
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
return self.poll(cx); return self.poll(cx);
} else {
panic!()
}
} }
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
@ -859,7 +840,6 @@ where
DispatchError::Upgrade DispatchError::Upgrade
}) })
} }
DispatcherState::None => panic!(),
} }
} }
} }
@ -949,9 +929,9 @@ mod tests {
Poll::Ready(res) => assert!(res.is_err()), Poll::Ready(res) => assert!(res.is_err()),
} }
if let DispatcherState::Normal(ref inner) = h1.inner { if let DispatcherState::Normal(ref mut inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert!(inner.flags.contains(Flags::READ_DISCONNECT));
assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); assert_eq!(&inner.io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
} }
}) })
.await; .await;