1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-23 23:34:35 +01:00

add Header trait

This commit is contained in:
Nikolay Kim 2018-03-05 19:28:42 -08:00
parent 0c30057c8c
commit e182ed33b1
8 changed files with 304 additions and 107 deletions

View File

@ -0,0 +1,52 @@
use http::header;
use header::{Header, HttpDate, IntoHeaderValue};
use error::ParseError;
use httpmessage::HttpMessage;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IfModifiedSince(pub HttpDate);
impl Header for IfModifiedSince {
fn name() -> header::HeaderName {
header::IF_MODIFIED_SINCE
}
fn parse<T: HttpMessage>(msg: &T) -> Result<Self, ParseError> {
let val = msg.headers().get(Self::name())
.ok_or(ParseError::Header)?.to_str().map_err(|_| ParseError::Header)?;
Ok(IfModifiedSince(val.parse()?))
}
}
impl IntoHeaderValue for IfModifiedSince {
type Error = header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<header::HeaderValue, Self::Error> {
self.0.try_into()
}
}
#[cfg(test)]
mod tests {
use time::Tm;
use test::TestRequest;
use httpmessage::HttpMessage;
use super::HttpDate;
use super::IfModifiedSince;
fn date() -> HttpDate {
Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8,
tm_mday: 7, tm_mon: 10, tm_year: 94,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0}.into()
}
#[test]
fn test_if_mod_since() {
let req = TestRequest::with_hdr(IfModifiedSince(date())).finish();
let h = req.get::<IfModifiedSince>().unwrap();
assert_eq!(h.0, date());
}
}

View File

@ -0,0 +1,52 @@
use http::header;
use header::{Header, HttpDate, IntoHeaderValue};
use error::ParseError;
use httpmessage::HttpMessage;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IfUnmodifiedSince(pub HttpDate);
impl Header for IfUnmodifiedSince {
fn name() -> header::HeaderName {
header::IF_MODIFIED_SINCE
}
fn parse<T: HttpMessage>(msg: &T) -> Result<Self, ParseError> {
let val = msg.headers().get(Self::name())
.ok_or(ParseError::Header)?.to_str().map_err(|_| ParseError::Header)?;
Ok(IfUnmodifiedSince(val.parse()?))
}
}
impl IntoHeaderValue for IfUnmodifiedSince {
type Error = header::InvalidHeaderValueBytes;
fn try_into(self) -> Result<header::HeaderValue, Self::Error> {
self.0.try_into()
}
}
#[cfg(test)]
mod tests {
use time::Tm;
use test::TestRequest;
use httpmessage::HttpMessage;
use super::HttpDate;
use super::IfUnmodifiedSince;
fn date() -> HttpDate {
Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8,
tm_mday: 7, tm_mon: 10, tm_year: 94,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0}.into()
}
#[test]
fn test_if_mod_since() {
let req = TestRequest::with_hdr(IfUnmodifiedSince(date())).finish();
let h = req.get::<IfUnmodifiedSince>().unwrap();
assert_eq!(h.0, date());
}
}

5
src/header/common/mod.rs Normal file
View File

@ -0,0 +1,5 @@
mod if_modified_since;
mod if_unmodified_since;
pub use self::if_modified_since::IfModifiedSince;
pub use self::if_unmodified_since::IfUnmodifiedSince;

97
src/header/httpdate.rs Normal file
View File

