1
0
mirror of https://github.com/fafhrd91/actix-web synced 2025-07-01 08:45:10 +02:00

migrate actix-web to std::future

This commit is contained in:
Nikolay Kim
2019-11-20 23:33:22 +06:00
parent d081e57316
commit 3127dd4db6
46 changed files with 4134 additions and 3720 deletions

View File

@ -1,12 +1,16 @@
//! Form extractor
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, ops};
use actix_http::{Error, HttpMessage, Payload, Response};
use bytes::BytesMut;
use encoding_rs::{Encoding, UTF_8};
use futures::{Future, Poll, Stream};
use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready};
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -110,7 +114,7 @@ where
{
type Config = FormConfig;
type Error = Error;
type Future = Box<dyn Future<Item = Self, Error = Error>>;
type Future = LocalBoxFuture<'static, Result<Self, Error>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
@ -120,18 +124,19 @@ where
.map(|c| (c.limit, c.ehandler.clone()))
.unwrap_or((16384, None));
Box::new(
UrlEncoded::new(req, payload)
.limit(limit)
.map_err(move |e| {
UrlEncoded::new(req, payload)
.limit(limit)
.map(move |res| match res {
Err(e) => {
if let Some(err) = err {
(*err)(e, &req2)
Err((*err)(e, &req2))
} else {
e.into()
Err(e.into())
}
})
.map(Form),
)
}
Ok(item) => Ok(Form(item)),
})
.boxed_local()
}
}
@ -149,15 +154,15 @@ impl<T: fmt::Display> fmt::Display for Form<T> {
impl<T: Serialize> Responder for Form<T> {
type Error = Error;
type Future = Result<Response, Error>;
type Future = Ready<Result<Response, Error>>;
fn respond_to(self, _: &HttpRequest) -> Self::Future {
let body = match serde_urlencoded::to_string(&self.0) {
Ok(body) => body,
Err(e) => return Err(e.into()),
Err(e) => return err(e.into()),
};
Ok(Response::build(StatusCode::OK)
ok(Response::build(StatusCode::OK)
.set(ContentType::form_url_encoded())
.body(body))
}
@ -240,7 +245,7 @@ pub struct UrlEncoded<U> {
length: Option<usize>,
encoding: &'static Encoding,
err: Option<UrlencodedError>,
fut: Option<Box<dyn Future<Item = U, Error = UrlencodedError>>>,
fut: Option<LocalBoxFuture<'static, Result<U, UrlencodedError>>>,
}
impl<U> UrlEncoded<U> {
@ -301,45 +306,45 @@ impl<U> Future for UrlEncoded<U>
where
U: DeserializeOwned + 'static,
{
type Item = U;
type Error = UrlencodedError;
type Output = Result<U, UrlencodedError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return fut.poll();
return Pin::new(fut).poll(cx);
}
if let Some(err) = self.err.take() {
return Err(err);
return Poll::Ready(Err(err));
}
// payload size
let limit = self.limit;
if let Some(len) = self.length.take() {
if len > limit {
return Err(UrlencodedError::Overflow { size: len, limit });
return Poll::Ready(Err(UrlencodedError::Overflow { size: len, limit }));
}
}
// future
let encoding = self.encoding;
let fut = self
.stream
.take()
.unwrap()
.from_err()
.fold(BytesMut::with_capacity(8192), move |mut body, chunk| {
if (body.len() + chunk.len()) > limit {
Err(UrlencodedError::Overflow {
size: body.len() + chunk.len(),
limit,
})
} else {
body.extend_from_slice(&chunk);
Ok(body)
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if (body.len() + chunk.len()) > limit {
return Err(UrlencodedError::Overflow {
size: body.len() + chunk.len(),
limit,
});
} else {
body.extend_from_slice(&chunk);
}
}
})
.and_then(move |body| {
if encoding == UTF_8 {
serde_urlencoded::from_bytes::<U>(&body)
.map_err(|_| UrlencodedError::Parse)
@ -351,9 +356,10 @@ where
serde_urlencoded::from_str::<U>(&body)
.map_err(|_| UrlencodedError::Parse)
}
});
self.fut = Some(Box::new(fut));
self.poll()
}
.boxed_local(),
);
self.poll(cx)
}
}
@ -374,20 +380,24 @@ mod tests {
#[test]
fn test_form() {
let (req, mut pl) =
TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
let Form(s) = block_on(Form::<Info>::from_request(&req, &mut pl)).unwrap();
assert_eq!(
s,
Info {
hello: "world".into(),
counter: 123
}
);
let Form(s) = Form::<Info>::from_request(&req, &mut pl).await.unwrap();
assert_eq!(
s,
Info {
hello: "world".into(),
counter: 123
}
);
})
}
fn eq(err: UrlencodedError, other: UrlencodedError) -> bool {
@ -410,81 +420,93 @@ mod tests {
#[test]
fn test_urlencoded_error() {
let (req, mut pl) =
TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(CONTENT_LENGTH, "xxxx")
.to_http_parts();
let info = block_on(UrlEncoded::<Info>::new(&req, &mut pl));
assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength));
let (req, mut pl) =
TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(CONTENT_LENGTH, "1000000")
.to_http_parts();
let info = block_on(UrlEncoded::<Info>::new(&req, &mut pl));
assert!(eq(
info.err().unwrap(),
UrlencodedError::Overflow { size: 0, limit: 0 }
));
let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain")
.header(CONTENT_LENGTH, "10")
block_on(async {
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(CONTENT_LENGTH, "xxxx")
.to_http_parts();
let info = block_on(UrlEncoded::<Info>::new(&req, &mut pl));
assert!(eq(info.err().unwrap(), UrlencodedError::ContentType));
let info = UrlEncoded::<Info>::new(&req, &mut pl).await;
assert!(eq(info.err().unwrap(), UrlencodedError::UnknownLength));
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(CONTENT_LENGTH, "1000000")
.to_http_parts();
let info = UrlEncoded::<Info>::new(&req, &mut pl).await;
assert!(eq(
info.err().unwrap(),
UrlencodedError::Overflow { size: 0, limit: 0 }
));
let (req, mut pl) = TestRequest::with_header(CONTENT_TYPE, "text/plain")
.header(CONTENT_LENGTH, "10")
.to_http_parts();
let info = UrlEncoded::<Info>::new(&req, &mut pl).await;
assert!(eq(info.err().unwrap(), UrlencodedError::ContentType));
})
}
#[test]
fn test_urlencoded() {
let (req, mut pl) =
TestRequest::with_header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
let info = block_on(UrlEncoded::<Info>::new(&req, &mut pl)).unwrap();
assert_eq!(
info,
Info {
hello: "world".to_owned(),
counter: 123
}
);
let info = UrlEncoded::<Info>::new(&req, &mut pl).await.unwrap();
assert_eq!(
info,
Info {
hello: "world".to_owned(),
counter: 123
}
);
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded; charset=utf-8",
)
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
let (req, mut pl) = TestRequest::with_header(
CONTENT_TYPE,
"application/x-www-form-urlencoded; charset=utf-8",
)
.header(CONTENT_LENGTH, "11")
.set_payload(Bytes::from_static(b"hello=world&counter=123"))
.to_http_parts();
let info = block_on(UrlEncoded::<Info>::new(&req, &mut pl)).unwrap();
assert_eq!(
info,
Info {
hello: "world".to_owned(),
counter: 123
}
);
let info = UrlEncoded::<Info>::new(&req, &mut pl).await.unwrap();
assert_eq!(
info,
Info {
hello: "world".to_owned(),
counter: 123
}
);
})
}
#[test]
fn test_responder() {
let req = TestRequest::default().to_http_request();
block_on(async {
let req = TestRequest::default().to_http_request();
let form = Form(Info {
hello: "world".to_string(),
counter: 123,
});
let resp = form.respond_to(&req).unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/x-www-form-urlencoded")
);
let form = Form(Info {
hello: "world".to_string(),
counter: 123,
});
let resp = form.respond_to(&req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(CONTENT_TYPE).unwrap(),
HeaderValue::from_static("application/x-www-form-urlencoded")
);
use crate::responder::tests::BodyTest;
assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123");
use crate::responder::tests::BodyTest;
assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123");
})
}
}

