1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-23 15:51:06 +01:00

Make it work with actix-web 4.0.0-beta.7

This commit is contained in:
Andrey Kutejko 2021-06-21 18:00:20 +02:00
parent 64eec6e550
commit 6682fc826f
18 changed files with 168 additions and 119 deletions

View File

@ -101,7 +101,7 @@ jobs:
- name: Clear the cargo caches - name: Clear the cargo caches
run: | run: |
cargo install cargo-cache --no-default-features --features ci-autoclean cargo install cargo-cache --version 0.6.2 --no-default-features --features ci-autoclean
cargo-cache cargo-cache
build_and_test_other: build_and_test_other:
@ -180,5 +180,5 @@ jobs:
- name: Clear the cargo caches - name: Clear the cargo caches
run: | run: |
cargo install cargo-cache --no-default-features --features ci-autoclean cargo install cargo-cache --version 0.6.2 --no-default-features --features ci-autoclean
cargo-cache cargo-cache

View File

@ -19,8 +19,8 @@ name = "actix_cors"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.5", default-features = false } actix-web = { version = "4.0.0-beta.7", default-features = false }
actix-service = "2.0.0-beta.5" actix-service = "2"
derive_more = "0.99.5" derive_more = "0.99.5"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }

View File

@ -1,6 +1,10 @@
use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc}; use std::{
collections::HashSet, convert::TryInto, error::Error as StdError,
iter::FromIterator, rc::Rc,
};
use actix_web::{ use actix_web::{
body::MessageBody,
dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform}, dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
error::{Error, Result}, error::{Error, Result},
http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri}, http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri},
@ -487,9 +491,10 @@ impl<S, B> Transform<S, ServiceRequest> for Cors
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S>;

View File

@ -1,6 +1,7 @@
use std::{convert::TryInto, rc::Rc}; use std::{convert::TryInto, error::Error as StdError, rc::Rc};
use actix_web::{ use actix_web::{
body::{Body, MessageBody},
dev::{Service, ServiceRequest, ServiceResponse}, dev::{Service, ServiceRequest, ServiceResponse},
error::{Error, Result}, error::{Error, Result},
http::{ http::{
@ -25,8 +26,13 @@ pub struct CorsMiddleware<S> {
pub(crate) inner: Rc<Inner>, pub(crate) inner: Rc<Inner>,
} }
impl<S> CorsMiddleware<S> { impl<S, B> CorsMiddleware<S>
fn handle_preflight<B>(inner: &Inner, req: ServiceRequest) -> ServiceResponse<B> { where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
B: MessageBody + 'static,
B::Error: StdError,
{
fn handle_preflight(inner: &Inner, req: ServiceRequest) -> ServiceResponse {
if let Err(err) = inner if let Err(err) = inner
.validate_origin(req.head()) .validate_origin(req.head())
.and_then(|_| inner.validate_allowed_method(req.head())) .and_then(|_| inner.validate_allowed_method(req.head()))
@ -69,11 +75,10 @@ impl<S> CorsMiddleware<S> {
} }
let res = res.finish(); let res = res.finish();
let res = res.into_body();
req.into_response(res) req.into_response(res)
} }
fn augment_response<B>( fn augment_response(
inner: &Inner, inner: &Inner,
mut res: ServiceResponse<B>, mut res: ServiceResponse<B>,
) -> ServiceResponse<B> { ) -> ServiceResponse<B> {
@ -112,20 +117,21 @@ impl<S> CorsMiddleware<S> {
} }
} }
type CorsMiddlewareServiceFuture<B> = Either< type CorsMiddlewareServiceFuture = Either<
Ready<Result<ServiceResponse<B>, Error>>, Ready<Result<ServiceResponse, Error>>,
LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>, LocalBoxFuture<'static, Result<ServiceResponse, Error>>,
>; >;
impl<S, B> Service<ServiceRequest> for CorsMiddleware<S> impl<S, B> Service<ServiceRequest> for CorsMiddleware<S>
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = Error; type Error = Error;
type Future = CorsMiddlewareServiceFuture<B>; type Future = CorsMiddlewareServiceFuture;
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
@ -157,6 +163,7 @@ where
} else { } else {
res res
} }
.map(|res| res.map_body(|_, body| Body::from_message(body)))
} }
.boxed_local(); .boxed_local();

