1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-01-18 22:01:50 +01:00

Remove ConnectionLifetime trait. Simplify Acquired handling (#2072)

This commit is contained in:
fakeshadow 2021-03-15 19:56:23 -07:00 committed by GitHub
parent d93314a683
commit 69dd1a9bd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 76 additions and 114 deletions

View File

@ -94,14 +94,6 @@ pub trait Connection {
>; >;
} }
pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static {
/// Close connection
fn close(self: Pin<&mut Self>);
/// Release connection to the connection pool
fn release(self: Pin<&mut Self>);
}
#[doc(hidden)] #[doc(hidden)]
/// HTTP client connection /// HTTP client connection
pub struct IoConnection<T> pub struct IoConnection<T>
@ -110,7 +102,7 @@ where
{ {
io: Option<ConnectionType<T>>, io: Option<ConnectionType<T>>,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Acquired<T>,
} }
impl<T> fmt::Debug for IoConnection<T> impl<T> fmt::Debug for IoConnection<T>
@ -130,7 +122,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
pub(crate) fn new( pub(crate) fn new(
io: ConnectionType<T>, io: ConnectionType<T>,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, pool: Acquired<T>,
) -> Self { ) -> Self {
IoConnection { IoConnection {
pool, pool,
@ -139,13 +131,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
} }
} }
pub(crate) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
(self.io.unwrap(), self.created)
}
#[cfg(test)] #[cfg(test)]
pub(crate) fn into_parts(self) -> (ConnectionType<T>, time::Instant, Acquired<T>) { pub(crate) fn into_parts(self) -> (ConnectionType<T>, time::Instant, Acquired<T>) {
(self.io.unwrap(), self.created, self.pool.unwrap()) (self.io.unwrap(), self.created, self.pool)
} }
async fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>( async fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
@ -173,13 +161,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { self.pool.release(ConnectionType::H2(io), self.created);
pool.release(IoConnection::new(
ConnectionType::H2(io),
self.created,
None,
));
}
Err(SendRequestError::TunnelNotSupported) Err(SendRequestError::TunnelNotSupported)
} }
} }

View File