View File

@ -1,10 +1,14 @@
//! Json extractor/responder
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, ops};
use bytes::BytesMut;
use futures::{Future, Poll, Stream};
use futures::future::{err, ok, FutureExt, LocalBoxFuture, Ready};
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json;
@ -118,15 +122,15 @@ where
impl<T: Serialize> Responder for Json<T> {
type Error = Error;
type Future = Result<Response, Error>;
type Future = Ready<Result<Response, Error>>;
fn respond_to(self, _: &HttpRequest) -> Self::Future {
let body = match serde_json::to_string(&self.0) {
Ok(body) => body,
Err(e) => return Err(e.into()),
Err(e) => return err(e.into()),
};
Ok(Response::build(StatusCode::OK)
ok(Response::build(StatusCode::OK)
.content_type("application/json")
.body(body))
}
@ -169,7 +173,7 @@ where
T: DeserializeOwned + 'static,
{
type Error = Error;
type Future = Box<dyn Future<Item = Self, Error = Error>>;
type Future = LocalBoxFuture<'static, Result<Self, Error>>;
type Config = JsonConfig;
#[inline]
@ -180,23 +184,24 @@ where
.map(|c| (c.limit, c.ehandler.clone(), c.content_type.clone()))
.unwrap_or((32768, None, None));
Box::new(
JsonBody::new(req, payload, ctype)
.limit(limit)
.map_err(move |e| {
JsonBody::new(req, payload, ctype)
.limit(limit)
.map(move |res| match res {
Err(e) => {
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req2.path()
);
if let Some(err) = err {
(*err)(e, &req2)
Err((*err)(e, &req2))
} else {
e.into()
Err(e.into())
}
})
.map(Json),
)
}
Ok(data) => Ok(Json(data)),
})
.boxed_local()
}
}
@ -290,7 +295,7 @@ pub struct JsonBody<U> {
length: Option<usize>,
stream: Option<Decompress<Payload>>,
err: Option<JsonPayloadError>,
fut: Option<Box<dyn Future<Item = U, Error = JsonPayloadError>>>,
fut: Option<LocalBoxFuture<'static, Result<U, JsonPayloadError>>>,
}
impl<U> JsonBody<U>
@ -349,41 +354,43 @@ impl<U> Future for JsonBody<U>
where
U: DeserializeOwned + 'static,
{
type Item = U;
type Error = JsonPayloadError;
type Output = Result<U, JsonPayloadError>;
fn poll(&mut self) -> Poll<U, JsonPayloadError> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return fut.poll();
return Pin::new(fut).poll(cx);
}
if let Some(err) = self.err.take() {
return Err(err);
return Poll::Ready(Err(err));
}
let limit = self.limit;
if let Some(len) = self.length.take() {
if len > limit {
return Err(JsonPayloadError::Overflow);
return Poll::Ready(Err(JsonPayloadError::Overflow));
}
}
let mut stream = self.stream.take().unwrap();
let fut = self
.stream
.take()
.unwrap()
.from_err()
.fold(BytesMut::with_capacity(8192), move |mut body, chunk| {
if (body.len() + chunk.len()) > limit {
Err(JsonPayloadError::Overflow)
} else {
body.extend_from_slice(&chunk);
Ok(body)
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if (body.len() + chunk.len()) > limit {
return Err(JsonPayloadError::Overflow);
} else {
body.extend_from_slice(&chunk);
}
}
})
.and_then(|body| Ok(serde_json::from_slice::<U>(&body)?));
self.fut = Some(Box::new(fut));
self.poll()
Ok(serde_json::from_slice::<U>(&body)?)
}
.boxed_local(),
);
self.poll(cx)
}
}
@ -395,7 +402,7 @@ mod tests {
use super::*;
use crate::error::InternalError;
use crate::http::header;
use crate::test::{block_on, TestRequest};
use crate::test::{block_on, load_stream, TestRequest};
use crate::HttpResponse;
#[derive(Serialize, Deserialize, PartialEq, Debug)]
@ -419,218 +426,234 @@ mod tests {
#[test]
fn test_responder() {
let req = TestRequest::default().to_http_request();
block_on(async {
let req = TestRequest::default().to_http_request();
let j = Json(MyObject {
name: "test".to_string(),
});
let resp = j.respond_to(&req).unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/json")
);
let j = Json(MyObject {
name: "test".to_string(),
});
let resp = j.respond_to(&req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/json")
);
use crate::responder::tests::BodyTest;
assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}");
use crate::responder::tests::BodyTest;
assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}");
})
}
#[test]
fn test_custom_error_responder() {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(10).error_handler(|err, _| {
let msg = MyObject {
name: "invalid request".to_string(),
};
let resp = HttpResponse::BadRequest()
.body(serde_json::to_string(&msg).unwrap());
InternalError::from_response(err, resp).into()
}))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(10).error_handler(|err, _| {
let msg = MyObject {
name: "invalid request".to_string(),
};
let resp = HttpResponse::BadRequest()
.body(serde_json::to_string(&msg).unwrap());
InternalError::from_response(err, resp).into()
}))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
let mut resp = Response::from_error(s.err().unwrap().into());
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
let mut resp = Response::from_error(s.err().unwrap().into());
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = block_on(resp.take_body().concat2()).unwrap();
let msg: MyObject = serde_json::from_slice(&body).unwrap();
assert_eq!(msg.name, "invalid request");
let body = load_stream(resp.take_body()).await.unwrap();
let msg: MyObject = serde_json::from_slice(&body).unwrap();
assert_eq!(msg.name, "invalid request");
})
}
#[test]
fn test_extract() {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl)).unwrap();
assert_eq!(s.name, "test");
assert_eq!(
s.into_inner(),
MyObject {
name: "test".to_string()
}
);
let s = Json::<MyObject>::from_request(&req, &mut pl).await.unwrap();
assert_eq!(s.name, "test");
assert_eq!(
s.into_inner(),
MyObject {
name: "test".to_string()
}
);
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(10))
.to_http_parts();
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(10))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(format!("{}", s.err().unwrap())
.contains("Json payload size is bigger than allowed"));
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
assert!(format!("{}", s.err().unwrap())
.contains("Json payload size is bigger than allowed"));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(
JsonConfig::default()
.limit(10)
.error_handler(|_, _| JsonPayloadError::ContentType.into()),
)
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(format!("{}", s.err().unwrap()).contains("Content type error"));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(
JsonConfig::default()
.limit(10)
.error_handler(|_, _| JsonPayloadError::ContentType.into()),
)
.to_http_parts();
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
assert!(format!("{}", s.err().unwrap()).contains("Content type error"));
})
}
#[test]
fn test_json_body() {
let (req, mut pl) = TestRequest::default().to_http_parts();
let json = block_on(JsonBody::<MyObject>::new(&req, &mut pl, None));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
block_on(async {
let (req, mut pl) = TestRequest::default().to_http_parts();
let json = JsonBody::<MyObject>::new(&req, &mut pl, None).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let (req, mut pl) = TestRequest::default()
.header(
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"),
)
.to_http_parts();
let json = JsonBody::<MyObject>::new(&req, &mut pl, None).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"),
)
.to_http_parts();
let json = JsonBody::<MyObject>::new(&req, &mut pl, None)
.limit(100)
.await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.to_http_parts();
let json = JsonBody::<MyObject>::new(&req, &mut pl, None).await;
assert_eq!(
json.ok().unwrap(),
MyObject {
name: "test".to_owned()
}
);
})
}
#[test]
fn test_with_json_and_bad_content_type() {
block_on(async {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"),
)
.to_http_parts();
let json = block_on(JsonBody::<MyObject>::new(&req, &mut pl, None));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"),
)
.to_http_parts();
let json = block_on(JsonBody::<MyObject>::new(&req, &mut pl, None).limit(100));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
header::HeaderValue::from_static("text/plain"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(4096))
.to_http_parts();
let json = block_on(JsonBody::<MyObject>::new(&req, &mut pl, None));
assert_eq!(
json.ok().unwrap(),
MyObject {
name: "test".to_owned()
}
);
}
#[test]
fn test_with_json_and_bad_content_type() {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/plain"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().limit(4096))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(s.is_err())
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
assert!(s.is_err())
})
}
#[test]
fn test_with_json_and_good_custom_content_type() {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/plain"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().content_type(|mime: mime::Mime| {
mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
}))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/plain"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().content_type(|mime: mime::Mime| {
mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
}))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(s.is_ok())
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
assert!(s.is_ok())
})
}
#[test]
fn test_with_json_and_bad_custom_content_type() {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/html"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().content_type(|mime: mime::Mime| {
mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
}))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::with_header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/html"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.data(JsonConfig::default().content_type(|mime: mime::Mime| {
mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
}))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(s.is_err())
let s = Json::<MyObject>::from_request(&req, &mut pl).await;
assert!(s.is_err())
})
}
}