View File

@ -15,13 +15,13 @@ name = "actix_identity"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-service = "2.0.0-beta.5" actix-service = "2"
actix-web = { version = "4.0.0-beta.5", default-features = false, features = ["cookies", "secure-cookies"] } actix-web = { version = "4.0.0-beta.7", default-features = false, features = ["cookies", "secure-cookies"] }
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
time = "0.2.23" time = "0.2.23"
[dev-dependencies] [dev-dependencies]
actix-http = "3.0.0-beta.4" actix-http = "3.0.0-beta.7"
actix-rt = "2" actix-rt = "2"

View File

@ -7,7 +7,7 @@ use time::Duration;
use actix_web::{ use actix_web::{
cookie::{Cookie, CookieJar, Key, SameSite}, cookie::{Cookie, CookieJar, Key, SameSite},
dev::{ServiceRequest, ServiceResponse}, dev::{ServiceRequest, ServiceResponse},
error::{Error, Result}, error::{Error, ErrorInternalServerError, Result},
http::header::{self, HeaderValue}, http::header::{self, HeaderValue},
HttpMessage, HttpMessage,
}; };
@ -69,16 +69,18 @@ impl CookieIdentityInner {
value: Option<CookieValue>, value: Option<CookieValue>,
) -> Result<()> { ) -> Result<()> {
let add_cookie = value.is_some(); let add_cookie = value.is_some();
let val = value.map(|val| { let val = value
if !self.legacy_supported() { .map(|val| {
serde_json::to_string(&val) if !self.legacy_supported() {
} else { serde_json::to_string(&val)
Ok(val.identity) } else {
} Ok(val.identity)
}); }
})
.transpose()
.map_err(ErrorInternalServerError)?;
let mut cookie = let mut cookie = Cookie::new(self.name.clone(), val.unwrap_or_default());
Cookie::new(self.name.clone(), val.unwrap_or_else(|| Ok(String::new()))?);
cookie.set_path(self.path.clone()); cookie.set_path(self.path.clone());
cookie.set_secure(self.secure); cookie.set_secure(self.secure);
cookie.set_http_only(true); cookie.set_http_only(true);
@ -108,10 +110,10 @@ impl CookieIdentityInner {
}; };
if add_cookie { if add_cookie {
jar.private(&key).add(cookie); jar.private_mut(&key).add(cookie);
} else { } else {
jar.add_original(cookie.clone()); jar.add_original(cookie.clone());
jar.private(&key).remove(cookie); jar.private_mut(&key).remove(cookie);
} }
for cookie in jar.delta() { for cookie in jar.delta() {
@ -391,7 +393,7 @@ mod tests {
.copied() .copied()
.collect(); .collect();
jar.private(&Key::derive_from(&key)).add(Cookie::new( jar.private_mut(&Key::derive_from(&key)).add(Cookie::new(
COOKIE_NAME, COOKIE_NAME,
serde_json::to_string(&CookieValue { serde_json::to_string(&CookieValue {
identity: identity.to_string(), identity: identity.to_string(),
@ -575,7 +577,7 @@ mod tests {
fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> { fn legacy_login_cookie(identity: &'static str) -> Cookie<'static> {
let mut jar = CookieJar::new(); let mut jar = CookieJar::new();
jar.private(&Key::derive_from(&COOKIE_KEY_MASTER)) jar.private_mut(&Key::derive_from(&COOKIE_KEY_MASTER))
.add(Cookie::new(COOKIE_NAME, identity)); .add(Cookie::new(COOKIE_NAME, identity));
jar.get(COOKIE_NAME).unwrap().clone() jar.get(COOKIE_NAME).unwrap().clone()
} }

View File

@ -1,6 +1,7 @@
use std::rc::Rc; use std::{error::Error as StdError, rc::Rc};
use actix_web::{ use actix_web::{
body::{Body, MessageBody},
dev::{Service, ServiceRequest, ServiceResponse, Transform}, dev::{Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, Result, Error, HttpMessage, Result,
}; };
@ -41,9 +42,10 @@ where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static, S::Future: 'static,
T: IdentityPolicy, T: IdentityPolicy,
B: 'static, B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = Error; type Error = Error;
type InitError = (); type InitError = ();
type Transform = IdentityServiceMiddleware<S, T>; type Transform = IdentityServiceMiddleware<S, T>;
@ -73,12 +75,13 @@ impl<S, T> Clone for IdentityServiceMiddleware<S, T> {
impl<S, T, B> Service<ServiceRequest> for IdentityServiceMiddleware<S, T> impl<S, T, B> Service<ServiceRequest> for IdentityServiceMiddleware<S, T>
where where
B: 'static, B: MessageBody + 'static,
B::Error: StdError,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static, S::Future: 'static,
T: IdentityPolicy, T: IdentityPolicy,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = Error; type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -100,11 +103,13 @@ where
if let Some(id) = id { if let Some(id) = id {
match backend.to_response(id.id, id.changed, &mut res).await { match backend.to_response(id.id, id.changed, &mut res).await {
Ok(_) => Ok(res), Ok(_) => {
Ok(res.map_body(|_, body| Body::from_message(body)))
}
Err(e) => Ok(res.error_response(e)), Err(e) => Ok(res.error_response(e)),
} }
} else { } else {
Ok(res) Ok(res.map_body(|_, body| Body::from_message(body)))
} }
} }
Err(err) => Ok(req.error_response(err)), Err(err) => Ok(req.error_response(err)),

View File

@ -20,7 +20,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
actix-rt = "2" actix-rt = "2"
actix-web = { version = "4.0.0-beta.5", default_features = false } actix-web = { version = "4.0.0-beta.7", default_features = false }
derive_more = "0.99.5" derive_more = "0.99.5"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }
prost = "0.7" prost = "0.7"

View File

@ -8,7 +8,7 @@ authors = [
] ]
[dependencies] [dependencies]
actix-web = "4.0.0-beta.5" actix-web = "4.0.0-beta.7"
actix-protobuf = { path = "../../" } actix-protobuf = { path = "../../" }
env_logger = "0.8" env_logger = "0.8"

View File

@ -12,11 +12,13 @@ use prost::DecodeError as ProtoBufDecodeError;
use prost::EncodeError as ProtoBufEncodeError; use prost::EncodeError as ProtoBufEncodeError;
use prost::Message; use prost::Message;
use actix_web::dev::{HttpResponseBuilder, Payload}; use actix_web::dev::Payload;
use actix_web::error::{Error, PayloadError, ResponseError}; use actix_web::error::{Error, ErrorBadRequest, ErrorPayloadTooLarge, PayloadError};
use actix_web::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; use actix_web::http::header::{CONTENT_LENGTH, CONTENT_TYPE};
use actix_web::web::BytesMut; use actix_web::web::BytesMut;
use actix_web::{FromRequest, HttpMessage, HttpRequest, HttpResponse, Responder}; use actix_web::{
FromRequest, HttpMessage, HttpRequest, HttpResponse, HttpResponseBuilder, Responder,
};
use futures_util::future::{FutureExt, LocalBoxFuture}; use futures_util::future::{FutureExt, LocalBoxFuture};
use futures_util::StreamExt; use futures_util::StreamExt;
@ -39,11 +41,11 @@ pub enum ProtoBufPayloadError {
Payload(PayloadError), Payload(PayloadError),
} }
impl ResponseError for ProtoBufPayloadError { impl Into<Error> for ProtoBufPayloadError {
fn error_response(&self) -> HttpResponse { fn into(self) -> Error {
match *self { match self {
ProtoBufPayloadError::Overflow => HttpResponse::PayloadTooLarge().into(), ProtoBufPayloadError::Overflow => ErrorPayloadTooLarge(self).into(),
_ => HttpResponse::BadRequest().into(), _ => ErrorBadRequest(self).into(),
} }
} }
} }
@ -143,9 +145,7 @@ impl<T: Message + Default> Responder for ProtoBuf<T> {
Ok(()) => HttpResponse::Ok() Ok(()) => HttpResponse::Ok()
.content_type("application/protobuf") .content_type("application/protobuf")
.body(buf), .body(buf),
Err(err) => HttpResponse::from_error(Error::from( Err(err) => HttpResponse::from_error(ProtoBufPayloadError::Serialize(err)),
ProtoBufPayloadError::Serialize(err),
)),
} }
} }
} }
@ -255,7 +255,7 @@ impl ProtoBufResponseBuilder for HttpResponseBuilder {
let mut body = Vec::new(); let mut body = Vec::new();
value value
.encode(&mut body) .encode(&mut body)
.map_err(ProtoBufPayloadError::Serialize)?; .map_err(|err| Into::<Error>::into(ProtoBufPayloadError::Serialize(err)))?;
Ok(self.body(body)) Ok(self.body(body))
} }
} }

View File

@ -31,31 +31,31 @@ web = [
] ]
[dependencies] [dependencies]
actix = { version = "0.11.0", default-features = false } actix = { version = "0.12.0", default-features = false }
actix-rt = { version = "2.1", default-features = false } actix-rt = { version = "2.1", default-features = false }
actix-service = "2.0.0-beta.5" actix-service = "2"
actix-tls = { version = "3.0.0-beta.5", default-features = false, features = ["connect"] } actix-tls = { version = "3.0.0-beta.5", default-features = false, features = ["connect"] }
log = "0.4.6" log = "0.4.6"
backoff = "0.2.1" backoff = "0.3.0"
derive_more = "0.99.2" derive_more = "0.99.2"
futures-core = { version = "0.3.7", default-features = false } futures-core = { version = "0.3.7", default-features = false }
redis2 = { package = "redis", version = "0.19.0", features = ["tokio-comp", "tokio-native-tls-comp"] } redis2 = { package = "redis", version = "0.20.0", features = ["tokio-comp", "tokio-native-tls-comp"] }
redis-async = { version = "0.8", default-features = false, features = ["tokio10"] } redis-async = { version = "0.9.2", default-features = false, features = ["tokio10"] }
time = "0.2.23" time = "0.2.23"
tokio = { version = "1", features = ["sync"] } tokio = { version = "1", features = ["sync"] }
tokio-util = "0.6.1" tokio-util = "0.6.1"
# actix-session # actix-session
actix-web = { version = "4.0.0-beta.5", default_features = false, optional = true } actix-web = { version = "4.0.0-beta.7", default_features = false, optional = true }
actix-session = { version = "0.5.0-beta.1", optional = true } actix-session = { version = "0.5.0-beta.1", optional = true }
rand = { version = "0.8.0", optional = true } rand = { version = "0.8.0", optional = true }
serde = { version = "1.0.101", optional = true } serde = { version = "1.0.101", optional = true }
serde_json = { version = "1.0.40", optional = true } serde_json = { version = "1.0.40", optional = true }
[dev-dependencies] [dev-dependencies]
actix-test = "0.1.0-beta.1" actix-test = "0.1.0-beta.3"
actix-http = "3.0.0-beta.5" actix-http = "3.0.0-beta.7"
actix-rt = "2.1" actix-rt = "2.1"
env_logger = "0.7" env_logger = "0.8"
serde_derive = "1.0" serde_derive = "1.0"

View File

@ -1,7 +1,8 @@
use actix_redis::RedisSession; use actix_redis::RedisSession;
use actix_session::Session; use actix_session::Session;
use actix_web::{ use actix_web::{
cookie, middleware, web, App, Error, HttpResponse, HttpServer, Responder, cookie, error::InternalError, middleware, web, App, Error, HttpResponse, HttpServer,
Responder,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -19,10 +20,10 @@ struct User {
} }
impl User { impl User {
fn authenticate(credentials: Credentials) -> Result<Self, HttpResponse> { fn authenticate(credentials: Credentials) -> Result<Self, Error> {
// TODO: figure out why I keep getting hacked // TODO: figure out why I keep getting hacked
if &credentials.password != "hunter2" { if &credentials.password != "hunter2" {
return Err(HttpResponse::Unauthorized().json("Unauthorized")); return Err(unauthorized());
} }
Ok(User { Ok(User {
@ -33,29 +34,32 @@ impl User {
} }
} }
pub fn validate_session(session: &Session) -> Result<i64, HttpResponse> { fn unauthorized() -> Error {
let user_id: Option<i64> = session.get("user_id").unwrap_or(None); InternalError::from_response(
"Unauthorized",
HttpResponse::Unauthorized().json("Unauthorized").into(),
)
.into()
}
match user_id { pub fn validate_session(session: &Session) -> Result<i64, Error> {
Some(id) => { let user_id: i64 = session
// keep the user's session alive .get("user_id")
session.renew(); .unwrap_or(None)
Ok(id) .ok_or_else(unauthorized)?;
} // keep the user's session alive
None => Err(HttpResponse::Unauthorized().json("Unauthorized")), session.renew();
} Ok(user_id)
} }
async fn login( async fn login(
credentials: web::Json<Credentials>, credentials: web::Json<Credentials>,
session: Session, session: Session,
) -> Result<impl Responder, HttpResponse> { ) -> Result<impl Responder, Error> {
let credentials = credentials.into_inner(); let credentials = credentials.into_inner();
match User::authenticate(credentials) { let user = User::authenticate(credentials)?;
Ok(user) => session.insert("user_id", user.id).unwrap(), session.insert("user_id", user.id).unwrap();
Err(_) => return Err(HttpResponse::Unauthorized().json("Unauthorized")),
};
Ok("Welcome!") Ok("Welcome!")
} }

View File

@ -6,7 +6,7 @@ use actix_session::{Session, SessionStatus};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::header::{self, HeaderValue}; use actix_web::http::header::{self, HeaderValue};
use actix_web::{error, Error, HttpMessage}; use actix_web::{error, Error};
use futures_core::future::LocalBoxFuture; use futures_core::future::LocalBoxFuture;
use rand::{distributions::Alphanumeric, rngs::OsRng, Rng}; use rand::{distributions::Alphanumeric, rngs::OsRng, Rng};
use redis_async::resp::RespValue; use redis_async::resp::RespValue;
@ -311,7 +311,7 @@ impl Inner {
// set cookie // set cookie
let mut jar = CookieJar::new(); let mut jar = CookieJar::new();
jar.signed(&self.key).add(cookie); jar.signed_mut(&self.key).add(cookie);
(value, Some(jar)) (value, Some(jar))
}; };
@ -320,10 +320,8 @@ impl Inner {
let state: HashMap<_, _> = state.collect(); let state: HashMap<_, _> = state.collect();
let body = match serde_json::to_string(&state) { let body =
Err(e) => return Err(e.into()), serde_json::to_string(&state).map_err(error::ErrorInternalServerError)?;
Ok(body) => body,
};
let cmd = Command(resp_array!["SET", cache_key, body, "EX", &self.ttl]); let cmd = Command(resp_array!["SET", cache_key, body, "EX", &self.ttl]);
@ -444,9 +442,9 @@ mod test {
let id: Option<String> = session.get("user_id")?; let id: Option<String> = session.get("user_id")?;
if let Some(x) = id { if let Some(x) = id {
session.purge(); session.purge();
Ok(format!("Logged out: {}", x).into()) Ok(HttpResponse::Ok().body(format!("Logged out: {}", x)))
} else { } else {
Ok("Could not log out anonymous user".into()) Ok(HttpResponse::Ok().body("Could not log out anonymous user"))
} }
} }
@ -648,7 +646,11 @@ mod test {
.unwrap(); .unwrap();
assert_ne!( assert_ne!(
OffsetDateTime::now_utc().year(), OffsetDateTime::now_utc().year(),
cookie_4.expires().map(|t| t.year()).unwrap() cookie_4
.expires()
.and_then(|e| e.datetime())
.map(|t| t.year())
.unwrap()
); );
// Step 10: GET index, including session cookie #2 in request // Step 10: GET index, including session cookie #2 in request

View File

@ -20,7 +20,7 @@ default = ["cookie-session"]
cookie-session = ["actix-web/secure-cookies"] cookie-session = ["actix-web/secure-cookies"]
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.5", default_features = false, features = ["cookies"] } actix-web = { version = "4.0.0-beta.7", default_features = false, features = ["cookies"] }
actix-service = "2.0.0-beta.5" actix-service = "2.0.0-beta.5"
derive_more = "0.99.5" derive_more = "0.99.5"

View File

@ -1,12 +1,13 @@
//! Cookie based sessions. See docs for [`CookieSession`]. //! Cookie based sessions. See docs for [`CookieSession`].
use std::{collections::HashMap, rc::Rc}; use std::{collections::HashMap, error::Error as StdError, rc::Rc};
use actix_service::{Service, Transform}; use actix_service::{Service, Transform};
use actix_web::body::{Body, MessageBody};
use actix_web::cookie::{Cookie, CookieJar, Key, SameSite}; use actix_web::cookie::{Cookie, CookieJar, Key, SameSite};
use actix_web::dev::{ServiceRequest, ServiceResponse}; use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::http::{header::SET_COOKIE, HeaderValue}; use actix_web::http::{header::SET_COOKIE, HeaderValue};
use actix_web::{Error, HttpMessage, ResponseError}; use actix_web::{Error, ResponseError};
use derive_more::Display; use derive_more::Display;
use futures_util::future::{ok, LocalBoxFuture, Ready}; use futures_util::future::{ok, LocalBoxFuture, Ready};
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
@ -106,8 +107,8 @@ impl CookieSessionInner {
let mut jar = CookieJar::new(); let mut jar = CookieJar::new();
match self.security { match self.security {
CookieSecurity::Signed => jar.signed(&self.key).add(cookie), CookieSecurity::Signed => jar.signed_mut(&self.key).add(cookie),
CookieSecurity::Private => jar.private(&self.key).add(cookie), CookieSecurity::Private => jar.private_mut(&self.key).add(cookie),
} }
for cookie in jar.delta() { for cookie in jar.delta() {
@ -292,13 +293,15 @@ impl CookieSession {
} }
} }
impl<S, B: 'static> Transform<S, ServiceRequest> for CookieSession impl<S, B> Transform<S, ServiceRequest> for CookieSession
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>>, S: Service<ServiceRequest, Response = ServiceResponse<B>>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: StdError + 'static,
B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = S::Error; type Error = S::Error;
type InitError = (); type InitError = ();
type Transform = CookieSessionMiddleware<S>; type Transform = CookieSessionMiddleware<S>;
@ -318,13 +321,15 @@ pub struct CookieSessionMiddleware<S> {
inner: Rc<CookieSessionInner>, inner: Rc<CookieSessionInner>,
} }
impl<S, B: 'static> Service<ServiceRequest> for CookieSessionMiddleware<S> impl<S, B> Service<ServiceRequest> for CookieSessionMiddleware<S>
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>>, S: Service<ServiceRequest, Response = ServiceResponse<B>>,
S::Future: 'static, S::Future: 'static,
S::Error: 'static, S::Error: 'static,
B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
@ -346,32 +351,35 @@ where
Box::pin(async move { Box::pin(async move {
let mut res = fut.await?; let mut res = fut.await?;
let res = match Session::get_changes(&mut res) { let result = match Session::get_changes(&mut res) {
(SessionStatus::Changed, state) | (SessionStatus::Renewed, state) => { (SessionStatus::Changed, state) | (SessionStatus::Renewed, state) => {
res.checked_expr(|res| inner.set_cookie(res, state)) inner.set_cookie(&mut res, state)
} }
(SessionStatus::Unchanged, state) if prolong_expiration => { (SessionStatus::Unchanged, state) if prolong_expiration => {
res.checked_expr(|res| inner.set_cookie(res, state)) inner.set_cookie(&mut res, state)
} }
// set a new session cookie upon first request (new client) // set a new session cookie upon first request (new client)
(SessionStatus::Unchanged, _) => { (SessionStatus::Unchanged, _) => {
if is_new { if is_new {
let state: HashMap<String, String> = HashMap::new(); let state: HashMap<String, String> = HashMap::new();
res.checked_expr(|res| inner.set_cookie(res, state.into_iter())) inner.set_cookie(&mut res, state.into_iter())
} else { } else {
res Ok(())
} }
} }
(SessionStatus::Purged, _) => { (SessionStatus::Purged, _) => {
let _ = inner.remove_cookie(&mut res); let _ = inner.remove_cookie(&mut res);
res Ok(())
} }
}; };
Ok(res) match result {
Ok(_) => Ok(res.map_body(|_, body| Body::from_message(body))),
Err(error) => Ok(res.error_response(error)),
}
}) })
} }
} }
@ -533,7 +541,9 @@ mod tests {
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.expect("Cookie is set") .expect("Cookie is set")
.expires() .expires()
.expect("Expiration is set"); .expect("Expiration is set")
.datetime()
.expect("Expiration is a datetime");
actix_rt::time::sleep(std::time::Duration::from_secs(1)).await; actix_rt::time::sleep(std::time::Duration::from_secs(1)).await;
@ -545,7 +555,9 @@ mod tests {
.find(|c| c.name() == "actix-session") .find(|c| c.name() == "actix-session")
.expect("Cookie is set") .expect("Cookie is set")
.expires() .expires()
.expect("Expiration is set"); .expect("Expiration is set")
.datetime()
.expect("Expiration is a datetime");
assert!(expires_2 - expires_1 >= Duration::seconds(1)); assert!(expires_2 - expires_1 >= Duration::seconds(1));
} }

View File

@ -51,6 +51,7 @@ use std::{
use actix_web::{ use actix_web::{
dev::{Extensions, Payload, RequestHead, ServiceRequest, ServiceResponse}, dev::{Extensions, Payload, RequestHead, ServiceRequest, ServiceResponse},
error::ErrorInternalServerError,
Error, FromRequest, HttpMessage, HttpRequest, Error, FromRequest, HttpMessage, HttpRequest,
}; };
use futures_util::future::{ok, Ready}; use futures_util::future::{ok, Ready};
@ -148,7 +149,9 @@ impl Session {
/// Get a `value` from the session. /// Get a `value` from the session.
pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> { pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
if let Some(s) = self.0.borrow().state.get(key) { if let Some(s) = self.0.borrow().state.get(key) {
Ok(Some(serde_json::from_str(s)?)) Ok(Some(
serde_json::from_str(s).map_err(ErrorInternalServerError)?,
))
} else { } else {
Ok(None) Ok(None)
} }
@ -174,7 +177,7 @@ impl Session {
if inner.status != SessionStatus::Purged { if inner.status != SessionStatus::Purged {
inner.status = SessionStatus::Changed; inner.status = SessionStatus::Changed;
let val = serde_json::to_string(&value)?; let val = serde_json::to_string(&value).map_err(ErrorInternalServerError)?;
inner.state.insert(key.into(), val); inner.state.insert(key.into(), val);
} }

View File

@ -20,7 +20,7 @@ name = "actix_web_httpauth"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
actix-web = { version = "4.0.0-beta.5", default_features = false } actix-web = { version = "4.0.0-beta.7", default_features = false }
actix-service = "2.0.0-beta.5" actix-service = "2.0.0-beta.5"
base64 = "0.13" base64 = "0.13"
futures-util = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false }

View File

@ -1,8 +1,12 @@
//! HTTP Authentication middleware. //! HTTP Authentication middleware.
use std::{future::Future, marker::PhantomData, pin::Pin, rc::Rc, sync::Arc}; use std::{
error::Error as StdError, future::Future, marker::PhantomData, pin::Pin, rc::Rc,
sync::Arc,
};
use actix_web::{ use actix_web::{
body::{Body, MessageBody},
dev::{Service, ServiceRequest, ServiceResponse, Transform}, dev::{Service, ServiceRequest, ServiceResponse, Transform},
Error, Error,
}; };
@ -120,8 +124,10 @@ where
F: Fn(ServiceRequest, T) -> O + 'static, F: Fn(ServiceRequest, T) -> O + 'static,
O: Future<Output = Result<ServiceRequest, Error>> + 'static, O: Future<Output = Result<ServiceRequest, Error>> + 'static,
T: AuthExtractor + 'static, T: AuthExtractor + 'static,
B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = Error; type Error = Error;
type Transform = AuthenticationMiddleware<S, F, T>; type Transform = AuthenticationMiddleware<S, F, T>;
type InitError = (); type InitError = ();
@ -153,10 +159,12 @@ where
F: Fn(ServiceRequest, T) -> O + 'static, F: Fn(ServiceRequest, T) -> O + 'static,
O: Future<Output = Result<ServiceRequest, Error>> + 'static, O: Future<Output = Result<ServiceRequest, Error>> + 'static,
T: AuthExtractor + 'static, T: AuthExtractor + 'static,
B: MessageBody + 'static,
B::Error: StdError,
{ {
type Response = ServiceResponse<B>; type Response = ServiceResponse;
type Error = S::Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>; type Future = LocalBoxFuture<'static, Result<ServiceResponse, Error>>;
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
@ -177,7 +185,8 @@ where
// middleware to do their thing (eg. cors adding headers) // middleware to do their thing (eg. cors adding headers)
let req = process_fn(req, credentials).await?; let req = process_fn(req, credentials).await?;
service.call(req).await let res = service.call(req).await?;
Ok(res.map_body(|_, body| Body::from_message(body)))
} }
.boxed_local() .boxed_local()
} }