From f678842e463347bddf16df85d4d07e7e9c243610 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sun, 31 Jul 2022 15:10:22 +0100 Subject: [PATCH] modululize -settings --- actix-settings/src/actix.rs | 514 +----------------- actix-settings/src/actix/address.rs | 89 +++ actix-settings/src/actix/backlog.rs | 59 ++ actix-settings/src/actix/keep_alive.rs | 82 +++ .../src/actix/max_connection_rate.rs | 59 ++ actix-settings/src/actix/max_connections.rs | 59 ++ actix-settings/src/actix/mode.rs | 25 + actix-settings/src/actix/num_workers.rs | 59 ++ actix-settings/src/actix/timeout.rs | 88 +++ actix-settings/src/actix/tls.rs | 11 + actix-settings/src/lib.rs | 18 +- 11 files changed, 561 insertions(+), 502 deletions(-) create mode 100644 actix-settings/src/actix/address.rs create mode 100644 actix-settings/src/actix/backlog.rs create mode 100644 actix-settings/src/actix/keep_alive.rs create mode 100644 actix-settings/src/actix/max_connection_rate.rs create mode 100644 actix-settings/src/actix/max_connections.rs create mode 100644 actix-settings/src/actix/mode.rs create mode 100644 actix-settings/src/actix/num_workers.rs create mode 100644 actix-settings/src/actix/timeout.rs create mode 100644 actix-settings/src/actix/tls.rs diff --git a/actix-settings/src/actix.rs b/actix-settings/src/actix.rs index 128a52123..77f83d1ac 100644 --- a/actix-settings/src/actix.rs +++ b/actix-settings/src/actix.rs @@ -1,17 +1,31 @@ -use std::{fmt, path::PathBuf}; +use serde::Deserialize; -use once_cell::sync::Lazy; -use regex::Regex; -use serde::{de, Deserialize}; +mod address; +mod backlog; +mod keep_alive; +mod max_connection_rate; +mod max_connections; +mod mode; +mod num_workers; +mod timeout; +mod tls; -use crate::{core::Parse, error::AtError}; +pub use self::address::Address; +pub use self::backlog::Backlog; +pub use self::keep_alive::KeepAlive; +pub use self::max_connection_rate::MaxConnectionRate; +pub use self::max_connections::MaxConnections; +pub use self::mode::Mode; +pub use self::num_workers::NumWorkers; +pub use self::timeout::Timeout; +pub use self::tls::Tls; /// Settings types for Actix Web. -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct ActixSettings { pub hosts: Vec
, - pub mode: Mode, + pub mode: mode::Mode, pub enable_compression: bool, pub enable_log: bool, pub num_workers: NumWorkers, @@ -24,489 +38,3 @@ pub struct ActixSettings { pub shutdown_timeout: Timeout, pub tls: Tls, } - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] -pub struct Address { - pub host: String, - pub port: u16, -} - -pub(crate) static ADDR_REGEX: Lazy = Lazy::new(|| { - Regex::new( - r#"(?x) - \[ # opening square bracket - (\s)* # optional whitespace - "(?P[^"]+)" # host name (string) - , # separating comma - (\s)* # optional whitespace - (?P\d+) # port number (integer) - (\s)* # optional whitespace - \] # closing square bracket - "#, - ) - .expect("Failed to compile regex: ADDR_REGEX") -}); - -pub(crate) static ADDR_LIST_REGEX: Lazy = Lazy::new(|| { - Regex::new( - r#"(?x) - \[ # opening square bracket (list) - (\s)* # optional whitespace - (?P( - \[".*", (\s)* \d+\] # element - (,)? # element separator - (\s)* # optional whitespace - )*) - (\s)* # optional whitespace - \] # closing square bracket (list) - "#, - ) - .expect("Failed to compile regex: ADDRS_REGEX") -}); - -impl Parse for Address { - fn parse(string: &str) -> Result { - let mut items = string - .trim() - .trim_start_matches('[') - .trim_end_matches(']') - .split(','); - - let parse_error = || AtError::ParseAddressError(string.to_string()); - - if !ADDR_REGEX.is_match(string) { - return Err(parse_error()); - } - - Ok(Self { - host: items.next().ok_or_else(parse_error)?.trim().to_string(), - port: items.next().ok_or_else(parse_error)?.trim().parse()?, - }) - } -} - -impl Parse for Vec
{ - fn parse(string: &str) -> Result { - let parse_error = || AtError::ParseAddressError(string.to_string()); - - if !ADDR_LIST_REGEX.is_match(string) { - return Err(parse_error()); - } - - let mut addrs = vec![]; - - for list_caps in ADDR_LIST_REGEX.captures_iter(string) { - let elements = &list_caps["elements"].trim(); - for elt_caps in ADDR_REGEX.captures_iter(elements) { - addrs.push(Address { - host: elt_caps["host"].to_string(), - port: elt_caps["port"].parse()?, - }); - } - } - - Ok(addrs) - } -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] -pub enum Mode { - #[serde(rename = "development")] - Development, - - #[serde(rename = "production")] - Production, -} - -impl Parse for Mode { - fn parse(string: &str) -> std::result::Result { - match string { - "development" => Ok(Self::Development), - "production" => Ok(Self::Production), - _ => Err(InvalidValue! { - expected: "\"development\" | \"production\".", - got: string, - }), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum NumWorkers { - Default, - Manual(usize), -} - -impl Parse for NumWorkers { - fn parse(string: &str) -> std::result::Result { - match string { - "default" => Ok(NumWorkers::Default), - string => match string.parse::() { - Ok(val) => Ok(NumWorkers::Manual(val)), - Err(_) => Err(InvalidValue! { - expected: "a positive integer", - got: string, - }), - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for NumWorkers { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct NumWorkersVisitor; - - impl<'de> de::Visitor<'de> for NumWorkersVisitor { - type Value = NumWorkers; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\" or a string containing an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match NumWorkers::parse(value) { - Ok(num_workers) => Ok(num_workers), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(NumWorkersVisitor) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Backlog { - Default, - Manual(usize), -} - -impl Parse for Backlog { - fn parse(string: &str) -> std::result::Result { - match string { - "default" => Ok(Backlog::Default), - string => match string.parse::() { - Ok(val) => Ok(Backlog::Manual(val)), - Err(_) => Err(InvalidValue! { - expected: "an integer > 0", - got: string, - }), - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for Backlog { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct BacklogVisitor; - - impl<'de> de::Visitor<'de> for BacklogVisitor { - type Value = Backlog; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\" or a string containing an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match Backlog::parse(value) { - Ok(backlog) => Ok(backlog), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(BacklogVisitor) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum MaxConnections { - Default, - Manual(usize), -} - -impl Parse for MaxConnections { - fn parse(string: &str) -> std::result::Result { - match string { - "default" => Ok(MaxConnections::Default), - string => match string.parse::() { - Ok(val) => Ok(MaxConnections::Manual(val)), - Err(_) => Err(InvalidValue! { - expected: "an integer > 0", - got: string, - }), - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for MaxConnections { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct MaxConnectionsVisitor; - - impl<'de> de::Visitor<'de> for MaxConnectionsVisitor { - type Value = MaxConnections; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\" or a string containing an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match MaxConnections::parse(value) { - Ok(max_connections) => Ok(max_connections), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(MaxConnectionsVisitor) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum MaxConnectionRate { - Default, - Manual(usize), -} - -impl Parse for MaxConnectionRate { - fn parse(string: &str) -> std::result::Result { - match string { - "default" => Ok(MaxConnectionRate::Default), - string => match string.parse::() { - Ok(val) => Ok(MaxConnectionRate::Manual(val)), - Err(_) => Err(InvalidValue! { - expected: "an integer > 0", - got: string, - }), - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for MaxConnectionRate { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct MaxConnectionRateVisitor; - - impl<'de> de::Visitor<'de> for MaxConnectionRateVisitor { - type Value = MaxConnectionRate; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\" or a string containing an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match MaxConnectionRate::parse(value) { - Ok(max_connection_rate) => Ok(max_connection_rate), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(MaxConnectionRateVisitor) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum KeepAlive { - Default, - Disabled, - Os, - Seconds(usize), -} - -impl Parse for KeepAlive { - fn parse(string: &str) -> std::result::Result { - pub(crate) static FMT: Lazy = - Lazy::new(|| Regex::new(r"^\d+ seconds$").expect("Failed to compile regex: FMT")); - - pub(crate) static DIGITS: Lazy = - Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: FMT")); - - macro_rules! invalid_value { - ($got:expr) => { - Err(InvalidValue! { - expected: "a string of the format \"N seconds\" where N is an integer > 0", - got: $got, - }) - }; - } - - let digits_in = |m: regex::Match| &string[m.start()..m.end()]; - match string { - "default" => Ok(KeepAlive::Default), - "disabled" => Ok(KeepAlive::Disabled), - "OS" | "os" => Ok(KeepAlive::Os), - string if !FMT.is_match(string) => invalid_value!(string), - string => match DIGITS.find(string) { - None => invalid_value!(string), - Some(mat) => match digits_in(mat).parse() { - Ok(val) => Ok(KeepAlive::Seconds(val)), - Err(_) => invalid_value!(string), - }, - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for KeepAlive { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct KeepAliveVisitor; - - impl<'de> de::Visitor<'de> for KeepAliveVisitor { - type Value = KeepAlive; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match KeepAlive::parse(value) { - Ok(keep_alive) => Ok(keep_alive), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(KeepAliveVisitor) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Timeout { - Default, - Milliseconds(usize), - Seconds(usize), -} - -impl Parse for Timeout { - fn parse(string: &str) -> std::result::Result { - pub static FMT: Lazy = Lazy::new(|| { - Regex::new(r"^\d+ (milliseconds|seconds)$").expect("Failed to compile regex: FMT") - }); - - pub static DIGITS: Lazy = - Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: DIGITS")); - - pub static UNIT: Lazy = Lazy::new(|| { - Regex::new(r"(milliseconds|seconds)$").expect("Failed to compile regex: UNIT") - }); - - macro_rules! invalid_value { - ($got:expr) => { - Err(InvalidValue! { - expected: "a string of the format \"N seconds\" or \"N milliseconds\" where N is an integer > 0", - got: $got, - }) - } - } - match string { - "default" => Ok(Timeout::Default), - string if !FMT.is_match(string) => invalid_value!(string), - string => match (DIGITS.find(string), UNIT.find(string)) { - (None, _) => invalid_value!(string), - (_, None) => invalid_value!(string), - (Some(dmatch), Some(umatch)) => { - let digits = &string[dmatch.start()..dmatch.end()]; - let unit = &string[umatch.start()..umatch.end()]; - match (digits.parse(), unit) { - (Ok(v), "milliseconds") => Ok(Timeout::Milliseconds(v)), - (Ok(v), "seconds") => Ok(Timeout::Seconds(v)), - _ => invalid_value!(string), - } - } - }, - } - } -} - -impl<'de> serde::Deserialize<'de> for Timeout { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct TimeoutVisitor; - - impl<'de> de::Visitor<'de> for TimeoutVisitor { - type Value = Timeout; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; - formatter.write_str(msg) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match Timeout::parse(value) { - Ok(num_workers) => Ok(num_workers), - Err(AtError::InvalidValue { expected, got, .. }) => Err( - de::Error::invalid_value(de::Unexpected::Str(&got), &expected), - ), - Err(_) => unreachable!(), - } - } - } - - deserializer.deserialize_string(TimeoutVisitor) - } -} - -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] -#[serde(rename_all = "kebab-case")] -pub struct Tls { - pub enabled: bool, - pub certificate: PathBuf, - pub private_key: PathBuf, -} diff --git a/actix-settings/src/actix/address.rs b/actix-settings/src/actix/address.rs new file mode 100644 index 000000000..e6af25116 --- /dev/null +++ b/actix-settings/src/actix/address.rs @@ -0,0 +1,89 @@ +use once_cell::sync::Lazy; +use regex::Regex; +use serde::Deserialize; + +use crate::{core::Parse, error::AtError}; + +static ADDR_REGEX: Lazy = Lazy::new(|| { + Regex::new( + r#"(?x) + \[ # opening square bracket + (\s)* # optional whitespace + "(?P[^"]+)" # host name (string) + , # separating comma + (\s)* # optional whitespace + (?P\d+) # port number (integer) + (\s)* # optional whitespace + \] # closing square bracket + "#, + ) + .expect("Failed to compile regex: ADDR_REGEX") +}); + +static ADDR_LIST_REGEX: Lazy = Lazy::new(|| { + Regex::new( + r#"(?x) + \[ # opening square bracket (list) + (\s)* # optional whitespace + (?P( + \[".*", (\s)* \d+\] # element + (,)? # element separator + (\s)* # optional whitespace + )*) + (\s)* # optional whitespace + \] # closing square bracket (list) + "#, + ) + .expect("Failed to compile regex: ADDRS_REGEX") +}); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)] +pub struct Address { + pub host: String, + pub port: u16, +} + +impl Parse for Address { + fn parse(string: &str) -> Result { + let mut items = string + .trim() + .trim_start_matches('[') + .trim_end_matches(']') + .split(','); + + let parse_error = || AtError::ParseAddressError(string.to_string()); + + if !ADDR_REGEX.is_match(string) { + return Err(parse_error()); + } + + Ok(Self { + host: items.next().ok_or_else(parse_error)?.trim().to_string(), + port: items.next().ok_or_else(parse_error)?.trim().parse()?, + }) + } +} + +impl Parse for Vec
{ + fn parse(string: &str) -> Result { + let parse_error = || AtError::ParseAddressError(string.to_string()); + + if !ADDR_LIST_REGEX.is_match(string) { + return Err(parse_error()); + } + + let mut addrs = vec![]; + + for list_caps in ADDR_LIST_REGEX.captures_iter(string) { + let elements = &list_caps["elements"].trim(); + for elt_caps in ADDR_REGEX.captures_iter(elements) { + addrs.push(Address { + host: elt_caps["host"].to_string(), + port: elt_caps["port"].parse()?, + }); + } + } + + Ok(addrs) + } +} diff --git a/actix-settings/src/actix/backlog.rs b/actix-settings/src/actix/backlog.rs new file mode 100644 index 000000000..882e01ae9 --- /dev/null +++ b/actix-settings/src/actix/backlog.rs @@ -0,0 +1,59 @@ +use std::fmt; + +use serde::de; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Backlog { + Default, + Manual(usize), +} + +impl Parse for Backlog { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(Backlog::Default), + string => match string.parse::() { + Ok(val) => Ok(Backlog::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> de::Deserialize<'de> for Backlog { + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct BacklogVisitor; + + impl<'de> de::Visitor<'de> for BacklogVisitor { + type Value = Backlog; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match Backlog::parse(value) { + Ok(backlog) => Ok(backlog), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(BacklogVisitor) + } +} diff --git a/actix-settings/src/actix/keep_alive.rs b/actix-settings/src/actix/keep_alive.rs new file mode 100644 index 000000000..82aa7174f --- /dev/null +++ b/actix-settings/src/actix/keep_alive.rs @@ -0,0 +1,82 @@ +use std::fmt; + +use once_cell::sync::Lazy; +use regex::Regex; +use serde::{de, Deserialize}; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum KeepAlive { + Default, + Disabled, + Os, + Seconds(usize), +} + +impl Parse for KeepAlive { + fn parse(string: &str) -> std::result::Result { + pub(crate) static FMT: Lazy = + Lazy::new(|| Regex::new(r"^\d+ seconds$").expect("Failed to compile regex: FMT")); + + pub(crate) static DIGITS: Lazy = + Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: FMT")); + + macro_rules! invalid_value { + ($got:expr) => { + Err(InvalidValue! { + expected: "a string of the format \"N seconds\" where N is an integer > 0", + got: $got, + }) + }; + } + + let digits_in = |m: regex::Match| &string[m.start()..m.end()]; + match string { + "default" => Ok(KeepAlive::Default), + "disabled" => Ok(KeepAlive::Disabled), + "OS" | "os" => Ok(KeepAlive::Os), + string if !FMT.is_match(string) => invalid_value!(string), + string => match DIGITS.find(string) { + None => invalid_value!(string), + Some(mat) => match digits_in(mat).parse() { + Ok(val) => Ok(KeepAlive::Seconds(val)), + Err(_) => invalid_value!(string), + }, + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for KeepAlive { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct KeepAliveVisitor; + + impl<'de> de::Visitor<'de> for KeepAliveVisitor { + type Value = KeepAlive; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match KeepAlive::parse(value) { + Ok(keep_alive) => Ok(keep_alive), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(KeepAliveVisitor) + } +} diff --git a/actix-settings/src/actix/max_connection_rate.rs b/actix-settings/src/actix/max_connection_rate.rs new file mode 100644 index 000000000..0ecfb21d4 --- /dev/null +++ b/actix-settings/src/actix/max_connection_rate.rs @@ -0,0 +1,59 @@ +use std::fmt; + +use serde::{de, Deserialize}; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MaxConnectionRate { + Default, + Manual(usize), +} + +impl Parse for MaxConnectionRate { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(MaxConnectionRate::Default), + string => match string.parse::() { + Ok(val) => Ok(MaxConnectionRate::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for MaxConnectionRate { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MaxConnectionRateVisitor; + + impl<'de> de::Visitor<'de> for MaxConnectionRateVisitor { + type Value = MaxConnectionRate; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match MaxConnectionRate::parse(value) { + Ok(max_connection_rate) => Ok(max_connection_rate), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(MaxConnectionRateVisitor) + } +} diff --git a/actix-settings/src/actix/max_connections.rs b/actix-settings/src/actix/max_connections.rs new file mode 100644 index 000000000..55a467d59 --- /dev/null +++ b/actix-settings/src/actix/max_connections.rs @@ -0,0 +1,59 @@ +use std::fmt; + +use serde::{de, Deserialize}; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MaxConnections { + Default, + Manual(usize), +} + +impl Parse for MaxConnections { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(MaxConnections::Default), + string => match string.parse::() { + Ok(val) => Ok(MaxConnections::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for MaxConnections { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MaxConnectionsVisitor; + + impl<'de> de::Visitor<'de> for MaxConnectionsVisitor { + type Value = MaxConnections; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match MaxConnections::parse(value) { + Ok(max_connections) => Ok(max_connections), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(MaxConnectionsVisitor) + } +} diff --git a/actix-settings/src/actix/mode.rs b/actix-settings/src/actix/mode.rs new file mode 100644 index 000000000..7689ca6c2 --- /dev/null +++ b/actix-settings/src/actix/mode.rs @@ -0,0 +1,25 @@ +use crate::{core::Parse, error::AtError}; + +use serde::Deserialize; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)] +pub enum Mode { + #[serde(rename = "development")] + Development, + + #[serde(rename = "production")] + Production, +} + +impl Parse for Mode { + fn parse(string: &str) -> std::result::Result { + match string { + "development" => Ok(Self::Development), + "production" => Ok(Self::Production), + _ => Err(InvalidValue! { + expected: "\"development\" | \"production\".", + got: string, + }), + } + } +} diff --git a/actix-settings/src/actix/num_workers.rs b/actix-settings/src/actix/num_workers.rs new file mode 100644 index 000000000..513931c0f --- /dev/null +++ b/actix-settings/src/actix/num_workers.rs @@ -0,0 +1,59 @@ +use std::fmt; + +use serde::de; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum NumWorkers { + Default, + Manual(usize), +} + +impl Parse for NumWorkers { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(NumWorkers::Default), + string => match string.parse::() { + Ok(val) => Ok(NumWorkers::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "a positive integer", + got: string, + }), + }, + } + } +} + +impl<'de> de::Deserialize<'de> for NumWorkers { + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct NumWorkersVisitor; + + impl<'de> de::Visitor<'de> for NumWorkersVisitor { + type Value = NumWorkers; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match NumWorkers::parse(value) { + Ok(num_workers) => Ok(num_workers), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(NumWorkersVisitor) + } +} diff --git a/actix-settings/src/actix/timeout.rs b/actix-settings/src/actix/timeout.rs new file mode 100644 index 000000000..1022d3767 --- /dev/null +++ b/actix-settings/src/actix/timeout.rs @@ -0,0 +1,88 @@ +use std::fmt; + +use once_cell::sync::Lazy; +use regex::Regex; +use serde::de; + +use crate::{core::Parse, error::AtError}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Timeout { + Default, + Milliseconds(usize), + Seconds(usize), +} + +impl Parse for Timeout { + fn parse(string: &str) -> std::result::Result { + pub static FMT: Lazy = Lazy::new(|| { + Regex::new(r"^\d+ (milliseconds|seconds)$").expect("Failed to compile regex: FMT") + }); + + pub static DIGITS: Lazy = + Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: DIGITS")); + + pub static UNIT: Lazy = Lazy::new(|| { + Regex::new(r"(milliseconds|seconds)$").expect("Failed to compile regex: UNIT") + }); + + macro_rules! invalid_value { + ($got:expr) => { + Err(InvalidValue! { + expected: "a string of the format \"N seconds\" or \"N milliseconds\" where N is an integer > 0", + got: $got, + }) + } + } + match string { + "default" => Ok(Timeout::Default), + string if !FMT.is_match(string) => invalid_value!(string), + string => match (DIGITS.find(string), UNIT.find(string)) { + (None, _) => invalid_value!(string), + (_, None) => invalid_value!(string), + (Some(dmatch), Some(umatch)) => { + let digits = &string[dmatch.start()..dmatch.end()]; + let unit = &string[umatch.start()..umatch.end()]; + match (digits.parse(), unit) { + (Ok(v), "milliseconds") => Ok(Timeout::Milliseconds(v)), + (Ok(v), "seconds") => Ok(Timeout::Seconds(v)), + _ => invalid_value!(string), + } + } + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for Timeout { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct TimeoutVisitor; + + impl<'de> de::Visitor<'de> for TimeoutVisitor { + type Value = Timeout; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match Timeout::parse(value) { + Ok(num_workers) => Ok(num_workers), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(TimeoutVisitor) + } +} diff --git a/actix-settings/src/actix/tls.rs b/actix-settings/src/actix/tls.rs new file mode 100644 index 000000000..043913b70 --- /dev/null +++ b/actix-settings/src/actix/tls.rs @@ -0,0 +1,11 @@ +use std::path::PathBuf; + +use serde::Deserialize; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct Tls { + pub enabled: bool, + pub certificate: PathBuf, + pub private_key: PathBuf, +} diff --git a/actix-settings/src/lib.rs b/actix-settings/src/lib.rs index 3432af703..b8be4b561 100644 --- a/actix-settings/src/lib.rs +++ b/actix-settings/src/lib.rs @@ -1,4 +1,5 @@ -/// A library to process Server.toml files +//! Easily manage Actix Web's settings from a TOML file and environment variables. + use std::{ env, fmt, fs::File, @@ -22,7 +23,10 @@ mod error; mod actix; mod core; -pub use self::actix::*; +pub use self::actix::{ + ActixSettings, Address, Backlog, KeepAlive, MaxConnectionRate, MaxConnections, Mode, + NumWorkers, Timeout, Tls, +}; pub use self::core::Parse; pub use self::error::{AtError, AtResult}; @@ -214,15 +218,11 @@ where #[cfg(test)] mod tests { - #![allow(non_snake_case)] + use super::*; use std::path::Path; use actix_web::{App, HttpServer}; - use serde::Deserialize; - - use crate::actix::*; // used for value construction in assertions - use crate::{ApplySettings, AtResult, BasicSettings, Settings}; #[test] fn apply_settings() -> AtResult<()> { @@ -657,12 +657,12 @@ mod tests { #[test] fn override_extended_field_with_custom_type() -> AtResult<()> { - #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, PartialEq, Eq, Deserialize)] struct NestedSetting { foo: String, bar: bool, } - #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, PartialEq, Eq, Deserialize)] struct AppSettings { #[serde(rename = "example-name")] example_name: String,