View File

@ -5,6 +5,7 @@ use std::{fmt, ops};
use actix_http::error::{Error, ErrorNotFound};
use actix_router::PathDeserializer;
use futures::future::{ready, Ready};
use serde::de;
use crate::dev::Payload;
@ -159,7 +160,7 @@ where
T: de::DeserializeOwned,
{
type Error = Error;
type Future = Result<Self, Error>;
type Future = Ready<Result<Self, Error>>;
type Config = PathConfig;
#[inline]
@ -169,21 +170,23 @@ where
.map(|c| c.ehandler.clone())
.unwrap_or(None);
de::Deserialize::deserialize(PathDeserializer::new(req.match_info()))
.map(|inner| Path { inner })
.map_err(move |e| {
log::debug!(
"Failed during Path extractor deserialization. \
Request path: {:?}",
req.path()
);
if let Some(error_handler) = error_handler {
let e = PathError::Deserialize(e);
(error_handler)(e, req)
} else {
ErrorNotFound(e)
}
})
ready(
de::Deserialize::deserialize(PathDeserializer::new(req.match_info()))
.map(|inner| Path { inner })
.map_err(move |e| {
log::debug!(
"Failed during Path extractor deserialization. \
Request path: {:?}",
req.path()
);
if let Some(error_handler) = error_handler {
let e = PathError::Deserialize(e);
(error_handler)(e, req)
} else {
ErrorNotFound(e)
}
}),
)
}
}
@ -268,100 +271,116 @@ mod tests {
#[test]
fn test_extract_path_single() {
let resource = ResourceDef::new("/{value}/");
block_on(async {
let resource = ResourceDef::new("/{value}/");
let mut req = TestRequest::with_uri("/32/").to_srv_request();
resource.match_path(req.match_info_mut());
let mut req = TestRequest::with_uri("/32/").to_srv_request();
resource.match_path(req.match_info_mut());
let (req, mut pl) = req.into_parts();
assert_eq!(*Path::<i8>::from_request(&req, &mut pl).unwrap(), 32);
assert!(Path::<MyStruct>::from_request(&req, &mut pl).is_err());
let (req, mut pl) = req.into_parts();
assert_eq!(*Path::<i8>::from_request(&req, &mut pl).await.unwrap(), 32);
assert!(Path::<MyStruct>::from_request(&req, &mut pl).await.is_err());
})
}
#[test]
fn test_tuple_extract() {
let resource = ResourceDef::new("/{key}/{value}/");
block_on(async {
let resource = ResourceDef::new("/{key}/{value}/");
let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
resource.match_path(req.match_info_mut());
let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
resource.match_path(req.match_info_mut());
let (req, mut pl) = req.into_parts();
let res =
block_on(<(Path<(String, String)>,)>::from_request(&req, &mut pl)).unwrap();
assert_eq!((res.0).0, "name");
assert_eq!((res.0).1, "user1");
let (req, mut pl) = req.into_parts();
let res = <(Path<(String, String)>,)>::from_request(&req, &mut pl)
.await
.unwrap();
assert_eq!((res.0).0, "name");
assert_eq!((res.0).1, "user1");
let res = block_on(
<(Path<(String, String)>, Path<(String, String)>)>::from_request(
let res = <(Path<(String, String)>, Path<(String, String)>)>::from_request(
&req, &mut pl,
),
)
.unwrap();
assert_eq!((res.0).0, "name");
assert_eq!((res.0).1, "user1");
assert_eq!((res.1).0, "name");
assert_eq!((res.1).1, "user1");
)
.await
.unwrap();
assert_eq!((res.0).0, "name");
assert_eq!((res.0).1, "user1");
assert_eq!((res.1).0, "name");
assert_eq!((res.1).1, "user1");
let () = <()>::from_request(&req, &mut pl).unwrap();
let () = <()>::from_request(&req, &mut pl).await.unwrap();
})
}
#[test]
fn test_request_extract() {
let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
block_on(async {
let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
let resource = ResourceDef::new("/{key}/{value}/");
resource.match_path(req.match_info_mut());
let resource = ResourceDef::new("/{key}/{value}/");
resource.match_path(req.match_info_mut());
let (req, mut pl) = req.into_parts();
let mut s = Path::<MyStruct>::from_request(&req, &mut pl).unwrap();
assert_eq!(s.key, "name");
assert_eq!(s.value, "user1");
s.value = "user2".to_string();
assert_eq!(s.value, "user2");
assert_eq!(
format!("{}, {:?}", s, s),
"MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }"
);
let s = s.into_inner();
assert_eq!(s.value, "user2");
let (req, mut pl) = req.into_parts();
let mut s = Path::<MyStruct>::from_request(&req, &mut pl).await.unwrap();
assert_eq!(s.key, "name");
assert_eq!(s.value, "user1");
s.value = "user2".to_string();
assert_eq!(s.value, "user2");
assert_eq!(
format!("{}, {:?}", s, s),
"MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }"
);
let s = s.into_inner();
assert_eq!(s.value, "user2");
let s = Path::<(String, String)>::from_request(&req, &mut pl).unwrap();
assert_eq!(s.0, "name");
assert_eq!(s.1, "user1");
let s = Path::<(String, String)>::from_request(&req, &mut pl)
.await
.unwrap();
assert_eq!(s.0, "name");
assert_eq!(s.1, "user1");
let mut req = TestRequest::with_uri("/name/32/").to_srv_request();
let resource = ResourceDef::new("/{key}/{value}/");
resource.match_path(req.match_info_mut());
let mut req = TestRequest::with_uri("/name/32/").to_srv_request();
let resource = ResourceDef::new("/{key}/{value}/");
resource.match_path(req.match_info_mut());
let (req, mut pl) = req.into_parts();
let s = Path::<Test2>::from_request(&req, &mut pl).unwrap();
assert_eq!(s.as_ref().key, "name");
assert_eq!(s.value, 32);
let (req, mut pl) = req.into_parts();
let s = Path::<Test2>::from_request(&req, &mut pl).await.unwrap();
assert_eq!(s.as_ref().key, "name");
assert_eq!(s.value, 32);
let s = Path::<(String, u8)>::from_request(&req, &mut pl).unwrap();
assert_eq!(s.0, "name");
assert_eq!(s.1, 32);
let s = Path::<(String, u8)>::from_request(&req, &mut pl)
.await
.unwrap();
assert_eq!(s.0, "name");
assert_eq!(s.1, 32);
let res = Path::<Vec<String>>::from_request(&req, &mut pl).unwrap();
assert_eq!(res[0], "name".to_owned());
assert_eq!(res[1], "32".to_owned());
let res = Path::<Vec<String>>::from_request(&req, &mut pl)
.await
.unwrap();
assert_eq!(res[0], "name".to_owned());
assert_eq!(res[1], "32".to_owned());
})
}
#[test]
fn test_custom_err_handler() {
let (req, mut pl) = TestRequest::with_uri("/name/user1/")
.data(PathConfig::default().error_handler(|err, _| {
error::InternalError::from_response(
err,
HttpResponse::Conflict().finish(),
)
.into()
}))
.to_http_parts();
block_on(async {
let (req, mut pl) = TestRequest::with_uri("/name/user1/")
.data(PathConfig::default().error_handler(|err, _| {
error::InternalError::from_response(
err,
HttpResponse::Conflict().finish(),
)
.into()
}))
.to_http_parts();
let s = block_on(Path::<(usize,)>::from_request(&req, &mut pl)).unwrap_err();
let res: HttpResponse = s.into();
let s = Path::<(usize,)>::from_request(&req, &mut pl)
.await
.unwrap_err();
let res: HttpResponse = s.into();
assert_eq!(res.status(), http::StatusCode::CONFLICT);
assert_eq!(res.status(), http::StatusCode::CONFLICT);
})
}
}

