1
0
mirror of https://github.com/fafhrd91/actix-net synced 2024-12-03 19:42:13 +01:00

add Tlv trait

This commit is contained in:
Rob Ede 2022-10-31 19:11:04 +00:00
parent b9877392ab
commit 69d9afd39e
No known key found for this signature in database
GPG Key ID: 97C636207D3EF933
6 changed files with 122 additions and 121 deletions

View File

@ -21,6 +21,7 @@ crc32fast = "1"
futures-core = { version = "0.3.17", default-features = false, features = ["std"] }
futures-util = { version = "0.3.17", default-features = false, features = ["std"] }
itoa = "1"
smallvec = "1"
tokio = { version = "1.13.1", features = ["sync", "io-util"] }
tracing = { version = "0.1.30", default-features = false, features = ["log"] }

View File

@ -11,7 +11,7 @@ use std::{
},
};
use actix_proxy_protocol::{v1, v2, AddressFamily, Command, Tlv, TransportProtocol};
use actix_proxy_protocol::{tlv, v1, v2, AddressFamily, Command, TransportProtocol};
use actix_rt::net::TcpStream;
use actix_server::Server;
use actix_service::{fn_service, ServiceFactoryExt as _};
@ -31,6 +31,7 @@ TLV = type-length-value
TO DO:
handle UNKNOWN transport
v2 UNSPEC mode
AF_UNIX socket
*/
fn extend_with_ip_bytes(buf: &mut Vec<u8>, ip: IpAddr) {
@ -71,18 +72,10 @@ async fn wrap_with_proxy_protocol_v2(mut stream: TcpStream) -> io::Result<()> {
UPSTREAM.to_string()
);
let mut proxy_header = v2::Header::new(
Command::Proxy,
TransportProtocol::Stream,
AddressFamily::Inet,
SocketAddr::from(([127, 0, 0, 1], 8082)),
*UPSTREAM,
vec![
Tlv::new(0x05, [0x34, 0x32, 0x36, 0x39]), // UNIQUE_ID
Tlv::new(0x04, "NOOP m9"), // NOOP
],
);
let mut proxy_header = v2::Header::new_tcp_ipv4_proxy(([127, 0, 0, 1], 8082), *UPSTREAM);
proxy_header.add_tlv(0x05, [0x34, 0x32, 0x36, 0x39]); // UNIQUE_ID
proxy_header.add_tlv(0x04, "NOOP m9"); // NOOP
proxy_header.add_crc23c_checksum_tlv();
proxy_header.write_to_tokio(&mut upstream).await?;

View File

@ -16,6 +16,7 @@ use std::{
use arrayvec::{ArrayString, ArrayVec};
use tokio::io::{AsyncWrite, AsyncWriteExt as _};
pub mod tlv;
pub mod v1;
pub mod v2;
@ -165,80 +166,3 @@ enum ProxyProtocolHeader {
V1(v1::Header),
V2(v2::Header),
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Pp2Crc32c {
checksum: u32,
}
impl Pp2Crc32c {
fn to_tlv(&self) -> Tlv {
Tlv {
r#type: 0x03,
value: self.checksum.to_be_bytes().to_vec(),
}
}
}
#[derive(Debug, Clone)]
pub struct Tlv {
r#type: u8,
value: Vec<u8>,
}
impl Tlv {
pub fn new(r#type: u8, value: impl Into<Vec<u8>>) -> Self {
let value = value.into();
assert!(
!value.is_empty(),
"TLV values must have length greater than 1",
);
Self { r#type, value }
}
fn len(&self) -> u16 {
// 1b type + 2b len
// + [len]b value
1 + 2 + (self.value.len() as u16)
}
pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
wrt.write_all(&[self.r#type])?;
wrt.write_all(&(self.value.len() as u16).to_be_bytes())?;
wrt.write_all(&self.value)?;
Ok(())
}
pub fn as_crc32c(&self) -> Option<Pp2Crc32c> {
if self.r#type != 0x03 {
return None;
}
let checksum_bytes = <[u8; 4]>::try_from(self.value.as_slice()).ok()?;
Some(Pp2Crc32c {
checksum: u32::from_be_bytes(checksum_bytes),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tlv_as_crc32c() {
let noop = Tlv::new(0x04, vec![0x00]);
assert_eq!(noop.as_crc32c(), None);
let noop = Tlv::new(0x03, vec![0x08, 0x70, 0x17, 0x7b]);
assert_eq!(
noop.as_crc32c(),
Some(Pp2Crc32c {
checksum: 141563771
})
);
}
}

View File

@ -0,0 +1,58 @@
use std::convert::TryFrom;
pub trait Tlv: Sized {
const TYPE: u8;
fn try_from_parts(typ: u8, value: &[u8]) -> Option<Self>;
fn value_bytes(&self) -> Vec<u8>;
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Crc32c {
pub(crate) checksum: u32,
}
impl Tlv for Crc32c {
const TYPE: u8 = 0x03;
fn try_from_parts(typ: u8, value: &[u8]) -> Option<Self> {
if typ != Self::TYPE {
return None;
}
let checksum_bytes = <[u8; 4]>::try_from(value).ok()?;
Some(Self {
checksum: u32::from_be_bytes(checksum_bytes),
})
}
fn value_bytes(&self) -> Vec<u8> {
self.checksum.to_be_bytes().to_vec()
}
}
#[cfg(test)]
mod tests {
use super::*;
// #[test]
// #[should_panic]
// fn tlv_zero_len() {
// Tlv::new(0x00, vec![]);
// }
#[test]
fn tlv_as_crc32c() {
// noop
assert_eq!(Crc32c::try_from_parts(0x04, &[0x00]), None);
assert_eq!(
Crc32c::try_from_parts(0x03, &[0x08, 0x70, 0x17, 0x7b]),
Some(Crc32c {
checksum: 141563771
})
);
}
}

View File

@ -3,9 +3,13 @@ use std::{
net::{IpAddr, SocketAddr},
};
use smallvec::{smallvec, SmallVec, ToSmallVec as _};
use tokio::io::{AsyncWrite, AsyncWriteExt as _};
use crate::{AddressFamily, Command, Pp2Crc32c, Tlv, TransportProtocol, Version};
use crate::{
tlv::{Crc32c, Tlv},
AddressFamily, Command, TransportProtocol, Version,
};
pub(crate) const SIGNATURE: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
@ -18,26 +22,39 @@ pub struct Header {
address_family: AddressFamily,
src: SocketAddr,
dst: SocketAddr,
tlvs: Vec<Tlv>,
tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>,
}
impl Header {
pub const fn new(
pub fn new(
command: Command,
transport_protocol: TransportProtocol,
address_family: AddressFamily,
src: SocketAddr,
dst: SocketAddr,
tlvs: Vec<Tlv>,
src: impl Into<SocketAddr>,
dst: impl Into<SocketAddr>,
) -> Self {
Self {
command,
transport_protocol,
address_family,
src: src.into(),
dst: dst.into(),
tlvs: SmallVec::new(),
}
}
pub fn new_tcp_ipv4_proxy(src: impl Into<SocketAddr>, dst: impl Into<SocketAddr>) -> Self {
Self::new(
Command::Proxy,
TransportProtocol::Stream,
AddressFamily::Inet,
src,
dst,
tlvs,
}
)
}
pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) {
self.tlvs.push((typ, SmallVec::from_slice(value.as_ref())));
}
fn v2_len(&self) -> u16 {
@ -47,7 +64,12 @@ impl Header {
16 + 2 // 16b IPv6 + 2b port number
};
(addr_len * 2) + self.tlvs.iter().map(|tlv| tlv.len()).sum::<u16>()
(addr_len * 2)
+ self
.tlvs
.iter()
.map(|(_, value)| 1 + 2 + value.len() as u16)
.sum::<u16>()
}
pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> {
@ -81,8 +103,10 @@ impl Header {
wrt.write_all(&self.dst.port().to_be_bytes())?;
// TLVs
for tlv in &self.tlvs {
tlv.write_to(wrt)?;
for (typ, value) in &self.tlvs {
wrt.write_all(&[*typ])?;
wrt.write_all(&(value.len() as u16).to_be_bytes())?;
wrt.write_all(&value)?;
}
Ok(())
@ -100,8 +124,14 @@ impl Header {
buf
}
pub fn has_tlv<T: Tlv>(&self) -> bool {
self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE)
}
/// If this is not called last thing it will be wrong.
pub fn add_crc23c_checksum_tlv(&mut self) {
if self.tlvs.iter().any(|tlv| tlv.as_crc32c().is_some()) {
// don't add a checksum if it is already set
if self.has_tlv::<Crc32c>() {
return;
}
@ -114,7 +144,7 @@ impl Header {
// the bits unchanged.
// add zeroed checksum field to TLVs
let crc = Pp2Crc32c::default().to_tlv();
let crc = (Crc32c::TYPE, Crc32c::default().value_bytes().to_smallvec());
self.tlvs.push(crc);
// write PROXY header to buffer
@ -123,16 +153,18 @@ impl Header {
// calculate CRC on buffer and update CRC TLV
let crc_calc = crc32fast::hash(&buf);
self.tlvs.last_mut().unwrap().value = crc_calc.to_be_bytes().to_vec();
self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec();
tracing::debug!("checksum is {}", crc_calc);
}
pub fn validate_crc32c_tlv(&self) -> Option<bool> {
dbg!(&self.tlvs);
// exit early if no crc32c TLV is present
let crc_sent = self.tlvs.iter().filter_map(|tlv| tlv.as_crc32c()).next()?;
// extract crc32c TLV or exit early if none is present
let crc_sent = self
.tlvs
.iter()
.filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, &value))
.next()?;
// If the checksum is provided as part of the PROXY header and the checksum
// functionality is supported by the receiver, the receiver MUST:
@ -145,9 +177,9 @@ impl Header {
// The default procedure for handling an invalid TCP connection is to abort it.
let mut this = self.clone();
for tlv in this.tlvs.iter_mut() {
if tlv.as_crc32c().is_some() {
tlv.value.fill(0);
for (typ, value) in this.tlvs.iter_mut() {
if Crc32c::try_from_parts(*typ, &value).is_some() {
value.fill(0);
}
}
@ -168,12 +200,6 @@ mod tests {
use super::*;
#[test]
#[should_panic]
fn tlv_zero_len() {
Tlv::new(0x00, vec![]);
}
#[test]
fn write_v2_no_tlvs() {
let mut exp = Vec::new();
@ -189,7 +215,6 @@ mod tests {
AddressFamily::Inet,
SocketAddr::from(([127, 0, 0, 1], 1234)),
SocketAddr::from(([127, 0, 0, 2], 80)),
vec![],
);
assert_eq!(header.v2_len(), 12);
@ -207,7 +232,7 @@ mod tests {
exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); // 45-49
exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); // 50-53 NOOP TLV
let header = Header::new(
let mut header = Header::new(
Command::Local,
TransportProtocol::Stream,
AddressFamily::Inet,
@ -216,9 +241,10 @@ mod tests {
Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
65535,
)),
vec![Tlv::new(0x04, [0])],
);
header.add_tlv(0x04, [0]);
assert_eq!(header.v2_len(), 36 + 4);
assert_eq!(header.to_vec(), exp);
}
@ -248,7 +274,6 @@ mod tests {
AddressFamily::Inet,
SocketAddr::from(([127, 0, 0, 1], 80)),
SocketAddr::from(([127, 0, 0, 1], 80)),
vec![],
);
assert!(
@ -266,7 +291,7 @@ mod tests {
assert_eq!(header.validate_crc32c_tlv().unwrap(), true);
// mangle crc32c TLV and assert that validate now fails
*header.tlvs.last_mut().unwrap().value.last_mut().unwrap() = 0x00;
*header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00;
assert_eq!(header.validate_crc32c_tlv().unwrap(), false);
}
}