@ -0,0 +1,97 @@
use std::fmt::{self, Display};
use std::io::Write;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use time;
use bytes::{BytesMut, BufMut};
use http::header::{HeaderValue, InvalidHeaderValueBytes};
use error::ParseError;
use super::IntoHeaderValue;
/// A timestamp with HTTP formatting and parsing
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HttpDate(time::Tm);
impl FromStr for HttpDate {
type Err = ParseError;
fn from_str(s: &str) -> Result<HttpDate, ParseError> {
match time::strptime(s, "%a, %d %b %Y %T %Z").or_else(|_| {
time::strptime(s, "%A, %d-%b-%y %T %Z")
}).or_else(|_| {
time::strptime(s, "%c")
}) {
Ok(t) => Ok(HttpDate(t)),
Err(_) => Err(ParseError::Header),
}
}
}
impl Display for HttpDate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0.to_utc().rfc822(), f)
}
}
impl From<time::Tm> for HttpDate {
fn from(tm: time::Tm) -> HttpDate {
HttpDate(tm)
}
}
impl From<SystemTime> for HttpDate {
fn from(sys: SystemTime) -> HttpDate {
let tmspec = match sys.duration_since(UNIX_EPOCH) {
Ok(dur) => {
time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32)
},
Err(err) => {
let neg = err.duration();
time::Timespec::new(-(neg.as_secs() as i64), -(neg.subsec_nanos() as i32))
},
};
HttpDate(time::at_utc(tmspec))
}
}
impl IntoHeaderValue for HttpDate {
type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer();
write!(wrt, "{}", self.0.rfc822()).unwrap();
unsafe{Ok(HeaderValue::from_shared_unchecked(wrt.get_mut().take().freeze()))}
}
}
impl From<HttpDate> for SystemTime {
fn from(date: HttpDate) -> SystemTime {
let spec = date.0.to_timespec();
if spec.sec >= 0 {
UNIX_EPOCH + Duration::new(spec.sec as u64, spec.nsec as u32)
} else {
UNIX_EPOCH - Duration::new(spec.sec as u64, spec.nsec as u32)
}
}
}
#[cfg(test)]
mod tests {
use time::Tm;
use super::HttpDate;
const NOV_07: HttpDate = HttpDate(Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8, tm_mday: 7, tm_mon: 10, tm_year: 94,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0});
#[test]
fn test_date() {
assert_eq!("Sun, 07 Nov 1994 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07);
assert_eq!("Sunday, 07-Nov-94 08:48:37 GMT".parse::<HttpDate>().unwrap(), NOV_07);
assert_eq!("Sun Nov 7 08:48:37 1994".parse::<HttpDate>().unwrap(), NOV_07);
assert!("this-is-no-date".parse::<HttpDate>().is_err());
}
}

View File