View File

@ -1,12 +1,15 @@
//! Payload/Bytes/String extractors
use std::future::Future;
use std::pin::Pin;
use std::str;
use std::task::{Context, Poll};
use actix_http::error::{Error, ErrorBadRequest, PayloadError};
use actix_http::HttpMessage;
use bytes::{Bytes, BytesMut};
use encoding_rs::UTF_8;
use futures::future::{err, Either, FutureResult};
use futures::{Future, Poll, Stream};
use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::{Stream, StreamExt};
use mime::Mime;
use crate::dev;
@ -19,21 +22,19 @@ use crate::request::HttpRequest;
/// ## Example
///
/// ```rust
/// use futures::{Future, Stream};
/// use futures::{Future, Stream, StreamExt};
/// use actix_web::{web, error, App, Error, HttpResponse};
///
/// /// extract binary data from request
/// fn index(body: web::Payload) -> impl Future<Item = HttpResponse, Error = Error>
/// async fn index(mut body: web::Payload) -> Result<HttpResponse, Error>
/// {
/// body.map_err(Error::from)
/// .fold(web::BytesMut::new(), move |mut body, chunk| {
/// body.extend_from_slice(&chunk);
/// Ok::<_, Error>(body)
/// })
/// .and_then(|body| {
/// format!("Body {:?}!", body);
/// Ok(HttpResponse::Ok().finish())
/// })
/// let mut bytes = web::BytesMut::new();
/// while let Some(item) = body.next().await {
/// bytes.extend_from_slice(&item?);
/// }
///
/// format!("Body {:?}!", bytes);
/// Ok(HttpResponse::Ok().finish())
/// }
///
/// fn main() {
@ -53,12 +54,14 @@ impl Payload {
}
impl Stream for Payload {
type Item = Bytes;
type Error = PayloadError;
type Item = Result<Bytes, PayloadError>;
#[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, PayloadError> {
self.0.poll()
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.0).poll_next(cx)
}
}
@ -67,21 +70,19 @@ impl Stream for Payload {
/// ## Example
///
/// ```rust
/// use futures::{Future, Stream};
/// use futures::{Future, Stream, StreamExt};
/// use actix_web::{web, error, App, Error, HttpResponse};
///
/// /// extract binary data from request
/// fn index(body: web::Payload) -> impl Future<Item = HttpResponse, Error = Error>
/// async fn index(mut body: web::Payload) -> Result<HttpResponse, Error>
/// {
/// body.map_err(Error::from)
/// .fold(web::BytesMut::new(), move |mut body, chunk| {
/// body.extend_from_slice(&chunk);
/// Ok::<_, Error>(body)
/// })
/// .and_then(|body| {
/// format!("Body {:?}!", body);
/// Ok(HttpResponse::Ok().finish())
/// })
/// let mut bytes = web::BytesMut::new();
/// while let Some(item) = body.next().await {
/// bytes.extend_from_slice(&item?);
/// }
///
/// format!("Body {:?}!", bytes);
/// Ok(HttpResponse::Ok().finish())
/// }
///
/// fn main() {
@ -94,11 +95,11 @@ impl Stream for Payload {
impl FromRequest for Payload {
type Config = PayloadConfig;
type Error = Error;
type Future = Result<Payload, Error>;
type Future = Ready<Result<Payload, Error>>;
#[inline]
fn from_request(_: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
Ok(Payload(payload.take()))
ok(Payload(payload.take()))
}
}
@ -130,8 +131,10 @@ impl FromRequest for Payload {
impl FromRequest for Bytes {
type Config = PayloadConfig;
type Error = Error;
type Future =
Either<Box<dyn Future<Item = Bytes, Error = Error>>, FutureResult<Bytes, Error>>;
type Future = Either<
LocalBoxFuture<'static, Result<Bytes, Error>>,
Ready<Result<Bytes, Error>>,
>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
@ -144,13 +147,12 @@ impl FromRequest for Bytes {
};
if let Err(e) = cfg.check_mimetype(req) {
return Either::B(err(e));
return Either::Right(err(e));
}
let limit = cfg.limit;
Either::A(Box::new(
HttpMessageBody::new(req, payload).limit(limit).from_err(),
))
let fut = HttpMessageBody::new(req, payload).limit(limit);
Either::Left(async move { Ok(fut.await?) }.boxed_local())
}
}
@ -185,8 +187,8 @@ impl FromRequest for String {
type Config = PayloadConfig;
type Error = Error;
type Future = Either<
Box<dyn Future<Item = String, Error = Error>>,
FutureResult<String, Error>,
LocalBoxFuture<'static, Result<String, Error>>,
Ready<Result<String, Error>>,
>;
#[inline]
@ -201,33 +203,34 @@ impl FromRequest for String {
// check content-type
if let Err(e) = cfg.check_mimetype(req) {
return Either::B(err(e));
return Either::Right(err(e));
}
// check charset
let encoding = match req.encoding() {
Ok(enc) => enc,
Err(e) => return Either::B(err(e.into())),
Err(e) => return Either::Right(err(e.into())),
};
let limit = cfg.limit;
let fut = HttpMessageBody::new(req, payload).limit(limit);
Either::A(Box::new(
HttpMessageBody::new(req, payload)
.limit(limit)
.from_err()
.and_then(move |body| {
if encoding == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding
.decode_without_bom_handling_and_without_replacement(&body)
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?)
}
}),
))
Either::Left(
async move {
let body = fut.await?;
if encoding == UTF_8 {
Ok(str::from_utf8(body.as_ref())
.map_err(|_| ErrorBadRequest("Can not decode body"))?
.to_owned())
} else {
Ok(encoding
.decode_without_bom_handling_and_without_replacement(&body)
.map(|s| s.into_owned())
.ok_or_else(|| ErrorBadRequest("Can not decode body"))?)
}
}
.boxed_local(),
)
}
}
/// Payload configuration for request's payload.
@ -300,7 +303,7 @@ pub struct HttpMessageBody {
length: Option<usize>,
stream: Option<dev::Decompress<dev::Payload>>,
err: Option<PayloadError>,
fut: Option<Box<dyn Future<Item = Bytes, Error = PayloadError>>>,
fut: Option<LocalBoxFuture<'static, Result<Bytes, PayloadError>>>,
}
impl HttpMessageBody {
@ -346,42 +349,43 @@ impl HttpMessageBody {
}
impl Future for HttpMessageBody {
type Item = Bytes;
type Error = PayloadError;
type Output = Result<Bytes, PayloadError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return fut.poll();
return Pin::new(fut).poll(cx);
}
if let Some(err) = self.err.take() {
return Err(err);
return Poll::Ready(Err(err));
}
if let Some(len) = self.length.take() {
if len > self.limit {
return Err(PayloadError::Overflow);
return Poll::Ready(Err(PayloadError::Overflow));
}
}
// future
let limit = self.limit;
self.fut = Some(Box::new(
self.stream
.take()
.unwrap()
.from_err()
.fold(BytesMut::with_capacity(8192), move |mut body, chunk| {
if (body.len() + chunk.len()) > limit {
Err(PayloadError::Overflow)
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if body.len() + chunk.len() > limit {
return Err(PayloadError::Overflow);
} else {
body.extend_from_slice(&chunk);
Ok(body)
}
})
.map(|body| body.freeze()),
));
self.poll()
}
Ok(body.freeze())
}
.boxed_local(),
);
self.poll(cx)
}
}

