use std::collections::VecDeque; use std::marker::PhantomData; use std::time::Instant; use std::{fmt, mem}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::Service; use actix_utils::cloneable::CloneableService; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; use futures::{try_ready, Async, Future, Poll, Sink, Stream}; use h2::server::{Connection, SendResponse}; use h2::{RecvStream, SendStream}; use http::header::{ HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, }; use http::HttpTryFrom; use log::{debug, error, trace}; use tokio_timer::Delay; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError}; use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; use crate::response::Response; const CHUNK_SIZE: usize = 16_384; /// Dispatcher for HTTP/2 protocol pub struct Dispatcher< T: AsyncRead + AsyncWrite, S: Service, B: MessageBody, > { service: CloneableService, connection: Connection, config: ServiceConfig, ka_expire: Instant, ka_timer: Option, _t: PhantomData, } impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Error: Into, S::Future: 'static, S::Response: Into>, B: MessageBody + 'static, { pub fn new( service: CloneableService, connection: Connection, config: ServiceConfig, timeout: Option, ) -> Self { // let keepalive = config.keep_alive_enabled(); // let flags = if keepalive { // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED // } else { // Flags::empty() // }; // keep-alive timer let (ka_expire, ka_timer) = if let Some(delay) = timeout { (delay.deadline(), Some(delay)) } else if let Some(delay) = config.keep_alive_timer() { (delay.deadline(), Some(delay)) } else { (config.now(), None) }; Dispatcher { service, config, ka_expire, ka_timer, connection, _t: PhantomData, } } } impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Error: Into, S::Future: 'static, S::Response: Into>, B: MessageBody + 'static, { type Item = (); type Error = DispatchError; #[inline] fn poll(&mut self) -> Poll { loop { match self.connection.poll()? { Async::Ready(None) => return Ok(Async::Ready(())), Async::Ready(Some((req, res))) => { // update keep-alive expire if self.ka_timer.is_some() { if let Some(expire) = self.config.keep_alive_expire() { self.ka_expire = expire; } } let (parts, body) = req.into_parts(); let mut req = Request::with_payload(body.into()); let head = &mut req.head_mut(); head.uri = parts.uri; head.method = parts.method; head.version = parts.version; head.headers = parts.headers.into(); tokio_current_thread::spawn(ServiceResponse:: { state: ServiceResponseState::ServiceCall( self.service.call(req), Some(res), ), config: self.config.clone(), buffer: None, }) } Async::NotReady => return Ok(Async::NotReady), } } } } struct ServiceResponse { state: ServiceResponseState, config: ServiceConfig, buffer: Option, } enum ServiceResponseState { ServiceCall(F, Option>), SendPayload(SendStream, ResponseBody), } impl ServiceResponse where F: Future, F::Error: Into, F::Item: Into>, B: MessageBody + 'static, { fn prepare_response( &self, head: &ResponseHead, length: &mut BodySize, ) -> http::Response<()> { let mut has_date = false; let mut skip_len = length != &BodySize::Stream; let mut res = http::Response::new(()); *res.status_mut() = head.status; *res.version_mut() = http::Version::HTTP_2; // Content length match head.status { http::StatusCode::NO_CONTENT | http::StatusCode::CONTINUE | http::StatusCode::PROCESSING => *length = BodySize::None, http::StatusCode::SWITCHING_PROTOCOLS => { skip_len = true; *length = BodySize::Stream; } _ => (), } let _ = match length { BodySize::None | BodySize::Stream => None, BodySize::Empty => res .headers_mut() .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), BodySize::Sized(len) => res.headers_mut().insert( CONTENT_LENGTH, HeaderValue::try_from(format!("{}", len)).unwrap(), ), BodySize::Sized64(len) => res.headers_mut().insert( CONTENT_LENGTH, HeaderValue::try_from(format!("{}", len)).unwrap(), ), }; // copy headers for (key, value) in head.headers.iter() { match *key { CONNECTION | TRANSFER_ENCODING => continue, // http2 specific CONTENT_LENGTH if skip_len => continue, DATE => has_date = true, _ => (), } res.headers_mut().append(key, value.clone()); } // set date header if !has_date { let mut bytes = BytesMut::with_capacity(29); self.config.set_date_header(&mut bytes); res.headers_mut() .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); } res } } impl Future for ServiceResponse where F: Future, F::Error: Into, F::Item: Into>, B: MessageBody + 'static, { type Item = (); type Error = (); fn poll(&mut self) -> Poll { match self.state { ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { match call.poll() { Ok(Async::Ready(res)) => { let (res, body) = res.into().replace_body(()); let mut send = send.take().unwrap(); let mut length = body.length(); let h2_res = self.prepare_response(res.head(), &mut length); let stream = send .send_response(h2_res, length.is_eof()) .map_err(|e| { trace!("Error sending h2 response: {:?}", e); })?; if length.is_eof() { Ok(Async::Ready(())) } else { self.state = ServiceResponseState::SendPayload(stream, body); self.poll() } } Ok(Async::NotReady) => Ok(Async::NotReady), Err(_e) => { let res: Response = Response::InternalServerError().finish(); let (res, body) = res.replace_body(()); let mut send = send.take().unwrap(); let mut length = body.length(); let h2_res = self.prepare_response(res.head(), &mut length); let stream = send .send_response(h2_res, length.is_eof()) .map_err(|e| { trace!("Error sending h2 response: {:?}", e); })?; if length.is_eof() { Ok(Async::Ready(())) } else { self.state = ServiceResponseState::SendPayload( stream, body.into_body(), ); self.poll() } } } } ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { loop { if let Some(ref mut buffer) = self.buffer { match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? { Async::NotReady => return Ok(Async::NotReady), Async::Ready(None) => return Ok(Async::Ready(())), Async::Ready(Some(cap)) => { let len = buffer.len(); let bytes = buffer.split_to(std::cmp::min(cap, len)); if let Err(e) = stream.send_data(bytes, false) { warn!("{:?}", e); return Err(()); } else if !buffer.is_empty() { let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); stream.reserve_capacity(cap); } else { self.buffer.take(); } } } } else { match body.poll_next() { Ok(Async::NotReady) => { return Ok(Async::NotReady); } Ok(Async::Ready(None)) => { if let Err(e) = stream.send_data(Bytes::new(), true) { warn!("{:?}", e); return Err(()); } else { return Ok(Async::Ready(())); } } Ok(Async::Ready(Some(chunk))) => { stream.reserve_capacity(std::cmp::min( chunk.len(), CHUNK_SIZE, )); self.buffer = Some(chunk); } Err(e) => { error!("Response payload stream error: {:?}", e); return Err(()); } } } } }, } } }