@ -1,20 +1,37 @@
use std::fmt::{self, Display};
use std::io::Write;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
//! Various http headers
// A lot of code is inspired by hyper
use time;
use bytes::{Bytes, BytesMut, BufMut};
use bytes::Bytes;
use http::{Error as HttpError};
use http::header::{HeaderValue, InvalidHeaderValue, InvalidHeaderValueBytes};
use http::header::{InvalidHeaderValue, InvalidHeaderValueBytes};
pub use httpresponse::ConnectionType;
pub use cookie::{Cookie, CookieBuilder};
pub use http_range::HttpRange;
pub use http::header::{HeaderName, HeaderValue};
use error::ParseError;
use httpmessage::HttpMessage;
pub use httpresponse::ConnectionType;
mod common;
mod httpdate;
pub use self::common::*;
pub use self::httpdate::HttpDate;
#[doc(hidden)]
/// A trait for any object that will represent a header field and value.
pub trait Header where Self: IntoHeaderValue {
/// Returns the name of the header field
fn name() -> HeaderName;
/// Parse a header
fn parse<T: HttpMessage>(msg: &T) -> Result<Self, ParseError>;
}
#[doc(hidden)]
/// A trait for any object that can be Converted to a `HeaderValue`
pub trait IntoHeaderValue: Sized {
/// The type returned in the event of a conversion error.
type Error: Into<HttpError>;
@ -116,82 +133,3 @@ impl<'a> From<&'a str> for ContentEncoding {
}
}
}
/// A timestamp with HTTP formatting and parsing
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Date(time::Tm);
impl FromStr for Date {
type Err = ParseError;
fn from_str(s: &str) -> Result<Date, ParseError> {
match time::strptime(s, "%a, %d %b %Y %T %Z").or_else(|_| {
time::strptime(s, "%A, %d-%b-%y %T %Z")
}).or_else(|_| {
time::strptime(s, "%c")
}) {
Ok(t) => Ok(Date(t)),
Err(_) => Err(ParseError::Header),
}
}
}
impl Display for Date {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0.to_utc().rfc822(), f)
}
}
impl From<SystemTime> for Date {
fn from(sys: SystemTime) -> Date {
let tmspec = match sys.duration_since(UNIX_EPOCH) {
Ok(dur) => {
time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32)
},
Err(err) => {
let neg = err.duration();
time::Timespec::new(-(neg.as_secs() as i64), -(neg.subsec_nanos() as i32))
},
};
Date(time::at_utc(tmspec))
}
}
impl IntoHeaderValue for Date {
type Error = InvalidHeaderValueBytes;
fn try_into(self) -> Result<HeaderValue, Self::Error> {
let mut wrt = BytesMut::with_capacity(29).writer();
write!(wrt, "{}", self.0.rfc822()).unwrap();
HeaderValue::from_shared(wrt.get_mut().take().freeze())
}
}
impl From<Date> for SystemTime {
fn from(date: Date) -> SystemTime {
let spec = date.0.to_timespec();
if spec.sec >= 0 {
UNIX_EPOCH + Duration::new(spec.sec as u64, spec.nsec as u32)
} else {
UNIX_EPOCH - Duration::new(spec.sec as u64, spec.nsec as u32)
}
}
}
#[cfg(test)]
mod tests {
use time::Tm;
use super::Date;
const NOV_07: HttpDate = HttpDate(Tm {
tm_nsec: 0, tm_sec: 37, tm_min: 48, tm_hour: 8, tm_mday: 7, tm_mon: 10, tm_year: 94,
tm_wday: 0, tm_isdst: 0, tm_yday: 0, tm_utcoff: 0});
#[test]
fn test_date() {
assert_eq!("Sun, 07 Nov 1994 08:48:37 GMT".parse::<Date>().unwrap(), NOV_07);
assert_eq!("Sunday, 07-Nov-94 08:48:37 GMT".parse::<Date>().unwrap(), NOV_07);
assert_eq!("Sun Nov 7 08:48:37 1994".parse::<Date>().unwrap(), NOV_07);
assert!("this-is-no-date".parse::<Date>().is_err());
}
}

View File