View File

@ -4,6 +4,7 @@ use std::sync::Arc;
use std::{fmt, ops};
use actix_http::error::Error;
use futures::future::{err, ok, Ready};
use serde::de;
use serde_urlencoded;
@ -132,7 +133,7 @@ where
T: de::DeserializeOwned,
{
type Error = Error;
type Future = Result<Self, Error>;
type Future = Ready<Result<Self, Error>>;
type Config = QueryConfig;
#[inline]
@ -143,7 +144,7 @@ where
.unwrap_or(None);
serde_urlencoded::from_str::<T>(req.query_string())
.map(|val| Ok(Query(val)))
.map(|val| ok(Query(val)))
.unwrap_or_else(move |e| {
let e = QueryPayloadError::Deserialize(e);
@ -159,7 +160,7 @@ where
e.into()
};
Err(e)
err(e)
})
}
}
@ -227,7 +228,7 @@ mod tests {
use super::*;
use crate::error::InternalError;
use crate::test::TestRequest;
use crate::test::{block_on, TestRequest};
use crate::HttpResponse;
#[derive(Deserialize, Debug, Display)]
@ -253,42 +254,46 @@ mod tests {
#[test]
fn test_request_extract() {
let req = TestRequest::with_uri("/name/user1/").to_srv_request();
let (req, mut pl) = req.into_parts();
assert!(Query::<Id>::from_request(&req, &mut pl).is_err());
block_on(async {
let req = TestRequest::with_uri("/name/user1/").to_srv_request();
let (req, mut pl) = req.into_parts();
assert!(Query::<Id>::from_request(&req, &mut pl).await.is_err());
let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
let (req, mut pl) = req.into_parts();
let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
let (req, mut pl) = req.into_parts();
let mut s = Query::<Id>::from_request(&req, &mut pl).unwrap();
assert_eq!(s.id, "test");
assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }");
let mut s = Query::<Id>::from_request(&req, &mut pl).await.unwrap();
assert_eq!(s.id, "test");
assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }");
s.id = "test1".to_string();
let s = s.into_inner();
assert_eq!(s.id, "test1");
s.id = "test1".to_string();
let s = s.into_inner();
assert_eq!(s.id, "test1");
})
}
#[test]
fn test_custom_error_responder() {
let req = TestRequest::with_uri("/name/user1/")
.data(QueryConfig::default().error_handler(|e, _| {
let resp = HttpResponse::UnprocessableEntity().finish();
InternalError::from_response(e, resp).into()
}))
.to_srv_request();
block_on(async {
let req = TestRequest::with_uri("/name/user1/")
.data(QueryConfig::default().error_handler(|e, _| {
let resp = HttpResponse::UnprocessableEntity().finish();
InternalError::from_response(e, resp).into()
}))
.to_srv_request();
let (req, mut pl) = req.into_parts();
let query = Query::<Id>::from_request(&req, &mut pl);
let (req, mut pl) = req.into_parts();
let query = Query::<Id>::from_request(&req, &mut pl).await;
assert!(query.is_err());
assert_eq!(
query
.unwrap_err()
.as_response_error()
.error_response()
.status(),
StatusCode::UNPROCESSABLE_ENTITY
);
assert!(query.is_err());
assert_eq!(
query
.unwrap_err()
.as_response_error()
.error_response()
.status(),
StatusCode::UNPROCESSABLE_ENTITY
);
})
}
}

