mirror of
https://github.com/fafhrd91/actix-web
synced 2024-11-24 00:21:08 +01:00
refactor dispatcher to avoid possible UB with DispatcherState Pin
This commit is contained in:
parent
69dab0063c
commit
c05f9475c5
@ -71,7 +71,6 @@ where
|
|||||||
{
|
{
|
||||||
Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
|
Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
|
||||||
Upgrade(#[pin] U::Future),
|
Upgrade(#[pin] U::Future),
|
||||||
None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
@ -101,7 +100,7 @@ where
|
|||||||
ka_expire: Instant,
|
ka_expire: Instant,
|
||||||
ka_timer: Option<Delay>,
|
ka_timer: Option<Delay>,
|
||||||
|
|
||||||
io: T,
|
io: Option<T>,
|
||||||
read_buf: BytesMut,
|
read_buf: BytesMut,
|
||||||
write_buf: BytesMut,
|
write_buf: BytesMut,
|
||||||
codec: Codec,
|
codec: Codec,
|
||||||
@ -148,22 +147,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, S, B, X, U> DispatcherState<T, S, B, X, U>
|
|
||||||
where
|
|
||||||
S: Service<Request = Request>,
|
|
||||||
S::Error: Into<Error>,
|
|
||||||
B: MessageBody,
|
|
||||||
X: Service<Request = Request, Response = Request>,
|
|
||||||
X::Error: Into<Error>,
|
|
||||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
|
||||||
U::Error: fmt::Display,
|
|
||||||
{
|
|
||||||
fn take(self: Pin<&mut Self>) -> Self {
|
|
||||||
std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
enum PollResponse {
|
enum PollResponse {
|
||||||
Upgrade(Request),
|
Upgrade(Request),
|
||||||
DoNothing,
|
DoNothing,
|
||||||
@ -258,7 +241,7 @@ where
|
|||||||
state: State::None,
|
state: State::None,
|
||||||
error: None,
|
error: None,
|
||||||
messages: VecDeque::new(),
|
messages: VecDeque::new(),
|
||||||
io,
|
io: Some(io),
|
||||||
codec,
|
codec,
|
||||||
read_buf,
|
read_buf,
|
||||||
service,
|
service,
|
||||||
@ -322,9 +305,10 @@ where
|
|||||||
let len = self.write_buf.len();
|
let len = self.write_buf.len();
|
||||||
let mut written = 0;
|
let mut written = 0;
|
||||||
#[project]
|
#[project]
|
||||||
let InnerDispatcher { mut io, write_buf, .. } = self.project();
|
let InnerDispatcher { io, write_buf, .. } = self.project();
|
||||||
|
let mut io = Pin::new(io.as_mut().unwrap());
|
||||||
while written < len {
|
while written < len {
|
||||||
match Pin::new(&mut io).poll_write(cx, &write_buf[written..])
|
match io.as_mut().poll_write(cx, &write_buf[written..])
|
||||||
{
|
{
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
return Err(DispatchError::Io(io::Error::new(
|
return Err(DispatchError::Io(io::Error::new(
|
||||||
@ -751,10 +735,10 @@ where
|
|||||||
} else {
|
} else {
|
||||||
// flush buffer
|
// flush buffer
|
||||||
inner.as_mut().poll_flush(cx)?;
|
inner.as_mut().poll_flush(cx)?;
|
||||||
if !inner.write_buf.is_empty() {
|
if !inner.write_buf.is_empty() || inner.io.is_none() {
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
} else {
|
} else {
|
||||||
match Pin::new(inner.project().io).poll_shutdown(cx) {
|
match Pin::new(inner.project().io).as_pin_mut().unwrap().poll_shutdown(cx) {
|
||||||
Poll::Ready(res) => {
|
Poll::Ready(res) => {
|
||||||
Poll::Ready(res.map_err(DispatchError::from))
|
Poll::Ready(res.map_err(DispatchError::from))
|
||||||
}
|
}
|
||||||
@ -767,7 +751,7 @@ where
|
|||||||
let should_disconnect =
|
let should_disconnect =
|
||||||
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
||||||
let mut inner_p = inner.as_mut().project();
|
let mut inner_p = inner.as_mut().project();
|
||||||
read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)?
|
read_available(cx, inner_p.io.as_mut().unwrap(), &mut inner_p.read_buf)?
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@ -793,20 +777,17 @@ where
|
|||||||
|
|
||||||
// switch to upgrade handler
|
// switch to upgrade handler
|
||||||
if let PollResponse::Upgrade(req) = result {
|
if let PollResponse::Upgrade(req) = result {
|
||||||
if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() {
|
let inner_p = inner.as_mut().project();
|
||||||
let mut parts = FramedParts::with_read_buf(
|
let mut parts = FramedParts::with_read_buf(
|
||||||
inner.io,
|
inner_p.io.take().unwrap(),
|
||||||
inner.codec,
|
std::mem::take(inner_p.codec),
|
||||||
inner.read_buf,
|
std::mem::take(inner_p.read_buf),
|
||||||
);
|
);
|
||||||
parts.write_buf = inner.write_buf;
|
parts.write_buf = std::mem::take(inner_p.write_buf);
|
||||||
let framed = Framed::from_parts(parts);
|
let framed = Framed::from_parts(parts);
|
||||||
let upgrade = inner.upgrade.unwrap().call((req, framed));
|
let upgrade = inner_p.upgrade.take().unwrap().call((req, framed));
|
||||||
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
|
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
|
||||||
return self.poll(cx);
|
return self.poll(cx);
|
||||||
} else {
|
|
||||||
panic!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// we didnt get WouldBlock from write operation,
|
// we didnt get WouldBlock from write operation,
|
||||||
@ -859,7 +840,6 @@ where
|
|||||||
DispatchError::Upgrade
|
DispatchError::Upgrade
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
DispatcherState::None => panic!(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -949,9 +929,9 @@ mod tests {
|
|||||||
Poll::Ready(res) => assert!(res.is_err()),
|
Poll::Ready(res) => assert!(res.is_err()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if let DispatcherState::Normal(ref inner) = h1.inner {
|
if let DispatcherState::Normal(ref mut inner) = h1.inner {
|
||||||
assert!(inner.flags.contains(Flags::READ_DISCONNECT));
|
assert!(inner.flags.contains(Flags::READ_DISCONNECT));
|
||||||
assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
|
assert_eq!(&inner.io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
Loading…
Reference in New Issue
Block a user