1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 07:53:00 +01:00

refactor payload handling

This commit is contained in:
Nikolay Kim 2019-02-13 13:52:11 -08:00
parent 8d4ce0c956
commit 118606262b
14 changed files with 186 additions and 125 deletions

View File

@ -18,7 +18,7 @@ fn main() {
.client_timeout(1000) .client_timeout(1000)
.client_disconnect(1000) .client_disconnect(1000)
.server_hostname("localhost") .server_hostname("localhost")
.finish(|req: Request| { .finish(|mut req: Request| {
req.body().limit(512).and_then(|bytes: Bytes| { req.body().limit(512).and_then(|bytes: Bytes| {
info!("request body: {:?}", bytes); info!("request body: {:?}", bytes);
let mut res = Response::Ok(); let mut res = Response::Ok();

View File

@ -8,7 +8,7 @@ use futures::Future;
use log::info; use log::info;
use std::env; use std::env;
fn handle_request(req: Request) -> impl Future<Item = Response, Error = Error> { fn handle_request(mut req: Request) -> impl Future<Item = Response, Error = Error> {
req.body().limit(512).from_err().and_then(|bytes: Bytes| { req.body().limit(512).from_err().and_then(|bytes: Bytes| {
info!("request body: {:?}", bytes); info!("request body: {:?}", bytes);
let mut res = Response::Ok(); let mut res = Response::Ok();

View File

@ -13,7 +13,6 @@ use crate::body::{BodyLength, MessageBody};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::h1; use crate::h1;
use crate::message::RequestHead; use crate::message::RequestHead;
use crate::payload::PayloadStream;
pub(crate) fn send_request<T, B>( pub(crate) fn send_request<T, B>(
io: T, io: T,
@ -205,7 +204,9 @@ pub(crate) struct Payload<Io> {
} }
impl<Io: ConnectionLifetime> Payload<Io> { impl<Io: ConnectionLifetime> Payload<Io> {
pub fn stream(framed: Framed<Io, h1::ClientCodec>) -> PayloadStream { pub fn stream(
framed: Framed<Io, h1::ClientCodec>,
) -> Box<Stream<Item = Bytes, Error = PayloadError>> {
Box::new(Payload { Box::new(Payload {
framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), framed: Some(framed.map_codec(|codec| codec.into_payload_codec())),
}) })

View File

@ -1,4 +1,3 @@
use std::cell::RefCell;
use std::time; use std::time;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
@ -10,7 +9,7 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCOD
use http::{request::Request, HttpTryFrom, Version}; use http::{request::Request, HttpTryFrom, Version};
use crate::body::{BodyLength, MessageBody}; use crate::body::{BodyLength, MessageBody};
use crate::message::{RequestHead, ResponseHead}; use crate::message::{Message, RequestHead, ResponseHead};
use super::connection::{ConnectionType, IoConnection}; use super::connection::{ConnectionType, IoConnection};
use super::error::SendRequestError; use super::error::SendRequestError;
@ -103,14 +102,14 @@ where
.and_then(|resp| { .and_then(|resp| {
let (parts, body) = resp.into_parts(); let (parts, body) = resp.into_parts();
let mut head = ResponseHead::default(); let mut head: Message<ResponseHead> = Message::new();
head.version = parts.version; head.version = parts.version;
head.status = parts.status; head.status = parts.status;
head.headers = parts.headers; head.headers = parts.headers;
Ok(ClientResponse { Ok(ClientResponse {
head, head,
payload: RefCell::new(body.into()), payload: body.into(),
}) })
}) })
.from_err() .from_err()

View File

@ -1,5 +1,4 @@
use std::cell::RefCell; use std::fmt;
use std::{fmt, mem};
use bytes::Bytes; use bytes::Bytes;
use futures::{Poll, Stream}; use futures::{Poll, Stream};
@ -7,13 +6,13 @@ use http::{HeaderMap, StatusCode, Version};
use crate::error::PayloadError; use crate::error::PayloadError;
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::message::{Head, ResponseHead}; use crate::message::{Head, Message, ResponseHead};
use crate::payload::{Payload, PayloadStream}; use crate::payload::{Payload, PayloadStream};
/// Client Response /// Client Response
pub struct ClientResponse { pub struct ClientResponse {
pub(crate) head: ResponseHead, pub(crate) head: Message<ResponseHead>,
pub(crate) payload: RefCell<Payload>, pub(crate) payload: Payload,
} }
impl HttpMessage for ClientResponse { impl HttpMessage for ClientResponse {
@ -23,9 +22,8 @@ impl HttpMessage for ClientResponse {
&self.head.headers &self.head.headers
} }
#[inline] fn take_payload(&mut self) -> Payload {
fn payload(&self) -> Payload<Self::Stream> { std::mem::replace(&mut self.payload, Payload::None)
mem::replace(&mut *self.payload.borrow_mut(), Payload::None)
} }
} }
@ -33,8 +31,8 @@ impl ClientResponse {
/// Create new Request instance /// Create new Request instance
pub fn new() -> ClientResponse { pub fn new() -> ClientResponse {
ClientResponse { ClientResponse {
head: ResponseHead::default(), head: Message::new(),
payload: RefCell::new(Payload::None), payload: Payload::None,
} }
} }
@ -80,7 +78,7 @@ impl ClientResponse {
/// Set response payload /// Set response payload
pub fn set_payload(&mut self, payload: Payload) { pub fn set_payload(&mut self, payload: Payload) {
*self.payload.get_mut() = payload; self.payload = payload;
} }
} }
@ -89,7 +87,7 @@ impl Stream for ClientResponse {
type Error = PayloadError; type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self.payload.get_mut().poll() self.payload.poll()
} }
} }

View File

@ -316,7 +316,9 @@ where
match self.framed.get_codec().message_type() { match self.framed.get_codec().message_type() {
MessageType::Payload => { MessageType::Payload => {
let (ps, pl) = Payload::create(false); let (ps, pl) = Payload::create(false);
req = req.set_payload(crate::Payload::H1(pl)); let (req1, _) =
req.replace_payload(crate::Payload::H1(pl));
req = req1;
self.payload = Some(ps); self.payload = Some(ps);
} }
MessageType::Stream => { MessageType::Stream => {

View File

@ -12,7 +12,6 @@ use log::error;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::config::{KeepAlive, ServiceConfig}; use crate::config::{KeepAlive, ServiceConfig};
use crate::error::{DispatchError, ParseError}; use crate::error::{DispatchError, ParseError};
use crate::payload::PayloadStream;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
@ -29,7 +28,7 @@ pub struct H1Service<T, S, B> {
impl<T, S, B> H1Service<T, S, B> impl<T, S, B> H1Service<T, S, B>
where where
S: NewService<Request = Request<PayloadStream>>, S: NewService<Request = Request>,
S::Error: Debug, S::Error: Debug,
S::Response: Into<Response<B>>, S::Response: Into<Response<B>>,
S::Service: 'static, S::Service: 'static,

View File

@ -1,4 +1,4 @@
use std::{mem, str}; use std::str;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use encoding::all::UTF_8; use encoding::all::UTF_8;
@ -21,13 +21,13 @@ use crate::payload::Payload;
/// Trait that implements general purpose operations on http messages /// Trait that implements general purpose operations on http messages
pub trait HttpMessage: Sized { pub trait HttpMessage: Sized {
/// Type of message payload stream /// Type of message payload stream
type Stream: Stream<Item = Bytes, Error = PayloadError> + Sized; type Stream;
/// Read the message headers. /// Read the message headers.
fn headers(&self) -> &HeaderMap; fn headers(&self) -> &HeaderMap;
/// Message payload stream /// Message payload stream
fn payload(&self) -> Payload<Self::Stream>; fn take_payload(&mut self) -> Payload<Self::Stream>;
#[doc(hidden)] #[doc(hidden)]
/// Get a header /// Get a header
@ -130,7 +130,10 @@ pub trait HttpMessage: Sized {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn body(&self) -> MessageBody<Self> { fn body(&mut self) -> MessageBody<Self>
where
Self::Stream: Stream<Item = Bytes, Error = PayloadError> + Sized,
{
MessageBody::new(self) MessageBody::new(self)
} }
@ -164,7 +167,10 @@ pub trait HttpMessage: Sized {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn urlencoded<T: DeserializeOwned>(&self) -> UrlEncoded<Self, T> { fn urlencoded<T: DeserializeOwned>(&mut self) -> UrlEncoded<Self, T>
where
Self::Stream: Stream<Item = Bytes, Error = PayloadError>,
{
UrlEncoded::new(self) UrlEncoded::new(self)
} }
@ -200,12 +206,18 @@ pub trait HttpMessage: Sized {
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
fn json<T: DeserializeOwned>(&self) -> JsonBody<Self, T> { fn json<T: DeserializeOwned + 'static>(&mut self) -> JsonBody<Self, T>
where
Self::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
{
JsonBody::new(self) JsonBody::new(self)
} }
/// Return stream of lines. /// Return stream of lines.
fn readlines(&self) -> Readlines<Self> { fn readlines(&mut self) -> Readlines<Self>
where
Self::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
{
Readlines::new(self) Readlines::new(self)
} }
} }
@ -220,16 +232,20 @@ pub struct Readlines<T: HttpMessage> {
err: Option<ReadlinesError>, err: Option<ReadlinesError>,
} }
impl<T: HttpMessage> Readlines<T> { impl<T> Readlines<T>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create a new stream to read request line by line. /// Create a new stream to read request line by line.
fn new(req: &T) -> Self { fn new(req: &mut T) -> Self {
let encoding = match req.encoding() { let encoding = match req.encoding() {
Ok(enc) => enc, Ok(enc) => enc,
Err(err) => return Self::err(err.into()), Err(err) => return Self::err(err.into()),
}; };
Readlines { Readlines {
stream: req.payload(), stream: req.take_payload(),
buff: BytesMut::with_capacity(262_144), buff: BytesMut::with_capacity(262_144),
limit: 262_144, limit: 262_144,
checked_buff: true, checked_buff: true,
@ -256,7 +272,11 @@ impl<T: HttpMessage> Readlines<T> {
} }
} }
impl<T: HttpMessage + 'static> Stream for Readlines<T> { impl<T> Stream for Readlines<T>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
{
type Item = String; type Item = String;
type Error = ReadlinesError; type Error = ReadlinesError;
@ -362,9 +382,13 @@ pub struct MessageBody<T: HttpMessage> {
fut: Option<Box<Future<Item = Bytes, Error = PayloadError>>>, fut: Option<Box<Future<Item = Bytes, Error = PayloadError>>>,
} }
impl<T: HttpMessage> MessageBody<T> { impl<T> MessageBody<T>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create `MessageBody` for request. /// Create `MessageBody` for request.
pub fn new(req: &T) -> MessageBody<T> { pub fn new(req: &mut T) -> MessageBody<T> {
let mut len = None; let mut len = None;
if let Some(l) = req.headers().get(header::CONTENT_LENGTH) { if let Some(l) = req.headers().get(header::CONTENT_LENGTH) {
if let Ok(s) = l.to_str() { if let Ok(s) = l.to_str() {
@ -379,9 +403,9 @@ impl<T: HttpMessage> MessageBody<T> {
} }
MessageBody { MessageBody {
stream: req.take_payload(),
limit: 262_144, limit: 262_144,
length: len, length: len,
stream: req.payload(),
fut: None, fut: None,
err: None, err: None,
} }
@ -406,7 +430,8 @@ impl<T: HttpMessage> MessageBody<T> {
impl<T> Future for MessageBody<T> impl<T> Future for MessageBody<T>
where where
T: HttpMessage + 'static, T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
{ {
type Item = Bytes; type Item = Bytes;
type Error = PayloadError; type Error = PayloadError;
@ -429,7 +454,7 @@ where
// future // future
let limit = self.limit; let limit = self.limit;
self.fut = Some(Box::new( self.fut = Some(Box::new(
mem::replace(&mut self.stream, Payload::None) std::mem::replace(&mut self.stream, Payload::None)
.from_err() .from_err()
.fold(BytesMut::with_capacity(8192), move |mut body, chunk| { .fold(BytesMut::with_capacity(8192), move |mut body, chunk| {
if (body.len() + chunk.len()) > limit { if (body.len() + chunk.len()) > limit {
@ -455,9 +480,13 @@ pub struct UrlEncoded<T: HttpMessage, U> {
fut: Option<Box<Future<Item = U, Error = UrlencodedError>>>, fut: Option<Box<Future<Item = U, Error = UrlencodedError>>>,
} }
impl<T: HttpMessage, U> UrlEncoded<T, U> { impl<T, U> UrlEncoded<T, U>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
{
/// Create a new future to URL encode a request /// Create a new future to URL encode a request
pub fn new(req: &T) -> UrlEncoded<T, U> { pub fn new(req: &mut T) -> UrlEncoded<T, U> {
// check content type // check content type
if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" {
return Self::err(UrlencodedError::ContentType); return Self::err(UrlencodedError::ContentType);
@ -482,7 +511,7 @@ impl<T: HttpMessage, U> UrlEncoded<T, U> {
UrlEncoded { UrlEncoded {
encoding, encoding,
stream: req.payload(), stream: req.take_payload(),
limit: 262_144, limit: 262_144,
length: len, length: len,
fut: None, fut: None,
@ -510,7 +539,8 @@ impl<T: HttpMessage, U> UrlEncoded<T, U> {
impl<T, U> Future for UrlEncoded<T, U> impl<T, U> Future for UrlEncoded<T, U>
where where
T: HttpMessage + 'static, T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
U: DeserializeOwned + 'static, U: DeserializeOwned + 'static,
{ {
type Item = U; type Item = U;
@ -535,7 +565,7 @@ where
// future // future
let encoding = self.encoding; let encoding = self.encoding;
let fut = mem::replace(&mut self.stream, Payload::None) let fut = std::mem::replace(&mut self.stream, Payload::None)
.from_err() .from_err()
.fold(BytesMut::with_capacity(8192), move |mut body, chunk| { .fold(BytesMut::with_capacity(8192), move |mut body, chunk| {
if (body.len() + chunk.len()) > limit { if (body.len() + chunk.len()) > limit {
@ -691,7 +721,7 @@ mod tests {
#[test] #[test]
fn test_urlencoded_error() { fn test_urlencoded_error() {
let req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
@ -702,7 +732,7 @@ mod tests {
UrlencodedError::UnknownLength UrlencodedError::UnknownLength
); );
let req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
@ -713,7 +743,7 @@ mod tests {
UrlencodedError::Overflow UrlencodedError::Overflow
); );
let req = TestRequest::with_header(header::CONTENT_TYPE, "text/plain") let mut req = TestRequest::with_header(header::CONTENT_TYPE, "text/plain")
.header(header::CONTENT_LENGTH, "10") .header(header::CONTENT_LENGTH, "10")
.finish(); .finish();
assert_eq!( assert_eq!(
@ -724,7 +754,7 @@ mod tests {
#[test] #[test]
fn test_urlencoded() { fn test_urlencoded() {
let req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
@ -740,7 +770,7 @@ mod tests {
}) })
); );
let req = TestRequest::with_header( let mut req = TestRequest::with_header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
"application/x-www-form-urlencoded; charset=utf-8", "application/x-www-form-urlencoded; charset=utf-8",
) )
@ -759,19 +789,20 @@ mod tests {
#[test] #[test]
fn test_message_body() { fn test_message_body() {
let req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish();
match req.body().poll().err().unwrap() { match req.body().poll().err().unwrap() {
PayloadError::UnknownLength => (), PayloadError::UnknownLength => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let req = TestRequest::with_header(header::CONTENT_LENGTH, "1000000").finish(); let mut req =
TestRequest::with_header(header::CONTENT_LENGTH, "1000000").finish();
match req.body().poll().err().unwrap() { match req.body().poll().err().unwrap() {
PayloadError::Overflow => (), PayloadError::Overflow => (),
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let req = TestRequest::default() let mut req = TestRequest::default()
.set_payload(Bytes::from_static(b"test")) .set_payload(Bytes::from_static(b"test"))
.finish(); .finish();
match req.body().poll().ok().unwrap() { match req.body().poll().ok().unwrap() {
@ -779,7 +810,7 @@ mod tests {
_ => unreachable!("error"), _ => unreachable!("error"),
} }
let req = TestRequest::default() let mut req = TestRequest::default()
.set_payload(Bytes::from_static(b"11111111111111")) .set_payload(Bytes::from_static(b"11111111111111"))
.finish(); .finish();
match req.body().limit(5).poll().err().unwrap() { match req.body().limit(5).poll().err().unwrap() {
@ -790,14 +821,14 @@ mod tests {
#[test] #[test]
fn test_readlines() { fn test_readlines() {
let req = TestRequest::default() let mut req = TestRequest::default()
.set_payload(Bytes::from_static( .set_payload(Bytes::from_static(
b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\ b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\
industry. Lorem Ipsum has been the industry's standard dummy\n\ industry. Lorem Ipsum has been the industry's standard dummy\n\
Contrary to popular belief, Lorem Ipsum is not simply random text.", Contrary to popular belief, Lorem Ipsum is not simply random text.",
)) ))
.finish(); .finish();
let mut r = Readlines::new(&req); let mut r = Readlines::new(&mut req);
match r.poll().ok().unwrap() { match r.poll().ok().unwrap() {
Async::Ready(Some(s)) => assert_eq!( Async::Ready(Some(s)) => assert_eq!(
s, s,

View File

@ -2,11 +2,12 @@ use bytes::BytesMut;
use futures::{Future, Poll, Stream}; use futures::{Future, Poll, Stream};
use http::header::CONTENT_LENGTH; use http::header::CONTENT_LENGTH;
use bytes::Bytes;
use mime; use mime;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde_json; use serde_json;
use crate::error::JsonPayloadError; use crate::error::{JsonPayloadError, PayloadError};
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::payload::Payload; use crate::payload::Payload;
@ -41,7 +42,7 @@ use crate::payload::Payload;
/// } /// }
/// # fn main() {} /// # fn main() {}
/// ``` /// ```
pub struct JsonBody<T: HttpMessage, U: DeserializeOwned> { pub struct JsonBody<T: HttpMessage, U> {
limit: usize, limit: usize,
length: Option<usize>, length: Option<usize>,
stream: Payload<T::Stream>, stream: Payload<T::Stream>,
@ -49,9 +50,14 @@ pub struct JsonBody<T: HttpMessage, U: DeserializeOwned> {
fut: Option<Box<Future<Item = U, Error = JsonPayloadError>>>, fut: Option<Box<Future<Item = U, Error = JsonPayloadError>>>,
} }
impl<T: HttpMessage, U: DeserializeOwned> JsonBody<T, U> { impl<T, U> JsonBody<T, U>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
U: DeserializeOwned + 'static,
{
/// Create `JsonBody` for request. /// Create `JsonBody` for request.
pub fn new(req: &T) -> Self { pub fn new(req: &mut T) -> Self {
// check content-type // check content-type
let json = if let Ok(Some(mime)) = req.mime_type() { let json = if let Ok(Some(mime)) = req.mime_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
@ -80,7 +86,7 @@ impl<T: HttpMessage, U: DeserializeOwned> JsonBody<T, U> {
JsonBody { JsonBody {
limit: 262_144, limit: 262_144,
length: len, length: len,
stream: req.payload(), stream: req.take_payload(),
fut: None, fut: None,
err: None, err: None,
} }
@ -93,7 +99,12 @@ impl<T: HttpMessage, U: DeserializeOwned> JsonBody<T, U> {
} }
} }
impl<T: HttpMessage + 'static, U: DeserializeOwned + 'static> Future for JsonBody<T, U> { impl<T, U> Future for JsonBody<T, U>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError> + 'static,
U: DeserializeOwned + 'static,
{
type Item = U; type Item = U;
type Error = JsonPayloadError; type Error = JsonPayloadError;
@ -162,11 +173,11 @@ mod tests {
#[test] #[test]
fn test_json_body() { fn test_json_body() {
let req = TestRequest::default().finish(); let mut req = TestRequest::default().finish();
let mut json = req.json::<MyObject>(); let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType);
let req = TestRequest::default() let mut req = TestRequest::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"), header::HeaderValue::from_static("application/text"),
@ -175,7 +186,7 @@ mod tests {
let mut json = req.json::<MyObject>(); let mut json = req.json::<MyObject>();
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType);
let req = TestRequest::default() let mut req = TestRequest::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
@ -188,7 +199,7 @@ mod tests {
let mut json = req.json::<MyObject>().limit(100); let mut json = req.json::<MyObject>().limit(100);
assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow);
let req = TestRequest::default() let mut req = TestRequest::default()
.header( .header(
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),

View File

@ -2,9 +2,8 @@ use std::cell::{Ref, RefCell, RefMut};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::rc::Rc; use std::rc::Rc;
use http::{HeaderMap, Method, StatusCode, Uri, Version};
use crate::extensions::Extensions; use crate::extensions::Extensions;
use crate::http::{HeaderMap, Method, StatusCode, Uri, Version};
/// Represents various types of connection /// Represents various types of connection
#[derive(Copy, Clone, PartialEq, Debug)] #[derive(Copy, Clone, PartialEq, Debug)]
@ -21,6 +20,12 @@ pub enum ConnectionType {
pub trait Head: Default + 'static { pub trait Head: Default + 'static {
fn clear(&mut self); fn clear(&mut self);
/// Read the message headers.
fn headers(&self) -> &HeaderMap;
/// Mutable reference to the message headers.
fn headers_mut(&mut self) -> &mut HeaderMap;
/// Connection type /// Connection type
fn connection_type(&self) -> ConnectionType; fn connection_type(&self) -> ConnectionType;
@ -68,6 +73,14 @@ impl Head for RequestHead {
self.extensions.borrow_mut().clear(); self.extensions.borrow_mut().clear();
} }
fn headers(&self) -> &HeaderMap {
&self.headers
}
fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
fn set_connection_type(&mut self, ctype: ConnectionType) { fn set_connection_type(&mut self, ctype: ConnectionType) {
self.ctype = Some(ctype) self.ctype = Some(ctype)
} }
@ -129,6 +142,14 @@ impl Head for ResponseHead {
self.headers.clear(); self.headers.clear();
} }
fn headers(&self) -> &HeaderMap {
&self.headers
}
fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
fn set_connection_type(&mut self, ctype: ConnectionType) { fn set_connection_type(&mut self, ctype: ConnectionType) {
self.ctype = Some(ctype) self.ctype = Some(ctype)
} }

View File

@ -15,21 +15,21 @@ pub enum Payload<S = PayloadStream> {
Stream(S), Stream(S),
} }
impl<S> From<RecvStream> for Payload<S> {
fn from(v: RecvStream) -> Self {
Payload::H2(crate::h2::Payload::new(v))
}
}
impl<S> From<crate::h1::Payload> for Payload<S> { impl<S> From<crate::h1::Payload> for Payload<S> {
fn from(pl: crate::h1::Payload) -> Self { fn from(v: crate::h1::Payload) -> Self {
Payload::H1(pl) Payload::H1(v)
} }
} }
impl<S> From<crate::h2::Payload> for Payload<S> { impl<S> From<crate::h2::Payload> for Payload<S> {
fn from(pl: crate::h2::Payload) -> Self { fn from(v: crate::h2::Payload) -> Self {
Payload::H2(pl) Payload::H2(v)
}
}
impl<S> From<RecvStream> for Payload<S> {
fn from(v: RecvStream) -> Self {
Payload::H2(crate::h2::Payload::new(v))
} }
} }

View File

@ -1,11 +1,8 @@
use std::cell::{Ref, RefCell, RefMut}; use std::cell::{Ref, RefMut};
use std::{fmt, mem}; use std::fmt;
use bytes::Bytes;
use futures::Stream;
use http::{header, HeaderMap, Method, Uri, Version}; use http::{header, HeaderMap, Method, Uri, Version};
use crate::error::PayloadError;
use crate::extensions::Extensions; use crate::extensions::Extensions;
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::message::{Message, RequestHead}; use crate::message::{Message, RequestHead};
@ -13,31 +10,27 @@ use crate::payload::{Payload, PayloadStream};
/// Request /// Request
pub struct Request<P = PayloadStream> { pub struct Request<P = PayloadStream> {
pub(crate) payload: RefCell<Payload<P>>, pub(crate) payload: Payload<P>,
pub(crate) inner: Message<RequestHead>, pub(crate) head: Message<RequestHead>,
} }
impl<P> HttpMessage for Request<P> impl<P> HttpMessage for Request<P> {
where
P: Stream<Item = Bytes, Error = PayloadError>,
{
type Stream = P; type Stream = P;
fn headers(&self) -> &HeaderMap { fn headers(&self) -> &HeaderMap {
&self.head().headers &self.head().headers
} }
#[inline] fn take_payload(&mut self) -> Payload<P> {
fn payload(&self) -> Payload<Self::Stream> { std::mem::replace(&mut self.payload, Payload::None)
mem::replace(&mut *self.payload.borrow_mut(), Payload::None)
} }
} }
impl<P> From<Message<RequestHead>> for Request<P> { impl From<Message<RequestHead>> for Request<PayloadStream> {
fn from(msg: Message<RequestHead>) -> Self { fn from(head: Message<RequestHead>) -> Self {
Request { Request {
payload: RefCell::new(Payload::None), head,
inner: msg, payload: Payload::None,
} }
} }
} }
@ -46,8 +39,8 @@ impl Request<PayloadStream> {
/// Create new Request instance /// Create new Request instance
pub fn new() -> Request<PayloadStream> { pub fn new() -> Request<PayloadStream> {
Request { Request {
payload: RefCell::new(Payload::None), head: Message::new(),
inner: Message::new(), payload: Payload::None,
} }
} }
} }
@ -56,38 +49,44 @@ impl<P> Request<P> {
/// Create new Request instance /// Create new Request instance
pub fn with_payload(payload: Payload<P>) -> Request<P> { pub fn with_payload(payload: Payload<P>) -> Request<P> {
Request { Request {
payload: RefCell::new(payload), payload,
inner: Message::new(), head: Message::new(),
} }
} }
/// Create new Request instance /// Create new Request instance
pub fn set_payload<I, P1>(self, payload: I) -> Request<P1> pub fn replace_payload<P1>(self, payload: Payload<P1>) -> (Request<P1>, Payload<P>) {
where let pl = self.payload;
I: Into<Payload<P1>>, (
{ Request {
Request { payload,
payload: RefCell::new(payload.into()), head: self.head,
inner: self.inner, },
} pl,
)
}
/// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> {
std::mem::replace(&mut self.payload, Payload::None)
} }
/// Split request into request head and payload /// Split request into request head and payload
pub fn into_parts(self) -> (Message<RequestHead>, Payload<P>) { pub fn into_parts(self) -> (Message<RequestHead>, Payload<P>) {
(self.inner, self.payload.into_inner()) (self.head, self.payload)
} }
#[inline] #[inline]
/// Http message part of the request /// Http message part of the request
pub fn head(&self) -> &RequestHead { pub fn head(&self) -> &RequestHead {
&*self.inner &*self.head
} }
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
/// Mutable reference to a http message part of the request /// Mutable reference to a http message part of the request
pub fn head_mut(&mut self) -> &mut RequestHead { pub fn head_mut(&mut self) -> &mut RequestHead {
&mut *self.inner &mut *self.head
} }
/// Request's uri. /// Request's uri.
@ -135,13 +134,13 @@ impl<P> Request<P> {
/// Request extensions /// Request extensions
#[inline] #[inline]
pub fn extensions(&self) -> Ref<Extensions> { pub fn extensions(&self) -> Ref<Extensions> {
self.inner.extensions() self.head.extensions()
} }
/// Mutable reference to a the request's extensions /// Mutable reference to a the request's extensions
#[inline] #[inline]
pub fn extensions_mut(&self) -> RefMut<Extensions> { pub fn extensions_mut(&self) -> RefMut<Extensions> {
self.inner.extensions_mut() self.head.extensions_mut()
} }
/// Check if request requires connection upgrade /// Check if request requires connection upgrade
@ -155,7 +154,7 @@ impl<P> Request<P> {
} }
} }
impl<Payload> fmt::Debug for Request<Payload> { impl<P> fmt::Debug for Request<P> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!( writeln!(
f, f,

View File

@ -47,7 +47,7 @@ fn test_h1_v2() {
assert!(repr.contains("ClientRequest")); assert!(repr.contains("ClientRequest"));
assert!(repr.contains("x-test")); assert!(repr.contains("x-test"));
let response = srv.block_on(request.send(&mut connector)).unwrap(); let mut response = srv.block_on(request.send(&mut connector)).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -55,7 +55,7 @@ fn test_h1_v2() {
assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let request = srv.post().finish().unwrap(); let request = srv.post().finish().unwrap();
let response = srv.block_on(request.send(&mut connector)).unwrap(); let mut response = srv.block_on(request.send(&mut connector)).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response

View File

@ -89,7 +89,7 @@ fn test_h2_body() -> std::io::Result<()> {
.map_err(|e| println!("Openssl error: {}", e)) .map_err(|e| println!("Openssl error: {}", e))
.and_then( .and_then(
h2::H2Service::build() h2::H2Service::build()
.finish(|req: Request<_>| { .finish(|mut req: Request<_>| {
req.body() req.body()
.limit(1024 * 1024) .limit(1024 * 1024)
.and_then(|body| Ok(Response::Ok().body(body))) .and_then(|body| Ok(Response::Ok().body(body)))
@ -101,7 +101,7 @@ fn test_h2_body() -> std::io::Result<()> {
let req = client::ClientRequest::get(srv.surl("/")) let req = client::ClientRequest::get(srv.surl("/"))
.body(data.clone()) .body(data.clone())
.unwrap(); .unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
let body = srv.block_on(response.body().limit(1024 * 1024)).unwrap(); let body = srv.block_on(response.body().limit(1024 * 1024)).unwrap();
@ -350,7 +350,7 @@ fn test_headers() {
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.block_on(req.send(&mut connector)).unwrap(); let mut response = srv.block_on(req.send(&mut connector)).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -387,7 +387,7 @@ fn test_body() {
}); });
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -402,7 +402,7 @@ fn test_head_empty() {
}); });
let req = client::ClientRequest::head(srv.url("/")).finish().unwrap(); let req = client::ClientRequest::head(srv.url("/")).finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
@ -428,7 +428,7 @@ fn test_head_binary() {
}); });
let req = client::ClientRequest::head(srv.url("/")).finish().unwrap(); let req = client::ClientRequest::head(srv.url("/")).finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
{ {
@ -477,7 +477,7 @@ fn test_body_length() {
}); });
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -496,7 +496,7 @@ fn test_body_chunked_explicit() {
}); });
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -517,7 +517,7 @@ fn test_body_chunked_implicit() {
}); });
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -540,7 +540,7 @@ fn test_response_http_error_handling() {
}); });
let req = srv.get().finish().unwrap(); let req = srv.get().finish().unwrap();
let response = srv.send_request(req).unwrap(); let mut response = srv.send_request(req).unwrap();
assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
// read response // read response