1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-19 06:04:40 +01:00

fix bug where upgrade future is not reset properly (#1880)

This commit is contained in:
fakeshadow 2021-01-07 08:57:34 +08:00 committed by GitHub
parent 85753130d9
commit 6d710629af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 55 deletions

View File

@ -3,8 +3,7 @@ use std::task::{Context, Poll};
use std::{fmt, mem}; use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::Stream; use futures_core::{ready, Stream};
use futures_util::ready;
use pin_project::pin_project; use pin_project::pin_project;
use crate::error::Error; use crate::error::Error;

View File

@ -338,7 +338,7 @@ where
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project(); this = self.as_mut().project();
*this.upgrade = Some(upgrade); *this.upgrade = Some(upgrade);
this.fut_ex.set(None); this.fut_upg.set(None);
} }
let result = ready!(this let result = ready!(this

View File

@ -9,7 +9,6 @@ use actix_rt::net::TcpStream;
use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes; use bytes::Bytes;
use futures_core::{ready, Future}; use futures_core::{ready, Future};
use futures_util::future::ok;
use h2::server::{self, Handshake}; use h2::server::{self, Handshake};
use pin_project::pin_project; use pin_project::pin_project;
@ -175,9 +174,9 @@ where
Error = DispatchError, Error = DispatchError,
InitError = (), InitError = (),
> { > {
pipeline_factory(|io: TcpStream| { pipeline_factory(|io: TcpStream| async {
let peer_addr = io.peer_addr().ok(); let peer_addr = io.peer_addr().ok();
ok((io, Protocol::Http1, peer_addr)) Ok((io, Protocol::Http1, peer_addr))
}) })
.and_then(self) .and_then(self)
} }
@ -227,7 +226,7 @@ mod openssl {
.map_err(TlsError::Tls) .map_err(TlsError::Tls)
.map_init_err(|_| panic!()), .map_init_err(|_| panic!()),
) )
.and_then(|io: SslStream<TcpStream>| { .and_then(|io: SslStream<TcpStream>| async {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") { if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2 Protocol::Http2
@ -238,7 +237,7 @@ mod openssl {
Protocol::Http1 Protocol::Http1
}; };
let peer_addr = io.get_ref().peer_addr().ok(); let peer_addr = io.get_ref().peer_addr().ok();
ok((io, proto, peer_addr)) Ok((io, proto, peer_addr))
}) })
.and_then(self.map_err(TlsError::Service)) .and_then(self.map_err(TlsError::Service))
} }
@ -295,7 +294,7 @@ mod rustls {
.map_err(TlsError::Tls) .map_err(TlsError::Tls)
.map_init_err(|_| panic!()), .map_init_err(|_| panic!()),
) )
.and_then(|io: TlsStream<TcpStream>| { .and_then(|io: TlsStream<TcpStream>| async {
let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() { let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") { if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2 Protocol::Http2
@ -306,7 +305,7 @@ mod rustls {
Protocol::Http1 Protocol::Http1
}; };
let peer_addr = io.get_ref().0.peer_addr().ok(); let peer_addr = io.get_ref().0.peer_addr().ok();
ok((io, proto, peer_addr)) Ok((io, proto, peer_addr))
}) })
.and_then(self.map_err(TlsError::Service)) .and_then(self.map_err(TlsError::Service))
} }
@ -413,7 +412,7 @@ where
.map_err(|e| log::error!("Init http service error: {:?}", e)))?; .map_err(|e| log::error!("Init http service error: {:?}", e)))?;
this = self.as_mut().project(); this = self.as_mut().project();
*this.upgrade = Some(upgrade); *this.upgrade = Some(upgrade);
this.fut_ex.set(None); this.fut_upg.set(None);
} }
let result = ready!(this let result = ready!(this
@ -645,53 +644,30 @@ where
{ {
type Output = Result<(), DispatchError>; type Output = Result<(), DispatchError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().state.poll(cx) match self.as_mut().project().state.project() {
}
}
impl<T, S, B, X, U> State<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>,
S::Error: Into<Error> + 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request, Response = Request>,
X::Error: Into<Error>,
U: Service<(Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), DispatchError>> {
match self.as_mut().project() {
StateProj::H1(disp) => disp.poll(cx), StateProj::H1(disp) => disp.poll(cx),
StateProj::H2(disp) => disp.poll(cx), StateProj::H2(disp) => disp.poll(cx),
StateProj::H2Handshake(ref mut data) => { StateProj::H2Handshake(data) => {
let conn = if let Some(ref mut item) = data { match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) {
match Pin::new(&mut item.0).poll(cx) { Ok(conn) => {
Poll::Ready(Ok(conn)) => conn, let (_, cfg, srv, on_connect_data, peer_addr) =
Poll::Ready(Err(err)) => { data.take().unwrap();
trace!("H2 handshake error: {}", err); self.as_mut().project().state.set(State::H2(Dispatcher::new(
return Poll::Ready(Err(err.into())); srv,
} conn,
Poll::Pending => return Poll::Pending, on_connect_data,
cfg,
None,
peer_addr,
)));
self.poll(cx)
} }
} else { Err(err) => {
panic!() trace!("H2 handshake error: {}", err);
}; Poll::Ready(Err(err.into()))
let (_, cfg, srv, on_connect_data, peer_addr) = data.take().unwrap(); }
self.set(State::H2(Dispatcher::new( }
srv,
conn,
on_connect_data,
cfg,
None,
peer_addr,
)));
self.poll(cx)
} }
} }
} }