use std::{ io, net::{IpAddr, SocketAddr}, }; use smallvec::{smallvec, SmallVec, ToSmallVec as _}; use tokio::io::{AsyncWrite, AsyncWriteExt as _}; use crate::{ tlv::{Crc32c, Tlv}, AddressFamily, Command, TransportProtocol, Version, }; pub(crate) const SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; #[derive(Debug, Clone)] pub struct Header { command: Command, transport_protocol: TransportProtocol, address_family: AddressFamily, src: SocketAddr, dst: SocketAddr, tlvs: SmallVec<[(u8, SmallVec<[u8; 16]>); 4]>, } impl Header { pub fn new( command: Command, transport_protocol: TransportProtocol, address_family: AddressFamily, src: impl Into, dst: impl Into, ) -> Self { Self { command, transport_protocol, address_family, src: src.into(), dst: dst.into(), tlvs: SmallVec::new(), } } pub fn new_tcp_ipv4_proxy(src: impl Into, dst: impl Into) -> Self { Self::new( Command::Proxy, TransportProtocol::Stream, AddressFamily::Inet, src, dst, ) } pub fn add_tlv(&mut self, typ: u8, value: impl AsRef<[u8]>) { self.tlvs.push((typ, SmallVec::from_slice(value.as_ref()))); } fn v2_len(&self) -> u16 { let addr_len = if self.src.is_ipv4() { 4 + 2 // 4b IPv4 + 2b port number } else { 16 + 2 // 16b IPv6 + 2b port number }; (addr_len * 2) + self .tlvs .iter() .map(|(_, value)| 1 + 2 + value.len() as u16) .sum::() } pub fn write_to(&self, wrt: &mut impl io::Write) -> io::Result<()> { // PROXY v2 signature wrt.write_all(&SIGNATURE)?; // version | command wrt.write_all(&[Version::V2.v2_hi() | self.command.v2_lo()])?; // address family | transport protocol wrt.write_all(&[self.address_family.v2_hi() | self.transport_protocol.v2_lo()])?; // rest-of-header length wrt.write_all(&self.v2_len().to_be_bytes())?; tracing::debug!("proxy rest-of-header len: {}", self.v2_len()); fn write_ip_bytes_to(wrt: &mut impl io::Write, ip: IpAddr) -> io::Result<()> { match ip { IpAddr::V4(ip) => wrt.write_all(&ip.octets()), IpAddr::V6(ip) => wrt.write_all(&ip.octets()), } } // L3 (IP) address write_ip_bytes_to(wrt, self.src.ip())?; write_ip_bytes_to(wrt, self.dst.ip())?; // L4 ports wrt.write_all(&self.src.port().to_be_bytes())?; wrt.write_all(&self.dst.port().to_be_bytes())?; // TLVs for (typ, value) in &self.tlvs { wrt.write_all(&[*typ])?; wrt.write_all(&(value.len() as u16).to_be_bytes())?; wrt.write_all(&value)?; } Ok(()) } pub async fn write_to_tokio(&self, wrt: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { let buf = self.to_vec(); wrt.write_all(&buf).await } fn to_vec(&self) -> Vec { // TODO: figure out cap let mut buf = Vec::with_capacity(64); self.write_to(&mut buf).unwrap(); buf } pub fn has_tlv(&self) -> bool { self.tlvs.iter().any(|&(typ, _)| typ == T::TYPE) } /// If this is not called last thing it will be wrong. pub fn add_crc23c_checksum_tlv(&mut self) { // don't add a checksum if it is already set if self.has_tlv::() { return; } // When the checksum is supported by the sender after constructing the header // the sender MUST: // - initialize the checksum field to '0's. // - calculate the CRC32c checksum of the PROXY header as described in RFC4960, // Appendix B [8]. // - put the resultant value into the checksum field, and leave the rest of // the bits unchanged. // add zeroed checksum field to TLVs let crc = (Crc32c::TYPE, Crc32c::default().value_bytes().to_smallvec()); self.tlvs.push(crc); // write PROXY header to buffer let mut buf = Vec::new(); self.write_to(&mut buf).unwrap(); // calculate CRC on buffer and update CRC TLV let crc_calc = crc32fast::hash(&buf); self.tlvs.last_mut().unwrap().1 = crc_calc.to_be_bytes().to_smallvec(); tracing::debug!("checksum is {}", crc_calc); } pub fn validate_crc32c_tlv(&self) -> Option { // extract crc32c TLV or exit early if none is present let crc_sent = self .tlvs .iter() .filter_map(|(typ, value)| Crc32c::try_from_parts(*typ, &value)) .next()?; // If the checksum is provided as part of the PROXY header and the checksum // functionality is supported by the receiver, the receiver MUST: // - store the received CRC32c checksum value aside. // - replace the 32 bits of the checksum field in the received PROXY header with // all '0's and calculate a CRC32c checksum value of the whole PROXY header. // - verify that the calculated CRC32c checksum is the same as the received // CRC32c checksum. If it is not, the receiver MUST treat the TCP connection // providing the header as invalid. // The default procedure for handling an invalid TCP connection is to abort it. let mut this = self.clone(); for (typ, value) in this.tlvs.iter_mut() { if Crc32c::try_from_parts(*typ, &value).is_some() { value.fill(0); } } let mut buf = Vec::new(); this.write_to(&mut buf).unwrap(); let mut crc_calc = crc32fast::hash(&buf); Some(crc_sent.checksum == crc_calc) } } #[cfg(test)] mod tests { use std::net::Ipv6Addr; use const_str::hex; use pretty_assertions::assert_eq; use super::*; #[test] fn write_v2_no_tlvs() { let mut exp = Vec::new(); exp.extend_from_slice(&SIGNATURE); // 0-11 exp.extend_from_slice(&[0x21, 0x11]); // 12-13 exp.extend_from_slice(&[0x00, 0x0C]); // 14-15 exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 2]); // 16-23 exp.extend_from_slice(&[0x04, 0xd2, 0x00, 80]); // 24-27 let header = Header::new( Command::Proxy, TransportProtocol::Stream, AddressFamily::Inet, SocketAddr::from(([127, 0, 0, 1], 1234)), SocketAddr::from(([127, 0, 0, 2], 80)), ); assert_eq!(header.v2_len(), 12); assert_eq!(header.to_vec(), exp); } #[test] fn write_v2_ipv6_tlv_noop() { let mut exp = Vec::new(); exp.extend_from_slice(&SIGNATURE); // 0-11 exp.extend_from_slice(&[0x20, 0x11]); // 12-13 exp.extend_from_slice(&[0x00, 0x28]); // 14-15 exp.extend_from_slice(&hex!("00000000000000000000000000000001")); // 16-31 exp.extend_from_slice(&hex!("000102030405060708090A0B0C0D0E0F")); // 32-45 exp.extend_from_slice(&[0x00, 80, 0xff, 0xff]); // 45-49 exp.extend_from_slice(&[0x04, 0x00, 0x01, 0x00]); // 50-53 NOOP TLV let mut header = Header::new( Command::Local, TransportProtocol::Stream, AddressFamily::Inet, SocketAddr::from((Ipv6Addr::LOCALHOST, 80)), SocketAddr::from(( Ipv6Addr::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), 65535, )), ); header.add_tlv(0x04, [0]); assert_eq!(header.v2_len(), 36 + 4); assert_eq!(header.to_vec(), exp); } #[test] fn write_v2_tlv_c2c() { let mut exp = Vec::new(); exp.extend_from_slice(&SIGNATURE); // 0-11 exp.extend_from_slice(&[0x21, 0x11]); // 12-13 exp.extend_from_slice(&[0x00, 0x13]); // 14-15 exp.extend_from_slice(&[127, 0, 0, 1, 127, 0, 0, 1]); // 16-23 exp.extend_from_slice(&[0x00, 80, 0x00, 80]); // 24-27 exp.extend_from_slice(&[0x03, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // 28-35 TLV crc32c assert_eq!( crc32fast::hash(&exp), // correct checksum calculated manually u32::from_be_bytes([0x08, 0x70, 0x17, 0x7b]), ); // re-assign actual checksum to last 4 bytes of expected byte array exp[31..35].copy_from_slice(&[0x08, 0x70, 0x17, 0x7b]); let mut header = Header::new( Command::Proxy, TransportProtocol::Stream, AddressFamily::Inet, SocketAddr::from(([127, 0, 0, 1], 80)), SocketAddr::from(([127, 0, 0, 1], 80)), ); assert!( header.validate_crc32c_tlv().is_none(), "header doesn't have CRC TLV added yet" ); // add crc32c TLV to header header.add_crc23c_checksum_tlv(); assert_eq!(header.v2_len(), 12 + 7); assert_eq!(header.to_vec(), exp); // struct can self-validate checksum assert_eq!(header.validate_crc32c_tlv().unwrap(), true); // mangle crc32c TLV and assert that validate now fails *header.tlvs.last_mut().unwrap().1.last_mut().unwrap() = 0x00; assert_eq!(header.validate_crc32c_tlv().unwrap(), false); } }