From 6cd40df38769bd046529c62a56660791b7407b93 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 19 Mar 2018 17:27:03 -0700 Subject: [PATCH] Fix server websockets big payloads support --- .travis.yml | 3 - CHANGES.md | 2 + src/client/parser.rs | 4 +- src/pipeline.rs | 280 +++++++++++++++++++++-------------------- src/server/h1writer.rs | 6 +- src/ws/client.rs | 4 +- src/ws/mod.rs | 3 +- tests/test_ws.rs | 30 ++++- 8 files changed, 178 insertions(+), 154 deletions(-) diff --git a/.travis.yml b/.travis.yml index dfa93d40e..aa7f0c1e5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,9 +12,6 @@ matrix: - rust: stable - rust: beta - rust: nightly - allow_failures: - - rust: nightly - - rust: beta #rust: # - 1.21.0 diff --git a/CHANGES.md b/CHANGES.md index 7ff63f669..ab798f06c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,8 @@ * Allow to set client websocket handshake timeout +* Fix server websockets big payloads support + ## 0.4.9 (2018-03-16) diff --git a/src/client/parser.rs b/src/client/parser.rs index 6ffcd76e4..8fe399009 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -145,9 +145,7 @@ impl HttpResponseParser { // convert headers let mut hdrs = HeaderMap::new(); for header in headers[..headers_len].iter() { - let n_start = header.name.as_ptr() as usize - bytes_ptr; - let n_end = n_start + header.name.len(); - if let Ok(name) = HeaderName::try_from(slice.slice(n_start, n_end)) { + if let Ok(name) = HeaderName::try_from(header.name) { let v_start = header.value.as_ptr() as usize - bytes_ptr; let v_end = v_start + header.value.len(); let value = unsafe { diff --git a/src/pipeline.rs b/src/pipeline.rs index e92e16f54..b5772e9a3 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -453,167 +453,171 @@ impl ProcessResponse { fn poll_io(mut self, io: &mut Writer, info: &mut PipelineInfo) -> Result, PipelineState> { - if self.drain.is_none() && self.running != RunningState::Paused { - // if task is paused, write buffer is probably full - 'outter: loop { - let result = match mem::replace(&mut self.iostate, IOState::Done) { - IOState::Response => { - let encoding = self.resp.content_encoding().unwrap_or(info.encoding); + loop { + if self.drain.is_none() && self.running != RunningState::Paused { + // if task is paused, write buffer is probably full + 'inner: loop { + let result = match mem::replace(&mut self.iostate, IOState::Done) { + IOState::Response => { + let encoding = self.resp.content_encoding().unwrap_or(info.encoding); - let result = match io.start(info.req_mut().get_inner(), - &mut self.resp, encoding) - { - Ok(res) => res, - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - }; - - if let Some(err) = self.resp.error() { - if self.resp.status().is_server_error() { - error!("Error occured during request handling: {}", err); - } else { - warn!("Error occured during request handling: {}", err); - } - if log_enabled!(Debug) { - debug!("{:?}", err); - } - } - - // always poll stream or actor for the first time - match self.resp.replace_body(Body::Empty) { - Body::Streaming(stream) => { - self.iostate = IOState::Payload(stream); - continue - }, - Body::Actor(ctx) => { - self.iostate = IOState::Actor(ctx); - continue - }, - _ => (), - } - - result - }, - IOState::Payload(mut body) => { - match body.poll() { - Ok(Async::Ready(None)) => { - if let Err(err) = io.write_eof() { + let result = match io.start(info.req_mut().get_inner(), + &mut self.resp, encoding) + { + Ok(res) => res, + Err(err) => { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init(info, self.resp)) } - break + }; + + if let Some(err) = self.resp.error() { + if self.resp.status().is_server_error() { + error!("Error occured during request handling: {}", err); + } else { + warn!("Error occured during request handling: {}", err); + } + if log_enabled!(Debug) { + debug!("{:?}", err); + } + } + + // always poll stream or actor for the first time + match self.resp.replace_body(Body::Empty) { + Body::Streaming(stream) => { + self.iostate = IOState::Payload(stream); + continue 'inner + }, + Body::Actor(ctx) => { + self.iostate = IOState::Actor(ctx); + continue 'inner }, - Ok(Async::Ready(Some(chunk))) => { - self.iostate = IOState::Payload(body); - match io.write(chunk.into()) { - Err(err) => { + _ => (), + } + + result + }, + IOState::Payload(mut body) => { + match body.poll() { + Ok(Async::Ready(None)) => { + if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init(info, self.resp)) - }, - Ok(result) => result - } - } - Ok(Async::NotReady) => { - self.iostate = IOState::Payload(body); - break - }, - Err(err) => { - info.error = Some(err); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - }, - IOState::Actor(mut ctx) => { - if info.disconnected.take().is_some() { - ctx.disconnected(); - } - match ctx.poll() { - Ok(Async::Ready(Some(vec))) => { - if vec.is_empty() { - self.iostate = IOState::Actor(ctx); + } break - } - let mut res = None; - for frame in vec { - match frame { - Frame::Chunk(None) => { - info.context = Some(ctx); - if let Err(err) = io.write_eof() { - info.error = Some(err.into()); - return Ok( - FinishingMiddlewares::init(info, self.resp)) - } - break 'outter + }, + Ok(Async::Ready(Some(chunk))) => { + self.iostate = IOState::Payload(body); + match io.write(chunk.into()) { + Err(err) => { + info.error = Some(err.into()); + return Ok(FinishingMiddlewares::init(info, self.resp)) }, - Frame::Chunk(Some(chunk)) => { - match io.write(chunk) { - Err(err) => { + Ok(result) => result + } + } + Ok(Async::NotReady) => { + self.iostate = IOState::Payload(body); + break + }, + Err(err) => { + info.error = Some(err); + return Ok(FinishingMiddlewares::init(info, self.resp)) + } + } + }, + IOState::Actor(mut ctx) => { + if info.disconnected.take().is_some() { + ctx.disconnected(); + } + match ctx.poll() { + Ok(Async::Ready(Some(vec))) => { + if vec.is_empty() { + self.iostate = IOState::Actor(ctx); + break + } + let mut res = None; + for frame in vec { + match frame { + Frame::Chunk(None) => { + info.context = Some(ctx); + if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok( FinishingMiddlewares::init(info, self.resp)) - }, - Ok(result) => res = Some(result), - } - }, - Frame::Drain(fut) => self.drain = Some(fut), + } + break 'inner + }, + Frame::Chunk(Some(chunk)) => { + match io.write(chunk) { + Err(err) => { + info.error = Some(err.into()); + return Ok( + FinishingMiddlewares::init(info, self.resp)) + }, + Ok(result) => res = Some(result), + } + }, + Frame::Drain(fut) => self.drain = Some(fut), + } } + self.iostate = IOState::Actor(ctx); + if self.drain.is_some() { + self.running.resume(); + break 'inner + } + res.unwrap() + }, + Ok(Async::Ready(None)) => { + break } - self.iostate = IOState::Actor(ctx); - if self.drain.is_some() { - self.running.resume(); - break 'outter + Ok(Async::NotReady) => { + self.iostate = IOState::Actor(ctx); + break + } + Err(err) => { + info.error = Some(err); + return Ok(FinishingMiddlewares::init(info, self.resp)) } - res.unwrap() - }, - Ok(Async::Ready(None)) => { - break - } - Ok(Async::NotReady) => { - self.iostate = IOState::Actor(ctx); - break - } - Err(err) => { - info.error = Some(err); - return Ok(FinishingMiddlewares::init(info, self.resp)) } } - } - IOState::Done => break, - }; + IOState::Done => break, + }; - match result { - WriterState::Pause => { - self.running.pause(); - break + match result { + WriterState::Pause => { + self.running.pause(); + break + } + WriterState::Done => { + self.running.resume() + }, } - WriterState::Done => { - self.running.resume() + } + } + + // flush io but only if we need to + if self.running == RunningState::Paused || self.drain.is_some() { + match io.poll_completed(false) { + Ok(Async::Ready(_)) => { + self.running.resume(); + + // resolve drain futures + if let Some(tx) = self.drain.take() { + let _ = tx.send(()); + } + // restart io processing + continue }, - } - } - } - - // flush io but only if we need to - if self.running == RunningState::Paused || self.drain.is_some() { - match io.poll_completed(false) { - Ok(Async::Ready(_)) => { - self.running.resume(); - - // resolve drain futures - if let Some(tx) = self.drain.take() { - let _ = tx.send(()); + Ok(Async::NotReady) => + return Err(PipelineState::Response(self)), + Err(err) => { + info.error = Some(err.into()); + return Ok(FinishingMiddlewares::init(info, self.resp)) } - // restart io processing - return self.poll_io(io, info); - }, - Ok(Async::NotReady) => return Err(PipelineState::Response(self)), - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, self.resp)) } } + break } // response is completed diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index 531e3c8d5..c3eb5dc93 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -82,7 +82,9 @@ impl H1Writer { self.disconnected(); return Err(io::Error::new(io::ErrorKind::WriteZero, "")) }, - Ok(n) => written += n, + Ok(n) => { + written += n; + }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { return Ok(written) } @@ -229,7 +231,7 @@ impl Writer for H1Writer { if self.buffer.is_empty() { let pl: &[u8] = payload.as_ref(); let n = self.write_data(pl)?; - if pl.len() < n { + if n < pl.len() { self.buffer.extend_from_slice(&pl[n..]); return Ok(WriterState::Done); } diff --git a/src/ws/client.rs b/src/ws/client.rs index c5fdcf798..595930989 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -454,13 +454,13 @@ impl Stream for ClientReader { // read match Frame::parse(&mut inner.rx, false, max_size) { Ok(Async::Ready(Some(frame))) => { - let (finished, opcode, payload) = frame.unpack(); + let (_finished, opcode, payload) = frame.unpack(); match opcode { // continuation is not supported OpCode::Continue => { inner.closed = true; - return Err(ProtocolError::NoContinuation) + Err(ProtocolError::NoContinuation) }, OpCode::Bad => { inner.closed = true; diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 12fb4d709..7b41cf253 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -329,7 +329,8 @@ impl Stream for WsStream where S: Stream { match String::from_utf8(tmp) { Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => { + Err(e) => { + println!("ENC: {:?}", e); self.closed = true; Err(ProtocolError::BadEncoding) } diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 4d3ed4729..a4dd2c230 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -93,7 +93,8 @@ fn test_large_bin() { } struct Ws2 { - count: usize + count: usize, + bin: bool, } impl Actor for Ws2 { @@ -106,10 +107,14 @@ impl Actor for Ws2 { impl Ws2 { fn send(&mut self, ctx: &mut ws::WebsocketContext) { - ctx.text("0".repeat(65_536)); + if self.bin { + ctx.binary(Vec::from("0".repeat(65_536))); + } else { + ctx.text("0".repeat(65_536)); + } ctx.drain().and_then(|_, act, ctx| { act.count += 1; - if act.count != 100 { + if act.count != 10_000 { act.send(ctx); } actix::fut::ok(()) @@ -135,10 +140,25 @@ fn test_server_send_text() { let data = Some(ws::Message::Text("0".repeat(65_536))); let mut srv = test::TestServer::new( - |app| app.handler(|req| ws::start(req, Ws2{count:0}))); + |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: false}))); let (mut reader, _writer) = srv.ws().unwrap(); - for _ in 0..100 { + for _ in 0..10_000 { + let (item, r) = srv.execute(reader.into_future()).unwrap(); + reader = r; + assert_eq!(item, data); + } +} + +#[test] +fn test_server_send_bin() { + let data = Some(ws::Message::Binary(Binary::from("0".repeat(65_536)))); + + let mut srv = test::TestServer::new( + |app| app.handler(|req| ws::start(req, Ws2{count:0, bin: true}))); + let (mut reader, _writer) = srv.ws().unwrap(); + + for _ in 0..10_000 { let (item, r) = srv.execute(reader.into_future()).unwrap(); reader = r; assert_eq!(item, data);