diff --git a/src/payload.rs b/src/payload.rs index bfa4dc81..512d56f1 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -297,6 +297,21 @@ impl PayloadHelper where S: Stream { } } + #[inline] + pub fn get_chunk(&mut self) -> Poll, PayloadError> { + if self.items.is_empty() { + match self.poll_stream()? { + Async::Ready(true) => (), + Async::Ready(false) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + } + } + match self.items.front().map(|c| c.as_ref()) { + Some(chunk) => Ok(Async::Ready(Some(chunk))), + None => Ok(Async::NotReady), + } + } + #[inline] pub fn readexactly(&mut self, size: usize) -> Poll, PayloadError> { if size <= self.len { diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 1d758298..52a20e50 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -1,4 +1,4 @@ -use std::{fmt, mem}; +use std::{fmt, mem, ptr}; use std::iter::FromIterator; use bytes::{Bytes, BytesMut, BufMut}; use byteorder::{ByteOrder, BigEndian, NetworkEndian}; @@ -17,9 +17,6 @@ use ws::mask::apply_mask; #[derive(Debug)] pub(crate) struct Frame { finished: bool, - rsv1: bool, - rsv2: bool, - rsv3: bool, opcode: OpCode, payload: Binary, } @@ -51,9 +48,9 @@ impl Frame { Frame::message(payload, OpCode::Close, true, genmask) } - /// Parse the input stream into a frame. - pub fn parse(pl: &mut PayloadHelper, server: bool, max_size: usize) - -> Poll, ProtocolError> + fn read_copy_md( + pl: &mut PayloadHelper, server: bool, max_size: usize + ) -> Poll)>, ProtocolError> where S: Stream { let mut idx = 2; @@ -74,12 +71,14 @@ impl Frame { return Err(ProtocolError::MaskedFrame) } - let rsv1 = first & 0x40 != 0; - let rsv2 = first & 0x20 != 0; - let rsv3 = first & 0x10 != 0; + // Op code let opcode = OpCode::from(first & 0x0F); - let len = second & 0x7F; + if let OpCode::Bad = opcode { + return Err(ProtocolError::InvalidOpcode(first & 0x0F)) + } + + let len = second & 0x7F; let length = if len == 126 { let buf = match pl.copy(4)? { Async::Ready(Some(buf)) => buf, @@ -114,14 +113,106 @@ impl Frame { Async::NotReady => return Ok(Async::NotReady), }; - let mut mask_bytes = [0u8; 4]; - mask_bytes.copy_from_slice(&buf[idx..idx+4]); + let mask: &[u8] = &buf[idx..idx+4]; + let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; idx += 4; - Some(mask_bytes) + Some(mask_u32) } else { None }; + Ok(Async::Ready(Some((idx, finished, opcode, length, mask)))) + } + + fn read_chunk_md(chunk: &[u8], server: bool, max_size: usize) + -> Poll<(usize, bool, OpCode, usize, Option), ProtocolError> + { + let chunk_len = chunk.len(); + + let mut idx = 2; + if chunk_len < 2 { + return Ok(Async::NotReady) + } + + let first = chunk[0]; + let second = chunk[1]; + let finished = first & 0x80 != 0; + + // check masking + let masked = second & 0x80 != 0; + if !masked && server { + return Err(ProtocolError::UnmaskedFrame) + } else if masked && !server { + return Err(ProtocolError::MaskedFrame) + } + + // Op code + let opcode = OpCode::from(first & 0x0F); + + if let OpCode::Bad = opcode { + return Err(ProtocolError::InvalidOpcode(first & 0x0F)) + } + + let len = second & 0x7F; + let length = if len == 126 { + if chunk_len < 4 { + return Ok(Async::NotReady) + } + let len = NetworkEndian::read_uint(&chunk[idx..], 2) as usize; + idx += 2; + len + } else if len == 127 { + if chunk_len < 10 { + return Ok(Async::NotReady) + } + let len = NetworkEndian::read_uint(&chunk[idx..], 8) as usize; + idx += 8; + len + } else { + len as usize + }; + + // check for max allowed size + if length > max_size { + return Err(ProtocolError::Overflow) + } + + let mask = if server { + if chunk_len < idx + 4 { + return Ok(Async::NotReady) + } + + let mask: &[u8] = &chunk[idx..idx+4]; + let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; + idx += 4; + Some(mask_u32) + } else { + None + }; + + Ok(Async::Ready((idx, finished, opcode, length, mask))) + } + + /// Parse the input stream into a frame. + pub fn parse(pl: &mut PayloadHelper, server: bool, max_size: usize) + -> Poll, ProtocolError> + where S: Stream + { + let result = match pl.get_chunk()? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(chunk)) => Frame::read_chunk_md(chunk, server, max_size)?, + Async::Ready(None) => return Ok(Async::Ready(None)), + }; + + let (idx, finished, opcode, length, mask) = match result { + Async::NotReady => match Frame::read_copy_md(pl, server, max_size)? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(item)) => item, + Async::Ready(None) => return Ok(Async::Ready(None)), + }, + Async::Ready(item) => item, + }; + match pl.can_read(idx + length)? { Async::Ready(Some(true)) => (), Async::Ready(None) => return Ok(Async::Ready(None)), @@ -134,7 +225,7 @@ impl Frame { // get body if length == 0 { return Ok(Async::Ready(Some(Frame { - finished, rsv1, rsv2, rsv3, opcode, payload: Binary::from("") }))); + finished, opcode, payload: Binary::from("") }))); } let data = match pl.readexactly(length)? { @@ -143,11 +234,6 @@ impl Frame { Async::NotReady => panic!(), }; - // Disallow bad opcode - if let OpCode::Bad = opcode { - return Err(ProtocolError::InvalidOpcode(first & 0x0F)) - } - // control frames must have length <= 125 match opcode { OpCode::Ping | OpCode::Pong if length > 125 => { @@ -161,14 +247,14 @@ impl Frame { } // unmask - if let Some(ref mask) = mask { + if let Some(mask) = mask { #[allow(mutable_transmutes)] let p: &mut [u8] = unsafe{let ptr: &[u8] = &data; mem::transmute(ptr)}; apply_mask(p, mask); } Ok(Async::Ready(Some(Frame { - finished, rsv1, rsv2, rsv3, opcode, payload: data.into() }))) + finished, opcode, payload: data.into() }))) } /// Generate binary representation @@ -213,13 +299,13 @@ impl Frame { }; if genmask { - let mask: [u8; 4] = rand::random(); + let mask = rand::random::(); unsafe { { let buf_mut = buf.bytes_mut(); - buf_mut[..4].copy_from_slice(&mask); + *(buf_mut as *mut _ as *mut u32) = mask; buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref()); - apply_mask(&mut buf_mut[4..], &mask); + apply_mask(&mut buf_mut[4..], mask); } buf.advance_mut(payload_len + 4); } @@ -235,9 +321,6 @@ impl Default for Frame { fn default() -> Frame { Frame { finished: true, - rsv1: false, - rsv2: false, - rsv3: false, opcode: OpCode::Close, payload: Binary::from(&b""[..]), } @@ -250,15 +333,11 @@ impl fmt::Display for Frame { " final: {} - reserved: {} {} {} opcode: {} payload length: {} payload: 0x{} ", self.finished, - self.rsv1, - self.rsv2, - self.rsv3, self.opcode, self.payload.len(), self.payload.as_ref().iter().map( @@ -296,7 +375,6 @@ mod tests { let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); let frame = extract(Frame::parse(&mut buf, false, 1024)); - println!("FRAME: {}", frame); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1"[..]); diff --git a/src/ws/mask.rs b/src/ws/mask.rs index 2e5a2960..e29eefd9 100644 --- a/src/ws/mask.rs +++ b/src/ws/mask.rs @@ -2,11 +2,10 @@ use std::cmp::min; use std::mem::uninitialized; use std::ptr::copy_nonoverlapping; -use std::ptr; /// Mask/unmask a frame. #[inline] -pub fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { +pub fn apply_mask(buf: &mut [u8], mask: u32) { apply_mask_fast32(buf, mask) } @@ -21,9 +20,7 @@ fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { /// Faster version of `apply_mask()` which operates on 8-byte blocks. #[inline] -fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { - let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; - +fn apply_mask_fast32(buf: &mut [u8], mask_u32: u32) { let mut ptr = buf.as_mut_ptr(); let mut len = buf.len(); @@ -35,12 +32,14 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { ptr = ptr.offset(head as isize); } len -= head; - let mask_u32 = if cfg!(target_endian = "big") { + //let mask_u32 = + if cfg!(target_endian = "big") { mask_u32.rotate_left(8 * head as u32) } else { mask_u32.rotate_right(8 * head as u32) - }; + }//; + /* let head = min(len, (4 - (ptr as usize & 3)) & 3); if head > 0 { unsafe { @@ -55,7 +54,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: &[u8; 4]) { } } else { mask_u32 - } + }*/ } else { mask_u32 }; @@ -106,6 +105,7 @@ unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) { #[cfg(test)] mod tests { + use std::ptr; use super::{apply_mask_fallback, apply_mask_fast32}; #[test] @@ -113,6 +113,8 @@ mod tests { let mask = [ 0x6d, 0xb6, 0xb2, 0x80, ]; + let mask_u32: u32 = unsafe {ptr::read_unaligned(mask.as_ptr() as *const u32)}; + let unmasked = vec![ 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03, @@ -124,7 +126,7 @@ mod tests { apply_mask_fallback(&mut masked, &mask); let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast, &mask); + apply_mask_fast32(&mut masked_fast, mask_u32); assert_eq!(masked, masked_fast); } @@ -135,7 +137,7 @@ mod tests { apply_mask_fallback(&mut masked[1..], &mask); let mut masked_fast = unmasked.clone(); - apply_mask_fast32(&mut masked_fast[1..], &mask); + apply_mask_fast32(&mut masked_fast[1..], mask_u32); assert_eq!(masked, masked_fast); }