//! This code is taken from the [Tungstenite project](https://github.com/snapview/tungstenite-rs).
#![cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))]
use std::ptr::copy_nonoverlapping;
use std::slice;

// Holds a slice guaranteed to be shorter than 8 bytes
struct ShortSlice<'a>(&'a mut [u8]);

impl<'a> ShortSlice<'a> {
    unsafe fn new(slice: &'a mut [u8]) -> Self {
        // Sanity check for debug builds
        debug_assert!(slice.len() < 8);
        ShortSlice(slice)
    }
    fn len(&self) -> usize {
        self.0.len()
    }
}

/// Faster version of `apply_mask()` which operates on 8-byte blocks.
#[inline]
#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_lossless))]
pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) {
    // Extend the mask to 64 bits
    let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64);
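    // E.g. mask_u32 = 0x4433_2211 yields mask_u64 = 0x4433_2211_4433_2211,
    // i.e. the 4-byte mask repeated so it covers a full 8-byte block.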
    // Split the buffer into three segments
    let (head, mid, tail) = align_buf(buf);

    // Initial unaligned segment
    let head_len = head.len();
    if head_len > 0 {
        xor_short(head, mask_u64);
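        // The unaligned head consumed `head_len` bytes of the mask, so rotate
        // the 64-bit mask by 8 bits per consumed byte; the rotation direction
        // depends on how buffer bytes map onto the bytes of a u64 for this
        // endianness.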
        if cfg!(target_endian = "big") {
            mask_u64 = mask_u64.rotate_left(8 * head_len as u32);
        } else {
            mask_u64 = mask_u64.rotate_right(8 * head_len as u32);
        }
    }
    // Aligned segment
    for v in mid {
        *v ^= mask_u64;
    }
    // Final unaligned segment
    if tail.len() > 0 {
        xor_short(tail, mask_u64);
    }
}

#[inline]
// TODO: `copy_nonoverlapping` here compiles to a `memcpy` call. While that is
// not terribly inefficient, it could be done better: the compiler does not
// understand that a `ShortSlice` is always smaller than a u64.
#[cfg_attr(
    feature = "cargo-clippy",
    allow(clippy::needless_pass_by_value)
)]
fn xor_short(buf: ShortSlice, mask: u64) {
    // Unsafe: we know that a `ShortSlice` fits in a u64
    unsafe {
        let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len());
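        // Copy the (fewer than 8) bytes into a zeroed u64, XOR the whole word
        // against the mask, then copy the same `len` bytes back; bytes of `b`
        // beyond `len` are never written back to the buffer.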
        let mut b: u64 = 0;
        #[allow(trivial_casts)]
        copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
        b ^= mask;
        #[allow(trivial_casts)]
        copy_nonoverlapping(&b as *const _ as *const u8, ptr, len);
    }
}

#[inline]
// Unsafe: caller must ensure the buffer has the correct size and alignment
unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] {
    // Assert correct size and alignment in debug builds
    debug_assert!(buf.len().trailing_zeros() >= 3);
    debug_assert!((buf.as_ptr() as usize).trailing_zeros() >= 3);
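    // (`trailing_zeros() >= 3` means the low three bits are zero, i.e. both
    // the length and the address are multiples of 8.)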

    slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3)
}

#[inline]
// Splits a slice into three parts: an unaligned short head and tail, plus an aligned
// u64 mid section.
fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) {
    let start_ptr = buf.as_ptr() as usize;
    let end_ptr = start_ptr + buf.len();

    // Round *up* to next aligned boundary for start
    let start_aligned = (start_ptr + 7) & !0x7;
    // Round *down* to last aligned boundary for end
    let end_aligned = end_ptr & !0x7;
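    // E.g. a buffer starting at 0x1003 with length 16 (end 0x1013) gives
    // start_aligned = 0x1008 and end_aligned = 0x1010: a 5-byte head, one
    // aligned u64, and a 3-byte tail.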

    if end_aligned >= start_aligned {
        // We have our three segments (head, mid, tail)
        let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr);
        let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr);

        // Unsafe: we know the middle section is correctly aligned, and the outer
        // sections are smaller than 8 bytes
        unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice::new(tail)) }
    } else {
        // We didn't cross even one aligned boundary!

        // Unsafe: The outer sections are smaller than 8 bytes
        unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) }
    }
}

#[cfg(test)]
mod tests {
    use super::apply_mask;
    use byteorder::{ByteOrder, LittleEndian};

    /// A safe unoptimized mask application.
    fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) {
        for (i, byte) in buf.iter_mut().enumerate() {
            *byte ^= mask[i & 3];
        }
    }

    #[test]
    fn test_apply_mask() {
        let mask = [0x6d, 0xb6, 0xb2, 0x80];
        let mask_u32: u32 = LittleEndian::read_u32(&mask);

        let unmasked = vec![
            0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17,
            0x74, 0xf9, 0x12, 0x03,
        ];

        // Check masking with proper alignment.
        {
            let mut masked = unmasked.clone();
            apply_mask_fallback(&mut masked, &mask);

            let mut masked_fast = unmasked.clone();
            apply_mask(&mut masked_fast, mask_u32);

            assert_eq!(masked, masked_fast);
        }

        // Check masking without alignment.
        {
            let mut masked = unmasked.clone();
            apply_mask_fallback(&mut masked[1..], &mask);

            let mut masked_fast = unmasked.clone();
            apply_mask(&mut masked_fast[1..], mask_u32);

            assert_eq!(masked, masked_fast);
        }
    }
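
    /// Additional check (not part of the original Tungstenite tests): a sketch
    /// that compares the fast implementation against the byte-by-byte fallback
    /// over many offsets and lengths, so the head/mid/tail split performed by
    /// `align_buf` is exercised for a variety of buffer shapes.
    #[test]
    fn test_apply_mask_lengths_and_offsets() {
        let mask = [0x12, 0x34, 0x56, 0x78];
        let mask_u32: u32 = LittleEndian::read_u32(&mask);

        let original: Vec<u8> = (0u8..=63).collect();
        for offset in 0..8 {
            for len in 0..=(original.len() - offset) {
                let mut expected = original.clone();
                apply_mask_fallback(&mut expected[offset..offset + len], &mask);

                let mut actual = original.clone();
                apply_mask(&mut actual[offset..offset + len], mask_u32);

                assert_eq!(expected, actual);
            }
        }
    }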
}