diff --git a/.gitignore b/.gitignore index ec2e10ce9..1f89fc17f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ guide/build/ *.sock *~ .DS_Store + +Server.toml diff --git a/Cargo.toml b/Cargo.toml index 9a953a3ea..a7e099ede 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "actix-protobuf", "actix-redis", "actix-session", + "actix-settings", "actix-web-httpauth", ] @@ -17,6 +18,7 @@ actix-limitation = { path = "./actix-limitation" } actix-protobuf = { path = "./actix-protobuf" } actix-redis = { path = "./actix-redis" } actix-session = { path = "./actix-session" } +actix-settings = { path = "./actix-settings" } actix-web-httpauth = { path = "./actix-web-httpauth" } # uncomment to quickly test against local actix-web repo diff --git a/README.md b/README.md index 819a96519..e9662fd3b 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ | [actix-protobuf] | [![crates.io](https://img.shields.io/crates/v/actix-protobuf?label=latest)](https://crates.io/crates/actix-protobuf) [![dependency status](https://deps.rs/crate/actix-protobuf/0.8.0/status.svg)](https://deps.rs/crate/actix-protobuf/0.8.0) | Protobuf payload extractor. | | [actix-redis] | [![crates.io](https://img.shields.io/crates/v/actix-redis?label=latest)](https://crates.io/crates/actix-redis) [![dependency status](https://deps.rs/crate/actix-redis/0.12.0/status.svg)](https://deps.rs/crate/actix-redis/0.12.0) | Actor-based Redis client. | | [actix-session] | [![crates.io](https://img.shields.io/crates/v/actix-session?label=latest)](https://crates.io/crates/actix-session) [![dependency status](https://deps.rs/crate/actix-session/0.7.1/status.svg)](https://deps.rs/crate/actix-session/0.7.1) | Session management. | +| [actix-settings] | [![crates.io](https://img.shields.io/crates/v/actix-settings?label=latest)](https://crates.io/crates/actix-settings) [![dependency status](https://deps.rs/crate/actix-settings/0.5.2/status.svg)](https://deps.rs/crate/actix-settings/0.5.2) | Easily manage Actix Web's settings from a TOML file and environment variables. | | [actix-web-httpauth] | [![crates.io](https://img.shields.io/crates/v/actix-web-httpauth?label=latest)](https://crates.io/crates/actix-web-httpauth) [![dependency status](https://deps.rs/crate/actix-web-httpauth/0.8.0/status.svg)](https://deps.rs/crate/actix-web-httpauth/0.8.0) | HTTP authentication schemes. | --- @@ -55,6 +56,7 @@ To add a crate to this list, submit a pull request. [actix-protobuf]: ./actix-protobuf [actix-redis]: ./actix-redis [actix-session]: ./actix-session +[actix-settings]: ./actix-settings [actix-web-httpauth]: ./actix-web-httpauth [actix-web-lab]: https://crates.io/crates/actix-web-lab [actix-multipart-extract]: https://crates.io/crates/actix-multipart-extract diff --git a/actix-settings/CHANGES.md b/actix-settings/CHANGES.md new file mode 100644 index 000000000..d9a002aab --- /dev/null +++ b/actix-settings/CHANGES.md @@ -0,0 +1,9 @@ +# Changes + +## Unreleased - 2022-xx-xx +- Update Actix Web dependencies to v4 ecosystem. +- Rename `actix.ssl` settings object to `actix.tls`. +- `NoSettings` is now marked `#[non_exhaustive]`. + +## 0.5.2 - 2022-07-31 +- Adopted into @actix org from . diff --git a/actix-settings/Cargo.toml b/actix-settings/Cargo.toml new file mode 100644 index 000000000..b6f8783d8 --- /dev/null +++ b/actix-settings/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "actix-settings" +version = "0.5.2" +authors = [ + "Joey Ezechiels ", + "Rob Ede ", +] +edition = "2018" +description = "Easily manage Actix Web's settings from a TOML file and environment variables" +license = "MIT OR Apache-2.0" + +[dependencies] +actix-http = "3" +actix-service = "2" +actix-web = "4" + +ioe = "0.5" +once_cell = "1.13" +regex = "1.5.5" +serde = { version = "1", features = ["derive"] } +toml = "0.5" + +[dev-dependencies] +env_logger = "0.9" diff --git a/actix-settings/LICENSE-APACHE b/actix-settings/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-settings/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-settings/LICENSE-MIT b/actix-settings/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-settings/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-settings/README.md b/actix-settings/README.md new file mode 100644 index 000000000..513f573f9 --- /dev/null +++ b/actix-settings/README.md @@ -0,0 +1,31 @@ +# actix-settings + +> Easily manage Actix Web's settings from a TOML file and environment variables. + +[![crates.io](https://img.shields.io/crates/v/actix-settings?label=latest)](https://crates.io/crates/actix-settings) +[![Documentation](https://docs.rs/actix-settings/badge.svg?version=0.5.2)](https://docs.rs/actix-settings/0.5.2) +![Apache 2.0 or MIT licensed](https://img.shields.io/crates/l/actix-settings) +[![Dependency Status](https://deps.rs/crate/actix-settings/0.5.2/status.svg)](https://deps.rs/crate/actix-settings/0.5.2) + +## Documentation & Resources + +- [API Documentation](https://docs.rs/actix-settings) +- [Usage Example][usage] +- Minimum Supported Rust Version (MSRV): 1.57 + +### Custom Settings + +There is a way to extend the available settings. This can be used to combine the settings provided by Actix Web and those provided by application server built using `actix`. + +Have a look at [the usage example][usage] to see how. + +## WIP + +Configuration options for TLS set up are not yet implemented. + +## Special Thanks + +This crate was made possible by support from Accept B.V and [@jjpe]. + +[usage]: https://github.com/actix/actix-extras/blob/master/actix-settings/examples/actix.rs +[@jjpe]: https://github.com/jjpe diff --git a/actix-settings/examples/actix.rs b/actix-settings/examples/actix.rs new file mode 100644 index 000000000..46388082d --- /dev/null +++ b/actix-settings/examples/actix.rs @@ -0,0 +1,82 @@ +use actix_settings::{ApplySettings as _, Mode, Settings}; +use actix_web::{ + get, + middleware::{Compress, Condition, Logger}, + web, App, HttpServer, Responder, +}; + +#[get("/")] +async fn index(settings: web::Data) -> impl Responder { + format!( + r#"{{ + "mode": "{}", + "hosts": ["{}"] +}}"#, + match settings.actix.mode { + Mode::Development => "development", + Mode::Production => "production", + }, + settings + .actix + .hosts + .iter() + .map(|addr| { format!("{}:{}", addr.host, addr.port) }) + .collect::>() + .join(", "), + ) + .customize() + .insert_header(("content-type", "application/json")) +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + let mut settings = Settings::parse_toml("./examples/Server.toml") + .expect("Failed to parse `Settings` from Server.toml"); + + // If the environment variable `$APPLICATION__HOSTS` is set, + // have its value override the `settings.actix.hosts` setting: + Settings::override_field_with_env_var(&mut settings.actix.hosts, "APPLICATION__HOSTS")?; + + init_logger(&settings); + + HttpServer::new({ + // clone settings into each worker thread + let settings = settings.clone(); + + move || { + App::new() + // Include this `.wrap()` call for compression settings to take effect: + .wrap(Condition::new( + settings.actix.enable_compression, + Compress::default(), + )) + .wrap(Logger::default()) + // make `Settings` available to handlers + .app_data(web::Data::new(settings.clone())) + .service(index) + } + }) + // apply the `Settings` to Actix Web's `HttpServer` + .apply_settings(&settings) + .run() + .await +} + +/// Initialize the logging infrastructure +fn init_logger(settings: &Settings) { + if !settings.actix.enable_log { + return; + } + + std::env::set_var( + "RUST_LOG", + match settings.actix.mode { + Mode::Development => "actix_web=debug", + Mode::Production => "actix_web=info", + }, + ); + + std::env::set_var("RUST_BACKTRACE", "1"); + + env_logger::init(); +} diff --git a/actix-settings/src/actix.rs b/actix-settings/src/actix.rs new file mode 100644 index 000000000..128a52123 --- /dev/null +++ b/actix-settings/src/actix.rs @@ -0,0 +1,512 @@ +use std::{fmt, path::PathBuf}; + +use once_cell::sync::Lazy; +use regex::Regex; +use serde::{de, Deserialize}; + +use crate::{core::Parse, error::AtError}; + +/// Settings types for Actix Web. +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "kebab-case")] +pub struct ActixSettings { + pub hosts: Vec
, + pub mode: Mode, + pub enable_compression: bool, + pub enable_log: bool, + pub num_workers: NumWorkers, + pub backlog: Backlog, + pub max_connections: MaxConnections, + pub max_connection_rate: MaxConnectionRate, + pub keep_alive: KeepAlive, + pub client_timeout: Timeout, + pub client_shutdown: Timeout, + pub shutdown_timeout: Timeout, + pub tls: Tls, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +pub struct Address { + pub host: String, + pub port: u16, +} + +pub(crate) static ADDR_REGEX: Lazy = Lazy::new(|| { + Regex::new( + r#"(?x) + \[ # opening square bracket + (\s)* # optional whitespace + "(?P[^"]+)" # host name (string) + , # separating comma + (\s)* # optional whitespace + (?P\d+) # port number (integer) + (\s)* # optional whitespace + \] # closing square bracket + "#, + ) + .expect("Failed to compile regex: ADDR_REGEX") +}); + +pub(crate) static ADDR_LIST_REGEX: Lazy = Lazy::new(|| { + Regex::new( + r#"(?x) + \[ # opening square bracket (list) + (\s)* # optional whitespace + (?P( + \[".*", (\s)* \d+\] # element + (,)? # element separator + (\s)* # optional whitespace + )*) + (\s)* # optional whitespace + \] # closing square bracket (list) + "#, + ) + .expect("Failed to compile regex: ADDRS_REGEX") +}); + +impl Parse for Address { + fn parse(string: &str) -> Result { + let mut items = string + .trim() + .trim_start_matches('[') + .trim_end_matches(']') + .split(','); + + let parse_error = || AtError::ParseAddressError(string.to_string()); + + if !ADDR_REGEX.is_match(string) { + return Err(parse_error()); + } + + Ok(Self { + host: items.next().ok_or_else(parse_error)?.trim().to_string(), + port: items.next().ok_or_else(parse_error)?.trim().parse()?, + }) + } +} + +impl Parse for Vec
{ + fn parse(string: &str) -> Result { + let parse_error = || AtError::ParseAddressError(string.to_string()); + + if !ADDR_LIST_REGEX.is_match(string) { + return Err(parse_error()); + } + + let mut addrs = vec![]; + + for list_caps in ADDR_LIST_REGEX.captures_iter(string) { + let elements = &list_caps["elements"].trim(); + for elt_caps in ADDR_REGEX.captures_iter(elements) { + addrs.push(Address { + host: elt_caps["host"].to_string(), + port: elt_caps["port"].parse()?, + }); + } + } + + Ok(addrs) + } +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +pub enum Mode { + #[serde(rename = "development")] + Development, + + #[serde(rename = "production")] + Production, +} + +impl Parse for Mode { + fn parse(string: &str) -> std::result::Result { + match string { + "development" => Ok(Self::Development), + "production" => Ok(Self::Production), + _ => Err(InvalidValue! { + expected: "\"development\" | \"production\".", + got: string, + }), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum NumWorkers { + Default, + Manual(usize), +} + +impl Parse for NumWorkers { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(NumWorkers::Default), + string => match string.parse::() { + Ok(val) => Ok(NumWorkers::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "a positive integer", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for NumWorkers { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct NumWorkersVisitor; + + impl<'de> de::Visitor<'de> for NumWorkersVisitor { + type Value = NumWorkers; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match NumWorkers::parse(value) { + Ok(num_workers) => Ok(num_workers), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(NumWorkersVisitor) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Backlog { + Default, + Manual(usize), +} + +impl Parse for Backlog { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(Backlog::Default), + string => match string.parse::() { + Ok(val) => Ok(Backlog::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for Backlog { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct BacklogVisitor; + + impl<'de> de::Visitor<'de> for BacklogVisitor { + type Value = Backlog; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match Backlog::parse(value) { + Ok(backlog) => Ok(backlog), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(BacklogVisitor) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MaxConnections { + Default, + Manual(usize), +} + +impl Parse for MaxConnections { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(MaxConnections::Default), + string => match string.parse::() { + Ok(val) => Ok(MaxConnections::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for MaxConnections { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MaxConnectionsVisitor; + + impl<'de> de::Visitor<'de> for MaxConnectionsVisitor { + type Value = MaxConnections; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match MaxConnections::parse(value) { + Ok(max_connections) => Ok(max_connections), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(MaxConnectionsVisitor) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MaxConnectionRate { + Default, + Manual(usize), +} + +impl Parse for MaxConnectionRate { + fn parse(string: &str) -> std::result::Result { + match string { + "default" => Ok(MaxConnectionRate::Default), + string => match string.parse::() { + Ok(val) => Ok(MaxConnectionRate::Manual(val)), + Err(_) => Err(InvalidValue! { + expected: "an integer > 0", + got: string, + }), + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for MaxConnectionRate { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct MaxConnectionRateVisitor; + + impl<'de> de::Visitor<'de> for MaxConnectionRateVisitor { + type Value = MaxConnectionRate; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\" or a string containing an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match MaxConnectionRate::parse(value) { + Ok(max_connection_rate) => Ok(max_connection_rate), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(MaxConnectionRateVisitor) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum KeepAlive { + Default, + Disabled, + Os, + Seconds(usize), +} + +impl Parse for KeepAlive { + fn parse(string: &str) -> std::result::Result { + pub(crate) static FMT: Lazy = + Lazy::new(|| Regex::new(r"^\d+ seconds$").expect("Failed to compile regex: FMT")); + + pub(crate) static DIGITS: Lazy = + Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: FMT")); + + macro_rules! invalid_value { + ($got:expr) => { + Err(InvalidValue! { + expected: "a string of the format \"N seconds\" where N is an integer > 0", + got: $got, + }) + }; + } + + let digits_in = |m: regex::Match| &string[m.start()..m.end()]; + match string { + "default" => Ok(KeepAlive::Default), + "disabled" => Ok(KeepAlive::Disabled), + "OS" | "os" => Ok(KeepAlive::Os), + string if !FMT.is_match(string) => invalid_value!(string), + string => match DIGITS.find(string) { + None => invalid_value!(string), + Some(mat) => match digits_in(mat).parse() { + Ok(val) => Ok(KeepAlive::Seconds(val)), + Err(_) => invalid_value!(string), + }, + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for KeepAlive { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct KeepAliveVisitor; + + impl<'de> de::Visitor<'de> for KeepAliveVisitor { + type Value = KeepAlive; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match KeepAlive::parse(value) { + Ok(keep_alive) => Ok(keep_alive), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(KeepAliveVisitor) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Timeout { + Default, + Milliseconds(usize), + Seconds(usize), +} + +impl Parse for Timeout { + fn parse(string: &str) -> std::result::Result { + pub static FMT: Lazy = Lazy::new(|| { + Regex::new(r"^\d+ (milliseconds|seconds)$").expect("Failed to compile regex: FMT") + }); + + pub static DIGITS: Lazy = + Lazy::new(|| Regex::new(r"^\d+").expect("Failed to compile regex: DIGITS")); + + pub static UNIT: Lazy = Lazy::new(|| { + Regex::new(r"(milliseconds|seconds)$").expect("Failed to compile regex: UNIT") + }); + + macro_rules! invalid_value { + ($got:expr) => { + Err(InvalidValue! { + expected: "a string of the format \"N seconds\" or \"N milliseconds\" where N is an integer > 0", + got: $got, + }) + } + } + match string { + "default" => Ok(Timeout::Default), + string if !FMT.is_match(string) => invalid_value!(string), + string => match (DIGITS.find(string), UNIT.find(string)) { + (None, _) => invalid_value!(string), + (_, None) => invalid_value!(string), + (Some(dmatch), Some(umatch)) => { + let digits = &string[dmatch.start()..dmatch.end()]; + let unit = &string[umatch.start()..umatch.end()]; + match (digits.parse(), unit) { + (Ok(v), "milliseconds") => Ok(Timeout::Milliseconds(v)), + (Ok(v), "seconds") => Ok(Timeout::Seconds(v)), + _ => invalid_value!(string), + } + } + }, + } + } +} + +impl<'de> serde::Deserialize<'de> for Timeout { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct TimeoutVisitor; + + impl<'de> de::Visitor<'de> for TimeoutVisitor { + type Value = Timeout; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let msg = "Either \"default\", \"disabled\", \"os\", or a string of the format \"N seconds\" where N is an integer > 0"; + formatter.write_str(msg) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match Timeout::parse(value) { + Ok(num_workers) => Ok(num_workers), + Err(AtError::InvalidValue { expected, got, .. }) => Err( + de::Error::invalid_value(de::Unexpected::Str(&got), &expected), + ), + Err(_) => unreachable!(), + } + } + } + + deserializer.deserialize_string(TimeoutVisitor) + } +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "kebab-case")] +pub struct Tls { + pub enabled: bool, + pub certificate: PathBuf, + pub private_key: PathBuf, +} diff --git a/actix-settings/src/core.rs b/actix-settings/src/core.rs new file mode 100644 index 000000000..36df38c9e --- /dev/null +++ b/actix-settings/src/core.rs @@ -0,0 +1,38 @@ +use std::{path::PathBuf, str::FromStr}; + +use crate::error::AtError; + +pub trait Parse: Sized { + fn parse(string: &str) -> Result; +} + +impl Parse for bool { + fn parse(string: &str) -> Result { + Self::from_str(string).map_err(AtError::from) + } +} + +macro_rules! impl_parse_for_int_type { + ($($int_type:ty),+ $(,)?) => { + $( + impl Parse for $int_type { + fn parse(string: &str) -> Result { + Self::from_str(string).map_err(AtError::from) + } + } + )+ + } +} +impl_parse_for_int_type![i8, i16, i32, i64, i128, u8, u16, u32, u64, u128]; + +impl Parse for String { + fn parse(string: &str) -> Result { + Ok(string.to_string()) + } +} + +impl Parse for PathBuf { + fn parse(string: &str) -> Result { + Ok(PathBuf::from(string)) + } +} diff --git a/actix-settings/src/defaults.toml b/actix-settings/src/defaults.toml new file mode 100644 index 000000000..a73be1edb --- /dev/null +++ b/actix-settings/src/defaults.toml @@ -0,0 +1,72 @@ +[actix] +# For more info, see: https://docs.rs/actix-web/4/actix_web/struct.HttpServer.html. + +hosts = [ + ["0.0.0.0", 9000] # This should work for both development and deployment... + # # ... but other entries are possible, as well. +] +mode = "development" # Either "development" or "production". +enable-compression = true # Toggle compression middleware. +enable-log = true # Toggle logging middleware. + +# The number of workers that the server should start. +# By default the number of available logical cpu cores is used. +# Takes a string value: Either "default", or an integer N > 0 e.g. "6". +num-workers = "default" + +# The maximum number of pending connections. This refers to the number of clients +# that can be waiting to be served. Exceeding this number results in the client +# getting an error when attempting to connect. It should only affect servers under +# significant load. Generally set in the 64-2048 range. The default value is 2048. +# Takes a string value: Either "default", or an integer N > 0 e.g. "6". +backlog = "default" + +# Sets the maximum per-worker number of concurrent connections. All socket listeners +# will stop accepting connections when this limit is reached for each worker. +# By default max connections is set to a 25k. +# Takes a string value: Either "default", or an integer N > 0 e.g. "6". +max-connections = "default" + +# Sets the maximum per-worker concurrent connection establish process. All listeners +# will stop accepting connections when this limit is reached. It can be used to limit +# the global TLS CPU usage. By default max connections is set to a 256. +# Takes a string value: Either "default", or an integer N > 0 e.g. "6". +max-connection-rate = "default" + +# Set server keep-alive setting. By default keep alive is set to 5 seconds. +# Takes a string value: Either "default", "disabled", "os", +# or a string of the format "N seconds" where N is an integer > 0 e.g. "6 seconds". +keep-alive = "default" + +# Set server client timeout in milliseconds for first request. Defines a timeout +# for reading client request header. If a client does not transmit the entire set of +# headers within this time, the request is terminated with the 408 (Request Time-out) +# error. To disable timeout, set the value to 0. +# By default client timeout is set to 5000 milliseconds. +# Takes a string value: Either "default", or a string of the format "N milliseconds" +# where N is an integer > 0 e.g. "6 milliseconds". +client-timeout = "default" + +# Set server connection shutdown timeout in milliseconds. Defines a timeout for +# shutdown connection. If a shutdown procedure does not complete within this time, +# the request is dropped. To disable timeout set value to 0. +# By default client timeout is set to 5000 milliseconds. +# Takes a string value: Either "default", or a string of the format "N milliseconds" +# where N is an integer > 0 e.g. "6 milliseconds". +client-shutdown = "default" + +# Timeout for graceful workers shutdown. After receiving a stop signal, workers have +# this much time to finish serving requests. Workers still alive after the timeout +# are force dropped. By default shutdown timeout sets to 30 seconds. +# Takes a string value: Either "default", or a string of the format "N seconds" +# where N is an integer > 0 e.g. "6 seconds". +shutdown-timeout = "default" + +[actix.tls] # TLS is disabled by default because the certs don't exist +enabled = false +certificate = "path/to/cert/cert.pem" +private-key = "path/to/cert/key.pem" + +# The `application` table be used to express application-specific settings. +# See the `README.md` file for more details on how to use this. +[application] diff --git a/actix-settings/src/error.rs b/actix-settings/src/error.rs new file mode 100644 index 000000000..2e097a0a2 --- /dev/null +++ b/actix-settings/src/error.rs @@ -0,0 +1,123 @@ +use std::{env::VarError, io, num::ParseIntError, path::PathBuf, str::ParseBoolError}; + +use toml::de::Error as TomlError; + +pub type AtResult = std::result::Result; + +#[derive(Clone, Debug)] +pub enum AtError { + EnvVarError(VarError), + FileExists(PathBuf), + InvalidValue { + expected: &'static str, + got: String, + file: &'static str, + line: u32, + column: u32, + }, + IoError(ioe::IoError), + ParseBoolError(ParseBoolError), + ParseIntError(ParseIntError), + ParseAddressError(String), + TomlError(TomlError), +} + +macro_rules! InvalidValue { + (expected: $expected:expr, got: $got:expr,) => { + crate::AtError::InvalidValue { + expected: $expected, + got: $got.to_string(), + file: file!(), + line: line!(), + column: column!(), + } + }; +} + +impl From for AtError { + fn from(err: io::Error) -> Self { + Self::IoError(ioe::IoError::from(err)) + } +} + +impl From for AtError { + fn from(err: ioe::IoError) -> Self { + Self::IoError(err) + } +} + +impl From for AtError { + fn from(err: ParseBoolError) -> Self { + Self::ParseBoolError(err) + } +} + +impl From for AtError { + fn from(err: ParseIntError) -> Self { + Self::ParseIntError(err) + } +} + +impl From for AtError { + fn from(err: TomlError) -> Self { + Self::TomlError(err) + } +} + +impl From for AtError { + fn from(err: VarError) -> Self { + Self::EnvVarError(err) + } +} + +impl From for io::Error { + fn from(err: AtError) -> Self { + match err { + AtError::EnvVarError(var_error) => { + let msg = format!("Env var error: {}", var_error); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + + AtError::FileExists(path_buf) => { + let msg = format!("File exists: {}", path_buf.display()); + io::Error::new(io::ErrorKind::AlreadyExists, msg) + } + + AtError::InvalidValue { + expected, + ref got, + file, + line, + column, + } => { + let msg = format!( + "Expected {}, got {} (@ {}:{}:{})", + expected, got, file, line, column + ); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + + AtError::IoError(io_error) => io_error.into(), + + AtError::ParseBoolError(parse_bool_error) => { + let msg = format!("Failed to parse boolean: {}", parse_bool_error); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + + AtError::ParseIntError(parse_int_error) => { + let msg = format!("Failed to parse integer: {}", parse_int_error); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + + AtError::ParseAddressError(string) => { + let msg = format!("Failed to parse address: {}", string); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + + AtError::TomlError(toml_error) => { + let msg = format!("TOML error: {}", toml_error); + io::Error::new(io::ErrorKind::InvalidInput, msg) + } + } + } +} diff --git a/actix-settings/src/lib.rs b/actix-settings/src/lib.rs new file mode 100644 index 000000000..3432af703 --- /dev/null +++ b/actix-settings/src/lib.rs @@ -0,0 +1,705 @@ +/// A library to process Server.toml files +use std::{ + env, fmt, + fs::File, + io::{Read as _, Write as _}, + path::Path, + time::Duration, +}; + +use actix_http::{Request, Response}; +use actix_service::IntoServiceFactory; +use actix_web::{ + body::MessageBody, + dev::{AppConfig, ServiceFactory}, + http::KeepAlive as ActixKeepAlive, + Error as WebError, HttpServer, +}; +use serde::{de, Deserialize}; + +#[macro_use] +mod error; +mod actix; +mod core; + +pub use self::actix::*; +pub use self::core::Parse; +pub use self::error::{AtError, AtResult}; + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[serde(bound = "A: Deserialize<'de>")] +pub struct BasicSettings { + pub actix: ActixSettings, + pub application: A, +} + +pub type Settings = BasicSettings; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)] +#[non_exhaustive] +pub struct NoSettings {/* NOTE: turning this into a unit struct will cause deserialization failures. */} + +impl BasicSettings +where + A: de::DeserializeOwned, +{ + /// NOTE **DO NOT** mess with the ordering of the tables in this template. + /// Especially the `[application]` table needs to be last in order + /// for some tests to keep working. + pub(crate) const DEFAULT_TOML_TEMPLATE: &'static str = include_str!("./defaults.toml"); + + /// Parse an instance of `Self` from a `TOML` file located at `filepath`. + /// If the file doesn't exist, it is generated from the default `TOML` + /// template, after which the newly generated file is read in and parsed. + pub fn parse_toml

(filepath: P) -> AtResult + where + P: AsRef, + { + let filepath = filepath.as_ref(); + + if !filepath.exists() { + Self::write_toml_file(filepath)?; + } + + let mut f = File::open(filepath)?; + let mut contents = String::with_capacity(f.metadata()?.len() as usize); + f.read_to_string(&mut contents)?; + + Ok(toml::from_str::(&contents)?) + } + + /// Parse an instance of `Self` straight from the default `TOML` template. + pub fn from_default_template() -> AtResult { + Self::from_template(Self::DEFAULT_TOML_TEMPLATE) + } + + /// Parse an instance of `Self` straight from the default `TOML` template. + pub fn from_template(template: &str) -> AtResult { + Ok(toml::from_str(template)?) + } + + /// Write the default `TOML` template to a new file, to be located + /// at `filepath`. Return a `Error::FileExists(_)` error if a + /// file already exists at that location. + pub fn write_toml_file

(filepath: P) -> AtResult<()> + where + P: AsRef, + { + let filepath = filepath.as_ref(); + let contents = Self::DEFAULT_TOML_TEMPLATE.trim(); + + if filepath.exists() { + return Err(AtError::FileExists(filepath.to_path_buf())); + } + + let mut file = File::create(filepath)?; + file.write_all(contents.as_bytes())?; + file.flush()?; + + Ok(()) + } + + pub fn override_field(field: &mut F, value: V) -> AtResult<()> + where + F: Parse, + V: AsRef, + { + *field = F::parse(value.as_ref())?; + Ok(()) + } + + pub fn override_field_with_env_var(field: &mut F, var_name: N) -> AtResult<()> + where + F: Parse, + N: AsRef, + { + match env::var(var_name.as_ref()) { + Err(env::VarError::NotPresent) => Ok((/*NOP*/)), + Err(var_error) => Err(AtError::from(var_error)), + Ok(value) => Self::override_field(field, value), + } + } +} + +pub trait ApplySettings { + /// Apply a [`BasicSettings`] value to `self`. + /// + /// [`BasicSettings`]: ./struct.BasicSettings.html + #[must_use] + fn apply_settings(self, settings: &BasicSettings) -> Self + where + A: de::DeserializeOwned; +} + +impl ApplySettings for HttpServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + S::Future: 'static, + B: MessageBody + 'static, +{ + fn apply_settings(mut self, settings: &BasicSettings) -> Self + where + A: de::DeserializeOwned, + { + if settings.actix.tls.enabled { + // for Address { host, port } in &settings.actix.hosts { + // self = self.bind(format!("{}:{}", host, port)) + // .unwrap(/*TODO*/); + // } + todo!("[ApplySettings] TLS support has not been implemented yet."); + } else { + for Address { host, port } in &settings.actix.hosts { + self = self.bind(format!("{}:{}", host, port)) + .unwrap(/*TODO*/); + } + } + + self = match settings.actix.num_workers { + NumWorkers::Default => self, + NumWorkers::Manual(n) => self.workers(n), + }; + + self = match settings.actix.backlog { + Backlog::Default => self, + Backlog::Manual(n) => self.backlog(n as u32), + }; + + self = match settings.actix.max_connections { + MaxConnections::Default => self, + MaxConnections::Manual(n) => self.max_connections(n), + }; + + self = match settings.actix.max_connection_rate { + MaxConnectionRate::Default => self, + MaxConnectionRate::Manual(n) => self.max_connection_rate(n), + }; + + self = match settings.actix.keep_alive { + KeepAlive::Default => self, + KeepAlive::Disabled => self.keep_alive(ActixKeepAlive::Disabled), + KeepAlive::Os => self.keep_alive(ActixKeepAlive::Os), + KeepAlive::Seconds(n) => self.keep_alive(Duration::from_secs(n as u64)), + }; + + self = match settings.actix.client_timeout { + Timeout::Default => self, + Timeout::Milliseconds(n) => { + self.client_disconnect_timeout(Duration::from_millis(n as u64)) + } + Timeout::Seconds(n) => self.client_disconnect_timeout(Duration::from_secs(n as u64)), + }; + + self = match settings.actix.client_shutdown { + Timeout::Default => self, + Timeout::Milliseconds(n) => { + self.client_disconnect_timeout(Duration::from_millis(n as u64)) + } + Timeout::Seconds(n) => self.client_disconnect_timeout(Duration::from_secs(n as u64)), + }; + + self = match settings.actix.shutdown_timeout { + Timeout::Default => self, + Timeout::Milliseconds(_) => self.shutdown_timeout(1), + Timeout::Seconds(n) => self.shutdown_timeout(n as u64), + }; + + self + } +} + +#[cfg(test)] +mod tests { + #![allow(non_snake_case)] + + use std::path::Path; + + use actix_web::{App, HttpServer}; + use serde::Deserialize; + + use crate::actix::*; // used for value construction in assertions + use crate::{ApplySettings, AtResult, BasicSettings, Settings}; + + #[test] + fn apply_settings() -> AtResult<()> { + let settings = Settings::parse_toml("Server.toml")?; + let _ = HttpServer::new(App::new).apply_settings(&settings); + Ok(()) + } + + #[test] + fn override_field__hosts() { + let mut settings = Settings::from_default_template().unwrap(); + + assert_eq!( + settings.actix.hosts, + vec![Address { + host: "0.0.0.0".into(), + port: 9000 + },] + ); + + Settings::override_field( + &mut settings.actix.hosts, + r#"[ + ["0.0.0.0", 1234], + ["localhost", 2345] + ]"#, + ) + .unwrap(); + + assert_eq!( + settings.actix.hosts, + vec![ + Address { + host: "0.0.0.0".into(), + port: 1234 + }, + Address { + host: "localhost".into(), + port: 2345 + }, + ] + ); + } + + #[test] + fn override_field_with_env_var__hosts() { + let mut settings = Settings::from_default_template().unwrap(); + + assert_eq!( + settings.actix.hosts, + vec![Address { + host: "0.0.0.0".into(), + port: 9000 + },] + ); + + std::env::set_var( + "OVERRIDE__HOSTS", + r#"[ + ["0.0.0.0", 1234], + ["localhost", 2345] + ]"#, + ); + + Settings::override_field_with_env_var(&mut settings.actix.hosts, "OVERRIDE__HOSTS") + .unwrap(); + + assert_eq!( + settings.actix.hosts, + vec![ + Address { + host: "0.0.0.0".into(), + port: 1234 + }, + Address { + host: "localhost".into(), + port: 2345 + }, + ] + ); + } + + #[test] + fn override_field__mode() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.mode, Mode::Development); + Settings::override_field(&mut settings.actix.mode, "production")?; + assert_eq!(settings.actix.mode, Mode::Production); + Ok(()) + } + + #[test] + fn override_field_with_env_var__mode() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.mode, Mode::Development); + std::env::set_var("OVERRIDE__MODE", "production"); + Settings::override_field_with_env_var(&mut settings.actix.mode, "OVERRIDE__MODE")?; + assert_eq!(settings.actix.mode, Mode::Production); + Ok(()) + } + + #[test] + fn override_field__enable_compression() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(settings.actix.enable_compression); + Settings::override_field(&mut settings.actix.enable_compression, "false")?; + assert!(!settings.actix.enable_compression); + Ok(()) + } + + #[test] + fn override_field_with_env_var__enable_compression() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(settings.actix.enable_compression); + std::env::set_var("OVERRIDE__ENABLE_COMPRESSION", "false"); + Settings::override_field_with_env_var( + &mut settings.actix.enable_compression, + "OVERRIDE__ENABLE_COMPRESSION", + )?; + assert!(!settings.actix.enable_compression); + Ok(()) + } + + #[test] + fn override_field__enable_log() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(settings.actix.enable_log); + Settings::override_field(&mut settings.actix.enable_log, "false")?; + assert!(!settings.actix.enable_log); + Ok(()) + } + + #[test] + fn override_field_with_env_var__enable_log() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(settings.actix.enable_log); + std::env::set_var("OVERRIDE__ENABLE_LOG", "false"); + Settings::override_field_with_env_var( + &mut settings.actix.enable_log, + "OVERRIDE__ENABLE_LOG", + )?; + assert!(!settings.actix.enable_log); + Ok(()) + } + + #[test] + fn override_field__num_workers() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.num_workers, NumWorkers::Default); + Settings::override_field(&mut settings.actix.num_workers, "42")?; + assert_eq!(settings.actix.num_workers, NumWorkers::Manual(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__num_workers() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.num_workers, NumWorkers::Default); + std::env::set_var("OVERRIDE__NUM_WORKERS", "42"); + Settings::override_field_with_env_var( + &mut settings.actix.num_workers, + "OVERRIDE__NUM_WORKERS", + )?; + assert_eq!(settings.actix.num_workers, NumWorkers::Manual(42)); + Ok(()) + } + + #[test] + fn override_field__backlog() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.backlog, Backlog::Default); + Settings::override_field(&mut settings.actix.backlog, "42")?; + assert_eq!(settings.actix.backlog, Backlog::Manual(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__backlog() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.backlog, Backlog::Default); + std::env::set_var("OVERRIDE__BACKLOG", "42"); + Settings::override_field_with_env_var(&mut settings.actix.backlog, "OVERRIDE__BACKLOG")?; + assert_eq!(settings.actix.backlog, Backlog::Manual(42)); + Ok(()) + } + + #[test] + fn override_field__max_connections() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.max_connections, MaxConnections::Default); + Settings::override_field(&mut settings.actix.max_connections, "42")?; + assert_eq!(settings.actix.max_connections, MaxConnections::Manual(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__max_connections() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.max_connections, MaxConnections::Default); + std::env::set_var("OVERRIDE__MAX_CONNECTIONS", "42"); + Settings::override_field_with_env_var( + &mut settings.actix.max_connections, + "OVERRIDE__MAX_CONNECTIONS", + )?; + assert_eq!(settings.actix.max_connections, MaxConnections::Manual(42)); + Ok(()) + } + + #[test] + fn override_field__max_connection_rate() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.max_connection_rate, + MaxConnectionRate::Default + ); + Settings::override_field(&mut settings.actix.max_connection_rate, "42")?; + assert_eq!( + settings.actix.max_connection_rate, + MaxConnectionRate::Manual(42) + ); + Ok(()) + } + + #[test] + fn override_field_with_env_var__max_connection_rate() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.max_connection_rate, + MaxConnectionRate::Default + ); + std::env::set_var("OVERRIDE__MAX_CONNECTION_RATE", "42"); + Settings::override_field_with_env_var( + &mut settings.actix.max_connection_rate, + "OVERRIDE__MAX_CONNECTION_RATE", + )?; + assert_eq!( + settings.actix.max_connection_rate, + MaxConnectionRate::Manual(42) + ); + Ok(()) + } + + #[test] + fn override_field__keep_alive() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.keep_alive, KeepAlive::Default); + Settings::override_field(&mut settings.actix.keep_alive, "42 seconds")?; + assert_eq!(settings.actix.keep_alive, KeepAlive::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__keep_alive() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.keep_alive, KeepAlive::Default); + std::env::set_var("OVERRIDE__KEEP_ALIVE", "42 seconds"); + Settings::override_field_with_env_var( + &mut settings.actix.keep_alive, + "OVERRIDE__KEEP_ALIVE", + )?; + assert_eq!(settings.actix.keep_alive, KeepAlive::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field__client_timeout() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.client_timeout, Timeout::Default); + Settings::override_field(&mut settings.actix.client_timeout, "42 seconds")?; + assert_eq!(settings.actix.client_timeout, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__client_timeout() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.client_timeout, Timeout::Default); + std::env::set_var("OVERRIDE__CLIENT_TIMEOUT", "42 seconds"); + Settings::override_field_with_env_var( + &mut settings.actix.client_timeout, + "OVERRIDE__CLIENT_TIMEOUT", + )?; + assert_eq!(settings.actix.client_timeout, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field__client_shutdown() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.client_shutdown, Timeout::Default); + Settings::override_field(&mut settings.actix.client_shutdown, "42 seconds")?; + assert_eq!(settings.actix.client_shutdown, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__client_shutdown() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.client_shutdown, Timeout::Default); + std::env::set_var("OVERRIDE__CLIENT_SHUTDOWN", "42 seconds"); + Settings::override_field_with_env_var( + &mut settings.actix.client_shutdown, + "OVERRIDE__CLIENT_SHUTDOWN", + )?; + assert_eq!(settings.actix.client_shutdown, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field__shutdown_timeout() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.shutdown_timeout, Timeout::Default); + Settings::override_field(&mut settings.actix.shutdown_timeout, "42 seconds")?; + assert_eq!(settings.actix.shutdown_timeout, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field_with_env_var__shutdown_timeout() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!(settings.actix.shutdown_timeout, Timeout::Default); + std::env::set_var("OVERRIDE__SHUTDOWN_TIMEOUT", "42 seconds"); + Settings::override_field_with_env_var( + &mut settings.actix.shutdown_timeout, + "OVERRIDE__SHUTDOWN_TIMEOUT", + )?; + assert_eq!(settings.actix.shutdown_timeout, Timeout::Seconds(42)); + Ok(()) + } + + #[test] + fn override_field__tls__enabled() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(!settings.actix.tls.enabled); + Settings::override_field(&mut settings.actix.tls.enabled, "true")?; + assert!(settings.actix.tls.enabled); + Ok(()) + } + + #[test] + fn override_field_with_env_var__tls__enabled() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert!(!settings.actix.tls.enabled); + std::env::set_var("OVERRIDE__TLS_ENABLED", "true"); + Settings::override_field_with_env_var( + &mut settings.actix.tls.enabled, + "OVERRIDE__TLS_ENABLED", + )?; + assert!(settings.actix.tls.enabled); + Ok(()) + } + + #[test] + fn override_field__tls__certificate() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.tls.certificate, + Path::new("path/to/cert/cert.pem") + ); + Settings::override_field( + &mut settings.actix.tls.certificate, + "/overridden/path/to/cert/cert.pem", + )?; + assert_eq!( + settings.actix.tls.certificate, + Path::new("/overridden/path/to/cert/cert.pem") + ); + Ok(()) + } + + #[test] + fn override_field_with_env_var__tls__certificate() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.tls.certificate, + Path::new("path/to/cert/cert.pem") + ); + std::env::set_var( + "OVERRIDE__TLS_CERTIFICATE", + "/overridden/path/to/cert/cert.pem", + ); + Settings::override_field_with_env_var( + &mut settings.actix.tls.certificate, + "OVERRIDE__TLS_CERTIFICATE", + )?; + assert_eq!( + settings.actix.tls.certificate, + Path::new("/overridden/path/to/cert/cert.pem") + ); + Ok(()) + } + + #[test] + fn override_field__tls__private_key() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.tls.private_key, + Path::new("path/to/cert/key.pem") + ); + Settings::override_field( + &mut settings.actix.tls.private_key, + "/overridden/path/to/cert/key.pem", + )?; + assert_eq!( + settings.actix.tls.private_key, + Path::new("/overridden/path/to/cert/key.pem") + ); + Ok(()) + } + + #[test] + fn override_field_with_env_var__tls__private_key() -> AtResult<()> { + let mut settings = Settings::from_default_template()?; + assert_eq!( + settings.actix.tls.private_key, + Path::new("path/to/cert/key.pem") + ); + std::env::set_var( + "OVERRIDE__TLS_PRIVATE_KEY", + "/overridden/path/to/cert/key.pem", + ); + Settings::override_field_with_env_var( + &mut settings.actix.tls.private_key, + "OVERRIDE__TLS_PRIVATE_KEY", + )?; + assert_eq!( + settings.actix.tls.private_key, + Path::new("/overridden/path/to/cert/key.pem") + ); + Ok(()) + } + + #[test] + fn override_extended_field_with_custom_type() -> AtResult<()> { + #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] + struct NestedSetting { + foo: String, + bar: bool, + } + #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] + struct AppSettings { + #[serde(rename = "example-name")] + example_name: String, + #[serde(rename = "nested-field")] + nested_field: NestedSetting, + } + type CustomSettings = BasicSettings; + let mut settings = CustomSettings::from_template( + &(CustomSettings::DEFAULT_TOML_TEMPLATE.to_string() + // NOTE: Add these entries to the `[application]` table: + + "\nexample-name = \"example value\"" + + "\nnested-field = { foo = \"foo\", bar = false }"), + )?; + assert_eq!( + settings.application, + AppSettings { + example_name: "example value".into(), + nested_field: NestedSetting { + foo: "foo".into(), + bar: false, + }, + } + ); + CustomSettings::override_field( + &mut settings.application.example_name, + "/overridden/path/to/cert/key.pem", + )?; + assert_eq!( + settings.application, + AppSettings { + example_name: "/overridden/path/to/cert/key.pem".into(), + nested_field: NestedSetting { + foo: "foo".into(), + bar: false, + }, + } + ); + Ok(()) + } +}