1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-19 06:04:40 +01:00

fix bug where upgrade future is not reset properly (#1880)

This commit is contained in:
fakeshadow 2021-01-07 08:57:34 +08:00 committed by GitHub
parent 85753130d9
commit 6d710629af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 55 deletions

View File

@ -3,8 +3,7 @@ use std::task::{Context, Poll};
use std::{fmt, mem};
use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use futures_util::ready;
use futures_core::{ready, Stream};
use pin_project::pin_project;
use crate::error::Error;

View File

@ -338,7 +338,7 @@ where
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
this.fut_upg.set(None);
}
let result = ready!(this

View File

@ -9,7 +9,6 @@ use actix_rt::net::TcpStream;
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes;
use futures_core::{ready, Future};
use futures_util::future::ok;
use h2::server::{self, Handshake};
use pin_project::pin_project;
@ -175,9 +174,9 @@ where
Error = DispatchError,
InitError = (),
> {
pipeline_factory(|io: TcpStream| {
pipeline_factory(|io: TcpStream| async {
let peer_addr = io.peer_addr().ok();
ok((io, Protocol::Http1, peer_addr))
Ok((io, Protocol::Http1, peer_addr))
})
.and_then(self)
}
@ -227,7 +226,7 @@ mod openssl {
.map_err(TlsError::Tls)
.map_init_err(|_| panic!()),
)
.and_then(|io: SslStream<TcpStream>| {
.and_then(|io: SslStream<TcpStream>| async {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
@ -238,7 +237,7 @@ mod openssl {
Protocol::Http1
};
let peer_addr = io.get_ref().peer_addr().ok();
ok((io, proto, peer_addr))
Ok((io, proto, peer_addr))
})
.and_then(self.map_err(TlsError::Service))
}
@ -295,7 +294,7 @@ mod rustls {
.map_err(TlsError::Tls)
.map_init_err(|_| panic!()),
)
.and_then(|io: TlsStream<TcpStream>| {
.and_then(|io: TlsStream<TcpStream>| async {
let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
@ -306,7 +305,7 @@ mod rustls {
Protocol::Http1
};
let peer_addr = io.get_ref().0.peer_addr().ok();
ok((io, proto, peer_addr))
Ok((io, proto, peer_addr))
})
.and_then(self.map_err(TlsError::Service))
}
@ -413,7 +412,7 @@ where
.map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project();
*this.upgrade = Some(upgrade);
this.fut_ex.set(None);
this.fut_upg.set(None);
}
let result = ready!(this
@ -645,45 +644,16 @@ where
{
type Output = Result<(), DispatchError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().state.poll(cx)
}
}
impl<T, S, B, X, U> State<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request, Response = Request>,
X::Error: Into<Error>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), DispatchError>> {
match self.as_mut().project() {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().state.project() {
StateProj::H1(disp) => disp.poll(cx),
StateProj::H2(disp) => disp.poll(cx),
StateProj::H2Handshake(ref mut data) => {
let conn = if let Some(ref mut item) = data {
match Pin::new(&mut item.0).poll(cx) {
Poll::Ready(Ok(conn)) => conn,
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
return Poll::Ready(Err(err.into()));
}
Poll::Pending => return Poll::Pending,
}
} else {
panic!()
};
let (_, cfg, srv, on_connect_data, peer_addr) = data.take().unwrap();
self.set(State::H2(Dispatcher::new(
StateProj::H2Handshake(data) => {
match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) {
Ok(conn) => {
let (_, cfg, srv, on_connect_data, peer_addr) =
data.take().unwrap();
self.as_mut().project().state.set(State::H2(Dispatcher::new(
srv,
conn,
on_connect_data,
@ -693,6 +663,12 @@ where
)));
self.poll(cx)
}
Err(err) => {
trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into()))
}
}
}
}
}
}