View File

@ -1,9 +1,13 @@
use std::borrow::Cow;
use std::future::Future;
use std::pin::Pin;
use std::str;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use encoding_rs::{Encoding, UTF_8};
use futures::{Async, Poll, Stream};
use futures::Stream;
use pin_project::pin_project;
use crate::dev::Payload;
use crate::error::{PayloadError, ReadlinesError};
@ -22,7 +26,7 @@ pub struct Readlines<T: HttpMessage> {
impl<T> Readlines<T>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
T::Stream: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
/// Create a new stream to read request line by line.
pub fn new(req: &mut T) -> Self {
@ -62,20 +66,21 @@ where
impl<T> Stream for Readlines<T>
where
T: HttpMessage,
T::Stream: Stream<Item = Bytes, Error = PayloadError>,
T::Stream: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
type Item = String;
type Error = ReadlinesError;
type Item = Result<String, ReadlinesError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if let Some(err) = self.err.take() {
return Err(err);
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(err) = this.err.take() {
return Poll::Ready(Some(Err(err)));
}
// check if there is a newline in the buffer
if !self.checked_buff {
if !this.checked_buff {
let mut found: Option<usize> = None;
for (ind, b) in self.buff.iter().enumerate() {
for (ind, b) in this.buff.iter().enumerate() {
if *b == b'\n' {
found = Some(ind);
break;
@ -83,28 +88,28 @@ where
}
if let Some(ind) = found {
// check if line is longer than limit
if ind + 1 > self.limit {
return Err(ReadlinesError::LimitOverflow);
if ind + 1 > this.limit {
return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow)));
}
let line = if self.encoding == UTF_8 {
str::from_utf8(&self.buff.split_to(ind + 1))
let line = if this.encoding == UTF_8 {
str::from_utf8(&this.buff.split_to(ind + 1))
.map_err(|_| ReadlinesError::EncodingError)?
.to_owned()
} else {
self.encoding
this.encoding
.decode_without_bom_handling_and_without_replacement(
&self.buff.split_to(ind + 1),
&this.buff.split_to(ind + 1),
)
.map(Cow::into_owned)
.ok_or(ReadlinesError::EncodingError)?
};
return Ok(Async::Ready(Some(line)));
return Poll::Ready(Some(Ok(line)));
}
self.checked_buff = true;
this.checked_buff = true;
}
// poll req for more bytes
match self.stream.poll() {
Ok(Async::Ready(Some(mut bytes))) => {
match Pin::new(&mut this.stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut bytes))) => {
// check if there is a newline in bytes
let mut found: Option<usize> = None;
for (ind, b) in bytes.iter().enumerate() {
@ -115,15 +120,15 @@ where
}
if let Some(ind) = found {
// check if line is longer than limit
if ind + 1 > self.limit {
return Err(ReadlinesError::LimitOverflow);
if ind + 1 > this.limit {
return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow)));
}
let line = if self.encoding == UTF_8 {
let line = if this.encoding == UTF_8 {
str::from_utf8(&bytes.split_to(ind + 1))
.map_err(|_| ReadlinesError::EncodingError)?
.to_owned()
} else {
self.encoding
this.encoding
.decode_without_bom_handling_and_without_replacement(
&bytes.split_to(ind + 1),
)
@ -131,83 +136,72 @@ where
.ok_or(ReadlinesError::EncodingError)?
};
// extend buffer with rest of the bytes;
self.buff.extend_from_slice(&bytes);
self.checked_buff = false;
return Ok(Async::Ready(Some(line)));
this.buff.extend_from_slice(&bytes);
this.checked_buff = false;
return Poll::Ready(Some(Ok(line)));
}
self.buff.extend_from_slice(&bytes);
Ok(Async::NotReady)
this.buff.extend_from_slice(&bytes);
Poll::Pending
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Ok(Async::Ready(None)) => {
if self.buff.is_empty() {
return Ok(Async::Ready(None));
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
if this.buff.is_empty() {
return Poll::Ready(None);
}
if self.buff.len() > self.limit {
return Err(ReadlinesError::LimitOverflow);
if this.buff.len() > this.limit {
return Poll::Ready(Some(Err(ReadlinesError::LimitOverflow)));
}
let line = if self.encoding == UTF_8 {
str::from_utf8(&self.buff)
let line = if this.encoding == UTF_8 {
str::from_utf8(&this.buff)
.map_err(|_| ReadlinesError::EncodingError)?
.to_owned()
} else {
self.encoding
.decode_without_bom_handling_and_without_replacement(&self.buff)
this.encoding
.decode_without_bom_handling_and_without_replacement(&this.buff)
.map(Cow::into_owned)
.ok_or(ReadlinesError::EncodingError)?
};
self.buff.clear();
Ok(Async::Ready(Some(line)))
this.buff.clear();
Poll::Ready(Some(Ok(line)))
}
Err(e) => Err(ReadlinesError::from(e)),
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ReadlinesError::from(e)))),
}
}
}
#[cfg(test)]
mod tests {
use futures::stream::StreamExt;
use super::*;
use crate::test::{block_on, TestRequest};
#[test]
fn test_readlines() {
let mut req = TestRequest::default()
block_on(async {
let mut req = TestRequest::default()
.set_payload(Bytes::from_static(
b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\
industry. Lorem Ipsum has been the industry's standard dummy\n\
Contrary to popular belief, Lorem Ipsum is not simply random text.",
))
.to_request();
let stream = match block_on(Readlines::new(&mut req).into_future()) {
Ok((Some(s), stream)) => {
assert_eq!(
s,
"Lorem Ipsum is simply dummy text of the printing and typesetting\n"
);
stream
}
_ => unreachable!("error"),
};
let stream = match block_on(stream.into_future()) {
Ok((Some(s), stream)) => {
assert_eq!(
s,
"industry. Lorem Ipsum has been the industry's standard dummy\n"
);
stream
}
_ => unreachable!("error"),
};
let mut stream = Readlines::new(&mut req);
assert_eq!(
stream.next().await.unwrap().unwrap(),
"Lorem Ipsum is simply dummy text of the printing and typesetting\n"
);
match block_on(stream.into_future()) {
Ok((Some(s), _)) => {
assert_eq!(
s,
"Contrary to popular belief, Lorem Ipsum is not simply random text."
);
}
_ => unreachable!("error"),
}
assert_eq!(
stream.next().await.unwrap().unwrap(),
"industry. Lorem Ipsum has been the industry's standard dummy\n"
);
assert_eq!(
stream.next().await.unwrap().unwrap(),
"Contrary to popular belief, Lorem Ipsum is not simply random text."
);
})
}
}