mirror of
https://github.com/fafhrd91/actix-web
synced 2024-11-24 00:21:08 +01:00
refactor dispatcher to avoid possible UB with DispatcherState Pin
This commit is contained in:
parent
69dab0063c
commit
c05f9475c5
@ -71,7 +71,6 @@ where
|
||||
{
|
||||
Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
|
||||
Upgrade(#[pin] U::Future),
|
||||
None,
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
@ -101,7 +100,7 @@ where
|
||||
ka_expire: Instant,
|
||||
ka_timer: Option<Delay>,
|
||||
|
||||
io: T,
|
||||
io: Option<T>,
|
||||
read_buf: BytesMut,
|
||||
write_buf: BytesMut,
|
||||
codec: Codec,
|
||||
@ -148,22 +147,6 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> DispatcherState<T, S, B, X, U>
|
||||
where
|
||||
S: Service<Request = Request>,
|
||||
S::Error: Into<Error>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: Into<Error>,
|
||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
{
|
||||
fn take(self: Pin<&mut Self>) -> Self {
|
||||
std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None)
|
||||
}
|
||||
}
|
||||
|
||||
enum PollResponse {
|
||||
Upgrade(Request),
|
||||
DoNothing,
|
||||
@ -258,7 +241,7 @@ where
|
||||
state: State::None,
|
||||
error: None,
|
||||
messages: VecDeque::new(),
|
||||
io,
|
||||
io: Some(io),
|
||||
codec,
|
||||
read_buf,
|
||||
service,
|
||||
@ -322,9 +305,10 @@ where
|
||||
let len = self.write_buf.len();
|
||||
let mut written = 0;
|
||||
#[project]
|
||||
let InnerDispatcher { mut io, write_buf, .. } = self.project();
|
||||
let InnerDispatcher { io, write_buf, .. } = self.project();
|
||||
let mut io = Pin::new(io.as_mut().unwrap());
|
||||
while written < len {
|
||||
match Pin::new(&mut io).poll_write(cx, &write_buf[written..])
|
||||
match io.as_mut().poll_write(cx, &write_buf[written..])
|
||||
{
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Err(DispatchError::Io(io::Error::new(
|
||||
@ -751,10 +735,10 @@ where
|
||||
} else {
|
||||
// flush buffer
|
||||
inner.as_mut().poll_flush(cx)?;
|
||||
if !inner.write_buf.is_empty() {
|
||||
if !inner.write_buf.is_empty() || inner.io.is_none() {
|
||||
Poll::Pending
|
||||
} else {
|
||||
match Pin::new(inner.project().io).poll_shutdown(cx) {
|
||||
match Pin::new(inner.project().io).as_pin_mut().unwrap().poll_shutdown(cx) {
|
||||
Poll::Ready(res) => {
|
||||
Poll::Ready(res.map_err(DispatchError::from))
|
||||
}
|
||||
@ -767,7 +751,7 @@ where
|
||||
let should_disconnect =
|
||||
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
||||
let mut inner_p = inner.as_mut().project();
|
||||
read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)?
|
||||
read_available(cx, inner_p.io.as_mut().unwrap(), &mut inner_p.read_buf)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -793,20 +777,17 @@ where
|
||||
|
||||
// switch to upgrade handler
|
||||
if let PollResponse::Upgrade(req) = result {
|
||||
if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() {
|
||||
let inner_p = inner.as_mut().project();
|
||||
let mut parts = FramedParts::with_read_buf(
|
||||
inner.io,
|
||||
inner.codec,
|
||||
inner.read_buf,
|
||||
inner_p.io.take().unwrap(),
|
||||
std::mem::take(inner_p.codec),
|
||||
std::mem::take(inner_p.read_buf),
|
||||
);
|
||||
parts.write_buf = inner.write_buf;
|
||||
parts.write_buf = std::mem::take(inner_p.write_buf);
|
||||
let framed = Framed::from_parts(parts);
|
||||
let upgrade = inner.upgrade.unwrap().call((req, framed));
|
||||
let upgrade = inner_p.upgrade.take().unwrap().call((req, framed));
|
||||
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
|
||||
return self.poll(cx);
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
// we didnt get WouldBlock from write operation,
|
||||
@ -859,7 +840,6 @@ where
|
||||
DispatchError::Upgrade
|
||||
})
|
||||
}
|
||||
DispatcherState::None => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -949,9 +929,9 @@ mod tests {
|
||||
Poll::Ready(res) => assert!(res.is_err()),
|
||||
}
|
||||
|
||||
if let DispatcherState::Normal(ref inner) = h1.inner {
|
||||
if let DispatcherState::Normal(ref mut inner) = h1.inner {
|
||||
assert!(inner.flags.contains(Flags::READ_DISCONNECT));
|
||||
assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
|
||||
assert_eq!(&inner.io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
Loading…
Reference in New Issue
Block a user