From 3109f9be629db32d64e02381ab8e758567c5adc6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 10 Feb 2018 00:05:20 -0800 Subject: [PATCH] special handling for upgraded pipeline --- guide/src/qs_8.md | 2 +- src/server/h1writer.rs | 31 +++++++-- src/ws/client.rs | 26 +++---- src/ws/context.rs | 24 +++---- src/ws/frame.rs | 152 +++++++++++++---------------------------- 5 files changed, 91 insertions(+), 144 deletions(-) diff --git a/guide/src/qs_8.md b/guide/src/qs_8.md index 0f062edce..c2e533aaa 100644 --- a/guide/src/qs_8.md +++ b/guide/src/qs_8.md @@ -128,7 +128,7 @@ impl Handler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg { - ws::Message::Text(text) => ctx.text(&text), + ws::Message::Text(text) => ctx.text(text), _ => (), } } diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index 09f1b45d4..9bf99a624 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -96,15 +96,17 @@ impl Writer for H1Writer { fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> io::Result { // prepare task - self.flags.insert(Flags::STARTED); self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg); if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { - self.flags.insert(Flags::KEEPALIVE); + self.flags.insert(Flags::STARTED | Flags::KEEPALIVE); + } else { + self.flags.insert(Flags::STARTED); } // Connection upgrade let version = msg.version().unwrap_or_else(|| req.version); if msg.upgrade() { + self.flags.insert(Flags::UPGRADE); msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); } // keep-alive @@ -177,8 +179,29 @@ impl Writer for H1Writer { self.written += payload.len() as u64; if !self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::STARTED) { - // TODO: add warning, write after EOF - self.encoder.write(payload)?; + // shortcut for upgraded connection + if self.flags.contains(Flags::UPGRADE) { + if self.buffer.is_empty() { + match self.stream.write(payload.as_ref()) { + Ok(0) => { + self.disconnected(); + return Ok(WriterState::Done); + }, + Ok(n) => if payload.len() < n { + self.buffer.extend_from_slice(&payload.as_ref()[n..]) + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + return Ok(WriterState::Done) + } + Err(err) => return Err(err), + } + } else { + self.buffer.extend(payload); + } + } else { + // TODO: add warning, write after EOF + self.encoder.write(payload)?; + } } else { // might be response to EXCEPT self.buffer.extend_from_slice(payload.as_ref()) diff --git a/src/ws/client.rs b/src/ws/client.rs index 565e75804..8b4837c12 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -28,8 +28,8 @@ use client::{ClientRequest, ClientRequestBuilder, use client::{Connect, Connection, ClientConnector, ClientConnectorError}; use super::Message; +use super::frame::Frame; use super::proto::{CloseCode, OpCode}; -use super::frame::{Frame, FrameData}; pub type WsClientFuture = Future; @@ -444,17 +444,9 @@ impl WsClientWriter { /// Write payload #[inline] - fn write(&mut self, data: FrameData) { + fn write(&mut self, data: &Binary) { if !self.as_mut().closed { - match data { - FrameData::Complete(data) => { - let _ = self.as_mut().writer.write(&data); - }, - FrameData::Split(headers, payload) => { - let _ = self.as_mut().writer.write(&headers); - let _ = self.as_mut().writer.write(&payload); - } - } + let _ = self.as_mut().writer.write(data); } else { warn!("Trying to write to disconnected response"); } @@ -462,31 +454,31 @@ impl WsClientWriter { /// Send text frame #[inline] - pub fn text(&mut self, text: &str) { - self.write(Frame::message(Vec::from(text), OpCode::Text, true).generate(true)); + pub fn text>(&mut self, text: T) { + self.write(&Frame::message(text.into(), OpCode::Text, true, true)); } /// Send binary frame #[inline] pub fn binary>(&mut self, data: B) { - self.write(Frame::message(data, OpCode::Binary, true).generate(true)); + self.write(&Frame::message(data, OpCode::Binary, true, true)); } /// Send ping frame #[inline] pub fn ping(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(true)); + self.write(&Frame::message(Vec::from(message), OpCode::Ping, true, true)); } /// Send pong frame #[inline] pub fn pong(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(true)); + self.write(&Frame::message(Vec::from(message), OpCode::Pong, true, true)); } /// Send close frame #[inline] pub fn close(&mut self, code: CloseCode, reason: &str) { - self.write(Frame::close(code, reason).generate(true)); + self.write(&Frame::close(code, reason, true)); } } diff --git a/src/ws/context.rs b/src/ws/context.rs index a903a890b..835c1774e 100644 --- a/src/ws/context.rs +++ b/src/ws/context.rs @@ -14,7 +14,7 @@ use error::{Error, ErrorInternalServerError}; use httprequest::HttpRequest; use context::{Frame as ContextFrame, ActorHttpContext, Drain}; -use ws::frame::{Frame, FrameData}; +use ws::frame::Frame; use ws::proto::{OpCode, CloseCode}; @@ -105,21 +105,13 @@ impl WebsocketContext where A: Actor { /// Write payload #[inline] - fn write(&mut self, data: FrameData) { + fn write(&mut self, data: Binary) { if !self.disconnected { if self.stream.is_none() { self.stream = Some(SmallVec::new()); } let stream = self.stream.as_mut().unwrap(); - - match data { - FrameData::Complete(data) => - stream.push(ContextFrame::Chunk(Some(data))), - FrameData::Split(headers, payload) => { - stream.push(ContextFrame::Chunk(Some(headers))); - stream.push(ContextFrame::Chunk(Some(payload))); - } - } + stream.push(ContextFrame::Chunk(Some(data))); } else { warn!("Trying to write to disconnected response"); } @@ -140,31 +132,31 @@ impl WebsocketContext where A: Actor { /// Send text frame #[inline] pub fn text>(&mut self, text: T) { - self.write(Frame::message(text.into(), OpCode::Text, true).generate(false)); + self.write(Frame::message(text.into(), OpCode::Text, true, false)); } /// Send binary frame #[inline] pub fn binary>(&mut self, data: B) { - self.write(Frame::message(data, OpCode::Binary, true).generate(false)); + self.write(Frame::message(data, OpCode::Binary, true, false)); } /// Send ping frame #[inline] pub fn ping(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(false)); + self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false)); } /// Send pong frame #[inline] pub fn pong(&mut self, message: &str) { - self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(false)); + self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false)); } /// Send close frame #[inline] pub fn close(&mut self, code: CloseCode, reason: &str) { - self.write(Frame::close(code, reason).generate(false)); + self.write(Frame::close(code, reason, false)); } /// Returns drain future diff --git a/src/ws/frame.rs b/src/ws/frame.rs index e6e0f5352..612fe2f0a 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -9,14 +9,6 @@ use body::Binary; use ws::proto::{OpCode, CloseCode}; use ws::mask::apply_mask; -#[derive(Debug, PartialEq)] -pub(crate) enum FrameData { - Complete(Binary), - Split(Binary, Binary), -} - -const MAX_LEN: usize = 122; - /// A struct representing a `WebSocket` frame. #[derive(Debug)] pub(crate) struct Frame { @@ -35,20 +27,9 @@ impl Frame { (self.finished, self.opcode, self.payload) } - /// Create a new data frame. - #[inline] - pub fn message>(data: B, code: OpCode, finished: bool) -> Frame { - Frame { - finished: finished, - opcode: code, - payload: data.into(), - .. Frame::default() - } - } - /// Create a new Close control frame. #[inline] - pub fn close(code: CloseCode, reason: &str) -> Frame { + pub fn close(code: CloseCode, reason: &str, genmask: bool) -> Binary { let raw: [u8; 2] = unsafe { let u: u16 = code.into(); mem::transmute(u.to_be()) @@ -63,10 +44,7 @@ impl Frame { .cloned()) }; - Frame { - payload: payload.into(), - .. Frame::default() - } + Frame::message(payload, OpCode::Close, true, genmask) } /// Parse the input stream into a frame. @@ -162,7 +140,7 @@ impl Frame { } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125."))) + return Ok(Some(Frame::default())) } _ => () } @@ -183,96 +161,61 @@ impl Frame { } /// Generate binary representation - pub fn generate(self, genmask: bool) -> FrameData { - let mut one = 0u8; - let code: u8 = self.opcode.into(); - if self.finished { - one |= 0x80; - } - if self.rsv1 { - one |= 0x40; - } - if self.rsv2 { - one |= 0x20; - } - if self.rsv3 { - one |= 0x10; - } - one |= code; - - let (two, mask_size) = if genmask { - (0x80, 4) + pub fn message>(data: B, code: OpCode, + finished: bool, genmask: bool) -> Binary + { + let payload = data.into(); + let one: u8 = if finished { + 0x80 | Into::::into(code) } else { - (0, 0) + code.into() + }; + let payload_len = payload.len(); + let (two, p_len) = if genmask { + (0x80, payload_len + 4) + } else { + (0, payload_len) }; - let payload_len = self.payload.len(); - let mut buf = if payload_len < MAX_LEN { - if genmask { - let len = payload_len + 6; - let mask: [u8; 4] = rand::random(); - let mut buf = BytesMut::with_capacity(len); - { - let buf_mut = unsafe{buf.bytes_mut()}; - buf_mut[0] = one; - buf_mut[1] = two | payload_len as u8; - buf_mut[2..6].copy_from_slice(&mask); - buf_mut[6..payload_len+6].copy_from_slice(self.payload.as_ref()); - apply_mask(&mut buf_mut[6..], &mask); - } - unsafe{buf.advance_mut(len)}; - return FrameData::Complete(buf.into()) - } else { - let len = payload_len + 2; - let mut buf = BytesMut::with_capacity(len); - { - let buf_mut = unsafe{buf.bytes_mut()}; - buf_mut[0] = one; - buf_mut[1] = two | payload_len as u8; - buf_mut[2..payload_len+2].copy_from_slice(self.payload.as_ref()); - } - unsafe{buf.advance_mut(len)}; - return FrameData::Complete(buf.into()) - } - } else if payload_len < 126 { - let mut buf = BytesMut::with_capacity(mask_size + 2); + let mut buf = if payload_len < 126 { + let mut buf = BytesMut::with_capacity(p_len + 2); + buf.put_slice(&[one, two | payload_len as u8]); + buf + } else if payload_len <= 65_535 { + let mut buf = BytesMut::with_capacity(p_len + 4); + buf.put_slice(&[one, two | 126]); { let buf_mut = unsafe{buf.bytes_mut()}; - buf_mut[0] = one; - buf_mut[1] = two | payload_len as u8; + BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16); } unsafe{buf.advance_mut(2)}; buf - } else if payload_len <= 65_535 { - let mut buf = BytesMut::with_capacity(mask_size + 4); - { - let buf_mut = unsafe{buf.bytes_mut()}; - buf_mut[0] = one; - buf_mut[1] = two | 126; - BigEndian::write_u16(&mut buf_mut[2..4], payload_len as u16); - } - unsafe{buf.advance_mut(4)}; - buf } else { - let mut buf = BytesMut::with_capacity(mask_size + 10); + let mut buf = BytesMut::with_capacity(p_len + 8); + buf.put_slice(&[one, two | 127]); { let buf_mut = unsafe{buf.bytes_mut()}; - buf_mut[0] = one; - buf_mut[1] = two | 127; - BigEndian::write_u64(&mut buf_mut[2..10], payload_len as u64); + BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64); } - unsafe{buf.advance_mut(10)}; + unsafe{buf.advance_mut(8)}; buf }; if genmask { - let mut payload = Vec::from(self.payload.as_ref()); let mask: [u8; 4] = rand::random(); - apply_mask(&mut payload, &mask); - buf.extend_from_slice(&mask); - FrameData::Split(buf.into(), payload.into()) + unsafe { + { + let buf_mut = buf.bytes_mut(); + buf_mut[..4].copy_from_slice(&mask); + buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref()); + apply_mask(&mut buf_mut[4..], &mask); + } + buf.advance_mut(payload_len + 4); + } + buf.into() } else { - FrameData::Split(buf.into(), self.payload) + buf.put_slice(payload.as_ref()); + buf.into() } } } @@ -392,31 +335,28 @@ mod tests { #[test] fn test_ping_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Ping, true); - let res = frame.generate(false); + let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); let mut v = vec![137u8, 4u8]; v.extend(b"data"); - assert_eq!(res, FrameData::Complete(v.into())); + assert_eq!(frame, v.into()); } #[test] fn test_pong_frame() { - let frame = Frame::message(Vec::from("data"), OpCode::Pong, true); - let res = frame.generate(false); + let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false); let mut v = vec![138u8, 4u8]; v.extend(b"data"); - assert_eq!(res, FrameData::Complete(v.into())); + assert_eq!(frame, v.into()); } #[test] fn test_close_frame() { - let frame = Frame::close(CloseCode::Normal, "data"); - let res = frame.generate(false); + let frame = Frame::close(CloseCode::Normal, "data", false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); - assert_eq!(res, FrameData::Complete(v.into())); + assert_eq!(frame, v.into()); } }