1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-25 00:12:59 +01:00
actix-extras/src/server/h2.rs

440 lines
15 KiB
Rust
Raw Normal View History

2018-04-14 01:02:01 +02:00
use std::collections::VecDeque;
2017-11-02 20:54:09 +01:00
use std::io::{Read, Write};
use std::net::SocketAddr;
2018-04-14 01:02:01 +02:00
use std::rc::Rc;
2018-05-25 06:03:16 +02:00
use std::time::{Duration, Instant};
2018-04-14 01:02:01 +02:00
use std::{cmp, io, mem};
2017-11-02 20:54:09 +01:00
use bytes::{Buf, Bytes};
2018-04-14 01:02:01 +02:00
use futures::{Async, Future, Poll, Stream};
use http2::server::{self, Connection, Handshake, SendResponse};
use http2::{Reason, RecvStream};
use modhttp::request::Parts;
use tokio_io::{AsyncRead, AsyncWrite};
2018-05-25 06:03:16 +02:00
use tokio_timer::Delay;
2017-11-02 20:54:09 +01:00
2018-06-18 01:45:54 +02:00
use error::{Error, PayloadError};
2018-06-25 06:58:04 +02:00
use http::{StatusCode, Version};
2018-04-14 01:02:01 +02:00
use payload::{Payload, PayloadStatus, PayloadWriter};
2018-04-17 21:55:13 +02:00
use uri::Url;
2018-06-25 06:58:04 +02:00
use super::error::ServerError;
2018-04-14 01:02:01 +02:00
use super::h2writer::H2Writer;
2018-06-24 06:42:20 +02:00
use super::input::PayloadType;
2018-01-12 03:35:05 +01:00
use super::settings::WorkerSettings;
use super::{HttpHandler, HttpHandlerTask, Writer};
2018-01-12 03:35:05 +01:00
bitflags! {
    /// Connection-level state bits for `Http2`.
    struct Flags: u8 {
        /// Set once the connection reported end-of-stream or an error;
        /// after that no further requests are read from it.
        const DISCONNECTED = 1 << 1;
    }
}
/// HTTP/2 Transport
2018-04-14 01:02:01 +02:00
pub(crate) struct Http2<T, H>
where
T: AsyncRead + AsyncWrite + 'static,
2018-06-18 01:45:54 +02:00
H: HttpHandler + 'static,
{
flags: Flags,
2017-12-14 06:38:47 +01:00
settings: Rc<WorkerSettings<H>>,
addr: Option<SocketAddr>,
state: State<IoWrapper<T>>,
2018-03-18 19:05:44 +01:00
tasks: VecDeque<Entry<H>>,
2018-05-25 06:03:16 +02:00
keepalive_timer: Option<Delay>,
}
enum State<T: AsyncRead + AsyncWrite> {
Handshake(Handshake<T, Bytes>),
2018-01-12 01:22:27 +01:00
Connection(Connection<T, Bytes>),
Empty,
}
impl<T, H> Http2<T, H>
2018-04-14 01:02:01 +02:00
where
T: AsyncRead + AsyncWrite + 'static,
H: HttpHandler + 'static,
{
2018-04-14 01:02:01 +02:00
pub fn new(
2018-04-29 07:55:47 +02:00
settings: Rc<WorkerSettings<H>>, io: T, addr: Option<SocketAddr>, buf: Bytes,
2018-04-14 01:02:01 +02:00
) -> Self {
Http2 {
flags: Flags::empty(),
tasks: VecDeque::new(),
state: State::Handshake(server::handshake(IoWrapper {
unread: Some(buf),
inner: io,
})),
keepalive_timer: None,
addr,
settings,
2017-11-04 21:24:57 +01:00
}
}
2018-01-04 07:43:44 +01:00
pub(crate) fn shutdown(&mut self) {
self.state = State::Empty;
self.tasks.clear();
self.keepalive_timer.take();
}
2017-12-29 01:25:47 +01:00
pub fn settings(&self) -> &WorkerSettings<H> {
self.settings.as_ref()
}
pub fn poll(&mut self) -> Poll<(), ()> {
2017-11-04 17:07:44 +01:00
// server
2018-01-12 01:22:27 +01:00
if let State::Connection(ref mut conn) = self.state {
2017-11-04 21:24:57 +01:00
// keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() {
2017-11-09 04:31:25 +01:00
Ok(Async::Ready(_)) => {
trace!("Keep-alive timeout, close connection");
2018-04-14 01:02:01 +02:00
return Ok(Async::Ready(()));
2017-11-09 04:31:25 +01:00
}
2017-11-04 21:24:57 +01:00
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
}
}
2017-11-04 17:07:44 +01:00
loop {
let mut not_ready = true;
// check in-flight connections
for item in &mut self.tasks {
// read payload
item.poll_payload();
if !item.flags.contains(EntryFlags::EOF) {
let retry = item.payload.need_read() == PayloadStatus::Read;
2018-02-27 05:07:22 +01:00
loop {
match item.task.poll_io(&mut item.stream) {
Ok(Async::Ready(ready)) => {
if ready {
item.flags.insert(
2018-04-14 01:02:01 +02:00
EntryFlags::EOF | EntryFlags::FINISHED,
);
} else {
item.flags.insert(EntryFlags::EOF);
2018-02-27 05:07:22 +01:00
}
not_ready = false;
2018-04-14 01:02:01 +02:00
}
2018-02-27 05:07:22 +01:00
Ok(Async::NotReady) => {
if item.payload.need_read() == PayloadStatus::Read
&& !retry
{
2018-04-14 01:02:01 +02:00
continue;
2018-02-27 05:07:22 +01:00
}
2018-04-14 01:02:01 +02:00
}
2018-02-27 05:07:22 +01:00
Err(err) => {
error!("Unhandled error: {}", err);
item.flags.insert(
2018-05-17 21:20:20 +02:00
EntryFlags::EOF
| EntryFlags::ERROR
2018-04-14 01:02:01 +02:00
| EntryFlags::WRITE_DONE,
);
2018-02-27 05:07:22 +01:00
item.stream.reset(Reason::INTERNAL_ERROR);
2017-11-04 17:07:44 +01:00
}
}
2018-04-14 01:02:01 +02:00
break;
2017-11-04 17:07:44 +01:00
}
} else if !item.flags.contains(EntryFlags::FINISHED) {
2018-06-18 01:45:54 +02:00
match item.task.poll_completed() {
2017-11-04 17:07:44 +01:00
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
item.flags.insert(EntryFlags::FINISHED);
2018-04-14 01:02:01 +02:00
}
2017-11-25 18:28:25 +01:00
Err(err) => {
item.flags.insert(
2018-05-17 21:20:20 +02:00
EntryFlags::ERROR
| EntryFlags::WRITE_DONE
2018-04-14 01:02:01 +02:00
| EntryFlags::FINISHED,
);
2017-11-25 18:28:25 +01:00
error!("Unhandled error: {}", err);
2017-11-04 17:07:44 +01:00
}
}
}
if !item.flags.contains(EntryFlags::WRITE_DONE) {
match item.stream.poll_completed(false) {
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
item.flags.insert(EntryFlags::WRITE_DONE);
}
Err(_err) => {
item.flags.insert(EntryFlags::ERROR);
}
}
}
2017-11-04 17:07:44 +01:00
}
// cleanup finished tasks
while !self.tasks.is_empty() {
2018-04-14 01:02:01 +02:00
if self.tasks[0].flags.contains(EntryFlags::EOF)
&& self.tasks[0].flags.contains(EntryFlags::WRITE_DONE)
|| self.tasks[0].flags.contains(EntryFlags::ERROR)
{
2017-11-04 17:07:44 +01:00
self.tasks.pop_front();
} else {
2018-04-14 01:02:01 +02:00
break;
2017-11-04 17:07:44 +01:00
}
}
// get request
if !self.flags.contains(Flags::DISCONNECTED) {
2018-01-12 01:22:27 +01:00
match conn.poll() {
2017-11-04 17:07:44 +01:00
Ok(Async::Ready(None)) => {
not_ready = false;
self.flags.insert(Flags::DISCONNECTED);
2017-11-04 17:07:44 +01:00
for entry in &mut self.tasks {
entry.task.disconnected()
}
2018-04-14 01:02:01 +02:00
}
2017-11-18 17:50:07 +01:00
Ok(Async::Ready(Some((req, resp)))) => {
2017-11-04 17:07:44 +01:00
not_ready = false;
let (parts, body) = req.into_parts();
// stop keepalive timer
2017-11-04 21:24:57 +01:00
self.keepalive_timer.take();
2018-04-14 01:02:01 +02:00
self.tasks.push_back(Entry::new(
parts,
body,
resp,
self.addr,
&self.settings,
));
2017-11-04 17:07:44 +01:00
}
2017-11-04 21:24:57 +01:00
Ok(Async::NotReady) => {
// start keep-alive timer
2017-12-14 06:38:47 +01:00
if self.tasks.is_empty() {
if self.settings.keep_alive_enabled() {
let keep_alive = self.settings.keep_alive();
2017-12-14 06:38:47 +01:00
if keep_alive > 0 && self.keepalive_timer.is_none() {
trace!("Start keep-alive timer");
2018-05-25 06:03:16 +02:00
let mut timeout = Delay::new(
Instant::now()
+ Duration::new(keep_alive, 0),
);
2017-12-14 06:38:47 +01:00
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
}
} else {
// keep-alive disable, drop connection
2018-04-14 01:02:01 +02:00
return conn.poll_close().map_err(|e| {
error!("Error during connection close: {}", e)
});
2017-12-14 06:38:47 +01:00
}
} else {
// keep-alive unset, rely on operating system
2018-04-14 01:02:01 +02:00
return Ok(Async::NotReady);
2017-11-04 21:24:57 +01:00
}
}
Err(err) => {
trace!("Connection error: {}", err);
self.flags.insert(Flags::DISCONNECTED);
2017-11-04 21:24:57 +01:00
for entry in &mut self.tasks {
entry.task.disconnected()
}
self.keepalive_timer.take();
2018-04-14 01:02:01 +02:00
}
2017-11-04 17:07:44 +01:00
}
}
if not_ready {
2018-04-14 01:02:01 +02:00
if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED)
{
2018-05-17 21:20:20 +02:00
return conn
.poll_close()
2018-04-14 01:02:01 +02:00
.map_err(|e| error!("Error during connection close: {}", e));
2017-11-04 17:07:44 +01:00
} else {
2018-04-14 01:02:01 +02:00
return Ok(Async::NotReady);
2017-11-04 17:07:44 +01:00
}
}
}
}
// handshake
self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() {
2018-04-14 01:02:01 +02:00
Ok(Async::Ready(conn)) => State::Connection(conn),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => {
trace!("Error handling connection: {}", err);
2018-04-14 01:02:01 +02:00
return Err(());
}
}
} else {
mem::replace(&mut self.state, State::Empty)
};
2017-11-04 17:07:44 +01:00
self.poll()
}
}
bitflags! {
    /// Per-request progress bits tracked for each `Entry`.
    struct EntryFlags: u8 {
        /// The handler task's `poll_io` has finished (successfully or not).
        const EOF = 1;
        /// The request body stream reported its end.
        const REOF = 1 << 1;
        /// Serving the request failed; the entry can be removed regardless
        /// of its other flags.
        const ERROR = 1 << 2;
        /// The handler task reported completion.
        const FINISHED = 1 << 3;
        /// The response stream finished (or was abandoned on error).
        const WRITE_DONE = 1 << 4;
    }
}
2018-06-18 01:45:54 +02:00
/// The task serving one request: either the task produced by the handler
/// that accepted it, or a boxed fallback task.
enum EntryPipe<H>
where
    H: HttpHandler,
{
    /// Task returned by the handler that accepted the request.
    Task(H::Task),
    /// Fallback task, used when no handler accepted the request.
    Error(Box<HttpHandlerTask>),
}
impl<H: HttpHandler> EntryPipe<H> {
fn disconnected(&mut self) {
match *self {
EntryPipe::Task(ref mut task) => task.disconnected(),
EntryPipe::Error(ref mut task) => task.disconnected(),
}
}
fn poll_io(&mut self, io: &mut Writer) -> Poll<bool, Error> {
match *self {
EntryPipe::Task(ref mut task) => task.poll_io(io),
EntryPipe::Error(ref mut task) => task.poll_io(io),
}
}
fn poll_completed(&mut self) -> Poll<(), Error> {
match *self {
EntryPipe::Task(ref mut task) => task.poll_completed(),
EntryPipe::Error(ref mut task) => task.poll_completed(),
}
}
}
struct Entry<H: HttpHandler + 'static> {
task: EntryPipe<H>,
2017-11-07 01:23:58 +01:00
payload: PayloadType,
recv: RecvStream,
2018-03-18 19:05:44 +01:00
stream: H2Writer<H>,
flags: EntryFlags,
}
2018-06-18 01:45:54 +02:00
impl<H: HttpHandler + 'static> Entry<H> {
2018-04-14 01:02:01 +02:00
fn new(
parts: Parts, recv: RecvStream, resp: SendResponse<Bytes>,
addr: Option<SocketAddr>, settings: &Rc<WorkerSettings<H>>,
) -> Entry<H>
where
H: HttpHandler + 'static,
{
2017-11-27 04:00:57 +01:00
// Payload and Content-Encoding
let (psender, payload) = Payload::new(false);
2018-06-25 06:58:04 +02:00
let mut msg = settings.get_request_context();
msg.inner.url = Url::new(parts.uri);
msg.inner.method = parts.method;
msg.inner.version = parts.version;
msg.inner.headers = parts.headers;
*msg.inner.payload.borrow_mut() = Some(payload);
msg.inner.addr = addr;
2017-11-25 07:15:52 +01:00
// Payload sender
2018-06-25 06:58:04 +02:00
let psender = PayloadType::new(msg.headers(), psender);
2017-11-25 07:15:52 +01:00
// start request processing
let mut task = None;
2017-12-26 18:00:45 +01:00
for h in settings.handlers().iter_mut() {
2018-06-25 06:58:04 +02:00
msg = match h.handle(msg) {
2017-11-29 22:53:52 +01:00
Ok(t) => {
task = Some(t);
2018-04-14 01:02:01 +02:00
break;
}
2018-06-25 06:58:04 +02:00
Err(msg) => msg,
}
}
2018-04-14 01:02:01 +02:00
Entry {
2018-06-18 01:45:54 +02:00
task: task.map(EntryPipe::Task).unwrap_or_else(|| {
2018-06-25 06:58:04 +02:00
EntryPipe::Error(ServerError::err(
Version::HTTP_2,
StatusCode::NOT_FOUND,
))
2018-06-18 01:45:54 +02:00
}),
2018-04-14 01:02:01 +02:00
payload: psender,
2018-06-24 06:30:58 +02:00
stream: H2Writer::new(resp, Rc::clone(settings)),
2018-04-14 01:02:01 +02:00
flags: EntryFlags::empty(),
recv,
2017-11-04 17:07:44 +01:00
}
}
fn poll_payload(&mut self) {
2018-05-09 00:44:50 +02:00
while !self.flags.contains(EntryFlags::REOF)
&& self.payload.need_read() == PayloadStatus::Read
{
2017-11-04 17:07:44 +01:00
match self.recv.poll() {
Ok(Async::Ready(Some(chunk))) => {
2018-05-09 00:44:50 +02:00
let l = chunk.len();
2017-11-07 01:23:58 +01:00
self.payload.feed_data(chunk);
2018-05-09 00:44:50 +02:00
if let Err(err) = self.recv.release_capacity().release_capacity(l) {
self.payload.set_error(PayloadError::Http2(err));
break;
}
2018-04-14 01:02:01 +02:00
}
2017-11-04 17:07:44 +01:00
Ok(Async::Ready(None)) => {
self.flags.insert(EntryFlags::REOF);
2018-05-09 00:44:50 +02:00
self.payload.feed_eof();
}
Ok(Async::NotReady) => break,
Err(err) => {
self.payload.set_error(PayloadError::Http2(err));
break;
2017-11-04 17:07:44 +01:00
}
}
}
}
}
2017-11-02 20:54:09 +01:00
/// I/O adapter that first serves bytes read ahead of time, then reads from
/// the wrapped stream; writes go straight to the wrapped stream.
struct IoWrapper<T> {
    /// Bytes to serve before reading from `inner`; `None` once consumed.
    unread: Option<Bytes>,
    /// The wrapped transport.
    inner: T,
}
impl<T: Read> Read for IoWrapper<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Some(mut bytes) = self.unread.take() {
let size = cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
2017-11-02 20:54:09 +01:00
self.unread = Some(bytes);
}
Ok(size)
} else {
self.inner.read(buf)
}
}
}
impl<T> Write for IoWrapper<T>
where
    T: Write,
{
    /// Forwards the write to the wrapped stream unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Forwards the flush to the wrapped stream.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl<T: AsyncRead + 'static> AsyncRead for IoWrapper<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
}
impl<T: AsyncWrite + 'static> AsyncWrite for IoWrapper<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.inner.write_buf(buf)
}
}