@ -12,6 +12,7 @@ use encoding::label::encoding_from_whatwg_label;
use http::{header, HeaderMap};
use json::JsonBody;
use header::Header;
use multipart::Multipart;
use error::{ParseError, ContentTypeError,
HttpRangeError, PayloadError, UrlencodedError};
@ -23,6 +24,12 @@ pub trait HttpMessage {
/// Read the message headers.
fn headers(&self) -> &HeaderMap;
#[doc(hidden)]
/// Get a header
fn get<H: Header>(&self) -> Result<H, ParseError> where Self: Sized {
H::parse(self)
}
/// Read the request content type. If request does not contain
/// *Content-Type* header, empty str get returned.
fn content_type(&self) -> &str {

View File

@ -14,7 +14,7 @@ use serde::Serialize;
use body::Body;
use error::Error;
use handler::Responder;
use header::{IntoHeaderValue, ContentEncoding};
use header::{Header, IntoHeaderValue, ContentEncoding};
use httprequest::HttpRequest;
/// Represents various types of connection
@ -241,6 +241,34 @@ impl HttpResponseBuilder {
self
}
/// Set a header.
///
/// ```rust
/// # extern crate actix_web;
/// # use actix_web::*;
/// # use actix_web::httpcodes::*;
/// #
/// use actix_web::header;
///
/// fn index(req: HttpRequest) -> Result<HttpResponse> {
/// Ok(HttpOk.build()
/// .set(header::IfModifiedSince("Sun, 07 Nov 1994 08:48:37 GMT".parse()?))
/// .finish()?)
/// }
/// fn main() {}
/// ```
#[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self
{
if let Some(parts) = parts(&mut self.response, &self.err) {
match hdr.try_into() {
Ok(value) => { parts.headers.append(H::name(), value); }
Err(e) => self.err = Some(e.into()),
}
}
self
}
/// Set a header.
///
/// ```rust
@ -733,8 +761,9 @@ mod tests {
use std::str::FromStr;
use time::Duration;
use http::{Method, Uri};
use http::header::{COOKIE, CONTENT_TYPE};
use body::Binary;
use {headers, httpcodes};
use {header, httpcodes};
#[test]
fn test_debug() {
@ -746,7 +775,7 @@ mod tests {
#[test]
fn test_response_cookies() {
let mut headers = HeaderMap::new();
headers.insert(header::COOKIE,
headers.insert(COOKIE,
header::HeaderValue::from_static("cookie1=value1; cookie2=value2"));
let req = HttpRequest::new(
@ -755,7 +784,7 @@ mod tests {
let resp = httpcodes::HttpOk
.build()
.cookie(headers::Cookie::build("name", "value")
.cookie(header::Cookie::build("name", "value")
.domain("www.rust-lang.org")
.path("/test")
.http_only(true)
@ -803,7 +832,7 @@ mod tests {
fn test_content_type() {
let resp = HttpResponse::build(StatusCode::OK)
.content_type("text/plain").body(Body::Empty).unwrap();
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/plain")
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain")
}
#[test]
@ -820,7 +849,7 @@ mod tests {
fn test_json() {
let resp = HttpResponse::build(StatusCode::OK)
.json(vec!["v1", "v2", "v3"]).unwrap();
let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, header::HeaderValue::from_static("application/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")));
}
@ -828,9 +857,9 @@ mod tests {
#[test]
fn test_json_ct() {
let resp = HttpResponse::build(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/json")
.header(CONTENT_TYPE, "text/json")
.json(vec!["v1", "v2", "v3"]).unwrap();
let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
let ct = resp.headers().get(CONTENT_TYPE).unwrap();
assert_eq!(ct, header::HeaderValue::from_static("text/json"));
assert_eq!(*resp.body(), Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")));
}
@ -850,56 +879,56 @@ mod tests {
let resp: HttpResponse = "test".into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = "test".respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test"));
let resp: HttpResponse = b"test".as_ref().into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref()));
let resp: HttpResponse = b"test".as_ref().respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(b"test".as_ref()));
let resp: HttpResponse = "test".to_owned().into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned()));
let resp: HttpResponse = "test".to_owned().respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from("test".to_owned()));
let resp: HttpResponse = (&"test".to_owned()).into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned()));
let resp: HttpResponse = (&"test".to_owned()).respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("text/plain; charset=utf-8"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(&"test".to_owned()));
@ -907,7 +936,7 @@ mod tests {
let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test")));
@ -915,7 +944,7 @@ mod tests {
let b = Bytes::from_static(b"test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(Bytes::from_static(b"test")));
@ -923,7 +952,7 @@ mod tests {
let b = BytesMut::from("test");
let resp: HttpResponse = b.into();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test")));
@ -931,7 +960,7 @@ mod tests {
let b = BytesMut::from("test");
let resp: HttpResponse = b.respond_to(req.clone()).ok().unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(),
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/octet-stream"));
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.body().binary().unwrap(), &Binary::from(BytesMut::from("test")));

View File

@ -17,6 +17,7 @@ use net2::TcpBuilder;
use ws;
use body::Binary;
use error::Error;
use header::Header;
use handler::{Handler, Responder, ReplyItem};
use middleware::Middleware;
use application::{Application, HttpApplication};
@ -332,6 +333,12 @@ impl TestRequest<()> {
TestRequest::default().uri(path)
}
/// Create TestRequest and set header
pub fn with_hdr<H: Header>(hdr: H) -> TestRequest<()>
{
TestRequest::default().set(hdr)
}
/// Create TestRequest and set header
pub fn with_header<K, V>(key: K, value: V) -> TestRequest<()>
where HeaderName: HttpTryFrom<K>,
@ -375,6 +382,16 @@ impl<S> TestRequest<S> {
self
}
/// Set a header
pub fn set<H: Header>(mut self, hdr: H) -> Self
{
if let Ok(value) = hdr.try_into() {
self.headers.append(H::name(), value);
return self
}
panic!("Can not set header");
}
/// Set a header
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where HeaderName: HttpTryFrom<K>,