1
0
mirror of https://github.com/actix/actix-extras.git synced 2024-11-24 16:02:59 +01:00

special handling for upgraded pipeline

This commit is contained in:
Nikolay Kim 2018-02-10 00:05:20 -08:00
parent 2d049e4a9f
commit 3109f9be62
5 changed files with 91 additions and 144 deletions

View File

@ -128,7 +128,7 @@ impl Handler<ws::Message> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg { match msg {
ws::Message::Text(text) => ctx.text(&text), ws::Message::Text(text) => ctx.text(text),
_ => (), _ => (),
} }
} }

View File

@ -96,15 +96,17 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> io::Result<WriterState> { fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> io::Result<WriterState> {
// prepare task // prepare task
self.flags.insert(Flags::STARTED);
self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg); self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags.insert(Flags::KEEPALIVE); self.flags.insert(Flags::STARTED | Flags::KEEPALIVE);
} else {
self.flags.insert(Flags::STARTED);
} }
// Connection upgrade // Connection upgrade
let version = msg.version().unwrap_or_else(|| req.version); let version = msg.version().unwrap_or_else(|| req.version);
if msg.upgrade() { if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade")); msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade"));
} }
// keep-alive // keep-alive
@ -177,8 +179,29 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
self.written += payload.len() as u64; self.written += payload.len() as u64;
if !self.flags.contains(Flags::DISCONNECTED) { if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::STARTED) { if self.flags.contains(Flags::STARTED) {
// shortcut for upgraded connection
if self.flags.contains(Flags::UPGRADE) {
if self.buffer.is_empty() {
match self.stream.write(payload.as_ref()) {
Ok(0) => {
self.disconnected();
return Ok(WriterState::Done);
},
Ok(n) => if payload.len() < n {
self.buffer.extend_from_slice(&payload.as_ref()[n..])
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
return Ok(WriterState::Done)
}
Err(err) => return Err(err),
}
} else {
self.buffer.extend(payload);
}
} else {
// TODO: add warning, write after EOF // TODO: add warning, write after EOF
self.encoder.write(payload)?; self.encoder.write(payload)?;
}
} else { } else {
// might be response to EXCEPT // might be response to EXCEPT
self.buffer.extend_from_slice(payload.as_ref()) self.buffer.extend_from_slice(payload.as_ref())

View File

@ -28,8 +28,8 @@ use client::{ClientRequest, ClientRequestBuilder,
use client::{Connect, Connection, ClientConnector, ClientConnectorError}; use client::{Connect, Connection, ClientConnector, ClientConnectorError};
use super::Message; use super::Message;
use super::frame::Frame;
use super::proto::{CloseCode, OpCode}; use super::proto::{CloseCode, OpCode};
use super::frame::{Frame, FrameData};
pub type WsClientFuture = pub type WsClientFuture =
Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>; Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
@ -444,17 +444,9 @@ impl WsClientWriter {
/// Write payload /// Write payload
#[inline] #[inline]
fn write(&mut self, data: FrameData) { fn write(&mut self, data: &Binary) {
if !self.as_mut().closed { if !self.as_mut().closed {
match data { let _ = self.as_mut().writer.write(data);
FrameData::Complete(data) => {
let _ = self.as_mut().writer.write(&data);
},
FrameData::Split(headers, payload) => {
let _ = self.as_mut().writer.write(&headers);
let _ = self.as_mut().writer.write(&payload);
}
}
} else { } else {
warn!("Trying to write to disconnected response"); warn!("Trying to write to disconnected response");
} }
@ -462,31 +454,31 @@ impl WsClientWriter {
/// Send text frame /// Send text frame
#[inline] #[inline]
pub fn text(&mut self, text: &str) { pub fn text<T: Into<String>>(&mut self, text: T) {
self.write(Frame::message(Vec::from(text), OpCode::Text, true).generate(true)); self.write(&Frame::message(text.into(), OpCode::Text, true, true));
} }
/// Send binary frame /// Send binary frame
#[inline] #[inline]
pub fn binary<B: Into<Binary>>(&mut self, data: B) { pub fn binary<B: Into<Binary>>(&mut self, data: B) {
self.write(Frame::message(data, OpCode::Binary, true).generate(true)); self.write(&Frame::message(data, OpCode::Binary, true, true));
} }
/// Send ping frame /// Send ping frame
#[inline] #[inline]
pub fn ping(&mut self, message: &str) { pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(true)); self.write(&Frame::message(Vec::from(message), OpCode::Ping, true, true));
} }
/// Send pong frame /// Send pong frame
#[inline] #[inline]
pub fn pong(&mut self, message: &str) { pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(true)); self.write(&Frame::message(Vec::from(message), OpCode::Pong, true, true));
} }
/// Send close frame /// Send close frame
#[inline] #[inline]
pub fn close(&mut self, code: CloseCode, reason: &str) { pub fn close(&mut self, code: CloseCode, reason: &str) {
self.write(Frame::close(code, reason).generate(true)); self.write(&Frame::close(code, reason, true));
} }
} }

View File

@ -14,7 +14,7 @@ use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest; use httprequest::HttpRequest;
use context::{Frame as ContextFrame, ActorHttpContext, Drain}; use context::{Frame as ContextFrame, ActorHttpContext, Drain};
use ws::frame::{Frame, FrameData}; use ws::frame::Frame;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{OpCode, CloseCode};
@ -105,21 +105,13 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Write payload /// Write payload
#[inline] #[inline]
fn write(&mut self, data: FrameData) { fn write(&mut self, data: Binary) {
if !self.disconnected { if !self.disconnected {
if self.stream.is_none() { if self.stream.is_none() {
self.stream = Some(SmallVec::new()); self.stream = Some(SmallVec::new());
} }
let stream = self.stream.as_mut().unwrap(); let stream = self.stream.as_mut().unwrap();
stream.push(ContextFrame::Chunk(Some(data)));
match data {
FrameData::Complete(data) =>
stream.push(ContextFrame::Chunk(Some(data))),
FrameData::Split(headers, payload) => {
stream.push(ContextFrame::Chunk(Some(headers)));
stream.push(ContextFrame::Chunk(Some(payload)));
}
}
} else { } else {
warn!("Trying to write to disconnected response"); warn!("Trying to write to disconnected response");
} }
@ -140,31 +132,31 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Send text frame /// Send text frame
#[inline] #[inline]
pub fn text<T: Into<String>>(&mut self, text: T) { pub fn text<T: Into<String>>(&mut self, text: T) {
self.write(Frame::message(text.into(), OpCode::Text, true).generate(false)); self.write(Frame::message(text.into(), OpCode::Text, true, false));
} }
/// Send binary frame /// Send binary frame
#[inline] #[inline]
pub fn binary<B: Into<Binary>>(&mut self, data: B) { pub fn binary<B: Into<Binary>>(&mut self, data: B) {
self.write(Frame::message(data, OpCode::Binary, true).generate(false)); self.write(Frame::message(data, OpCode::Binary, true, false));
} }
/// Send ping frame /// Send ping frame
#[inline] #[inline]
pub fn ping(&mut self, message: &str) { pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(false)); self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false));
} }
/// Send pong frame /// Send pong frame
#[inline] #[inline]
pub fn pong(&mut self, message: &str) { pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(false)); self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false));
} }
/// Send close frame /// Send close frame
#[inline] #[inline]
pub fn close(&mut self, code: CloseCode, reason: &str) { pub fn close(&mut self, code: CloseCode, reason: &str) {
self.write(Frame::close(code, reason).generate(false)); self.write(Frame::close(code, reason, false));
} }
/// Returns drain future /// Returns drain future

