From 265b21312305386f118ad366d905b74a90a629d8 Mon Sep 17 00:00:00 2001 From: Keith Cirkel Date: Tue, 14 Jan 2025 00:51:40 +0000 Subject: [PATCH] implement contains_key, update, update_or (#459) * implement contains_key, update, update_or * docs(session): update docs for new methods * docs(session): clarify errors * test(session): fix doctest --------- Co-authored-by: Rob Ede --- actix-session/CHANGES.md | 2 + actix-session/src/session.rs | 109 ++++++++++++++++++++++++++++++++- actix-session/tests/session.rs | 39 ++++++++++++ justfile | 1 + 4 files changed, 148 insertions(+), 3 deletions(-) diff --git a/actix-session/CHANGES.md b/actix-session/CHANGES.md index 99a7b6461..37c0da8eb 100644 --- a/actix-session/CHANGES.md +++ b/actix-session/CHANGES.md @@ -2,6 +2,8 @@ ## Unreleased +- Add `Session::contains_key` method. +- Add `Session::update[_or]()` methods. - Update `redis` dependency to `0.27`. ## 0.10.1 diff --git a/actix-session/src/session.rs b/actix-session/src/session.rs index d7e9286d3..ef207271a 100644 --- a/actix-session/src/session.rs +++ b/actix-session/src/session.rs @@ -33,6 +33,9 @@ use serde::{de::DeserializeOwned, Serialize}; /// session.insert("counter", 1)?; /// } /// +/// // or use the shorthand +/// session.update_or("counter", 1, |count: i32| count + 1); +/// /// Ok("Welcome!") /// } /// # actix_web::web::to(index); @@ -97,6 +100,11 @@ impl Session { } } + /// Returns `true` if the session contains a value for the specified `key`. + pub fn contains_key(&self, key: &str) -> bool { + self.0.borrow().state.contains_key(key) + } + /// Get all raw key-value data from the session. /// /// Note that values are JSON encoded. @@ -114,7 +122,9 @@ impl Session { /// Any serializable value can be used and will be encoded as JSON in session data, hence why /// only a reference to the value is taken. /// - /// It returns an error if it fails to serialize `value` to JSON. + /// # Errors + /// + /// Returns an error if JSON serialization of `value` fails. pub fn insert( &self, key: impl Into, @@ -132,9 +142,8 @@ impl Session { .with_context(|| { format!( "Failed to serialize the provided `{}` type instance as JSON in order to \ - attach as session data to the `{}` key", + attach as session data to the `{key}` key", std::any::type_name::(), - &key ) }) .map_err(SessionInsertError)?; @@ -145,6 +154,83 @@ impl Session { Ok(()) } + /// Updates a key-value pair into the session. + /// + /// If the key exists then update it to the new value and place it back in. If the key does not + /// exist it will not be updated. + /// + /// Any serializable value can be used and will be encoded as JSON in the session data, hence + /// why only a reference to the value is taken. + /// + /// # Errors + /// + /// Returns an error if JSON serialization of the value fails. + pub fn update( + &self, + key: impl Into, + updater: F, + ) -> Result<(), SessionUpdateError> + where + F: FnOnce(T) -> T, + { + let mut inner = self.0.borrow_mut(); + let key_str = key.into(); + + if let Some(val_str) = inner.state.get(&key_str) { + let value = serde_json::from_str(val_str) + .with_context(|| { + format!( + "Failed to deserialize the JSON-encoded session data attached to key \ + `{key_str}` as a `{}` type", + std::any::type_name::() + ) + }) + .map_err(SessionUpdateError)?; + + let val = serde_json::to_string(&updater(value)) + .with_context(|| { + format!( + "Failed to serialize the provided `{}` type instance as JSON in order to \ + attach as session data to the `{key_str}` key", + std::any::type_name::(), + ) + }) + .map_err(SessionUpdateError)?; + + inner.state.insert(key_str, val); + } + + Ok(()) + } + + /// Updates a key-value pair into the session, or inserts a default value. + /// + /// If the key exists then update it to the new value and place it back in. If the key does not + /// exist the default value will be inserted instead. + /// + /// Any serializable value can be used and will be encoded as JSON in session data, hence why + /// only a reference to the value is taken. + /// + /// # Errors + /// + /// Returns error if JSON serialization of a value fails. + pub fn update_or( + &self, + key: &str, + default_value: T, + updater: F, + ) -> Result<(), SessionUpdateError> + where + F: FnOnce(T) -> T, + { + if self.contains_key(key) { + self.update(key, updater) + } else { + self.insert(key, default_value) + .map_err(|err| SessionUpdateError(err.into())) + } + } + /// Remove value from the session. /// /// If present, the JSON encoded value is returned. @@ -319,3 +405,20 @@ impl ResponseError for SessionInsertError { HttpResponse::new(self.status_code()) } } + +/// Error returned by [`Session::update`]. +#[derive(Debug, Display, From)] +#[display("{_0}")] +pub struct SessionUpdateError(anyhow::Error); + +impl StdError for SessionUpdateError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.0.as_ref()) + } +} + +impl ResponseError for SessionUpdateError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new(self.status_code()) + } +} diff --git a/actix-session/tests/session.rs b/actix-session/tests/session.rs index 53319e653..6a14430cb 100644 --- a/actix-session/tests/session.rs +++ b/actix-session/tests/session.rs @@ -69,6 +69,16 @@ async fn session_entries() { map.contains_key("test_num"); } +#[actix_web::test] +async fn session_contains_key() { + let req = test::TestRequest::default().to_srv_request(); + let session = req.get_session(); + session.insert("test_str", "val").unwrap(); + session.insert("test_str", 1).unwrap(); + assert!(session.contains_key("test_str")); + assert!(!session.contains_key("test_num")); +} + #[actix_web::test] async fn insert_session_after_renew() { let session = test::TestRequest::default().to_srv_request().get_session(); @@ -83,6 +93,35 @@ async fn insert_session_after_renew() { assert_eq!(session.status(), SessionStatus::Renewed); } +#[actix_web::test] +async fn update_session() { + let session = test::TestRequest::default().to_srv_request().get_session(); + + session.update("test_val", |c: u32| c + 1).unwrap(); + assert_eq!(session.status(), SessionStatus::Unchanged); + + session.insert("test_val", 0).unwrap(); + assert_eq!(session.status(), SessionStatus::Changed); + + session.update("test_val", |c: u32| c + 1).unwrap(); + assert_eq!(session.get("test_val").unwrap(), Some(1)); + + session.update("test_val", |c: u32| c + 1).unwrap(); + assert_eq!(session.get("test_val").unwrap(), Some(2)); +} + +#[actix_web::test] +async fn update_or_session() { + let session = test::TestRequest::default().to_srv_request().get_session(); + + session.update_or("test_val", 1, |c: u32| c + 1).unwrap(); + assert_eq!(session.status(), SessionStatus::Changed); + assert_eq!(session.get("test_val").unwrap(), Some(1)); + + session.update_or("test_val", 1, |c: u32| c + 1).unwrap(); + assert_eq!(session.get("test_val").unwrap(), Some(2)); +} + #[actix_web::test] async fn remove_session_after_renew() { let session = test::TestRequest::default().to_srv_request().get_session(); diff --git a/justfile b/justfile index a46dbc9ef..4d24efc84 100644 --- a/justfile +++ b/justfile @@ -43,6 +43,7 @@ update-readmes: [group("test")] test: cargo {{ toolchain }} nextest run --workspace --all-features + cargo {{ toolchain }} test --doc --workspace --all-features # Test workspace code and docs. [group("test")]