@ -7,7 +7,7 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use bytes::buf::BufMut; use bytes::buf::BufMut;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_core::Stream; use futures_core::Stream;
use futures_util::{future::poll_fn, SinkExt, StreamExt}; use futures_util::{future::poll_fn, SinkExt};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
@ -19,7 +19,7 @@ use crate::http::{
use crate::message::{RequestHeadType, ResponseHead}; use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::{Payload, PayloadStream}; use crate::payload::{Payload, PayloadStream};
use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; use super::connection::ConnectionType;
use super::error::{ConnectError, SendRequestError}; use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired; use super::pool::Acquired;
use crate::body::{BodySize, MessageBody}; use crate::body::{BodySize, MessageBody};
@ -29,7 +29,7 @@ pub(crate) async fn send_request<T, B>(
mut head: RequestHeadType, mut head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, acquired: Acquired<T>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> Result<(ResponseHead, Payload), SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -42,9 +42,9 @@ where
if let Some(host) = head.as_ref().uri.host() { if let Some(host) = head.as_ref().uri.host() {
let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); let mut wrt = BytesMut::with_capacity(host.len() + 5).writer();
let _ = match head.as_ref().uri.port_u16() { match head.as_ref().uri.port_u16() {
None | Some(80) | Some(443) => write!(wrt, "{}", host), None | Some(80) | Some(443) => write!(wrt, "{}", host)?,
Some(port) => write!(wrt, "{}:{}", host, port), Some(port) => write!(wrt, "{}:{}", host, port)?,
}; };
match wrt.get_mut().split().freeze().try_into_value() { match wrt.get_mut().split().freeze().try_into_value() {
@ -64,7 +64,7 @@ where
let io = H1Connection { let io = H1Connection {
created, created,
pool, acquired,
io: Some(io), io: Some(io),
}; };
@ -77,10 +77,8 @@ where
let is_expect = if head.as_ref().headers.contains_key(EXPECT) { let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
match body.size() { match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => { BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
let pin_framed = Pin::new(&mut framed); let keep_alive = framed.codec_ref().keepalive();
framed.io_mut().on_release(keep_alive);
let force_close = !pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close);
// TODO: use a new variant or a new type better describing error violate // TODO: use a new variant or a new type better describing error violate
// `Requirements for clients` session of above RFC // `Requirements for clients` session of above RFC
@ -128,8 +126,9 @@ where
match pin_framed.codec_ref().message_type() { match pin_framed.codec_ref().message_type() {
h1::MessageType::None => { h1::MessageType::None => {
let force_close = !pin_framed.codec_ref().keepalive(); let keep_alive = pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close); pin_framed.io_mut().on_release(keep_alive);
Ok((head, Payload::None)) Ok((head, Payload::None))
} }
_ => { _ => {
@ -151,12 +150,11 @@ where
framed.send((head, BodySize::None).into()).await?; framed.send((head, BodySize::None).into()).await?;
// read response // read response
if let (Some(result), framed) = framed.into_future().await { let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
let head = result.map_err(SendRequestError::from)?; .await
.ok_or(ConnectError::Disconnected)??;
Ok((head, framed)) Ok((head, framed))
} else {
Err(SendRequestError::from(ConnectError::Disconnected))
}
} }
/// send request body to the peer /// send request body to the peer
@ -165,7 +163,7 @@ pub(crate) async fn send_body<T, B>(
mut framed: Pin<&mut Framed<T, h1::ClientCodec>>, mut framed: Pin<&mut Framed<T, h1::ClientCodec>>,
) -> Result<(), SendRequestError> ) -> Result<(), SendRequestError>
where where
T: ConnectionLifetime + Unpin, T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody, B: MessageBody,
{ {
actix_rt::pin!(body); actix_rt::pin!(body);
@ -200,7 +198,7 @@ where
} }
} }
SinkExt::flush(Pin::into_inner(framed)).await?; SinkExt::flush(framed.get_mut()).await?;
Ok(()) Ok(())
} }
@ -208,41 +206,37 @@ where
/// HTTP client connection /// HTTP client connection
pub struct H1Connection<T> pub struct H1Connection<T>
where where
T: AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
/// T should be `Unpin` /// T should be `Unpin`
io: Option<T>, io: Option<T>,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, acquired: Acquired<T>,
} }
impl<T> ConnectionLifetime for H1Connection<T> impl<T> H1Connection<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
/// Close connection fn on_release(&mut self, keep_alive: bool) {
fn close(mut self: Pin<&mut Self>) { if keep_alive {
if let Some(mut pool) = self.pool.take() { self.release();
if let Some(io) = self.io.take() { } else {
pool.close(IoConnection::new( self.close();
ConnectionType::H1(io),
self.created,
None,
));
} }
} }
/// Close connection
fn close(&mut self) {
if let Some(io) = self.io.take() {
self.acquired.close(ConnectionType::H1(io));
}
} }
/// Release this connection to the connection pool /// Release this connection to the connection pool
fn release(mut self: Pin<&mut Self>) { fn release(&mut self) {
if let Some(mut pool) = self.pool.take() {
if let Some(io) = self.io.take() { if let Some(io) = self.io.take() {
pool.release(IoConnection::new( self.acquired.release(ConnectionType::H1(io), self.created);
ConnectionType::H1(io),
self.created,
None,
));
}
} }
} }
} }
@ -282,13 +276,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T>
} }
#[pin_project::pin_project] #[pin_project::pin_project]
pub(crate) struct PlStream<Io> { pub(crate) struct PlStream<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
#[pin] #[pin]
framed: Option<Framed<Io, h1::ClientPayloadCodec>>, framed: Option<Framed<H1Connection<Io>, h1::ClientPayloadCodec>>,
} }
impl<Io: ConnectionLifetime> PlStream<Io> { impl<Io> PlStream<Io>
fn new(framed: Framed<Io, h1::ClientCodec>) -> Self { where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(framed: Framed<H1Connection<Io>, h1::ClientCodec>) -> Self {
let framed = framed.into_map_codec(|codec| codec.into_payload_codec()); let framed = framed.into_map_codec(|codec| codec.into_payload_codec());
PlStream { PlStream {
@ -297,24 +297,26 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
} }
} }
impl<Io: ConnectionLifetime> Stream for PlStream<Io> { impl<Io> Stream for PlStream<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Item = Result<Bytes, PayloadError>; type Item = Result<Bytes, PayloadError>;
fn poll_next( fn poll_next(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> { ) -> Poll<Option<Self::Item>> {
let mut this = self.project(); let mut framed = self.project().framed.as_pin_mut().unwrap();
match this.framed.as_mut().as_pin_mut().unwrap().next_item(cx)? { match framed.as_mut().next_item(cx)? {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => { Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk { if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk))) Poll::Ready(Some(Ok(chunk)))
} else { } else {
let framed = this.framed.as_mut().as_pin_mut().unwrap(); let keep_alive = framed.codec_ref().keepalive();
let force_close = !framed.codec_ref().keepalive(); framed.io_mut().on_release(keep_alive);
release_connection(framed, force_close);
Poll::Ready(None) Poll::Ready(None)
} }
} }
@ -322,14 +324,3 @@ impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
} }
} }
} }
fn release_connection<T, U>(framed: Pin<&mut Framed<T, U>>, force_close: bool)
where
T: ConnectionLifetime,
{
if !force_close && framed.is_read_buf_empty() && framed.is_write_buf_empty() {
framed.io_pin().release()
} else {
framed.io_pin().close()
}
}