View File

@ -9,14 +9,6 @@ use body::Binary;
use ws::proto::{OpCode, CloseCode}; use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask; use ws::mask::apply_mask;
#[derive(Debug, PartialEq)]
pub(crate) enum FrameData {
Complete(Binary),
Split(Binary, Binary),
}
const MAX_LEN: usize = 122;
/// A struct representing a `WebSocket` frame. /// A struct representing a `WebSocket` frame.
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Frame { pub(crate) struct Frame {
@ -35,20 +27,9 @@ impl Frame {
(self.finished, self.opcode, self.payload) (self.finished, self.opcode, self.payload)
} }
/// Create a new data frame.
#[inline]
pub fn message<B: Into<Binary>>(data: B, code: OpCode, finished: bool) -> Frame {
Frame {
finished: finished,
opcode: code,
payload: data.into(),
.. Frame::default()
}
}
/// Create a new Close control frame. /// Create a new Close control frame.
#[inline] #[inline]
pub fn close(code: CloseCode, reason: &str) -> Frame { pub fn close(code: CloseCode, reason: &str, genmask: bool) -> Binary {
let raw: [u8; 2] = unsafe { let raw: [u8; 2] = unsafe {
let u: u16 = code.into(); let u: u16 = code.into();
mem::transmute(u.to_be()) mem::transmute(u.to_be())
@ -63,10 +44,7 @@ impl Frame {
.cloned()) .cloned())
}; };
Frame { Frame::message(payload, OpCode::Close, true, genmask)
payload: payload.into(),
.. Frame::default()
}
} }
/// Parse the input stream into a frame. /// Parse the input stream into a frame.
@ -162,7 +140,7 @@ impl Frame {
} }
OpCode::Close if length > 125 => { OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125."))) return Ok(Some(Frame::default()))
} }
_ => () _ => ()
} }
@ -183,96 +161,61 @@ impl Frame {
} }
/// Generate binary representation /// Generate binary representation
pub fn generate(self, genmask: bool) -> FrameData { pub fn message<B: Into<Binary>>(data: B, code: OpCode,
let mut one = 0u8; finished: bool, genmask: bool) -> Binary
let code: u8 = self.opcode.into(); {
if self.finished { let payload = data.into();
one |= 0x80; let one: u8 = if finished {
} 0x80 | Into::<u8>::into(code)
if self.rsv1 {
one |= 0x40;
}
if self.rsv2 {
one |= 0x20;
}
if self.rsv3 {
one |= 0x10;
}
one |= code;
let (two, mask_size) = if genmask {
(0x80, 4)
} else { } else {
(0, 0) code.into()
};
let payload_len = payload.len();
let (two, p_len) = if genmask {
(0x80, payload_len + 4)
} else {
(0, payload_len)
}; };
let payload_len = self.payload.len(); let mut buf = if payload_len < 126 {
let mut buf = if payload_len < MAX_LEN { let mut buf = BytesMut::with_capacity(p_len + 2);
if genmask { buf.put_slice(&[one, two | payload_len as u8]);
let len = payload_len + 6; buf
let mask: [u8; 4] = rand::random(); } else if payload_len <= 65_535 {
let mut buf = BytesMut::with_capacity(len); let mut buf = BytesMut::with_capacity(p_len + 4);
buf.put_slice(&[one, two | 126]);
{ {
let buf_mut = unsafe{buf.bytes_mut()}; let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one; BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16);
buf_mut[1] = two | payload_len as u8;
buf_mut[2..6].copy_from_slice(&mask);
buf_mut[6..payload_len+6].copy_from_slice(self.payload.as_ref());
apply_mask(&mut buf_mut[6..], &mask);
}
unsafe{buf.advance_mut(len)};
return FrameData::Complete(buf.into())
} else {
let len = payload_len + 2;
let mut buf = BytesMut::with_capacity(len);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | payload_len as u8;
buf_mut[2..payload_len+2].copy_from_slice(self.payload.as_ref());
}
unsafe{buf.advance_mut(len)};
return FrameData::Complete(buf.into())
}
} else if payload_len < 126 {
let mut buf = BytesMut::with_capacity(mask_size + 2);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | payload_len as u8;
} }
unsafe{buf.advance_mut(2)}; unsafe{buf.advance_mut(2)};
buf buf
} else if payload_len <= 65_535 {
let mut buf = BytesMut::with_capacity(mask_size + 4);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | 126;
BigEndian::write_u16(&mut buf_mut[2..4], payload_len as u16);
}
unsafe{buf.advance_mut(4)};
buf
} else { } else {
let mut buf = BytesMut::with_capacity(mask_size + 10); let mut buf = BytesMut::with_capacity(p_len + 8);
buf.put_slice(&[one, two | 127]);
{ {
let buf_mut = unsafe{buf.bytes_mut()}; let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one; BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64);
buf_mut[1] = two | 127;
BigEndian::write_u64(&mut buf_mut[2..10], payload_len as u64);
} }
unsafe{buf.advance_mut(10)}; unsafe{buf.advance_mut(8)};
buf buf
}; };
if genmask { if genmask {
let mut payload = Vec::from(self.payload.as_ref());
let mask: [u8; 4] = rand::random(); let mask: [u8; 4] = rand::random();
apply_mask(&mut payload, &mask); unsafe {
buf.extend_from_slice(&mask); {
FrameData::Split(buf.into(), payload.into()) let buf_mut = buf.bytes_mut();
buf_mut[..4].copy_from_slice(&mask);
buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref());
apply_mask(&mut buf_mut[4..], &mask);
}
buf.advance_mut(payload_len + 4);
}
buf.into()
} else { } else {
FrameData::Split(buf.into(), self.payload) buf.put_slice(payload.as_ref());
buf.into()
} }
} }
} }
@ -392,31 +335,28 @@ mod tests {
#[test] #[test]
fn test_ping_frame() { fn test_ping_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Ping, true); let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false);
let res = frame.generate(false);
let mut v = vec![137u8, 4u8]; let mut v = vec![137u8, 4u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into())); assert_eq!(frame, v.into());
} }
#[test] #[test]
fn test_pong_frame() { fn test_pong_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Pong, true); let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false);
let res = frame.generate(false);
let mut v = vec![138u8, 4u8]; let mut v = vec![138u8, 4u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into())); assert_eq!(frame, v.into());
} }
#[test] #[test]
fn test_close_frame() { fn test_close_frame() {
let frame = Frame::close(CloseCode::Normal, "data"); let frame = Frame::close(CloseCode::Normal, "data", false);
let res = frame.generate(false);
let mut v = vec![136u8, 6u8, 3u8, 232u8]; let mut v = vec![136u8, 6u8, 3u8, 232u8];
v.extend(b"data"); v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into())); assert_eq!(frame, v.into());
} }
} }