1
0
mirror of https://github.com/fafhrd91/actix-web synced 2024-11-24 08:22:59 +01:00
actix-web/src/h2.rs

325 lines
10 KiB
Rust
Raw Normal View History

use std::{io, cmp, mem};
use std::rc::Rc;
2017-11-02 20:54:09 +01:00
use std::io::{Read, Write};
2017-11-04 21:24:57 +01:00
use std::time::Duration;
use std::net::SocketAddr;
use std::collections::VecDeque;
2017-11-04 21:24:57 +01:00
use actix::Arbiter;
use http::request::Parts;
2017-11-04 17:07:44 +01:00
use http2::{Reason, RecvStream};
use http2::server::{Server, Handshake, Respond};
2017-11-02 20:54:09 +01:00
use bytes::{Buf, Bytes};
use futures::{Async, Poll, Future, Stream};
2017-11-02 20:54:09 +01:00
use tokio_io::{AsyncRead, AsyncWrite};
2017-11-04 21:24:57 +01:00
use tokio_core::reactor::Timeout;
2017-11-02 20:54:09 +01:00
2017-11-25 07:15:52 +01:00
use pipeline::Pipeline;
2017-11-06 10:24:49 +01:00
use h2writer::H2Writer;
use channel::HttpHandler;
2017-11-16 07:06:28 +01:00
use error::PayloadError;
2017-11-07 01:23:58 +01:00
use encoding::PayloadType;
2017-11-25 07:15:52 +01:00
use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
2017-11-16 07:06:28 +01:00
use payload::{Payload, PayloadWriter};
2017-11-04 21:24:57 +01:00
/// Idle period, in whole seconds, after which a connection with no in-flight
/// requests is closed by the keep-alive timer.
const KEEPALIVE_PERIOD: u64 = 15;
pub(crate) struct Http2<T, H>
where T: AsyncRead + AsyncWrite + 'static, H: 'static
{
router: Rc<Vec<H>>,
addr: Option<SocketAddr>,
state: State<IoWrapper<T>>,
2017-11-04 17:07:44 +01:00
disconnected: bool,
tasks: VecDeque<Entry>,
2017-11-04 21:24:57 +01:00
keepalive_timer: Option<Timeout>,
}
/// Lifecycle of the underlying HTTP/2 connection.
enum State<T>
    where T: AsyncRead + AsyncWrite
{
    /// Server handshake (from `Server::handshake`) still in progress.
    Handshake(Handshake<T, Bytes>),
    /// Handshake completed; new requests are polled from this handle.
    Server(Server<T, Bytes>),
    /// No connection attached.
    Empty,
}
impl<T, H> Http2<T, H>
where T: AsyncRead + AsyncWrite + 'static,
H: HttpHandler + 'static
{
pub fn new(stream: T, addr: Option<SocketAddr>, router: Rc<Vec<H>>, buf: Bytes) -> Self {
Http2{ router: router,
addr: addr,
2017-11-04 17:07:44 +01:00
disconnected: false,
tasks: VecDeque::new(),
state: State::Handshake(
2017-11-04 21:24:57 +01:00
Server::handshake(IoWrapper{unread: Some(buf), inner: stream})),
keepalive_timer: None,
}
}
pub fn poll(&mut self) -> Poll<(), ()> {
2017-11-04 17:07:44 +01:00
// server
if let State::Server(ref mut server) = self.state {
2017-11-04 21:24:57 +01:00
// keep-alive timer
if let Some(ref mut timeout) = self.keepalive_timer {
match timeout.poll() {
2017-11-09 04:31:25 +01:00
Ok(Async::Ready(_)) => {
trace!("Keep-alive timeout, close connection");
return Ok(Async::Ready(()))
}
2017-11-04 21:24:57 +01:00
Ok(Async::NotReady) => (),
Err(_) => unreachable!(),
}
}
2017-11-04 17:07:44 +01:00
loop {
let mut not_ready = true;
// check in-flight connections
for item in &mut self.tasks {
// read payload
item.poll_payload();
if !item.eof {
2017-11-25 07:15:52 +01:00
match item.task.poll_io(&mut item.stream) {
2017-11-04 17:07:44 +01:00
Ok(Async::Ready(ready)) => {
item.eof = true;
if ready {
item.finished = true;
}
not_ready = false;
},
Ok(Async::NotReady) => (),
2017-11-25 18:28:25 +01:00
Err(err) => {
error!("Unhandled error: {}", err);
2017-11-04 17:07:44 +01:00
item.eof = true;
item.error = true;
item.stream.reset(Reason::INTERNAL_ERROR);
}
}
} else if !item.finished {
match item.task.poll() {
Ok(Async::NotReady) => (),
Ok(Async::Ready(_)) => {
not_ready = false;
item.finished = true;
},
2017-11-25 18:28:25 +01:00
Err(err) => {
2017-11-04 17:07:44 +01:00
item.error = true;
item.finished = true;
2017-11-25 18:28:25 +01:00
error!("Unhandled error: {}", err);
2017-11-04 17:07:44 +01:00
}
}
}
}
// cleanup finished tasks
while !self.tasks.is_empty() {
if self.tasks[0].eof && self.tasks[0].finished || self.tasks[0].error {
self.tasks.pop_front();
} else {
break
}
}
// get request
if !self.disconnected {
match server.poll() {
Ok(Async::Ready(None)) => {
not_ready = false;
self.disconnected = true;
for entry in &mut self.tasks {
entry.task.disconnected()
}
},
2017-11-18 17:50:07 +01:00
Ok(Async::Ready(Some((req, resp)))) => {
2017-11-04 17:07:44 +01:00
not_ready = false;
let (parts, body) = req.into_parts();
// stop keepalive timer
2017-11-04 21:24:57 +01:00
self.keepalive_timer.take();
self.tasks.push_back(
2017-11-10 22:26:12 +01:00
Entry::new(parts, body, resp, self.addr, &self.router));
2017-11-04 17:07:44 +01:00
}
2017-11-04 21:24:57 +01:00
Ok(Async::NotReady) => {
// start keep-alive timer
2017-11-07 01:23:58 +01:00
if self.tasks.is_empty() && self.keepalive_timer.is_none() {
trace!("Start keep-alive timer");
let mut timeout = Timeout::new(
Duration::new(KEEPALIVE_PERIOD, 0),
Arbiter::handle()).unwrap();
// register timeout
let _ = timeout.poll();
self.keepalive_timer = Some(timeout);
2017-11-04 21:24:57 +01:00
}
}
Err(err) => {
trace!("Connection error: {}", err);
self.disconnected = true;
for entry in &mut self.tasks {
entry.task.disconnected()
}
self.keepalive_timer.take();
},
2017-11-04 17:07:44 +01:00
}
}
if not_ready {
if self.tasks.is_empty() && self.disconnected {
return Ok(Async::Ready(()))
} else {
return Ok(Async::NotReady)
}
}
}
}
// handshake
self.state = if let State::Handshake(ref mut handshake) = self.state {
match handshake.poll() {
Ok(Async::Ready(srv)) => {
State::Server(srv)
},
Ok(Async::NotReady) =>
return Ok(Async::NotReady),
Err(err) => {
trace!("Error handling connection: {}", err);
return Err(())
}
}
} else {
mem::replace(&mut self.state, State::Empty)
};
2017-11-04 17:07:44 +01:00
self.poll()
}
}
struct Entry {
2017-11-25 07:15:52 +01:00
task: Pipeline,
2017-11-07 01:23:58 +01:00
payload: PayloadType,
recv: RecvStream,
2017-11-04 17:07:44 +01:00
stream: H2Writer,
eof: bool,
error: bool,
finished: bool,
2017-11-04 17:07:44 +01:00
reof: bool,
capacity: usize,
}
impl Entry {
fn new<H>(parts: Parts,
recv: RecvStream,
resp: Respond<Bytes>,
addr: Option<SocketAddr>,
router: &Rc<Vec<H>>) -> Entry
where H: HttpHandler + 'static
{
let path = parts.uri.path().to_owned();
let query = parts.uri.query().unwrap_or("").to_owned();
2017-11-27 04:00:57 +01:00
// Payload and Content-Encoding
let (psender, payload) = Payload::new(false);
let mut req = HttpRequest::new(
2017-11-27 04:00:57 +01:00
parts.method, path, parts.version, parts.headers, query, payload);
2017-11-06 18:35:52 +01:00
// set remote addr
req.set_remove_addr(addr);
2017-11-25 07:15:52 +01:00
// Payload sender
let psender = PayloadType::new(req.headers(), psender);
// start request processing
let mut task = None;
for h in router.iter() {
if req.path().starts_with(h.prefix()) {
2017-11-27 04:00:57 +01:00
task = Some(h.handle(req));
break
}
}
2017-11-25 07:15:52 +01:00
Entry {task: task.unwrap_or_else(|| Pipeline::error(HTTPNotFound)),
2017-11-07 01:23:58 +01:00
payload: psender,
recv: recv,
2017-11-04 17:07:44 +01:00
stream: H2Writer::new(resp),
eof: false,
error: false,
2017-11-04 17:07:44 +01:00
finished: false,
reof: false,
capacity: 0,
}
}
fn poll_payload(&mut self) {
if !self.reof {
match self.recv.poll() {
Ok(Async::Ready(Some(chunk))) => {
2017-11-07 01:23:58 +01:00
self.payload.feed_data(chunk);
2017-11-04 17:07:44 +01:00
},
Ok(Async::Ready(None)) => {
self.reof = true;
},
Ok(Async::NotReady) => (),
Err(err) => {
2017-11-07 01:23:58 +01:00
self.payload.set_error(PayloadError::Http2(err))
2017-11-04 17:07:44 +01:00
}
}
2017-11-07 01:23:58 +01:00
let capacity = self.payload.capacity();
2017-11-04 17:07:44 +01:00
if self.capacity != capacity {
self.capacity = capacity;
if let Err(err) = self.recv.release_capacity().release_capacity(capacity) {
2017-11-07 01:23:58 +01:00
self.payload.set_error(PayloadError::Http2(err))
2017-11-04 17:07:44 +01:00
}
}
}
}
}
2017-11-02 20:54:09 +01:00
/// Transport adapter that hands out bytes consumed before the HTTP/2
/// handshake started, then continues with the wrapped stream.
struct IoWrapper<T> {
    /// Already-read bytes to return from reads before touching `inner`.
    unread: Option<Bytes>,
    /// Wrapped transport.
    inner: T,
}
impl<T: Read> Read for IoWrapper<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Some(mut bytes) = self.unread.take() {
let size = cmp::min(buf.len(), bytes.len());
buf[..size].copy_from_slice(&bytes[..size]);
if bytes.len() > size {
bytes.split_to(size);
2017-11-02 20:54:09 +01:00
self.unread = Some(bytes);
}
Ok(size)
} else {
self.inner.read(buf)
}
}
}
/// Writing is unaffected by the buffered bytes: every call is forwarded to
/// the wrapped stream unchanged.
impl<T: Write> Write for IoWrapper<T> {
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, data)
    }

    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl<T: AsyncRead + 'static> AsyncRead for IoWrapper<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
}
impl<T: AsyncWrite + 'static> AsyncWrite for IoWrapper<T> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.inner.write_buf(buf)
}
}