View File

@ -17,7 +17,7 @@ use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::Payload; use crate::payload::Payload;
use super::config::ConnectorConfig; use super::config::ConnectorConfig;
use super::connection::{ConnectionType, IoConnection}; use super::connection::ConnectionType;
use super::error::SendRequestError; use super::error::SendRequestError;
use super::pool::Acquired; use super::pool::Acquired;
use crate::client::connection::H2Connection; use crate::client::connection::H2Connection;
@ -27,7 +27,7 @@ pub(crate) async fn send_request<T, B>(
head: RequestHeadType, head: RequestHeadType,
body: B, body: B,
created: time::Instant, created: time::Instant,
pool: Option<Acquired<T>>, acquired: Acquired<T>,
) -> Result<(ResponseHead, Payload), SendRequestError> ) -> Result<(ResponseHead, Payload), SendRequestError>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -103,13 +103,13 @@ where
let res = poll_fn(|cx| io.poll_ready(cx)).await; let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res { if let Err(e) = res {
release(io, pool, created, e.is_io()); release(io, acquired, created, e.is_io());
return Err(SendRequestError::from(e)); return Err(SendRequestError::from(e));
} }
let resp = match io.send_request(req, eof) { let resp = match io.send_request(req, eof) {
Ok((fut, send)) => { Ok((fut, send)) => {
release(io, pool, created, false); release(io, acquired, created, false);
if !eof { if !eof {
send_body(body, send).await?; send_body(body, send).await?;
@ -117,7 +117,7 @@ where
fut.await.map_err(SendRequestError::from)? fut.await.map_err(SendRequestError::from)?
} }
Err(e) => { Err(e) => {
release(io, pool, created, e.is_io()); release(io, acquired, created, e.is_io());
return Err(e.into()); return Err(e.into());
} }
}; };
@ -181,16 +181,14 @@ async fn send_body<B: MessageBody>(
/// release SendRequest object /// release SendRequest object
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>( fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
io: H2Connection, io: H2Connection,
pool: Option<Acquired<T>>, acquired: Acquired<T>,
created: time::Instant, created: time::Instant,
close: bool, close: bool,
) { ) {
if let Some(mut pool) = pool {
if close { if close {
pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); acquired.close(ConnectionType::H2(io));
} else { } else {
pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); acquired.release(ConnectionType::H2(io), created);
}
} }
} }

