diff --git a/actix-proxy-protocol/Cargo.toml b/actix-proxy-protocol/Cargo.toml index ebdc4da9..f9996979 100755 --- a/actix-proxy-protocol/Cargo.toml +++ b/actix-proxy-protocol/Cargo.toml @@ -21,6 +21,7 @@ crc32fast = "1" futures-core = { version = "0.3.17", default-features = false, features = ["std"] } futures-util = { version = "0.3.17", default-features = false, features = ["std"] } itoa = "1" +smallvec = "1" tokio = { version = "1.13.1", features = ["sync", "io-util"] } tracing = { version = "0.1.30", default-features = false, features = ["log"] } diff --git a/actix-proxy-protocol/examples/proxy-server.rs b/actix-proxy-protocol/examples/proxy-server.rs index f6cee9cf..05c80b68 100644 --- a/actix-proxy-protocol/examples/proxy-server.rs +++ b/actix-proxy-protocol/examples/proxy-server.rs @@ -11,7 +11,7 @@ use std::{ }, }; -use actix_proxy_protocol::{v1, v2, AddressFamily, Command, Tlv, TransportProtocol}; +use actix_proxy_protocol::{tlv, v1, v2, AddressFamily, Command, TransportProtocol}; use actix_rt::net::TcpStream; use actix_server::Server; use actix_service::{fn_service, ServiceFactoryExt as _}; @@ -31,6 +31,7 @@ TLV = type-length-value TO DO: handle UNKNOWN transport v2 UNSPEC mode +AF_UNIX socket */ fn extend_with_ip_bytes(buf: &mut Vec, ip: IpAddr) { @@ -71,18 +72,10 @@ async fn wrap_with_proxy_protocol_v2(mut stream: TcpStream) -> io::Result<()> { UPSTREAM.to_string() ); - let mut proxy_header = v2::Header::new( - Command::Proxy, - TransportProtocol::Stream, - AddressFamily::Inet, - SocketAddr::from(([127, 0, 0, 1], 8082)), - *UPSTREAM, - vec![ - Tlv::new(0x05, [0x34, 0x32, 0x36, 0x39]), // UNIQUE_ID - Tlv::new(0x04, "NOOP m9"), // NOOP - ], - ); + let mut proxy_header = v2::Header::new_tcp_ipv4_proxy(([127, 0, 0, 1], 8082), *UPSTREAM); + proxy_header.add_tlv(0x05, [0x34, 0x32, 0x36, 0x39]); // UNIQUE_ID + proxy_header.add_tlv(0x04, "NOOP m9"); // NOOP proxy_header.add_crc23c_checksum_tlv(); proxy_header.write_to_tokio(&mut upstream).await?; diff --git a/actix-proxy-protocol/src/lib.rs b/actix-proxy-protocol/src/lib.rs index ed787cda..e79aade7 100644 --- a/actix-proxy-protocol/src/lib.rs +++ b/actix-proxy-protocol/src/lib.rs @@ -16,6 +16,7 @@ use std::{ use arrayvec::{ArrayString, ArrayVec}; use tokio::io::{AsyncWrite, AsyncWriteExt as _}; +pub mod tlv; pub mod v1; pub mod v2; @@ -165,80 +166,3 @@ enum ProxyProtocolHeader { V1(v1::Header), V2(v2::Header), } - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct Pp2Crc32c { - checksum: u32, -} - -impl Pp2Crc32c { - fn to_tlv(&self) -> Tlv { - Tlv { - r#type: 0x03, - value: self.checksum.to_be_bytes().to_vec(), - } - } -} - -#[derive(Debug, Clone)] -pub struct Tlv { - r#type: u8, - value: Vec, -} - -impl Tlv { - pub fn new(r#type: u8, value: impl Into>) -> Self { - let value = value.into(); - - assert!( - !value.is_empty(), - "TLV values must have length greater than 1", - ); - - Self { r#type, value } - } - - fn len(&self) -> u16 { - // 1b type + 2b len - // + [len]b value - 1 + 2 + (self.value.len() as u16) - } - - pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> { - wrt.write_all(&[self.r#type])?; - wrt.write_all(&(self.value.len() as u16).to_be_bytes())?; - wrt.write_all(&self.value)?; - Ok(()) - } - - pub fn as_crc32c(&self) -> Option { - if self.r#type != 0x03 { - return None; - } - - let checksum_bytes = <[u8; 4]>::try_from(self.value.as_slice()).ok()?; - - Some(Pp2Crc32c { - checksum: u32::from_be_bytes(checksum_bytes), - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn tlv_as_crc32c() { - let noop = Tlv::new(0x04, vec![0x00]); - assert_eq!(noop.as_crc32c(), None); - - let noop = Tlv::new(0x03, vec![0x08, 0x70, 0x17, 0x7b]); - assert_eq!( - noop.as_crc32c(), - Some(Pp2Crc32c { - checksum: 141563771 - }) - ); - } -} diff --git a/actix-proxy-protocol/src/tlv.rs b/actix-proxy-protocol/src/tlv.rs new file mode 100644 index 00000000..bf7a04ae --- /dev/null +++ b/actix-proxy-protocol/src/tlv.rs @@ -0,0 +1,58 @@ +use std::convert::TryFrom; + +pub trait Tlv: Sized { + const TYPE: u8; + + fn try_from_parts(typ: u8, value: &[u8]) -> Option; + + fn value_bytes(&self) -> Vec; +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct Crc32c { + pub(crate) checksum: u32, +} + +impl Tlv for Crc32c { + const TYPE: u8 = 0x03; + + fn try_from_parts(typ: u8, value: &[u8]) -> Option { + if typ != Self::TYPE { + return None; + } + + let checksum_bytes = <[u8; 4]>::try_from(value).ok()?; + + Some(Self { + checksum: u32::from_be_bytes(checksum_bytes), + }) + } + + fn value_bytes(&self) -> Vec { + self.checksum.to_be_bytes().to_vec() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // #[test] + // #[should_panic] + // fn tlv_zero_len() { + // Tlv::new(0x00, vec![]); + // } + + #[test] + fn tlv_as_crc32c() { + // noop + assert_eq!(Crc32c::try_from_parts(0x04, &[0x00]), None); + + assert_eq!( + Crc32c::try_from_parts(0x03, &[0x08, 0x70, 0x17, 0x7b]), + Some(Crc32c { + checksum: 141563771 + }) + ); + } +} diff --git a/actix-proxy-protocol/src/v1/mod.rs b/actix-proxy-protocol/src/v1.rs similarity index 100% rename from actix-proxy-protocol/src/v1/mod.rs rename to actix-proxy-protocol/src/v1.rs diff --git a/actix-proxy-protocol/src/v2/mod.rs b/actix-proxy-protocol/src/v2.rs similarity index 78% rename from actix-proxy-protocol/src/v2/mod.rs rename to actix-proxy-protocol/src/v2.rs index 8aebf7a0..efabdd12 100644 --- a/actix-proxy-protocol/src/v2/mod.rs +++ b/actix-proxy-protocol/src/v2.rs @@ -3,9 +3,13 @@ use std::{ net::{IpAddr, SocketAddr}, }; +use smallvec::{smallvec, SmallVec, ToSmallVec as _}; use tokio::io::{AsyncWrite, AsyncWriteExt as _}; -use crate::{AddressFamily, Command, Pp2Crc32c, Tlv, TransportProtocol, Version}; +use crate::{ + tlv::{Crc32c, Tlv}, + AddressFamily, Command, TransportProtocol, Version, +}; pub(crate) const SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, @@ -18,26 +22,39 @@ pub struct Header { address_family: AddressFamily, src: SocketAddr, dst: SocketAddr, - tlvs: Vec, + tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>, } impl Header { - pub const fn new( + pub fn new( command: Command, transport_protocol: TransportProtocol, address_family: AddressFamily, - src: SocketAddr, - dst: SocketAddr, - tlvs: Vec, + src: impl Into, + dst: impl Into, ) -> Self { Self { command, transport_protocol, address_family, + src: src.into(), + dst: dst.into(), + tlvs: SmallVec::new(), + } + } + + pub fn new_tcp_ipv4_proxy(src: impl Into, dst: impl Into) -> Self { + Self::new( + Command::Proxy, + TransportProtocol::Stream, + AddressFamily::Inet, src, dst, - tlvs, - } + ) + } + + pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) { + self.tlvs.push((typ, SmallVec::from_slice(value.as_ref()))); } fn v2_len(&self) -> u16 { @@ -47,7 +64,12 @@ impl Header { 16 + 2 // 16b IPv6 + 2b port number }; - (addr_len * 2) + self.tlvs.iter().map(|tlv| tlv.len()).sum::() + (addr_len * 2) + + self + .tlvs + .iter() + .map(|(_, value)| 1 + 2 + value.len() as u16) + .sum::() } pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> { @@ -81,8 +103,10 @@ impl Header { wrt.write_all(&self.dst.port().to_be_bytes())?; // TLVs - for tlv in &self.tlvs { - tlv.write_to(wrt)?; + for (typ, value) in &self.tlvs { + wrt.write_all(&[*typ])?; + wrt.write_all(&(value.len() as u16).to_be_bytes())?; + wrt.write_all(&value)?; } Ok(()) @@ -100,8 +124,14 @@ impl Header { buf } + pub fn has_tlv(&self) -> bool { + self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE) + } + + /// If this is not called last thing it will be wrong. pub fn add_crc23c_checksum_tlv(&mut self) { - if self.tlvs.iter().any(|tlv| tlv.as_crc32c().is_some()) { + // don't add a checksum if it is already set + if self.has_tlv::() { return; } @@ -114,7 +144,7 @@ impl Header { // the bits unchanged. // add zeroed checksum field to TLVs - let crc = Pp2Crc32c::default().to_tlv(); + let crc = (Crc32c::TYPE, Crc32c::default().value_bytes().to_smallvec()); self.tlvs.push(crc); // write PROXY header to buffer @@ -123,16 +153,18 @@ impl Header { // calculate CRC on buffer and update CRC TLV let crc_calc = crc32fast::hash(&buf); - self.tlvs.last_mut().unwrap().value = crc_calc.to_be_bytes().to_vec(); + self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec(); tracing::debug!("checksum is {}", crc_calc); } pub fn validate_crc32c_tlv(&self) -> Option { - dbg!(&self.tlvs); - - // exit early if no crc32c TLV is present - let crc_sent = self.tlvs.iter().filter_map(|tlv| tlv.as_crc32c()).next()?; + // extract crc32c TLV or exit early if none is present + let crc_sent = self + .tlvs + .iter() + .filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, &value)) + .next()?; // If the checksum is provided as part of the PROXY header and the checksum // functionality is supported by the receiver, the receiver MUST: @@ -145,9 +177,9 @@ impl Header { // The default procedure for handling an invalid TCP connection is to abort it. let mut this = self.clone(); - for tlv in this.tlvs.iter_mut() { - if tlv.as_crc32c().is_some() { - tlv.value.fill(0); + for (typ, value) in this.tlvs.iter_mut() { + if Crc32c::try_from_parts(*typ, &value).is_some() { + value.fill(0); } } @@ -168,12 +200,6 @@ mod tests { use super::*; - #[test] - #[should_panic] - fn tlv_zero_len() { - Tlv::new(0x00, vec![]); - } - #[test] fn write_v2_no_tlvs() { let mut exp = Vec::new(); @@ -189,7 +215,6 @@ mod tests { AddressFamily::Inet, SocketAddr::from(([127, 0, 0, 1], 1234)), SocketAddr::from(([127, 0, 0, 2], 80)), - vec![], ); assert_eq!(header.v2_len(), 12); @@ -207,7 +232,7 @@ mod tests { exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); // 45-49 exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); // 50-53 NOOP TLV - let header = Header::new( + let mut header = Header::new( Command::Local, TransportProtocol::Stream, AddressFamily::Inet, @@ -216,9 +241,10 @@ mod tests { Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), 65535, )), - vec![Tlv::new(0x04, [0])], ); + header.add_tlv(0x04, [0]); + assert_eq!(header.v2_len(), 36 + 4); assert_eq!(header.to_vec(), exp); } @@ -248,7 +274,6 @@ mod tests { AddressFamily::Inet, SocketAddr::from(([127, 0, 0, 1], 80)), SocketAddr::from(([127, 0, 0, 1], 80)), - vec![], ); assert!( @@ -266,7 +291,7 @@ mod tests { assert_eq!(header.validate_crc32c_tlv().unwrap(), true); // mangle crc32c TLV and assert that validate now fails - *header.tlvs.last_mut().unwrap().value.last_mut().unwrap() = 0x00; + *header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00; assert_eq!(header.validate_crc32c_tlv().unwrap(), false); } }