View File

@ -217,7 +217,7 @@ where
// construct acquired. It's used to put Io type back to pool/ close the Io type. // construct acquired. It's used to put Io type back to pool/ close the Io type.
// permit is carried with the whole lifecycle of Acquired. // permit is carried with the whole lifecycle of Acquired.
let acquired = Some(Acquired { key, inner, permit }); let acquired = Acquired { key, inner, permit };
// match the connection and spawn new one if did not get anything. // match the connection and spawn new one if did not get anything.
match conn { match conn {
@ -235,7 +235,7 @@ where
acquired, acquired,
)) ))
} else { } else {
let config = &acquired.as_ref().unwrap().inner.config; let config = &acquired.inner.config;
let (sender, connection) = handshake(io, config).await?; let (sender, connection) = handshake(io, config).await?;
Ok(IoConnection::new( Ok(IoConnection::new(
ConnectionType::H2(H2Connection::new(sender, connection)), ConnectionType::H2(H2Connection::new(sender, connection)),
@ -346,14 +346,12 @@ where
Io: AsyncRead + AsyncWrite + Unpin + 'static, Io: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
/// Close the IO. /// Close the IO.
pub(crate) fn close(&mut self, conn: IoConnection<Io>) { pub(crate) fn close(&self, conn: ConnectionType<Io>) {
let (conn, _) = conn.into_inner();
self.inner.close(conn); self.inner.close(conn);
} }
/// Release IO back into pool. /// Release IO back into pool.
pub(crate) fn release(&mut self, conn: IoConnection<Io>) { pub(crate) fn release(&self, conn: ConnectionType<Io>, created: Instant) {
let (io, created) = conn.into_inner();
let Acquired { key, inner, .. } = self; let Acquired { key, inner, .. } = self;
inner inner
@ -362,12 +360,12 @@ where
.entry(key.clone()) .entry(key.clone())
.or_insert_with(VecDeque::new) .or_insert_with(VecDeque::new)
.push_back(PooledConnection { .push_back(PooledConnection {
conn: io, conn,
created, created,
used: Instant::now(), used: Instant::now(),
}); });
let _ = &mut self.permit; let _ = &self.permit;
} }
} }
@ -447,8 +445,8 @@ mod test {
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
let (conn, created, mut acquired) = conn.into_parts(); let (conn, created, acquired) = conn.into_parts();
acquired.release(IoConnection::new(conn, created, None)); acquired.release(conn, created);
} }
#[actix_rt::test] #[actix_rt::test]

View File

@ -1,5 +1,3 @@
use std::task::Poll;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_util::future::{ready, Ready}; use futures_util::future::{ready, Ready};

View File

@ -1,5 +1,3 @@
use std::task::Poll;
use actix_codec::Framed; use actix_codec::Framed;
use actix_service::{Service, ServiceFactory}; use actix_service::{Service, ServiceFactory};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;

View File

@ -1,6 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use std::task::Poll;
use actix_http::{Extensions, Request, Response}; use actix_http::{Extensions, Request, Response};
use actix_router::{Path, ResourceDef, Router, Url}; use actix_router::{Path, ResourceDef, Router, Url};

View File

@ -2,7 +2,6 @@ use std::cell::RefCell;
use std::fmt; use std::fmt;
use std::future::Future; use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
use std::task::Poll;
use actix_http::{Error, Extensions, Response}; use actix_http::{Error, Extensions, Response};
use actix_router::IntoPattern; use actix_router::IntoPattern;

View File

@ -2,7 +2,6 @@ use std::cell::RefCell;
use std::fmt; use std::fmt;
use std::future::Future; use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
use std::task::Poll;
use actix_http::Extensions; use actix_http::Extensions;
use actix_router::{ResourceDef, Router}; use actix_router::{ResourceDef, Router};