diff --git a/.gitignore b/.gitignore index 0a66cac50..ec2e10ce9 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ guide/build/ *.pid *.sock *~ +.DS_Store diff --git a/Cargo.toml b/Cargo.toml index fe1f02214..169763d88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,19 @@ [workspace] members = [ - "actix-cors", - "actix-identity", - "actix-protobuf", - "actix-protobuf/examples/prost-example", - "actix-redis", - "actix-session", - "actix-web-httpauth", + "actix-amqp", + "actix-amqp/codec", + "actix-cors", + "actix-identity", + "actix-mqtt", + "actix-mqtt/codec", + "actix-protobuf", + "actix-protobuf/examples/prost-example", + "actix-redis", + "actix-session", + "actix-web-httpauth", ] [patch.crates-io] actix-session = { path = "actix-session" } +mqtt-codec = { path = "actix-mqtt/codec" } +amqp-codec = { path = "actix-amqp/codec" } diff --git a/actix-amqp/CHANGES.md b/actix-amqp/CHANGES.md new file mode 100755 index 000000000..092e9f4e7 --- /dev/null +++ b/actix-amqp/CHANGES.md @@ -0,0 +1,25 @@ +# Changes + +## Unreleased - 2020-xx-xx + + +## 0.1.4 - 2020-03-05 +* Add server handshake timeout + + +## 0.1.3 - 2020-02-10 +* Allow to override sender link attach frame + + +## 0.1.2 - 2019-12-25 +* Allow to specify multi-pattern for topics + + +## 0.1.1 - 2019-12-18 +* Separate control frame entries for detach sender qand detach receiver +* Proper detach remote receiver +* Replace `async fn` with `impl Future` + + +## 0.1.0 - 2019-12-11 +* Initial release diff --git a/actix-amqp/Cargo.toml b/actix-amqp/Cargo.toml new file mode 100755 index 000000000..b4141a6b6 --- /dev/null +++ b/actix-amqp/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "actix-amqp" +version = "0.1.4" +authors = ["Nikolay Kim "] +description = "AMQP 1.0 Client/Server framework" +documentation = "https://docs.rs/actix-amqp" +repository = "https://github.com/actix/actix-extras.git" +categories = ["network-programming"] +keywords = ["AMQP", "IoT", "messaging"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[dependencies] +amqp-codec = "0.1.0" +actix-codec = "0.2.0" +actix-service = "1.0.1" +actix-connect = "1.0.1" +actix-router = "0.2.2" +actix-utils = "1.0.4" +actix-rt = "1.0.0" +bytes = "0.5.3" +bytestring = "0.1.2" +derive_more = "0.99.2" +either = "1.5.3" +futures = "0.3.1" +fxhash = "0.2.1" +http = "0.2.0" +log = "0.4" +pin-project = "0.4.6" +uuid = { version = "0.8", features = ["v4"] } +slab = "0.4" + +[dev-dependencies] +env_logger = "0.6" +actix-testing = "1.0.0" diff --git a/actix-amqp/README.md b/actix-amqp/README.md new file mode 100755 index 000000000..37fc0c3a3 --- /dev/null +++ b/actix-amqp/README.md @@ -0,0 +1,3 @@ +# AMQP 1.0 Server Framework + +[![Build Status](https://travis-ci.org/actix/actix-amqp.svg?branch=master)](https://travis-ci.org/actix/actix-amqp) diff --git a/actix-amqp/codec/Cargo.toml b/actix-amqp/codec/Cargo.toml new file mode 100755 index 000000000..3d5d8c288 --- /dev/null +++ b/actix-amqp/codec/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "amqp-codec" +version = "0.1.0" +description = "AMQP 1.0 Protocol Codec" +authors = ["Nikolay Kim ", "Max Gortman ", "Mike Yagley "] +license = "MIT/Apache-2.0" +edition = "2018" + +[dependencies] +actix-codec = "0.2.0" +bytes = "0.5.2" +byteorder = "1.3.1" +bytestring = "0.1.1" +chrono = "0.4" +derive_more = "0.99.2" +fxhash = "0.2.1" +ordered-float = "1.0" +uuid = { version = "0.8", features = ["v4"] } + +[build-dependencies] +handlebars = { version = "0.27", optional = true } +serde = { version = "1.0", optional = true } +serde_derive = { version = "1.0", optional = true } +serde_json = { version = "1.0", optional = true } +lazy_static = { version = "1.0", optional = true } +regex = { version = "1.3", optional = true } + +[features] +default = [] + +from-spec = ["handlebars", "serde", "serde_derive", "serde_json", "lazy_static", "regex"] diff --git a/actix-amqp/codec/README.md b/actix-amqp/codec/README.md new file mode 100755 index 000000000..70d45cef4 --- /dev/null +++ b/actix-amqp/codec/README.md @@ -0,0 +1,3 @@ +# AMQP 1.0 Protocol Codec + +[![Build Status](https://travis-ci.org/fafhrd91/amqp-codec.svg?branch=master)](https://travis-ci.org/fafhrd91/amqp-codec) \ No newline at end of file diff --git a/actix-amqp/codec/build.rs b/actix-amqp/codec/build.rs new file mode 100755 index 000000000..3a83064c4 --- /dev/null +++ b/actix-amqp/codec/build.rs @@ -0,0 +1,102 @@ +#[cfg(feature = "from-spec")] +extern crate handlebars; +#[cfg(feature = "from-spec")] +#[macro_use] +extern crate lazy_static; +#[cfg(feature = "from-spec")] +extern crate serde; +#[cfg(feature = "from-spec")] +#[macro_use] +extern crate serde_derive; +#[cfg(feature = "from-spec")] +extern crate regex; +#[cfg(feature = "from-spec")] +extern crate serde_json; + +#[cfg(feature = "from-spec")] +mod codegen; + +fn main() { + #[cfg(feature = "from-spec")] + { + generate_from_spec(); + } +} + +#[cfg(feature = "from-spec")] +fn generate_from_spec() { + use handlebars::{Handlebars, Helper, RenderContext, RenderError}; + use std::env; + use std::fs::File; + use std::io::Write; + + let template = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/codegen/definitions.rs" + )); + let spec = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/codegen/specification.json" + )); + let out_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is not defined") + + "/src/protocol/"; + + let definitions = codegen::parse(spec); + + let mut codegen = Handlebars::new(); + codegen.register_helper( + "snake", + Box::new( + |h: &Helper, _: &Handlebars, rc: &mut RenderContext| -> Result<(), RenderError> { + let value = h + .param(0) + .ok_or_else(|| RenderError::new("Param not found for helper \"snake\""))?; + let param = value.value().as_str().ok_or_else(|| { + RenderError::new("Non-string param given to helper \"snake\"") + })?; + rc.writer.write_all(codegen::snake_case(param).as_bytes())?; + Ok(()) + }, + ), + ); + + codegen + .register_template_string("definitions", template.to_string()) + .expect("Failed to register template."); + let mut data = std::collections::BTreeMap::new(); + + data.insert("defs", definitions); + let def_path = std::path::Path::new(&out_dir).join("definitions.rs"); + { + let mut f = File::create(def_path.clone()).expect("Failed to create target file."); + let rendered = codegen + .render("definitions", &data) + .expect("Failed to render template."); + writeln!(f, "{}", rendered).expect("Failed to write to file."); + } + + reformat_file(&def_path); + + fn reformat_file(path: &std::path::Path) { + use std::fs::OpenOptions; + use std::io::{Read, Seek}; + std::process::Command::new("rustfmt") + .arg(path.to_str().unwrap()) + .output() + .expect("failed to format definitions.rs"); + + let mut f = OpenOptions::new() + .read(true) + .write(true) + .open(path) + .expect("failed to open file"); + let mut data = String::new(); + f.read_to_string(&mut data).expect("failed to read file."); + let regex = regex::Regex::new("(?P[\r]?\n)[\\s]*\r?\n").unwrap(); + data = regex.replace_all(&data, "$a").into(); + f.seek(std::io::SeekFrom::Start(0)).unwrap(); + f.set_len(data.len() as u64).unwrap(); + f.write_all(data.as_bytes()) + .expect("Error writing reformatted file"); + } +} diff --git a/actix-amqp/codec/codegen/definitions.rs b/actix-amqp/codec/codegen/definitions.rs new file mode 100755 index 000000000..83c04ffca --- /dev/null +++ b/actix-amqp/codec/codegen/definitions.rs @@ -0,0 +1,460 @@ +#![allow(unused_assignments, unused_variables, unreachable_patterns)] + +use std::u8; +use derive_more::From; +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; +use uuid::Uuid; + +use super::*; +use crate::errors::AmqpParseError; +use crate::codec::{self, decode_format_code, decode_list_header, Decode, DecodeFormatted, Encode}; + +#[derive(Clone, Debug, PartialEq, From)] +pub enum Frame { + Open(Open), + Begin(Begin), + Attach(Attach), + Flow(Flow), + Transfer(Transfer), + Disposition(Disposition), + Detach(Detach), + End(End), + Close(Close), + Empty, +} + +impl Frame { + pub fn name(&self) -> &'static str { + match self { + Frame::Open(_) => "Open", + Frame::Begin(_) => "Begin", + Frame::Attach(_) => "Attach", + Frame::Flow(_) => "Flow", + Frame::Transfer(_) => "Transfer", + Frame::Disposition(_) => "Disposition", + Frame::Detach(_) => "Detach", + Frame::End(_) => "End", + Frame::Close(_) => "Close", + Frame::Empty => "Empty", + } + } +} + +impl Decode for Frame { + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + if input.is_empty() { + Ok((input, Frame::Empty)) + } else { + let (input, fmt) = decode_format_code(input)?; + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(16) => decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r))), + Descriptor::Ulong(17) => { + decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r))) + } + Descriptor::Ulong(18) => { + decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r))) + } + Descriptor::Ulong(19) => decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r))), + Descriptor::Ulong(20) => { + decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r))) + } + Descriptor::Ulong(21) => { + decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r))) + } + Descriptor::Ulong(22) => { + decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r))) + } + Descriptor::Ulong(23) => decode_end_inner(input).map(|(i, r)| (i, Frame::End(r))), + Descriptor::Ulong(24) => { + decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:open:list" => { + decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:begin:list" => { + decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:attach:list" => { + decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:flow:list" => { + decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:transfer:list" => { + decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:disposition:list" => { + decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:detach:list" => { + decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:end:list" => { + decode_end_inner(input).map(|(i, r)| (i, Frame::End(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:close:list" => { + decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } + } +} + +impl Encode for Frame { + fn encoded_size(&self) -> usize { + match *self { + Frame::Open(ref v) => encoded_size_open_inner(v), + Frame::Begin(ref v) => encoded_size_begin_inner(v), + Frame::Attach(ref v) => encoded_size_attach_inner(v), + Frame::Flow(ref v) => encoded_size_flow_inner(v), + Frame::Transfer(ref v) => encoded_size_transfer_inner(v), + Frame::Disposition(ref v) => encoded_size_disposition_inner(v), + Frame::Detach(ref v) => encoded_size_detach_inner(v), + Frame::End(ref v) => encoded_size_end_inner(v), + Frame::Close(ref v) => encoded_size_close_inner(v), + Frame::Empty => 0, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + Frame::Open(ref v) => encode_open_inner(v, buf), + Frame::Begin(ref v) => encode_begin_inner(v, buf), + Frame::Attach(ref v) => encode_attach_inner(v, buf), + Frame::Flow(ref v) => encode_flow_inner(v, buf), + Frame::Transfer(ref v) => encode_transfer_inner(v, buf), + Frame::Disposition(ref v) => encode_disposition_inner(v, buf), + Frame::Detach(ref v) => encode_detach_inner(v, buf), + Frame::End(ref v) => encode_end_inner(v, buf), + Frame::Close(ref v) => encode_close_inner(v, buf), + Frame::Empty => (), + } + } +} + +{{#each defs.provides as |provide|}} +{{#if provide.described}} +#[derive(Clone, Debug, PartialEq)] +pub enum {{provide.name}} { +{{#each provide.options as |option|}} + {{option.ty}}({{option.ty}}), +{{/each}} +} + +impl DecodeFormatted for {{provide.name}} { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + {{#each provide.options as |option|}} + Descriptor::Ulong({{option.descriptor.code}}) => decode_{{snake option.ty}}_inner(input).map(|(i, r)| (i, {{provide.name}}::{{option.ty}}(r))), + {{/each}} + {{#each provide.options as |option|}} + Descriptor::Symbol(ref a) if a.as_str() == "{{option.descriptor.name}}" => decode_{{snake option.ty}}_inner(input).map(|(i, r)| (i, {{provide.name}}::{{option.ty}}(r))), + {{/each}} + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)) + } + } +} + +impl Encode for {{provide.name}} { + fn encoded_size(&self) -> usize { + match *self { + {{#each provide.options as |option|}} + {{provide.name}}::{{option.ty}}(ref v) => encoded_size_{{snake option.ty}}_inner(v), + {{/each}} + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + {{#each provide.options as |option|}} + {{provide.name}}::{{option.ty}}(ref v) => encode_{{snake option.ty}}_inner(v, buf), + {{/each}} + } + } +} +{{/if}} +{{/each}} + +{{#each defs.aliases as |alias|}} +pub type {{alias.name}} = {{alias.source}}; +{{/each}} + +{{#each defs.enums as |enum|}} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum {{enum.name}} { +{{#each enum.items as |item|}} + {{item.name}}, +{{/each}} +} +{{#if enum.is_symbol}} +impl {{enum.name}} { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + {{#each enum.items as |item|}} + "{{item.value}}" => Ok({{enum.name}}::{{item.name}}), + {{/each}} + _ => Err(AmqpParseError::UnknownEnumOption("{{enum.name}}")) + } + } +} +impl DecodeFormatted for {{enum.name}} { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for {{enum.name}} { + fn encoded_size(&self) -> usize { + match *self { + {{#each enum.items as |item|}} + {{enum.name}}::{{item.name}} => {{item.value_len}} + 2, + {{/each}} + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + {{#each enum.items as |item|}} + {{enum.name}}::{{item.name}} => StaticSymbol("{{item.value}}").encode(buf), + {{/each}} + } + } +} +{{else}} +impl {{enum.name}} { + pub fn try_from(v: {{enum.ty}}) -> Result { + match v { + {{#each enum.items as |item|}} + {{item.value}} => Ok({{enum.name}}::{{item.name}}), + {{/each}} + _ => Err(AmqpParseError::UnknownEnumOption("{{enum.name}}")) + } + } +} +impl DecodeFormatted for {{enum.name}} { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = {{enum.ty}}::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for {{enum.name}} { + fn encoded_size(&self) -> usize { + match *self { + {{#each enum.items as |item|}} + {{enum.name}}::{{item.name}} => { + let v : {{enum.ty}} = {{item.value}}; + v.encoded_size() + }, + {{/each}} + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + {{#each enum.items as |item|}} + {{enum.name}}::{{item.name}} => { + let v : {{enum.ty}} = {{item.value}}; + v.encode(buf); + }, + {{/each}} + } + } +} +{{/if}} +{{/each}} + +{{#each defs.described_restricted as |dr|}} +type {{dr.name}} = {{dr.ty}}; +fn decode_{{snake dr.name}}_inner(input: &[u8]) -> Result<(&[u8], {{dr.name}}), AmqpParseError> { + {{dr.name}}::decode(input) +} +fn encoded_size_{{snake dr.name}}_inner(dr: &{{dr.name}}) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_{{snake dr.name}}_inner(dr: &{{dr.name}}, buf: &mut BytesMut) { + Descriptor::Ulong({{dr.descriptor.code}}).encode(buf); + dr.encode(buf); +} +{{/each}} + +{{#each defs.lists as |list|}} +#[derive(Clone, Debug, PartialEq)] +pub struct {{list.name}} { + {{#each list.fields as |field|}} + {{#if field.optional}} + pub {{field.name}}: Option<{{{field.ty}}}>, + {{else}} + pub {{field.name}}: {{{field.ty}}}, + {{/if}} + {{/each}} + {{#if list.transfer}} + pub body: Option, + {{/if}} +} + +impl {{list.name}} { + {{#each list.fields as |field|}} + {{#if field.is_str}} + {{#if field.optional}} + pub fn {{field.name}}(&self) -> Option<&str> { + match self.{{field.name}} { + None => None, + Some(ref s) => Some(s.as_str()) + } + } + {{else}} + pub fn {{field.name}}(&self) -> &str { self.{{field.name}}.as_str() } + {{/if}} + {{else}} + {{#if field.is_ref}} + {{#if field.optional}} + pub fn {{field.name}}(&self) -> Option<&{{{field.ty}}}> { self.{{field.name}}.as_ref() } + {{else}} + pub fn {{field.name}}(&self) -> &{{{field.ty}}} { &self.{{field.name}} } + {{/if}} + {{else}} + {{#if field.optional}} + pub fn {{field.name}}(&self) -> Option<{{{field.ty}}}> { self.{{field.name}} } + {{else}} + pub fn {{field.name}}(&self) -> {{{field.ty}}} { self.{{field.name}} } + {{/if}} + {{/if}} + {{/if}} + {{/each}} + + {{#if list.transfer}} + pub fn body(&self) -> Option<&TransferBody> { + self.body.as_ref() + } + {{/if}} + + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 {{#each list.fields as |field|}} + 1{{/each}}; +} +#[allow(unused_mut)] +fn decode_{{snake list.name}}_inner(input: &[u8]) -> Result<(&[u8], {{list.name}}), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + {{#if list.fields}} + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + {{#each list.fields as |field|}} + {{#if field.optional}} + let {{field.name}}: Option<{{{field.ty}}}>; + if count > 0 { + let decoded = Option::<{{{field.ty}}}>::decode(input)?; + input = decoded.0; + {{field.name}} = decoded.1; + count -= 1; + } + else { + {{field.name}} = None; + } + {{else}} + let {{field.name}}: {{{field.ty}}}; + if count > 0 { + {{#if field.default}} + let (in1, decoded) = Option::<{{{field.ty}}}>::decode(input)?; + {{field.name}} = decoded.unwrap_or({{field.default}}); + {{else}} + let (in1, decoded) = {{{field.ty}}}::decode(input)?; + {{field.name}} = decoded; + {{/if}} + input = in1; + count -= 1; + } + else { + {{#if field.default}} + {{field.name}} = {{field.default}}; + {{else}} + return Err(AmqpParseError::RequiredFieldOmitted("{{field.name}}")) + {{/if}} + } + {{/if}} + {{/each}} + {{else}} + let mut remainder = &input[size..]; + {{/if}} + + {{#if list.transfer}} + let body = if remainder.is_empty() { + None + } else { + let b = Bytes::copy_from_slice(remainder); + remainder = &[]; + Some(b.into()) + }; + {{/if}} + + Ok((remainder, {{list.name}} { + {{#each list.fields as |field|}} + {{field.name}}, + {{/each}} + {{#if list.transfer}} + body + {{/if}} + })) +} + +fn encoded_size_{{snake list.name}}_inner(list: &{{list.name}}) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 {{#each list.fields as |field|}} + list.{{field.name}}.encoded_size(){{/each}}; + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { 12 } else { 6 }) + + content_size + + {{#if list.transfer}} + + list.body.as_ref().map(|b| b.len()).unwrap_or(0) + {{/if}} +} +fn encode_{{snake list.name}}_inner(list: &{{list.name}}, buf: &mut BytesMut) { + Descriptor::Ulong({{list.descriptor.code}}).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 {{#each list.fields as |field|}} + list.{{field.name}}.encoded_size(){{/each}}; + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32({{list.name}}::FIELD_COUNT as u32); + } + else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8({{list.name}}::FIELD_COUNT as u8); + } + {{#each list.fields as |field|}} + list.{{field.name}}.encode(buf); + {{/each}} + {{#if list.transfer}} + if let Some(ref body) = list.body { + body.encode(buf) + } + {{/if}} +} + +impl DecodeFormatted for {{list.name}} { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == {{list.descriptor.code}}, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"{{list.descriptor.name}}", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_{{snake list.name}}_inner(input) + } + } +} + +impl Encode for {{list.name}} { + fn encoded_size(&self) -> usize { encoded_size_{{snake list.name}}_inner(self) } + + fn encode(&self, buf: &mut BytesMut) { encode_{{snake list.name}}_inner(self, buf) } +} +{{/each}} diff --git a/actix-amqp/codec/codegen/mod.rs b/actix-amqp/codec/codegen/mod.rs new file mode 100755 index 000000000..78e8daa7d --- /dev/null +++ b/actix-amqp/codec/codegen/mod.rs @@ -0,0 +1,478 @@ +use serde::{Deserialize, Deserializer}; +use serde_json::from_str; +use std::collections::{HashMap, HashSet}; +use std::str::{FromStr, ParseBoolError}; +use std::sync::Mutex; + +lazy_static! { + static ref PRIMITIVE_TYPES: HashMap<&'static str, &'static str> = { + let mut m = HashMap::new(); + m.insert("*", "Variant"); + m.insert("binary", "Bytes"); + m.insert("string", "ByteString"); + m.insert("ubyte", "u8"); + m.insert("ushort", "u16"); + m.insert("uint", "u32"); + m.insert("ulong", "u64"); + m.insert("boolean", "bool"); + m + }; + static ref STRING_TYPES: HashSet<&'static str> = ["string", "symbol"].iter().cloned().collect(); + static ref REF_TYPES: Mutex> = Mutex::new( + [ + "Bytes", + "ByteString", + "Symbol", + "Fields", + "Map", + "MessageId", + "Address", + "NodeProperties", + "Outcome", + "DeliveryState", + "FilterSet", + "DeliveryTag", + "Symbols", + "IetfLanguageTags", + "ErrorCondition", + "DistributionMode" + ] + .iter() + .map(|s| s.to_string()) + .collect() + ); + static ref ENUM_TYPES: Mutex> = Mutex::new(HashSet::new()); +} + +pub fn parse(spec: &str) -> Definitions { + let types = from_str::>(spec).expect("Failed to parse AMQP spec."); + { + let mut ref_map = REF_TYPES.lock().unwrap(); + let mut enum_map = ENUM_TYPES.lock().unwrap(); + for t in types.iter() { + match *t { + _Type::Described(ref l) if l.source == "list" => { + ref_map.insert(camel_case(&*l.name)); + } + _Type::Choice(ref e) => { + enum_map.insert(camel_case(&*e.name)); + } + _ => {} + } + } + } + + Definitions::from(types) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +enum _Type { + Choice(_Enum), + Described(_Described), + Alias(_Alias), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct _Descriptor { + name: String, + code: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct _Enum { + name: String, + source: String, + provides: Option, + choice: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct _Described { + name: String, + class: String, + source: String, + provides: Option, + descriptor: _Descriptor, + #[serde(default)] + field: Vec<_Field>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct _Field { + name: String, + #[serde(rename = "type")] + ty: String, + #[serde(default)] + #[serde(deserialize_with = "string_as_bool")] + mandatory: bool, + default: Option, + #[serde(default)] + #[serde(deserialize_with = "string_as_bool")] + multiple: bool, + requires: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct _Alias { + name: String, + source: String, + provides: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Alias { + name: String, + source: String, + provides: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnumItem { + name: String, + value: String, + #[serde(default)] + value_len: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Definitions { + aliases: Vec, + enums: Vec, + lists: Vec, + described_restricted: Vec, + provides: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ProvidesEnum { + name: String, + described: bool, + options: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ProvidesItem { + ty: String, + descriptor: Descriptor, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Enum { + name: String, + ty: String, + provides: Vec, + items: Vec, + is_symbol: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Described { + name: String, + ty: String, + provides: Vec, + descriptor: Descriptor, + fields: Vec, + transfer: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Descriptor { + name: String, + domain: u32, + code: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Field { + name: String, + ty: String, + is_str: bool, + is_ref: bool, + optional: bool, + default: String, + multiple: bool, +} + +impl Definitions { + fn from(types: Vec<_Type>) -> Definitions { + let mut aliases = vec![]; + let mut enums = vec![]; + let mut lists = vec![]; + let mut described_restricted = vec![]; + let mut provide_map: HashMap> = HashMap::new(); + for t in types.into_iter() { + match t { + _Type::Alias(ref a) if a.source != "map" => { + let al = Alias::from(a.clone()); + Definitions::register_provides(&mut provide_map, &al.name, None, &al.provides); + aliases.push(al); + } + _Type::Choice(ref e) => { + let en = Enum::from(e.clone()); + Definitions::register_provides(&mut provide_map, &en.name, None, &en.provides); + enums.push(en); + } + _Type::Described(ref d) if d.source == "list" && d.class != "restricted" => { + let ls = Described::list(d.clone()); + Definitions::register_provides( + &mut provide_map, + &ls.name, + Some(ls.descriptor.clone()), + &ls.provides, + ); + lists.push(ls); + } + _Type::Described(ref d) if d.class == "restricted" => { + let ls = Described::alias(d.clone()); + Definitions::register_provides( + &mut provide_map, + &ls.name, + Some(ls.descriptor.clone()), + &ls.provides, + ); + described_restricted.push(ls); + } + _ => {} + } + } + + let provides = provide_map + .into_iter() + .filter_map(|(k, v)| { + if v.len() == 1 { + None + } else { + Some(ProvidesEnum { + described: if k == "Frame" { + false + } else { + v.iter().any(|v| v.descriptor.code != 0) + }, + name: k, + options: v, + }) + } + }) + .collect(); + + Definitions { + aliases, + enums, + lists, + described_restricted, + provides, + } + } + + fn register_provides( + map: &mut HashMap>, + name: &str, + descriptor: Option, + provides: &Vec, + ) { + for p in provides.iter() { + map.entry(p.clone()) + .or_insert_with(|| vec![]) + .push(ProvidesItem { + ty: name.to_string(), + descriptor: descriptor.clone().unwrap_or_else(|| Descriptor { + name: String::new(), + domain: 0, + code: 0, + }), + }); + } + } +} + +impl Alias { + fn from(a: _Alias) -> Alias { + Alias { + name: camel_case(&*a.name), + source: get_type_name(&*a.source, None), + provides: parse_provides(a.provides), + } + } +} + +impl Enum { + fn from(e: _Enum) -> Enum { + let ty = get_type_name(&*e.source, None); + let is_symbol = ty == "Symbol"; + Enum { + name: camel_case(&*e.name), + ty: ty.clone(), + provides: parse_provides(e.provides), + is_symbol, + items: e + .choice + .into_iter() + .map(|c| EnumItem { + name: camel_case(&*c.name), + value_len: c.value.len(), + value: c.value, + }) + .collect(), + } + } +} + +impl Described { + fn list(d: _Described) -> Described { + let transfer = d.name == "transfer"; + Described { + name: camel_case(&d.name), + ty: String::new(), + provides: parse_provides(d.provides), + descriptor: Descriptor::from(d.descriptor), + fields: d.field.into_iter().map(|f| Field::from(f)).collect(), + transfer, + } + } + fn alias(d: _Described) -> Described { + Described { + name: camel_case(&d.name), + ty: get_type_name(&d.source, None), + provides: parse_provides(d.provides), + descriptor: Descriptor::from(d.descriptor), + fields: d.field.into_iter().map(|f| Field::from(f)).collect(), + transfer: false, + } + } +} + +impl Descriptor { + fn from(d: _Descriptor) -> Descriptor { + let code_parts: Vec = d + .code + .split(":") + .map(|p| { + assert!(p.starts_with("0x")); + u32::from_str_radix(&p[2..], 16).expect("malformed descriptor code") + }) + .collect(); + Descriptor { + name: d.name, + domain: code_parts[0], + code: code_parts[1], + } + } +} +impl Field { + fn from(field: _Field) -> Field { + let mut ty = get_type_name(&*field.ty, field.requires); + if field.multiple { + ty.push('s'); + } + let is_str = STRING_TYPES.contains(&*ty) && !field.multiple; + let is_ref = REF_TYPES.lock().unwrap().contains(&ty); + let default = Field::format_default(field.default, &ty); + Field { + name: snake_case(&*field.name), + ty: ty, + is_ref, + is_str, + optional: !field.mandatory && default.len() == 0, + multiple: field.multiple, + default, + } + } + + fn format_default(default: Option, ty: &str) -> String { + match default { + None => String::new(), + Some(def) => { + if ENUM_TYPES.lock().unwrap().contains(ty) { + format!("{}::{}", ty, camel_case(&*def)) + } else { + def + } + } + } + } +} + +fn get_type_name(ty: &str, req: Option) -> String { + match req { + Some(t) => camel_case(&*t), + None => match PRIMITIVE_TYPES.get(ty) { + Some(p) => p.to_string(), + None => camel_case(&*ty), + }, + } +} + +fn parse_provides(p: Option) -> Vec { + p.map(|v| { + v.split_terminator(",") + .filter_map(|s| { + let s = s.trim(); + if s == "" { + None + } else { + Some(camel_case(&s)) + } + }) + .collect() + }) + .unwrap_or(vec![]) +} + +fn string_as_bool<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + D: Deserializer<'de>, +{ + Ok(String::deserialize(deserializer)? + .parse::() + .expect("Error parsing bool from string")) +} + +pub fn camel_case(name: &str) -> String { + let mut new_word = true; + name.chars().fold("".to_string(), |mut result, ch| { + if ch == '-' || ch == '_' || ch == ' ' { + new_word = true; + result + } else { + result.push(if new_word { + ch.to_ascii_uppercase() + } else { + ch + }); + new_word = false; + result + } + }) +} + +pub fn snake_case(name: &str) -> String { + match name { + "type" => "type_".to_string(), + "return" => "return_".to_string(), + name => { + let mut new_word = false; + let mut last_was_upper = false; + name.chars().fold("".to_string(), |mut result, ch| { + if ch == '-' || ch == '_' || ch == ' ' { + new_word = true; + result + } else { + let uppercase = ch.is_uppercase(); + if new_word || (!last_was_upper && !result.is_empty() && uppercase) { + result.push('_'); + new_word = false; + } + last_was_upper = uppercase; + result.push(if uppercase { + ch.to_ascii_lowercase() + } else { + ch + }); + result + } + }) + } + } +} diff --git a/actix-amqp/codec/codegen/specification.json b/actix-amqp/codec/codegen/specification.json new file mode 100755 index 000000000..1c4e5e6b5 --- /dev/null +++ b/actix-amqp/codec/codegen/specification.json @@ -0,0 +1,1234 @@ +[ + { + "name": "role", + "class": "restricted", + "source": "boolean", + "choice": [ + { + "name": "sender", + "value": "false" + }, + { + "name": "receiver", + "value": "true" + } + ] + }, + { + "name": "sender-settle-mode", + "class": "restricted", + "source": "ubyte", + "choice": [ + { + "name": "unsettled", + "value": "0" + }, + { + "name": "settled", + "value": "1" + }, + { + "name": "mixed", + "value": "2" + } + ] + }, + { + "name": "receiver-settle-mode", + "class": "restricted", + "source": "ubyte", + "choice": [ + { + "name": "first", + "value": "0" + }, + { + "name": "second", + "value": "1" + } + ] + }, + { + "name": "handle", + "class": "restricted", + "source": "uint" + }, + { + "name": "seconds", + "class": "restricted", + "source": "uint" + }, + { + "name": "milliseconds", + "class": "restricted", + "source": "uint" + }, + { + "name": "delivery-tag", + "class": "restricted", + "source": "binary" + }, + { + "name": "sequence-no", + "class": "restricted", + "source": "uint" + }, + { + "name": "delivery-number", + "class": "restricted", + "source": "sequence-no" + }, + { + "name": "transfer-number", + "class": "restricted", + "source": "sequence-no" + }, + { + "name": "message-format", + "class": "restricted", + "source": "uint" + }, + { + "name": "ietf-language-tag", + "class": "restricted", + "source": "symbol" + }, + { + "name": "fields", + "class": "restricted", + "source": "map" + }, + { + "name": "error", + "class": "composite", + "source": "list", + "descriptor": { + "name": "amqp:error:list", + "code": "0x00000000:0x0000001d" + }, + "field": [ + { + "name": "condition", + "type": "symbol", + "requires": "error-condition", + "mandatory": "true" + }, + { + "name": "description", + "type": "string" + }, + { + "name": "info", + "type": "fields" + } + ] + }, + { + "name": "amqp-error", + "class": "restricted", + "source": "symbol", + "provides": "error-condition", + "choice": [ + { + "name": "internal-error", + "value": "amqp:internal-error" + }, + { + "name": "not-found", + "value": "amqp:not-found" + }, + { + "name": "unauthorized-access", + "value": "amqp:unauthorized-access" + }, + { + "name": "decode-error", + "value": "amqp:decode-error" + }, + { + "name": "resource-limit-exceeded", + "value": "amqp:resource-limit-exceeded" + }, + { + "name": "not-allowed", + "value": "amqp:not-allowed" + }, + { + "name": "invalid-field", + "value": "amqp:invalid-field" + }, + { + "name": "not-implemented", + "value": "amqp:not-implemented" + }, + { + "name": "resource-locked", + "value": "amqp:resource-locked" + }, + { + "name": "precondition-failed", + "value": "amqp:precondition-failed" + }, + { + "name": "resource-deleted", + "value": "amqp:resource-deleted" + }, + { + "name": "illegal-state", + "value": "amqp:illegal-state" + }, + { + "name": "frame-size-too-small", + "value": "amqp:frame-size-too-small" + } + ] + }, + { + "name": "connection-error", + "class": "restricted", + "source": "symbol", + "provides": "error-condition", + "choice": [ + { + "name": "connection-forced", + "value": "amqp:connection:forced" + }, + { + "name": "framing-error", + "value": "amqp:connection:framing-error" + }, + { + "name": "redirect", + "value": "amqp:connection:redirect" + } + ] + }, + { + "name": "session-error", + "class": "restricted", + "source": "symbol", + "provides": "error-condition", + "choice": [ + { + "name": "window-violation", + "value": "amqp:session:window-violation" + }, + { + "name": "errant-link", + "value": "amqp:session:errant-link" + }, + { + "name": "handle-in-use", + "value": "amqp:session:handle-in-use" + }, + { + "name": "unattached-handle", + "value": "amqp:session:unattached-handle" + } + ] + }, + { + "name": "link-error", + "class": "restricted", + "source": "symbol", + "provides": "error-condition", + "choice": [ + { + "name": "detach-forced", + "value": "amqp:link:detach-forced" + }, + { + "name": "transfer-limit-exceeded", + "value": "amqp:link:transfer-limit-exceeded" + }, + { + "name": "message-size-exceeded", + "value": "amqp:link:message-size-exceeded" + }, + { + "name": "redirect", + "value": "amqp:link:redirect" + }, + { + "name": "stolen", + "value": "amqp:link:stolen" + } + ] + }, + { + "name": "open", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:open:list", + "code": "0x00000000:0x00000010" + }, + "field": [ + { + "name": "container-id", + "type": "string", + "mandatory": "true" + }, + { + "name": "hostname", + "type": "string" + }, + { + "name": "max-frame-size", + "type": "uint", + "default": "4294967295" + }, + { + "name": "channel-max", + "type": "ushort", + "default": "65535" + }, + { + "name": "idle-time-out", + "type": "milliseconds" + }, + { + "name": "outgoing-locales", + "type": "ietf-language-tag", + "multiple": "true" + }, + { + "name": "incoming-locales", + "type": "ietf-language-tag", + "multiple": "true" + }, + { + "name": "offered-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "desired-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "properties", + "type": "fields" + } + ] + }, + { + "name": "begin", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:begin:list", + "code": "0x00000000:0x00000011" + }, + "field": [ + { + "name": "remote-channel", + "type": "ushort" + }, + { + "name": "next-outgoing-id", + "type": "transfer-number", + "mandatory": "true" + }, + { + "name": "incoming-window", + "type": "uint", + "mandatory": "true" + }, + { + "name": "outgoing-window", + "type": "uint", + "mandatory": "true" + }, + { + "name": "handle-max", + "type": "handle", + "default": "4294967295" + }, + { + "name": "offered-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "desired-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "properties", + "type": "fields" + } + ] + }, + { + "name": "attach", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:attach:list", + "code": "0x00000000:0x00000012" + }, + "field": [ + { + "name": "name", + "type": "string", + "mandatory": "true" + }, + { + "name": "handle", + "type": "handle", + "mandatory": "true" + }, + { + "name": "role", + "type": "role", + "mandatory": "true" + }, + { + "name": "snd-settle-mode", + "type": "sender-settle-mode", + "default": "mixed" + }, + { + "name": "rcv-settle-mode", + "type": "receiver-settle-mode", + "default": "first" + }, + { + "name": "source", + "type": "*", + "requires": "source" + }, + { + "name": "target", + "type": "*", + "requires": "target" + }, + { + "name": "unsettled", + "type": "map" + }, + { + "name": "incomplete-unsettled", + "type": "boolean", + "default": "false" + }, + { + "name": "initial-delivery-count", + "type": "sequence-no" + }, + { + "name": "max-message-size", + "type": "ulong" + }, + { + "name": "offered-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "desired-capabilities", + "type": "symbol", + "multiple": "true" + }, + { + "name": "properties", + "type": "fields" + } + ] + }, + { + "name": "flow", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:flow:list", + "code": "0x00000000:0x00000013" + }, + "field": [ + { + "name": "next-incoming-id", + "type": "transfer-number" + }, + { + "name": "incoming-window", + "type": "uint", + "mandatory": "true" + }, + { + "name": "next-outgoing-id", + "type": "transfer-number", + "mandatory": "true" + }, + { + "name": "outgoing-window", + "type": "uint", + "mandatory": "true" + }, + { + "name": "handle", + "type": "handle" + }, + { + "name": "delivery-count", + "type": "sequence-no" + }, + { + "name": "link-credit", + "type": "uint" + }, + { + "name": "available", + "type": "uint" + }, + { + "name": "drain", + "type": "boolean", + "default": "false" + }, + { + "name": "echo", + "type": "boolean", + "default": "false" + }, + { + "name": "properties", + "type": "fields" + } + ] + }, + { + "name": "transfer", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:transfer:list", + "code": "0x00000000:0x00000014" + }, + "field": [ + { + "name": "handle", + "type": "handle", + "mandatory": "true" + }, + { + "name": "delivery-id", + "type": "delivery-number" + }, + { + "name": "delivery-tag", + "type": "delivery-tag" + }, + { + "name": "message-format", + "type": "message-format" + }, + { + "name": "settled", + "type": "boolean" + }, + { + "name": "more", + "type": "boolean", + "default": "false" + }, + { + "name": "rcv-settle-mode", + "type": "receiver-settle-mode" + }, + { + "name": "state", + "type": "*", + "requires": "delivery-state" + }, + { + "name": "resume", + "type": "boolean", + "default": "false" + }, + { + "name": "aborted", + "type": "boolean", + "default": "false" + }, + { + "name": "batchable", + "type": "boolean", + "default": "false" + } + ] + }, + { + "name": "disposition", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:disposition:list", + "code": "0x00000000:0x00000015" + }, + "field": [ + { + "name": "role", + "type": "role", + "mandatory": "true" + }, + { + "name": "first", + "type": "delivery-number", + "mandatory": "true" + }, + { + "name": "last", + "type": "delivery-number" + }, + { + "name": "settled", + "type": "boolean", + "default": "false" + }, + { + "name": "state", + "type": "*", + "requires": "delivery-state" + }, + { + "name": "batchable", + "type": "boolean", + "default": "false" + } + ] + }, + { + "name": "detach", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:detach:list", + "code": "0x00000000:0x00000016" + }, + "field": [ + { + "name": "handle", + "type": "handle", + "mandatory": "true" + }, + { + "name": "closed", + "type": "boolean", + "default": "false" + }, + { + "name": "error", + "type": "error" + } + ] + }, + { + "name": "end", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:end:list", + "code": "0x00000000:0x00000017" + }, + "field": [{ + "name": "error", + "type": "error" + }] + }, + { + "name": "close", + "class": "composite", + "source": "list", + "provides": "frame", + "descriptor": { + "name": "amqp:close:list", + "code": "0x00000000:0x00000018" + }, + "field": [{ + "name": "error", + "type": "error" + }] + }, + { + "name": "sasl-code", + "class": "restricted", + "source": "ubyte", + "choice": [ + { + "name": "ok", + "value": "0" + }, + { + "name": "auth", + "value": "1" + }, + { + "name": "sys", + "value": "2" + }, + { + "name": "sys-perm", + "value": "3" + }, + { + "name": "sys-temp", + "value": "4" + } + ] + }, + { + "name": "sasl-mechanisms", + "class": "composite", + "source": "list", + "provides": "sasl-frame-body", + "descriptor": { + "name": "amqp:sasl-mechanisms:list", + "code": "0x00000000:0x00000040" + }, + "field": [{ + "name": "sasl-server-mechanisms", + "type": "symbol", + "multiple": "true", + "mandatory": "true" + }] + }, + { + "name": "sasl-init", + "class": "composite", + "source": "list", + "provides": "sasl-frame-body", + "descriptor": { + "name": "amqp:sasl-init:list", + "code": "0x00000000:0x00000041" + }, + "field": [ + { + "name": "mechanism", + "type": "symbol", + "mandatory": "true" + }, + { + "name": "initial-response", + "type": "binary" + }, + { + "name": "hostname", + "type": "string" + } + ] + }, + { + "name": "sasl-challenge", + "class": "composite", + "source": "list", + "provides": "sasl-frame-body", + "descriptor": { + "name": "amqp:sasl-challenge:list", + "code": "0x00000000:0x00000042" + }, + "field": [{ + "name": "challenge", + "type": "binary", + "mandatory": "true" + }] + }, + { + "name": "sasl-response", + "class": "composite", + "source": "list", + "provides": "sasl-frame-body", + "descriptor": { + "name": "amqp:sasl-response:list", + "code": "0x00000000:0x00000043" + }, + "field": [{ + "name": "response", + "type": "binary", + "mandatory": "true" + }] + }, + { + "name": "sasl-outcome", + "class": "composite", + "source": "list", + "provides": "sasl-frame-body", + "descriptor": { + "name": "amqp:sasl-outcome:list", + "code": "0x00000000:0x00000044" + }, + "field": [ + { + "name": "code", + "type": "sasl-code", + "mandatory": "true" + }, + { + "name": "additional-data", + "type": "binary" + } + ] + }, + { + "name": "terminus-durability", + "class": "restricted", + "source": "uint", + "choice": [ + { + "name": "none", + "value": "0" + }, + { + "name": "configuration", + "value": "1" + }, + { + "name": "unsettled-state", + "value": "2" + } + ] + }, + { + "name": "terminus-expiry-policy", + "class": "restricted", + "source": "symbol", + "choice": [ + { + "name": "link-detach", + "value": "link-detach" + }, + { + "name": "session-end", + "value": "session-end" + }, + { + "name": "connection-close", + "value": "connection-close" + }, + { + "name": "never", + "value": "never" + } + ] + }, + { + "name": "node-properties", + "class": "restricted", + "source": "fields" + }, + { + "name": "filter-set", + "class": "restricted", + "source": "map" + }, + { + "name": "source", + "class": "composite", + "source": "list", + "provides": "source", + "descriptor": { + "name": "amqp:source:list", + "code": "0x00000000:0x00000028" + }, + "field": [ + { + "name": "address", + "type": "*", + "requires": "address" + }, + { + "name": "durable", + "type": "terminus-durability", + "default": "none" + }, + { + "name": "expiry-policy", + "type": "terminus-expiry-policy", + "default": "session-end" + }, + { + "name": "timeout", + "type": "seconds", + "default": "0" + }, + { + "name": "dynamic", + "type": "boolean", + "default": "false" + }, + { + "name": "dynamic-node-properties", + "type": "node-properties" + }, + { + "name": "distribution-mode", + "type": "symbol", + "requires": "distribution-mode" + }, + { + "name": "filter", + "type": "filter-set" + }, + { + "name": "default-outcome", + "type": "*", + "requires": "outcome" + }, + { + "name": "outcomes", + "type": "symbol", + "multiple": "true" + }, + { + "name": "capabilities", + "type": "symbol", + "multiple": "true" + } + ] + }, + { + "name": "target", + "class": "composite", + "source": "list", + "provides": "target", + "descriptor": { + "name": "amqp:target:list", + "code": "0x00000000:0x00000029" + }, + "field": [ + { + "name": "address", + "type": "*", + "requires": "address" + }, + { + "name": "durable", + "type": "terminus-durability", + "default": "none" + }, + { + "name": "expiry-policy", + "type": "terminus-expiry-policy", + "default": "session-end" + }, + { + "name": "timeout", + "type": "seconds", + "default": "0" + }, + { + "name": "dynamic", + "type": "boolean", + "default": "false" + }, + { + "name": "dynamic-node-properties", + "type": "node-properties" + }, + { + "name": "capabilities", + "type": "symbol", + "multiple": "true" + } + ] + }, + { + "name": "annotations", + "class": "restricted", + "source": "map" + }, + { + "name": "message-id-ulong", + "class": "restricted", + "source": "ulong", + "provides": "message-id" + }, + { + "name": "message-id-uuid", + "class": "restricted", + "source": "uuid", + "provides": "message-id" + }, + { + "name": "message-id-binary", + "class": "restricted", + "source": "binary", + "provides": "message-id" + }, + { + "name": "message-id-string", + "class": "restricted", + "source": "string", + "provides": "message-id" + }, + { + "name": "address", + "class": "restricted", + "source": "string", + "provides": "address" + }, + { + "name": "header", + "class": "composite", + "source": "list", + "provides": "section", + "descriptor": { + "name": "amqp:header:list", + "code": "0x00000000:0x00000070" + }, + "field": [ + { + "name": "durable", + "type": "boolean", + "default": "false" + }, + { + "name": "priority", + "type": "ubyte", + "default": "4" + }, + { + "name": "ttl", + "type": "milliseconds" + }, + { + "name": "first-acquirer", + "type": "boolean", + "default": "false" + }, + { + "name": "delivery-count", + "type": "uint", + "default": "0" + } + ] + }, + { + "name": "delivery-annotations", + "class": "restricted", + "source": "annotations", + "provides": "section", + "descriptor": { + "name": "amqp:delivery-annotations:map", + "code": "0x00000000:0x00000071" + } + }, + { + "name": "message-annotations", + "class": "restricted", + "source": "annotations", + "provides": "section", + "descriptor": { + "name": "amqp:message-annotations:map", + "code": "0x00000000:0x00000072" + } + }, + { + "name": "application-properties", + "class": "restricted", + "source": "string-variant-map", + "provides": "section", + "descriptor": { + "name": "amqp:application-properties:map", + "code": "0x00000000:0x00000074" + } + }, + { + "name": "data", + "class": "restricted", + "source": "binary", + "provides": "section", + "descriptor": { + "name": "amqp:data:binary", + "code": "0x00000000:0x00000075" + } + }, + { + "name": "amqp-sequence", + "class": "restricted", + "source": "list", + "provides": "section", + "descriptor": { + "name": "amqp:amqp-sequence:list", + "code": "0x00000000:0x00000076" + } + }, + { + "name": "amqp-value", + "class": "restricted", + "source": "*", + "provides": "section", + "descriptor": { + "name": "amqp:amqp-value:*", + "code": "0x00000000:0x00000077" + } + }, + { + "name": "footer", + "class": "restricted", + "source": "annotations", + "provides": "section", + "descriptor": { + "name": "amqp:footer:map", + "code": "0x00000000:0x00000078" + } + }, + { + "name": "properties", + "class": "composite", + "source": "list", + "provides": "section", + "descriptor": { + "name": "amqp:properties:list", + "code": "0x00000000:0x00000073" + }, + "field": [ + { + "name": "message-id", + "type": "*", + "requires": "message-id" + }, + { + "name": "user-id", + "type": "binary" + }, + { + "name": "to", + "type": "*", + "requires": "address" + }, + { + "name": "subject", + "type": "string" + }, + { + "name": "reply-to", + "type": "*", + "requires": "address" + }, + { + "name": "correlation-id", + "type": "*", + "requires": "message-id" + }, + { + "name": "content-type", + "type": "symbol" + }, + { + "name": "content-encoding", + "type": "symbol" + }, + { + "name": "absolute-expiry-time", + "type": "timestamp" + }, + { + "name": "creation-time", + "type": "timestamp" + }, + { + "name": "group-id", + "type": "string" + }, + { + "name": "group-sequence", + "type": "sequence-no" + }, + { + "name": "reply-to-group-id", + "type": "string" + } + ] + }, + { + "name": "received", + "class": "composite", + "source": "list", + "provides": "delivery-state", + "descriptor": { + "name": "amqp:received:list", + "code": "0x00000000:0x00000023" + }, + "field": [ + { + "name": "section-number", + "type": "uint", + "mandatory": "true" + }, + { + "name": "section-offset", + "type": "ulong", + "mandatory": "true" + } + ] + }, + { + "name": "accepted", + "class": "composite", + "source": "list", + "provides": "delivery-state, outcome", + "descriptor": { + "name": "amqp:accepted:list", + "code": "0x00000000:0x00000024" + } + }, + { + "name": "rejected", + "class": "composite", + "source": "list", + "provides": "delivery-state, outcome", + "descriptor": { + "name": "amqp:rejected:list", + "code": "0x00000000:0x00000025" + }, + "field": [{ + "name": "error", + "type": "error" + }] + }, + { + "name": "released", + "class": "composite", + "source": "list", + "provides": "delivery-state, outcome", + "descriptor": { + "name": "amqp:released:list", + "code": "0x00000000:0x00000026" + } + }, + { + "name": "modified", + "class": "composite", + "source": "list", + "provides": "delivery-state, outcome", + "descriptor": { + "name": "amqp:modified:list", + "code": "0x00000000:0x00000027" + }, + "field": [ + { + "name": "delivery-failed", + "type": "boolean" + }, + { + "name": "undeliverable-here", + "type": "boolean" + }, + { + "name": "message-annotations", + "type": "fields" + } + ] + } +] \ No newline at end of file diff --git a/actix-amqp/codec/src/codec/decode.rs b/actix-amqp/codec/src/codec/decode.rs new file mode 100755 index 000000000..f28faa334 --- /dev/null +++ b/actix-amqp/codec/src/codec/decode.rs @@ -0,0 +1,834 @@ +use std::collections::HashMap; +use std::convert::TryFrom; +use std::hash::{BuildHasher, Hash}; +use std::{char, str, u8}; + +use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; +use bytestring::ByteString; +use chrono::{DateTime, TimeZone, Utc}; +use fxhash::FxHashMap; +use ordered_float::OrderedFloat; +use uuid::Uuid; + +use crate::codec::{self, ArrayDecode, Decode, DecodeFormatted}; +use crate::errors::AmqpParseError; +use crate::framing::{self, AmqpFrame, SaslFrame, HEADER_LEN}; +use crate::protocol::{self, CompoundHeader}; +use crate::types::{ + Descriptor, List, Multiple, Str, Symbol, Variant, VariantMap, VecStringMap, VecSymbolMap, +}; + +macro_rules! be_read { + ($input:ident, $fn:ident, $size:expr) => {{ + decode_check_len!($input, $size); + let x = BigEndian::$fn($input); + Ok((&$input[$size..], x)) + }}; +} + +fn read_u8(input: &[u8]) -> Result<(&[u8], u8), AmqpParseError> { + decode_check_len!(input, 1); + Ok((&input[1..], input[0])) +} + +fn read_i8(input: &[u8]) -> Result<(&[u8], i8), AmqpParseError> { + decode_check_len!(input, 1); + Ok((&input[1..], input[0] as i8)) +} + +fn read_bytes_u8(input: &[u8]) -> Result<(&[u8], &[u8]), AmqpParseError> { + let (input, len) = read_u8(input)?; + let len = len as usize; + decode_check_len!(input, len); + let (bytes, input) = input.split_at(len); + Ok((input, bytes)) +} + +fn read_bytes_u32(input: &[u8]) -> Result<(&[u8], &[u8]), AmqpParseError> { + let result: Result<(&[u8], u32), AmqpParseError> = be_read!(input, read_u32, 4); + let (input, len) = result?; + let len = len as usize; + decode_check_len!(input, len); + let (bytes, input) = input.split_at(len); + Ok((input, bytes)) +} + +#[macro_export] +macro_rules! validate_code { + ($fmt:ident, $code:expr) => { + if $fmt != $code { + return Err(AmqpParseError::InvalidFormatCode($fmt)); + } + }; +} + +impl DecodeFormatted for bool { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_BOOLEAN => read_u8(input).map(|(i, o)| (i, o != 0)), + codec::FORMATCODE_BOOLEAN_TRUE => Ok((input, true)), + codec::FORMATCODE_BOOLEAN_FALSE => Ok((input, false)), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for u8 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_UBYTE); + read_u8(input) + } +} + +impl DecodeFormatted for u16 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_USHORT); + be_read!(input, read_u16, 2) + } +} + +impl DecodeFormatted for u32 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_UINT => be_read!(input, read_u32, 4), + codec::FORMATCODE_SMALLUINT => read_u8(input).map(|(i, o)| (i, u32::from(o))), + codec::FORMATCODE_UINT_0 => Ok((input, 0)), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for u64 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_ULONG => be_read!(input, read_u64, 8), + codec::FORMATCODE_SMALLULONG => read_u8(input).map(|(i, o)| (i, u64::from(o))), + codec::FORMATCODE_ULONG_0 => Ok((input, 0)), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for i8 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_BYTE); + read_i8(input) + } +} + +impl DecodeFormatted for i16 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_SHORT); + be_read!(input, read_i16, 2) + } +} + +impl DecodeFormatted for i32 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_INT => be_read!(input, read_i32, 4), + codec::FORMATCODE_SMALLINT => read_i8(input).map(|(i, o)| (i, i32::from(o))), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for i64 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_LONG => be_read!(input, read_i64, 8), + codec::FORMATCODE_SMALLLONG => read_i8(input).map(|(i, o)| (i, i64::from(o))), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for f32 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_FLOAT); + be_read!(input, read_f32, 4) + } +} + +impl DecodeFormatted for f64 { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DOUBLE); + be_read!(input, read_f64, 8) + } +} + +impl DecodeFormatted for char { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_CHAR); + let result: Result<(&[u8], u32), AmqpParseError> = be_read!(input, read_u32, 4); + let (i, o) = result?; + if let Some(c) = char::from_u32(o) { + Ok((i, c)) + } else { + Err(AmqpParseError::InvalidChar(o)) + } // todo: replace with CharTryFromError once try_from is stabilized + } +} + +impl DecodeFormatted for DateTime { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_TIMESTAMP); + be_read!(input, read_i64, 8).map(|(i, o)| (i, datetime_from_millis(o))) + } +} + +impl DecodeFormatted for Uuid { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_UUID); + decode_check_len!(input, 16); + let uuid = Uuid::from_slice(&input[..16])?; + Ok((&input[16..], uuid)) + } +} + +impl DecodeFormatted for Bytes { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_BINARY8 => { + read_bytes_u8(input).map(|(i, o)| (i, Bytes::copy_from_slice(o))) + } + codec::FORMATCODE_BINARY32 => { + read_bytes_u32(input).map(|(i, o)| (i, Bytes::copy_from_slice(o))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for ByteString { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_STRING8 => { + let (input, bytes) = read_bytes_u8(input)?; + Ok((input, ByteString::try_from(bytes)?)) + } + codec::FORMATCODE_STRING32 => { + let (input, bytes) = read_bytes_u32(input)?; + Ok((input, ByteString::try_from(bytes)?)) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for Str { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_STRING8 => { + let (input, bytes) = read_bytes_u8(input)?; + Ok((input, Str::from_str(str::from_utf8(bytes)?))) + } + codec::FORMATCODE_STRING32 => { + let (input, bytes) = read_bytes_u32(input)?; + Ok((input, Str::from_str(str::from_utf8(bytes)?))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for Symbol { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_SYMBOL8 => { + let (input, bytes) = read_bytes_u8(input)?; + Ok((input, Symbol::from_slice(str::from_utf8(bytes)?))) + } + codec::FORMATCODE_SYMBOL32 => { + let (input, bytes) = read_bytes_u32(input)?; + Ok((input, Symbol::from_slice(str::from_utf8(bytes)?))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl ArrayDecode for Symbol { + fn array_decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + let (input, bytes) = read_bytes_u32(input)?; + Ok((input, Symbol::from_slice(str::from_utf8(bytes)?))) + } +} + +impl DecodeFormatted + for HashMap +{ + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, header) = decode_map_header(input, fmt)?; + let mut map_input = &input[..header.size as usize]; + let count = header.count / 2; + let mut map: HashMap = + HashMap::with_capacity_and_hasher(count as usize, Default::default()); + for _ in 0..count { + let (input1, key) = K::decode(map_input)?; + let (input2, value) = V::decode(input1)?; + map_input = input2; + map.insert(key, value); // todo: ensure None returned? + } + // todo: validate map_input is empty + Ok((&input[header.size as usize..], map)) + } +} + +impl DecodeFormatted for Vec { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, header) = decode_array_header(input, fmt)?; + let item_fmt = input[0]; // todo: support descriptor + let mut input = &input[1..]; + let mut result: Vec = Vec::with_capacity(header.count as usize); + for _ in 0..header.count { + let (new_input, decoded) = T::decode_with_format(input, item_fmt)?; + result.push(decoded); + input = new_input; + } + Ok((input, result)) + } +} + +impl DecodeFormatted for VecSymbolMap { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, header) = decode_map_header(input, fmt)?; + let mut map_input = &input[..header.size as usize]; + let count = header.count / 2; + let mut map = Vec::with_capacity(count as usize); + for _ in 0..count { + let (input1, key) = Symbol::decode(map_input)?; + let (input2, value) = Variant::decode(input1)?; + map_input = input2; + map.push((key, value)); // todo: ensure None returned? + } + // todo: validate map_input is empty + Ok((&input[header.size as usize..], VecSymbolMap(map))) + } +} + +impl DecodeFormatted for VecStringMap { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, header) = decode_map_header(input, fmt)?; + let mut map_input = &input[..header.size as usize]; + let count = header.count / 2; + let mut map = Vec::with_capacity(count as usize); + for _ in 0..count { + let (input1, key) = Str::decode(map_input)?; + let (input2, value) = Variant::decode(input1)?; + map_input = input2; + map.push((key, value)); // todo: ensure None returned? + } + // todo: validate map_input is empty + Ok((&input[header.size as usize..], VecStringMap(map))) + } +} + +impl DecodeFormatted for Multiple { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => { + let (input, items) = Vec::::decode_with_format(input, fmt)?; + Ok((input, Multiple(items))) + } + _ => { + let (input, item) = T::decode_with_format(input, fmt)?; + Ok((input, Multiple(vec![item]))) + } + } + } +} + +impl DecodeFormatted for List { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (mut input, header) = decode_list_header(input, fmt)?; + let mut result: Vec = Vec::with_capacity(header.count as usize); + for _ in 0..header.count { + let (new_input, decoded) = Variant::decode(input)?; + result.push(decoded); + input = new_input; + } + Ok((input, List(result))) + } +} + +impl DecodeFormatted for Variant { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_NULL => Ok((input, Variant::Null)), + codec::FORMATCODE_BOOLEAN => { + bool::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Boolean(o))) + } + codec::FORMATCODE_BOOLEAN_FALSE => Ok((input, Variant::Boolean(false))), + codec::FORMATCODE_BOOLEAN_TRUE => Ok((input, Variant::Boolean(true))), + codec::FORMATCODE_UINT_0 => Ok((input, Variant::Uint(0))), + codec::FORMATCODE_ULONG_0 => Ok((input, Variant::Ulong(0))), + codec::FORMATCODE_UBYTE => { + u8::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ubyte(o))) + } + codec::FORMATCODE_USHORT => { + u16::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ushort(o))) + } + codec::FORMATCODE_UINT => { + u32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uint(o))) + } + codec::FORMATCODE_ULONG => { + u64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ulong(o))) + } + codec::FORMATCODE_BYTE => { + i8::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Byte(o))) + } + codec::FORMATCODE_SHORT => { + i16::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Short(o))) + } + codec::FORMATCODE_INT => { + i32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Int(o))) + } + codec::FORMATCODE_LONG => { + i64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Long(o))) + } + codec::FORMATCODE_SMALLUINT => { + u32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uint(o))) + } + codec::FORMATCODE_SMALLULONG => { + u64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ulong(o))) + } + codec::FORMATCODE_SMALLINT => { + i32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Int(o))) + } + codec::FORMATCODE_SMALLLONG => { + i64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Long(o))) + } + codec::FORMATCODE_FLOAT => f32::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::Float(OrderedFloat(o)))), + codec::FORMATCODE_DOUBLE => f64::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::Double(OrderedFloat(o)))), + // codec::FORMATCODE_DECIMAL32 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))), + // codec::FORMATCODE_DECIMAL64 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))), + // codec::FORMATCODE_DECIMAL128 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))), + codec::FORMATCODE_CHAR => { + char::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Char(o))) + } + codec::FORMATCODE_TIMESTAMP => DateTime::::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::Timestamp(o))), + codec::FORMATCODE_UUID => { + Uuid::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uuid(o))) + } + codec::FORMATCODE_BINARY8 => { + Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Binary(o))) + } + codec::FORMATCODE_BINARY32 => { + Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Binary(o))) + } + codec::FORMATCODE_STRING8 => ByteString::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::String(o.into()))), + codec::FORMATCODE_STRING32 => ByteString::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::String(o.into()))), + codec::FORMATCODE_SYMBOL8 => { + Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Symbol(o))) + } + codec::FORMATCODE_SYMBOL32 => { + Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Symbol(o))) + } + codec::FORMATCODE_LIST0 => Ok((input, Variant::List(List(vec![])))), + codec::FORMATCODE_LIST8 => { + List::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::List(o))) + } + codec::FORMATCODE_LIST32 => { + List::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::List(o))) + } + codec::FORMATCODE_MAP8 => FxHashMap::::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::Map(VariantMap::new(o)))), + codec::FORMATCODE_MAP32 => { + FxHashMap::::decode_with_format(input, fmt) + .map(|(i, o)| (i, Variant::Map(VariantMap::new(o)))) + } + // codec::FORMATCODE_ARRAY8 => Vec::::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Array(o))), + // codec::FORMATCODE_ARRAY32 => Vec::::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Array(o))), + codec::FORMATCODE_DESCRIBED => { + let (input, descriptor) = Descriptor::decode(input)?; + let (input, value) = Variant::decode(input)?; + Ok((input, Variant::Described((descriptor, Box::new(value))))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl DecodeFormatted for Option { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_NULL => Ok((input, None)), + _ => T::decode_with_format(input, fmt).map(|(i, o)| (i, Some(o))), + } + } +} + +impl DecodeFormatted for Descriptor { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_SMALLULONG => { + u64::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Ulong(o))) + } + codec::FORMATCODE_ULONG => { + u64::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Ulong(o))) + } + codec::FORMATCODE_SYMBOL8 => { + Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Symbol(o))) + } + codec::FORMATCODE_SYMBOL32 => { + Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Symbol(o))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl Decode for AmqpFrame { + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + let (input, channel_id) = decode_frame_header(input, framing::FRAME_TYPE_AMQP)?; + let (input, performative) = protocol::Frame::decode(input)?; + Ok((input, AmqpFrame::new(channel_id, performative))) + } +} + +impl Decode for SaslFrame { + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + let (input, _) = decode_frame_header(input, framing::FRAME_TYPE_SASL)?; + let (input, frame) = protocol::SaslFrameBody::decode(input)?; + Ok((input, SaslFrame { body: frame })) + } +} + +fn decode_frame_header( + input: &[u8], + expected_frame_type: u8, +) -> Result<(&[u8], u16), AmqpParseError> { + decode_check_len!(input, 4); + let doff = input[0]; + let frame_type = input[1]; + if frame_type != expected_frame_type { + return Err(AmqpParseError::UnexpectedFrameType(frame_type)); + } + + let channel_id = BigEndian::read_u16(&input[2..]); + let doff = doff as usize * 4; + if doff < HEADER_LEN { + return Err(AmqpParseError::InvalidSize); + } + let ext_header_len = doff - HEADER_LEN; + decode_check_len!(input, ext_header_len + 4); + let input = &input[ext_header_len + 4..]; // skipping remaining two header bytes and ext header + Ok((input, channel_id)) +} + +fn decode_array_header(input: &[u8], fmt: u8) -> Result<(&[u8], CompoundHeader), AmqpParseError> { + match fmt { + codec::FORMATCODE_ARRAY8 => decode_compound8(input), + codec::FORMATCODE_ARRAY32 => decode_compound32(input), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } +} + +pub(crate) fn decode_list_header( + input: &[u8], + fmt: u8, +) -> Result<(&[u8], CompoundHeader), AmqpParseError> { + match fmt { + codec::FORMATCODE_LIST0 => Ok((input, CompoundHeader::empty())), + codec::FORMATCODE_LIST8 => decode_compound8(input), + codec::FORMATCODE_LIST32 => decode_compound32(input), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } +} + +pub(crate) fn decode_map_header( + input: &[u8], + fmt: u8, +) -> Result<(&[u8], CompoundHeader), AmqpParseError> { + match fmt { + codec::FORMATCODE_MAP8 => decode_compound8(input), + codec::FORMATCODE_MAP32 => decode_compound32(input), + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } +} + +fn decode_compound8(input: &[u8]) -> Result<(&[u8], CompoundHeader), AmqpParseError> { + decode_check_len!(input, 2); + let size = input[0] - 1; // -1 for 1 byte count + let count = input[1]; + Ok(( + &input[2..], + CompoundHeader { + size: u32::from(size), + count: u32::from(count), + }, + )) +} + +fn decode_compound32(input: &[u8]) -> Result<(&[u8], CompoundHeader), AmqpParseError> { + decode_check_len!(input, 8); + let size = BigEndian::read_u32(input) - 4; // -4 for 4 byte count + let count = BigEndian::read_u32(&input[4..]); + Ok((&input[8..], CompoundHeader { size, count })) +} + +fn datetime_from_millis(millis: i64) -> DateTime { + let seconds = millis / 1000; + if seconds < 0 { + // In order to handle time before 1970 correctly, we need to subtract a second + // and use the nanoseconds field to add it back. This is a result of the nanoseconds + // parameter being u32 + let nanoseconds = ((1000 + (millis - (seconds * 1000))) * 1_000_000).abs() as u32; + Utc.timestamp(seconds - 1, nanoseconds) + } else { + let nanoseconds = ((millis - (seconds * 1000)) * 1_000_000).abs() as u32; + Utc.timestamp(seconds, nanoseconds) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::codec::Encode; + use bytes::{BufMut, BytesMut}; + + const LOREM: &str = include_str!("lorem.txt"); + + macro_rules! decode_tests { + ($($name:ident: $kind:ident, $test:expr, $expected:expr,)*) => { + $( + #[test] + fn $name() { + let b1 = &mut BytesMut::with_capacity(($test).encoded_size()); + ($test).encode(b1); + assert_eq!($expected, unwrap_value($kind::decode(b1))); + } + )* + } + } + + decode_tests! { + ubyte: u8, 255_u8, 255_u8, + ushort: u16, 350_u16, 350_u16, + + uint_zero: u32, 0_u32, 0_u32, + uint_small: u32, 128_u32, 128_u32, + uint_big: u32, 2147483647_u32, 2147483647_u32, + + ulong_zero: u64, 0_u64, 0_u64, + ulong_small: u64, 128_u64, 128_u64, + uulong_big: u64, 2147483649_u64, 2147483649_u64, + + byte: i8, -128_i8, -128_i8, + short: i16, -255_i16, -255_i16, + + int_zero: i32, 0_i32, 0_i32, + int_small: i32, -50000_i32, -50000_i32, + int_neg: i32, -128_i32, -128_i32, + + long_zero: i64, 0_i64, 0_i64, + long_big: i64, -2147483647_i64, -2147483647_i64, + long_small: i64, -128_i64, -128_i64, + + float: f32, 1.234_f32, 1.234_f32, + double: f64, 1.234_f64, 1.234_f64, + + test_char: char, '💯', '💯', + + uuid: Uuid, Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error"), + Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error"), + + binary_short: Bytes, Bytes::from(&[4u8, 5u8][..]), Bytes::from(&[4u8, 5u8][..]), + binary_long: Bytes, Bytes::from(&[4u8; 500][..]), Bytes::from(&[4u8; 500][..]), + + string_short: ByteString, ByteString::from("Hello there"), ByteString::from("Hello there"), + string_long: ByteString, ByteString::from(LOREM), ByteString::from(LOREM), + + // symbol_short: Symbol, Symbol::from("Hello there"), Symbol::from("Hello there"), + // symbol_long: Symbol, Symbol::from(LOREM), Symbol::from(LOREM), + + variant_ubyte: Variant, Variant::Ubyte(255_u8), Variant::Ubyte(255_u8), + variant_ushort: Variant, Variant::Ushort(350_u16), Variant::Ushort(350_u16), + + variant_uint_zero: Variant, Variant::Uint(0_u32), Variant::Uint(0_u32), + variant_uint_small: Variant, Variant::Uint(128_u32), Variant::Uint(128_u32), + variant_uint_big: Variant, Variant::Uint(2147483647_u32), Variant::Uint(2147483647_u32), + + variant_ulong_zero: Variant, Variant::Ulong(0_u64), Variant::Ulong(0_u64), + variant_ulong_small: Variant, Variant::Ulong(128_u64), Variant::Ulong(128_u64), + variant_ulong_big: Variant, Variant::Ulong(2147483649_u64), Variant::Ulong(2147483649_u64), + + variant_byte: Variant, Variant::Byte(-128_i8), Variant::Byte(-128_i8), + variant_short: Variant, Variant::Short(-255_i16), Variant::Short(-255_i16), + + variant_int_zero: Variant, Variant::Int(0_i32), Variant::Int(0_i32), + variant_int_small: Variant, Variant::Int(-50000_i32), Variant::Int(-50000_i32), + variant_int_neg: Variant, Variant::Int(-128_i32), Variant::Int(-128_i32), + + variant_long_zero: Variant, Variant::Long(0_i64), Variant::Long(0_i64), + variant_long_big: Variant, Variant::Long(-2147483647_i64), Variant::Long(-2147483647_i64), + variant_long_small: Variant, Variant::Long(-128_i64), Variant::Long(-128_i64), + + variant_float: Variant, Variant::Float(OrderedFloat(1.234_f32)), Variant::Float(OrderedFloat(1.234_f32)), + variant_double: Variant, Variant::Double(OrderedFloat(1.234_f64)), Variant::Double(OrderedFloat(1.234_f64)), + + variant_char: Variant, Variant::Char('💯'), Variant::Char('💯'), + + variant_uuid: Variant, Variant::Uuid(Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error")), + Variant::Uuid(Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error")), + + variant_binary_short: Variant, Variant::Binary(Bytes::from(&[4u8, 5u8][..])), Variant::Binary(Bytes::from(&[4u8, 5u8][..])), + variant_binary_long: Variant, Variant::Binary(Bytes::from(&[4u8; 500][..])), Variant::Binary(Bytes::from(&[4u8; 500][..])), + + variant_string_short: Variant, Variant::String(ByteString::from("Hello there").into()), Variant::String(ByteString::from("Hello there").into()), + variant_string_long: Variant, Variant::String(ByteString::from(LOREM).into()), Variant::String(ByteString::from(LOREM).into()), + + // variant_symbol_short: Variant, Variant::Symbol(Symbol::from("Hello there")), Variant::Symbol(Symbol::from("Hello there")), + // variant_symbol_long: Variant, Variant::Symbol(Symbol::from(LOREM)), Variant::Symbol(Symbol::from(LOREM)), + } + + fn unwrap_value(res: Result<(&[u8], T), AmqpParseError>) -> T { + let r = res.map(|(_i, o)| o); + assert!(r.is_ok()); + r.unwrap() + } + + #[test] + fn test_bool_true() { + let b1 = &mut BytesMut::with_capacity(0); + b1.put_u8(0x41); + assert_eq!(true, unwrap_value(bool::decode(b1))); + + let b2 = &mut BytesMut::with_capacity(0); + b2.put_u8(0x56); + b2.put_u8(0x01); + assert_eq!(true, unwrap_value(bool::decode(b2))); + } + + #[test] + fn test_bool_false() { + let b1 = &mut BytesMut::with_capacity(0); + b1.put_u8(0x42u8); + assert_eq!(false, unwrap_value(bool::decode(b1))); + + let b2 = &mut BytesMut::with_capacity(0); + b2.put_u8(0x56); + b2.put_u8(0x00); + assert_eq!(false, unwrap_value(bool::decode(b2))); + } + + /// UTC with a precision of milliseconds. For example, 1311704463521 + /// represents the moment 2011-07-26T18:21:03.521Z. + #[test] + fn test_timestamp() { + let b1 = &mut BytesMut::with_capacity(0); + let datetime = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521); + datetime.encode(b1); + + let expected = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521); + assert_eq!(expected, unwrap_value(DateTime::::decode(b1))); + } + + #[test] + fn test_timestamp_pre_unix() { + let b1 = &mut BytesMut::with_capacity(0); + let datetime = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521); + datetime.encode(b1); + + let expected = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521); + assert_eq!(expected, unwrap_value(DateTime::::decode(b1))); + } + + #[test] + fn variant_null() { + let mut b = BytesMut::with_capacity(0); + Variant::Null.encode(&mut b); + let t = unwrap_value(Variant::decode(&mut b)); + assert_eq!(Variant::Null, t); + } + + #[test] + fn variant_bool_true() { + let b1 = &mut BytesMut::with_capacity(0); + b1.put_u8(0x41); + assert_eq!(Variant::Boolean(true), unwrap_value(Variant::decode(b1))); + + let b2 = &mut BytesMut::with_capacity(0); + b2.put_u8(0x56); + b2.put_u8(0x01); + assert_eq!(Variant::Boolean(true), unwrap_value(Variant::decode(b2))); + } + + #[test] + fn variant_bool_false() { + let b1 = &mut BytesMut::with_capacity(0); + b1.put_u8(0x42u8); + assert_eq!(Variant::Boolean(false), unwrap_value(Variant::decode(b1))); + + let b2 = &mut BytesMut::with_capacity(0); + b2.put_u8(0x56); + b2.put_u8(0x00); + assert_eq!(Variant::Boolean(false), unwrap_value(Variant::decode(b2))); + } + + /// UTC with a precision of milliseconds. For example, 1311704463521 + /// represents the moment 2011-07-26T18:21:03.521Z. + #[test] + fn variant_timestamp() { + let b1 = &mut BytesMut::with_capacity(0); + let datetime = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521); + Variant::Timestamp(datetime).encode(b1); + + let expected = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521); + assert_eq!( + Variant::Timestamp(expected), + unwrap_value(Variant::decode(b1)) + ); + } + + #[test] + fn variant_timestamp_pre_unix() { + let b1 = &mut BytesMut::with_capacity(0); + let datetime = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521); + Variant::Timestamp(datetime).encode(b1); + + let expected = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521); + assert_eq!( + Variant::Timestamp(expected), + unwrap_value(Variant::decode(b1)) + ); + } + + #[test] + fn option_i8() { + let b1 = &mut BytesMut::with_capacity(0); + Some(42i8).encode(b1); + + assert_eq!(Some(42), unwrap_value(Option::::decode(b1))); + + let b2 = &mut BytesMut::with_capacity(0); + let o1: Option = None; + o1.encode(b2); + + assert_eq!(None, unwrap_value(Option::::decode(b2))); + } + + #[test] + fn option_string() { + let b1 = &mut BytesMut::with_capacity(0); + Some(ByteString::from("hello")).encode(b1); + + assert_eq!( + Some(ByteString::from("hello")), + unwrap_value(Option::::decode(b1)) + ); + + let b2 = &mut BytesMut::with_capacity(0); + let o1: Option = None; + o1.encode(b2); + + assert_eq!(None, unwrap_value(Option::::decode(b2))); + } +} diff --git a/actix-amqp/codec/src/codec/encode.rs b/actix-amqp/codec/src/codec/encode.rs new file mode 100755 index 000000000..d0acc58d8 --- /dev/null +++ b/actix-amqp/codec/src/codec/encode.rs @@ -0,0 +1,786 @@ +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash}; +use std::{i8, u8}; + +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; +use chrono::{DateTime, Utc}; +use uuid::Uuid; + +use crate::codec::{self, ArrayEncode, Encode}; +use crate::framing::{self, AmqpFrame, SaslFrame}; +use crate::types::{ + Descriptor, List, Multiple, StaticSymbol, Str, Symbol, Variant, VecStringMap, VecSymbolMap, +}; + +fn encode_null(buf: &mut BytesMut) { + buf.put_u8(codec::FORMATCODE_NULL); +} + +pub trait FixedEncode {} + +impl Encode for T { + fn encoded_size(&self) -> usize { + self.array_encoded_size() + 1 + } + fn encode(&self, buf: &mut BytesMut) { + buf.put_u8(T::ARRAY_FORMAT_CODE); + self.array_encode(buf); + } +} + +impl Encode for bool { + fn encoded_size(&self) -> usize { + 1 + } + fn encode(&self, buf: &mut BytesMut) { + buf.put_u8(if *self { + codec::FORMATCODE_BOOLEAN_TRUE + } else { + codec::FORMATCODE_BOOLEAN_FALSE + }); + } +} +impl ArrayEncode for bool { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BOOLEAN; + fn array_encoded_size(&self) -> usize { + 1 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u8(if *self { 1 } else { 0 }); + } +} + +impl FixedEncode for u8 {} +impl ArrayEncode for u8 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UBYTE; + fn array_encoded_size(&self) -> usize { + 1 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u8(*self); + } +} + +impl FixedEncode for u16 {} +impl ArrayEncode for u16 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_USHORT; + fn array_encoded_size(&self) -> usize { + 2 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u16(*self); + } +} + +impl Encode for u32 { + fn encoded_size(&self) -> usize { + if *self == 0 { + 1 + } else if *self > u32::from(u8::MAX) { + 5 + } else { + 2 + } + } + fn encode(&self, buf: &mut BytesMut) { + if *self == 0 { + buf.put_u8(codec::FORMATCODE_UINT_0) + } else if *self > u32::from(u8::MAX) { + buf.put_u8(codec::FORMATCODE_UINT); + buf.put_u32(*self); + } else { + buf.put_u8(codec::FORMATCODE_SMALLUINT); + buf.put_u8(*self as u8); + } + } +} +impl ArrayEncode for u32 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UINT; + fn array_encoded_size(&self) -> usize { + 4 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(*self); + } +} + +impl Encode for u64 { + fn encoded_size(&self) -> usize { + if *self == 0 { + 1 + } else if *self > u64::from(u8::MAX) { + 9 + } else { + 2 + } + } + + fn encode(&self, buf: &mut BytesMut) { + if *self == 0 { + buf.put_u8(codec::FORMATCODE_ULONG_0) + } else if *self > u64::from(u8::MAX) { + buf.put_u8(codec::FORMATCODE_ULONG); + buf.put_u64(*self); + } else { + buf.put_u8(codec::FORMATCODE_SMALLULONG); + buf.put_u8(*self as u8); + } + } +} + +impl ArrayEncode for u64 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_ULONG; + fn array_encoded_size(&self) -> usize { + 8 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u64(*self); + } +} + +impl FixedEncode for i8 {} + +impl ArrayEncode for i8 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BYTE; + fn array_encoded_size(&self) -> usize { + 1 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_i8(*self); + } +} + +impl FixedEncode for i16 {} + +impl ArrayEncode for i16 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_SHORT; + fn array_encoded_size(&self) -> usize { + 2 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_i16(*self); + } +} + +impl Encode for i32 { + fn encoded_size(&self) -> usize { + if *self > i32::from(i8::MAX) || *self < i32::from(i8::MIN) { + 5 + } else { + 2 + } + } + + fn encode(&self, buf: &mut BytesMut) { + if *self > i32::from(i8::MAX) || *self < i32::from(i8::MIN) { + buf.put_u8(codec::FORMATCODE_INT); + buf.put_i32(*self); + } else { + buf.put_u8(codec::FORMATCODE_SMALLINT); + buf.put_i8(*self as i8); + } + } +} + +impl ArrayEncode for i32 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_INT; + + fn array_encoded_size(&self) -> usize { + 4 + } + + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_i32(*self); + } +} + +impl Encode for i64 { + fn encoded_size(&self) -> usize { + if *self > i64::from(i8::MAX) || *self < i64::from(i8::MIN) { + 9 + } else { + 2 + } + } + + fn encode(&self, buf: &mut BytesMut) { + if *self > i64::from(i8::MAX) || *self < i64::from(i8::MIN) { + buf.put_u8(codec::FORMATCODE_LONG); + buf.put_i64(*self); + } else { + buf.put_u8(codec::FORMATCODE_SMALLLONG); + buf.put_i8(*self as i8); + } + } +} + +impl ArrayEncode for i64 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_LONG; + fn array_encoded_size(&self) -> usize { + 8 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_i64(*self); + } +} + +impl FixedEncode for f32 {} + +impl ArrayEncode for f32 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_FLOAT; + + fn array_encoded_size(&self) -> usize { + 4 + } + + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_f32(*self); + } +} + +impl FixedEncode for f64 {} + +impl ArrayEncode for f64 { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_DOUBLE; + fn array_encoded_size(&self) -> usize { + 8 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_f64(*self); + } +} + +impl FixedEncode for char {} + +impl ArrayEncode for char { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_CHAR; + fn array_encoded_size(&self) -> usize { + 4 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(*self as u32); + } +} + +impl FixedEncode for DateTime {} + +impl ArrayEncode for DateTime { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_TIMESTAMP; + fn array_encoded_size(&self) -> usize { + 8 + } + fn array_encode(&self, buf: &mut BytesMut) { + let timestamp = self.timestamp() * 1000 + i64::from(self.timestamp_subsec_millis()); + buf.put_i64(timestamp); + } +} + +impl FixedEncode for Uuid {} + +impl ArrayEncode for Uuid { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UUID; + fn array_encoded_size(&self) -> usize { + 16 + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.extend_from_slice(self.as_bytes()); + } +} + +impl Encode for Bytes { + fn encoded_size(&self) -> usize { + let length = self.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_BINARY32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_BINARY8); + buf.put_u8(length as u8); + } + buf.put_slice(self); + } +} + +impl ArrayEncode for Bytes { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BINARY32; + fn array_encoded_size(&self) -> usize { + 4 + self.len() + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(self.len() as u32); + buf.put_slice(&self); + } +} + +impl Encode for ByteString { + fn encoded_size(&self) -> usize { + let length = self.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_STRING32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_STRING8); + buf.put_u8(length as u8); + } + buf.put_slice(self.as_bytes()); + } +} +impl ArrayEncode for ByteString { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_STRING32; + fn array_encoded_size(&self) -> usize { + 4 + self.len() + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(self.len() as u32); + buf.put_slice(self.as_bytes()); + } +} + +impl Encode for str { + fn encoded_size(&self) -> usize { + let length = self.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_STRING32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_STRING8); + buf.put_u8(length as u8); + } + buf.put_slice(self.as_bytes()); + } +} + +impl ArrayEncode for str { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_STRING32; + fn array_encoded_size(&self) -> usize { + 4 + self.len() + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(self.len() as u32); + buf.put_slice(self.as_bytes()); + } +} + +impl Encode for Str { + fn encoded_size(&self) -> usize { + let length = self.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.as_str().len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_STRING32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_STRING8); + buf.put_u8(length as u8); + } + buf.put_slice(self.as_bytes()); + } +} + +impl Encode for Symbol { + fn encoded_size(&self) -> usize { + let length = self.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.as_str().len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_SYMBOL32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_SYMBOL8); + buf.put_u8(length as u8); + } + buf.put_slice(self.as_bytes()); + } +} + +impl ArrayEncode for Symbol { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_SYMBOL32; + fn array_encoded_size(&self) -> usize { + 4 + self.len() + } + fn array_encode(&self, buf: &mut BytesMut) { + buf.put_u32(self.len() as u32); + buf.put_slice(self.as_bytes()); + } +} + +impl Encode for StaticSymbol { + fn encoded_size(&self) -> usize { + let length = self.0.len(); + let size = if length > u8::MAX as usize { 5 } else { 2 }; + size + length + } + + fn encode(&self, buf: &mut BytesMut) { + let length = self.0.len(); + if length > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_SYMBOL32); + buf.put_u32(length as u32); + } else { + buf.put_u8(codec::FORMATCODE_SYMBOL8); + buf.put_u8(length as u8); + } + buf.put_slice(self.0.as_bytes()); + } +} + +fn map_encoded_size( + map: &HashMap, +) -> usize { + map.iter() + .fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size()) +} +impl Encode for HashMap { + fn encoded_size(&self) -> usize { + let size = map_encoded_size(self); + // f:1 + s:4 + c:4 vs f:1 + s:1 + c:1 + let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 }; + preamble + size + } + + fn encode(&self, buf: &mut BytesMut) { + let count = self.len() * 2; // key-value pair accounts for two items in count + let size = map_encoded_size(self); + if size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_MAP32); + buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows + buf.put_u32(count as u32); + } else { + buf.put_u8(codec::FORMATCODE_MAP8); + buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows + buf.put_u8(count as u8); + } + + for (k, v) in self { + k.encode(buf); + v.encode(buf); + } + } +} + +impl ArrayEncode for HashMap { + const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_MAP32; + fn array_encoded_size(&self) -> usize { + 8 + map_encoded_size(self) + } + fn array_encode(&self, buf: &mut BytesMut) { + let count = self.len() * 2; + let size = map_encoded_size(self) + 4; + buf.put_u32(size as u32); + buf.put_u32(count as u32); + + for (k, v) in self { + k.encode(buf); + v.encode(buf); + } + } +} + +impl Encode for VecSymbolMap { + fn encoded_size(&self) -> usize { + let size = self + .0 + .iter() + .fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size()); + + // f:1 + s:4 + c:4 vs f:1 + s:1 + c:1 + let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 }; + preamble + size + } + + fn encode(&self, buf: &mut BytesMut) { + let count = self.len() * 2; // key-value pair accounts for two items in count + let size = self + .0 + .iter() + .fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size()); + + if size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_MAP32); + buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows + buf.put_u32(count as u32); + } else { + buf.put_u8(codec::FORMATCODE_MAP8); + buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows + buf.put_u8(count as u8); + } + + for (k, v) in self.iter() { + k.encode(buf); + v.encode(buf); + } + } +} + +impl Encode for VecStringMap { + fn encoded_size(&self) -> usize { + let size = self + .0 + .iter() + .fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size()); + + // f:1 + s:4 + c:4 vs f:1 + s:1 + c:1 + let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 }; + preamble + size + } + + fn encode(&self, buf: &mut BytesMut) { + let count = self.len() * 2; // key-value pair accounts for two items in count + let size = self + .0 + .iter() + .fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size()); + + if size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_MAP32); + buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows + buf.put_u32(count as u32); + } else { + buf.put_u8(codec::FORMATCODE_MAP8); + buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows + buf.put_u8(count as u8); + } + + for (k, v) in self.iter() { + k.encode(buf); + v.encode(buf); + } + } +} + +fn array_encoded_size(vec: &[T]) -> usize { + vec.iter().fold(0, |r, i| r + i.array_encoded_size()) +} + +impl Encode for Vec { + fn encoded_size(&self) -> usize { + let content_size = array_encoded_size(self); + // format_code + size + count + item constructor -- todo: support described ctor? + (if content_size + 1 > u8::MAX as usize { + 10 + } else { + 4 + }) // +1 for 1 byte count and 1 byte format code + + content_size + } + + fn encode(&self, buf: &mut BytesMut) { + let size = array_encoded_size(self); + if size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_ARRAY32); + buf.put_u32((size + 5) as u32); // +4 for 4 byte count and 1 byte item ctor that follow + buf.put_u32(self.len() as u32); + } else { + buf.put_u8(codec::FORMATCODE_ARRAY8); + buf.put_u8((size + 2) as u8); // +1 for 1 byte count and 1 byte item ctor that follow + buf.put_u8(self.len() as u8); + } + buf.put_u8(T::ARRAY_FORMAT_CODE); + for i in self { + i.array_encode(buf); + } + } +} + +impl Encode for Multiple { + fn encoded_size(&self) -> usize { + let count = self.len(); + if count == 1 { + // special case: single item is encoded without array encoding + self.0[0].encoded_size() + } else { + self.0.encoded_size() + } + } + + fn encode(&self, buf: &mut BytesMut) { + let count = self.0.len(); + if count == 1 { + // special case: single item is encoded without array encoding + self.0[0].encode(buf) + } else { + self.0.encode(buf) + } + } +} + +fn list_encoded_size(vec: &List) -> usize { + vec.iter().fold(0, |r, i| r + i.encoded_size()) +} + +impl Encode for List { + fn encoded_size(&self) -> usize { + let content_size = list_encoded_size(self); + // format_code + size + count + (if content_size + 1 > u8::MAX as usize { + 9 + } else { + 3 + }) + content_size + } + + fn encode(&self, buf: &mut BytesMut) { + let size = list_encoded_size(self); + if size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_ARRAY32); + buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follow + buf.put_u32(self.len() as u32); + } else { + buf.put_u8(codec::FORMATCODE_ARRAY8); + buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follow + buf.put_u8(self.len() as u8); + } + for i in self.iter() { + i.encode(buf); + } + } +} + +impl Encode for Variant { + fn encoded_size(&self) -> usize { + match *self { + Variant::Null => 1, + Variant::Boolean(b) => b.encoded_size(), + Variant::Ubyte(b) => b.encoded_size(), + Variant::Ushort(s) => s.encoded_size(), + Variant::Uint(i) => i.encoded_size(), + Variant::Ulong(l) => l.encoded_size(), + Variant::Byte(b) => b.encoded_size(), + Variant::Short(s) => s.encoded_size(), + Variant::Int(i) => i.encoded_size(), + Variant::Long(l) => l.encoded_size(), + Variant::Float(f) => f.encoded_size(), + Variant::Double(d) => d.encoded_size(), + Variant::Char(c) => c.encoded_size(), + Variant::Timestamp(ref t) => t.encoded_size(), + Variant::Uuid(ref u) => u.encoded_size(), + Variant::Binary(ref b) => b.encoded_size(), + Variant::String(ref s) => s.encoded_size(), + Variant::Symbol(ref s) => s.encoded_size(), + Variant::StaticSymbol(ref s) => s.encoded_size(), + Variant::List(ref l) => l.encoded_size(), + Variant::Map(ref m) => m.map.encoded_size(), + Variant::Described(ref dv) => dv.0.encoded_size() + dv.1.encoded_size(), + } + } + + /// Encodes `Variant` into provided `BytesMut` + fn encode(&self, buf: &mut BytesMut) { + match *self { + Variant::Null => encode_null(buf), + Variant::Boolean(b) => b.encode(buf), + Variant::Ubyte(b) => b.encode(buf), + Variant::Ushort(s) => s.encode(buf), + Variant::Uint(i) => i.encode(buf), + Variant::Ulong(l) => l.encode(buf), + Variant::Byte(b) => b.encode(buf), + Variant::Short(s) => s.encode(buf), + Variant::Int(i) => i.encode(buf), + Variant::Long(l) => l.encode(buf), + Variant::Float(f) => f.encode(buf), + Variant::Double(d) => d.encode(buf), + Variant::Char(c) => c.encode(buf), + Variant::Timestamp(ref t) => t.encode(buf), + Variant::Uuid(ref u) => u.encode(buf), + Variant::Binary(ref b) => b.encode(buf), + Variant::String(ref s) => s.encode(buf), + Variant::Symbol(ref s) => s.encode(buf), + Variant::StaticSymbol(ref s) => s.encode(buf), + Variant::List(ref l) => l.encode(buf), + Variant::Map(ref m) => m.map.encode(buf), + Variant::Described(ref dv) => { + dv.0.encode(buf); + dv.1.encode(buf); + } + } + } +} + +impl Encode for Option { + fn encoded_size(&self) -> usize { + self.as_ref().map_or(1, |v| v.encoded_size()) + } + + fn encode(&self, buf: &mut BytesMut) { + match *self { + Some(ref e) => e.encode(buf), + None => encode_null(buf), + } + } +} + +impl Encode for Descriptor { + fn encoded_size(&self) -> usize { + match *self { + Descriptor::Ulong(v) => 1 + v.encoded_size(), + Descriptor::Symbol(ref v) => 1 + v.encoded_size(), + } + } + + fn encode(&self, buf: &mut BytesMut) { + buf.put_u8(codec::FORMATCODE_DESCRIBED); + match *self { + Descriptor::Ulong(v) => v.encode(buf), + Descriptor::Symbol(ref v) => v.encode(buf), + } + } +} + +const WORD_LEN: usize = 4; + +impl Encode for AmqpFrame { + fn encoded_size(&self) -> usize { + framing::HEADER_LEN + self.performative().encoded_size() + } + + fn encode(&self, buf: &mut BytesMut) { + let doff: u8 = (framing::HEADER_LEN / WORD_LEN) as u8; + buf.put_u32(self.encoded_size() as u32); + buf.put_u8(doff); + buf.put_u8(framing::FRAME_TYPE_AMQP); + buf.put_u16(self.channel_id()); + self.performative().encode(buf); + } +} + +impl Encode for SaslFrame { + fn encoded_size(&self) -> usize { + framing::HEADER_LEN + self.body.encoded_size() + } + + fn encode(&self, buf: &mut BytesMut) { + let doff: u8 = (framing::HEADER_LEN / WORD_LEN) as u8; + buf.put_u32(self.encoded_size() as u32); + buf.put_u8(doff); + buf.put_u8(framing::FRAME_TYPE_SASL); + buf.put_u16(0); + self.body.encode(buf); + } +} diff --git a/actix-amqp/codec/src/codec/lorem.txt b/actix-amqp/codec/src/codec/lorem.txt new file mode 100755 index 000000000..fdc257a0a --- /dev/null +++ b/actix-amqp/codec/src/codec/lorem.txt @@ -0,0 +1 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec accumsan iaculis ipsum sed convallis. Phasellus consectetur justo et odio maximus, vel vehicula sapien venenatis. Nunc ac viverra risus. Pellentesque elementum, mauris et viverra ultricies, lacus erat varius nulla, eget maximus nisl sed. \ No newline at end of file diff --git a/actix-amqp/codec/src/codec/mod.rs b/actix-amqp/codec/src/codec/mod.rs new file mode 100755 index 000000000..e2cbc73c5 --- /dev/null +++ b/actix-amqp/codec/src/codec/mod.rs @@ -0,0 +1,150 @@ +use bytes::BytesMut; +use std::marker::Sized; + +use crate::errors::AmqpParseError; + +macro_rules! decode_check_len { + ($buf:ident, $size:expr) => { + if $buf.len() < $size { + return Err(AmqpParseError::Incomplete(Some($size))); + } + }; +} + +#[macro_use] +mod decode; +mod encode; + +pub(crate) use self::decode::decode_list_header; + +pub trait Encode { + fn encoded_size(&self) -> usize; + + fn encode(&self, buf: &mut BytesMut); +} + +pub trait ArrayEncode { + const ARRAY_FORMAT_CODE: u8; + + fn array_encoded_size(&self) -> usize; + + fn array_encode(&self, buf: &mut BytesMut); +} + +pub trait Decode +where + Self: Sized, +{ + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError>; +} + +pub trait DecodeFormatted +where + Self: Sized, +{ + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError>; +} + +pub trait ArrayDecode: Sized { + fn array_decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError>; +} + +impl Decode for T { + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + let (input, fmt) = decode_format_code(input)?; + T::decode_with_format(input, fmt) + } +} + +pub fn decode_format_code(input: &[u8]) -> Result<(&[u8], u8), AmqpParseError> { + decode_check_len!(input, 1); + Ok((&input[1..], input[0])) +} + +pub const FORMATCODE_DESCRIBED: u8 = 0x00; +pub const FORMATCODE_NULL: u8 = 0x40; // fixed width --V +pub const FORMATCODE_BOOLEAN: u8 = 0x56; +pub const FORMATCODE_BOOLEAN_TRUE: u8 = 0x41; +pub const FORMATCODE_BOOLEAN_FALSE: u8 = 0x42; +pub const FORMATCODE_UINT_0: u8 = 0x43; +pub const FORMATCODE_ULONG_0: u8 = 0x44; +pub const FORMATCODE_UBYTE: u8 = 0x50; +pub const FORMATCODE_USHORT: u8 = 0x60; +pub const FORMATCODE_UINT: u8 = 0x70; +pub const FORMATCODE_ULONG: u8 = 0x80; +pub const FORMATCODE_BYTE: u8 = 0x51; +pub const FORMATCODE_SHORT: u8 = 0x61; +pub const FORMATCODE_INT: u8 = 0x71; +pub const FORMATCODE_LONG: u8 = 0x81; +pub const FORMATCODE_SMALLUINT: u8 = 0x52; +pub const FORMATCODE_SMALLULONG: u8 = 0x53; +pub const FORMATCODE_SMALLINT: u8 = 0x54; +pub const FORMATCODE_SMALLLONG: u8 = 0x55; +pub const FORMATCODE_FLOAT: u8 = 0x72; +pub const FORMATCODE_DOUBLE: u8 = 0x82; +// pub const FORMATCODE_DECIMAL32: u8 = 0x74; +// pub const FORMATCODE_DECIMAL64: u8 = 0x84; +// pub const FORMATCODE_DECIMAL128: u8 = 0x94; +pub const FORMATCODE_CHAR: u8 = 0x73; +pub const FORMATCODE_TIMESTAMP: u8 = 0x83; +pub const FORMATCODE_UUID: u8 = 0x98; +pub const FORMATCODE_BINARY8: u8 = 0xa0; // variable --V +pub const FORMATCODE_BINARY32: u8 = 0xb0; +pub const FORMATCODE_STRING8: u8 = 0xa1; +pub const FORMATCODE_STRING32: u8 = 0xb1; +pub const FORMATCODE_SYMBOL8: u8 = 0xa3; +pub const FORMATCODE_SYMBOL32: u8 = 0xb3; +pub const FORMATCODE_LIST0: u8 = 0x45; // compound --V +pub const FORMATCODE_LIST8: u8 = 0xc0; +pub const FORMATCODE_LIST32: u8 = 0xd0; +pub const FORMATCODE_MAP8: u8 = 0xc1; +pub const FORMATCODE_MAP32: u8 = 0xd1; +pub const FORMATCODE_ARRAY8: u8 = 0xe0; +pub const FORMATCODE_ARRAY32: u8 = 0xf0; + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::codec::{Decode, Encode}; + use crate::errors::AmqpCodecError; + use crate::framing::{AmqpFrame, SaslFrame}; + use crate::protocol::SaslFrameBody; + + #[test] + fn test_sasl_mechanisms() -> Result<(), AmqpCodecError> { + let data = b"\x02\x01\0\0\0S@\xc02\x01\xe0/\x04\xb3\0\0\0\x07MSSBCBS\0\0\0\x05PLAIN\0\0\0\tANONYMOUS\0\0\0\x08EXTERNAL"; + + let (remainder, frame) = SaslFrame::decode(data.as_ref())?; + assert!(remainder.is_empty()); + match frame.body { + SaslFrameBody::SaslMechanisms(_) => (), + _ => panic!("error"), + } + + let mut buf = BytesMut::new(); + buf.reserve(frame.encoded_size()); + frame.encode(&mut buf); + buf.split_to(4); + assert_eq!(Bytes::from_static(data), buf.freeze()); + + Ok(()) + } + + #[test] + fn test_disposition() -> Result<(), AmqpCodecError> { + let data = b"\x02\0\0\0\0S\x15\xc0\x0c\x06AC@A\0S$\xc0\x01\0B"; + + let (remainder, frame) = AmqpFrame::decode(data.as_ref())?; + assert!(remainder.is_empty()); + assert_eq!(frame.performative().name(), "Disposition"); + + let mut buf = BytesMut::new(); + buf.reserve(frame.encoded_size()); + frame.encode(&mut buf); + buf.split_to(4); + assert_eq!(Bytes::from_static(data), buf.freeze()); + + Ok(()) + } +} diff --git a/actix-amqp/codec/src/errors.rs b/actix-amqp/codec/src/errors.rs new file mode 100755 index 000000000..0da95ca86 --- /dev/null +++ b/actix-amqp/codec/src/errors.rs @@ -0,0 +1,73 @@ +use uuid; + +use crate::protocol::ProtocolId; +use crate::types::Descriptor; + +#[derive(Debug, Display, From, Clone)] +pub enum AmqpParseError { + #[display(fmt = "Loaded item size is invalid")] + InvalidSize, + #[display(fmt = "More data required during frame parsing: '{:?}'", "_0")] + Incomplete(Option), + #[from(ignore)] + #[display(fmt = "Unexpected format code: '{}'", "_0")] + InvalidFormatCode(u8), + #[display(fmt = "Invalid value converting to char: {}", "_0")] + InvalidChar(u32), + #[display(fmt = "Unexpected descriptor: '{:?}'", "_0")] + InvalidDescriptor(Descriptor), + #[from(ignore)] + #[display(fmt = "Unexpected frame type: '{:?}'", "_0")] + UnexpectedFrameType(u8), + #[from(ignore)] + #[display(fmt = "Required field '{:?}' was omitted.", "_0")] + RequiredFieldOmitted(&'static str), + #[from(ignore)] + #[display(fmt = "Unknown {:?} option.", "_0")] + UnknownEnumOption(&'static str), + UuidParseError(uuid::Error), + Utf8Error(std::str::Utf8Error), +} + +#[derive(Debug, Display, From)] +pub enum AmqpCodecError { + ParseError(AmqpParseError), + #[display(fmt = "bytes left unparsed at the frame trail")] + UnparsedBytesLeft, + #[display(fmt = "max inbound frame size exceeded")] + MaxSizeExceeded, + #[display(fmt = "Io error: {:?}", _0)] + Io(Option), +} + +impl Clone for AmqpCodecError { + fn clone(&self) -> AmqpCodecError { + match self { + AmqpCodecError::ParseError(err) => AmqpCodecError::ParseError(err.clone()), + AmqpCodecError::UnparsedBytesLeft => AmqpCodecError::UnparsedBytesLeft, + AmqpCodecError::MaxSizeExceeded => AmqpCodecError::MaxSizeExceeded, + AmqpCodecError::Io(_) => AmqpCodecError::Io(None), + } + } +} + +impl From for AmqpCodecError { + fn from(err: std::io::Error) -> AmqpCodecError { + AmqpCodecError::Io(Some(err)) + } +} + +#[derive(Debug, Display, From)] +pub enum ProtocolIdError { + InvalidHeader, + Incompatible, + Unknown, + #[display(fmt = "Expected {:?} protocol id, seen {:?} instead.", exp, got)] + Unexpected { + exp: ProtocolId, + got: ProtocolId, + }, + Disconnected, + #[display(fmt = "io error: {:?}", "_0")] + Io(std::io::Error), +} diff --git a/actix-amqp/codec/src/framing.rs b/actix-amqp/codec/src/framing.rs new file mode 100755 index 000000000..681fa7ab3 --- /dev/null +++ b/actix-amqp/codec/src/framing.rs @@ -0,0 +1,80 @@ +use super::protocol; + +/// Length in bytes of the fixed frame header +pub const HEADER_LEN: usize = 8; + +/// AMQP Frame type marker (0) +pub const FRAME_TYPE_AMQP: u8 = 0x00; +pub const FRAME_TYPE_SASL: u8 = 0x01; + +/// Represents an AMQP Frame +#[derive(Clone, Debug, PartialEq)] +pub struct AmqpFrame { + channel_id: u16, + performative: protocol::Frame, +} + +impl AmqpFrame { + pub fn new(channel_id: u16, performative: protocol::Frame) -> AmqpFrame { + AmqpFrame { + channel_id, + performative, + } + } + + #[inline] + pub fn channel_id(&self) -> u16 { + self.channel_id + } + + #[inline] + pub fn performative(&self) -> &protocol::Frame { + &self.performative + } + + #[inline] + pub fn into_parts(self) -> (u16, protocol::Frame) { + (self.channel_id, self.performative) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct SaslFrame { + pub body: protocol::SaslFrameBody, +} + +impl SaslFrame { + pub fn new(body: protocol::SaslFrameBody) -> SaslFrame { + SaslFrame { body } + } +} + +impl From for SaslFrame { + fn from(item: protocol::SaslMechanisms) -> SaslFrame { + SaslFrame::new(protocol::SaslFrameBody::SaslMechanisms(item)) + } +} + +impl From for SaslFrame { + fn from(item: protocol::SaslInit) -> SaslFrame { + SaslFrame::new(protocol::SaslFrameBody::SaslInit(item)) + } +} + +impl From for SaslFrame { + fn from(item: protocol::SaslChallenge) -> SaslFrame { + SaslFrame::new(protocol::SaslFrameBody::SaslChallenge(item)) + } +} + +impl From for SaslFrame { + fn from(item: protocol::SaslResponse) -> SaslFrame { + SaslFrame::new(protocol::SaslFrameBody::SaslResponse(item)) + } +} + +impl From for SaslFrame { + fn from(item: protocol::SaslOutcome) -> SaslFrame { + SaslFrame::new(protocol::SaslFrameBody::SaslOutcome(item)) + } +} diff --git a/actix-amqp/codec/src/io.rs b/actix-amqp/codec/src/io.rs new file mode 100755 index 000000000..dc322de03 --- /dev/null +++ b/actix-amqp/codec/src/io.rs @@ -0,0 +1,160 @@ +use std::marker::PhantomData; + +use actix_codec::{Decoder, Encoder}; +use byteorder::{BigEndian, ByteOrder}; +use bytes::{BufMut, BytesMut}; + +use super::errors::{AmqpCodecError, ProtocolIdError}; +use super::framing::HEADER_LEN; +use crate::codec::{Decode, Encode}; +use crate::protocol::ProtocolId; + +const SIZE_LOW_WM: usize = 4096; +const SIZE_HIGH_WM: usize = 32768; + +#[derive(Debug)] +pub struct AmqpCodec { + state: DecodeState, + max_size: usize, + phantom: PhantomData, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + FrameHeader, + Frame(usize), +} + +impl Default for AmqpCodec { + fn default() -> Self { + Self::new() + } +} + +impl AmqpCodec { + pub fn new() -> AmqpCodec { + AmqpCodec { + state: DecodeState::FrameHeader, + max_size: 0, + phantom: PhantomData, + } + } + + /// Set max inbound frame size. + /// + /// If max size is set to `0`, size is unlimited. + /// By default max size is set to `0` + pub fn max_size(&mut self, size: usize) { + self.max_size = size; + } +} + +impl Decoder for AmqpCodec { + type Item = T; + type Error = AmqpCodecError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + loop { + match self.state { + DecodeState::FrameHeader => { + let len = src.len(); + if len < HEADER_LEN { + return Ok(None); + } + + // read frame size + let size = BigEndian::read_u32(src.as_ref()) as usize; + if self.max_size != 0 && size > self.max_size { + return Err(AmqpCodecError::MaxSizeExceeded); + } + self.state = DecodeState::Frame(size - 4); + src.split_to(4); + + if len < size { + // extend receiving buffer to fit the whole frame + if src.remaining_mut() < std::cmp::max(SIZE_LOW_WM, size + HEADER_LEN) { + src.reserve(SIZE_HIGH_WM); + } + return Ok(None); + } + } + DecodeState::Frame(size) => { + if src.len() < size { + return Ok(None); + } + + let frame_buf = src.split_to(size); + let (remainder, frame) = T::decode(frame_buf.as_ref())?; + if !remainder.is_empty() { + // todo: could it really happen? + return Err(AmqpCodecError::UnparsedBytesLeft); + } + self.state = DecodeState::FrameHeader; + return Ok(Some(frame)); + } + } + } + } +} + +impl Encoder for AmqpCodec { + type Item = T; + type Error = AmqpCodecError; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + let size = item.encoded_size(); + let need = std::cmp::max(SIZE_LOW_WM, size); + if dst.remaining_mut() < need { + dst.reserve(std::cmp::max(need, SIZE_HIGH_WM)); + } + + item.encode(dst); + Ok(()) + } +} + +const PROTOCOL_HEADER_LEN: usize = 8; +const PROTOCOL_HEADER_PREFIX: &[u8] = b"AMQP"; +const PROTOCOL_VERSION: &[u8] = &[1, 0, 0]; + +#[derive(Default, Debug)] +pub struct ProtocolIdCodec; + +impl Decoder for ProtocolIdCodec { + type Item = ProtocolId; + type Error = ProtocolIdError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.len() < PROTOCOL_HEADER_LEN { + Ok(None) + } else { + let src = src.split_to(8); + if &src[0..4] != PROTOCOL_HEADER_PREFIX { + Err(ProtocolIdError::InvalidHeader) + } else if &src[5..8] != PROTOCOL_VERSION { + Err(ProtocolIdError::Incompatible) + } else { + let protocol_id = src[4]; + match protocol_id { + 0 => Ok(Some(ProtocolId::Amqp)), + 2 => Ok(Some(ProtocolId::AmqpTls)), + 3 => Ok(Some(ProtocolId::AmqpSasl)), + _ => Err(ProtocolIdError::Unknown), + } + } + } + } +} + +impl Encoder for ProtocolIdCodec { + type Item = ProtocolId; + type Error = ProtocolIdError; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.reserve(PROTOCOL_HEADER_LEN); + dst.put_slice(PROTOCOL_HEADER_PREFIX); + dst.put_u8(item as u8); + dst.put_slice(PROTOCOL_VERSION); + Ok(()) + } +} diff --git a/actix-amqp/codec/src/lib.rs b/actix-amqp/codec/src/lib.rs new file mode 100755 index 000000000..e944f9188 --- /dev/null +++ b/actix-amqp/codec/src/lib.rs @@ -0,0 +1,17 @@ +#[macro_use] +extern crate derive_more; + +#[macro_use] +mod codec; +mod errors; +mod framing; +mod io; +mod message; +pub mod protocol; +pub mod types; + +pub use self::codec::{Decode, Encode}; +pub use self::errors::{AmqpCodecError, AmqpParseError, ProtocolIdError}; +pub use self::framing::{AmqpFrame, SaslFrame}; +pub use self::io::{AmqpCodec, ProtocolIdCodec}; +pub use self::message::{InMessage, MessageBody, OutMessage}; diff --git a/actix-amqp/codec/src/message/body.rs b/actix-amqp/codec/src/message/body.rs new file mode 100755 index 000000000..4b7cb53e4 --- /dev/null +++ b/actix-amqp/codec/src/message/body.rs @@ -0,0 +1,89 @@ +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::codec::{Encode, FORMATCODE_BINARY32, FORMATCODE_BINARY8}; +use crate::protocol::TransferBody; +use crate::types::{Descriptor, List, Variant}; + +use super::SECTION_PREFIX_LENGTH; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct MessageBody { + pub data: Vec, + pub sequence: Vec, + pub messages: Vec, + pub value: Option, +} + +impl MessageBody { + pub fn data(&self) -> Option<&Bytes> { + if self.data.is_empty() { + None + } else { + Some(&self.data[0]) + } + } + + pub fn value(&self) -> Option<&Variant> { + self.value.as_ref() + } + + pub fn set_data(&mut self, data: Bytes) { + self.data.clear(); + self.data.push(data); + } +} + +impl Encode for MessageBody { + fn encoded_size(&self) -> usize { + let mut size = self + .data + .iter() + .fold(0, |a, d| a + d.encoded_size() + SECTION_PREFIX_LENGTH); + size += self + .sequence + .iter() + .fold(0, |a, seq| a + seq.encoded_size() + SECTION_PREFIX_LENGTH); + size += self.messages.iter().fold(0, |a, m| { + let length = m.encoded_size(); + let size = length + if length > std::u8::MAX as usize { 5 } else { 2 }; + a + size + SECTION_PREFIX_LENGTH + }); + + if let Some(ref val) = self.value { + size + val.encoded_size() + SECTION_PREFIX_LENGTH + } else { + size + } + } + + fn encode(&self, dst: &mut BytesMut) { + self.data.iter().for_each(|d| { + Descriptor::Ulong(117).encode(dst); + d.encode(dst); + }); + self.sequence.iter().for_each(|seq| { + Descriptor::Ulong(118).encode(dst); + seq.encode(dst) + }); + if let Some(ref val) = self.value { + Descriptor::Ulong(119).encode(dst); + val.encode(dst); + } + // encode Message as nested Bytes object + self.messages.iter().for_each(|m| { + Descriptor::Ulong(117).encode(dst); + + // Bytes prefix + let length = m.encoded_size(); + if length > std::u8::MAX as usize { + dst.put_u8(FORMATCODE_BINARY32); + dst.put_u32(length as u32); + } else { + dst.put_u8(FORMATCODE_BINARY8); + dst.put_u8(length as u8); + } + // encode nested Message + m.encode(dst); + }); + } +} diff --git a/actix-amqp/codec/src/message/inmessage.rs b/actix-amqp/codec/src/message/inmessage.rs new file mode 100755 index 000000000..8c5d262c7 --- /dev/null +++ b/actix-amqp/codec/src/message/inmessage.rs @@ -0,0 +1,406 @@ +use std::cell::Cell; + +use bytes::{BufMut, Bytes, BytesMut}; +use fxhash::FxHashMap; + +use crate::codec::{Decode, Encode, FORMATCODE_BINARY8}; +use crate::errors::AmqpParseError; +use crate::protocol::{ + Annotations, Header, MessageFormat, Properties, Section, StringVariantMap, TransferBody, +}; +use crate::types::{Descriptor, Str, Variant}; + +use super::body::MessageBody; +use super::outmessage::OutMessage; +use super::SECTION_PREFIX_LENGTH; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct InMessage { + pub message_format: Option, + pub(super) header: Option
, + pub(super) delivery_annotations: Option, + pub(super) message_annotations: Option, + pub(super) properties: Option, + pub(super) application_properties: Option, + pub(super) footer: Option, + pub(super) body: MessageBody, + pub(super) size: Cell, +} + +impl InMessage { + /// Create new message and set body + pub fn with_body(body: Bytes) -> InMessage { + let mut msg = InMessage::default(); + msg.body.data.push(body); + msg + } + + /// Create new message and set messages as body + pub fn with_messages(messages: Vec) -> InMessage { + let mut msg = InMessage::default(); + msg.body.messages = messages; + msg + } + + /// Header + pub fn header(&self) -> Option<&Header> { + self.header.as_ref() + } + + /// Set message header + pub fn set_header(mut self, header: Header) -> Self { + self.header = Some(header); + self.size.set(0); + self + } + + /// Message properties + pub fn properties(&self) -> Option<&Properties> { + self.properties.as_ref() + } + + /// Add property + pub fn set_properties(mut self, f: F) -> Self + where + F: Fn(&mut Properties), + { + if let Some(ref mut props) = self.properties { + f(props); + } else { + let mut props = Properties::default(); + f(&mut props); + self.properties = Some(props); + } + self.size.set(0); + self + } + + /// Get application property + pub fn app_property(&self, key: &str) -> Option<&Variant> { + if let Some(ref props) = self.application_properties { + props.get(key) + } else { + None + } + } + + /// Get application properties + pub fn app_properties(&self) -> Option<&StringVariantMap> { + self.application_properties.as_ref() + } + + /// Get message annotation + pub fn message_annotation(&self, key: &str) -> Option<&Variant> { + if let Some(ref props) = self.message_annotations { + props.get(key) + } else { + None + } + } + + /// Add application property + pub fn set_app_property, V: Into>(mut self, key: K, value: V) -> Self { + if let Some(ref mut props) = self.application_properties { + props.insert(key.into(), value.into()); + } else { + let mut props = FxHashMap::default(); + props.insert(key.into(), value.into()); + self.application_properties = Some(props); + } + self.size.set(0); + self + } + + /// Call closure with message reference + pub fn update(self, f: F) -> Self + where + F: Fn(Self) -> Self, + { + self.size.set(0); + f(self) + } + + /// Call closure if value is Some value + pub fn if_some(self, value: &Option, f: F) -> Self + where + F: Fn(Self, &T) -> Self, + { + if let Some(ref val) = value { + self.size.set(0); + f(self, val) + } else { + self + } + } + + /// Message body + pub fn body(&self) -> &MessageBody { + &self.body + } + + /// Message value + pub fn value(&self) -> Option<&Variant> { + self.body.value.as_ref() + } + + /// Set message body value + pub fn set_value>(mut self, v: V) -> Self { + self.body.value = Some(v.into()); + self + } + + /// Set message body + pub fn set_body(mut self, f: F) -> Self + where + F: Fn(&mut MessageBody), + { + f(&mut self.body); + self.size.set(0); + self + } + + /// Create new message and set `correlation_id` property + pub fn reply_message(&self) -> OutMessage { + let mut msg = OutMessage::default().if_some(&self.properties, |mut msg, data| { + msg.set_properties(|props| props.correlation_id = data.message_id.clone()); + msg + }); + msg.message_format = self.message_format; + msg + } +} + +impl Decode for InMessage { + fn decode(mut input: &[u8]) -> Result<(&[u8], InMessage), AmqpParseError> { + let mut message = InMessage::default(); + + loop { + let (buf, sec) = Section::decode(input)?; + match sec { + Section::Header(val) => { + message.header = Some(val); + } + Section::DeliveryAnnotations(val) => { + message.delivery_annotations = Some(val); + } + Section::MessageAnnotations(val) => { + message.message_annotations = Some(val); + } + Section::ApplicationProperties(val) => { + message.application_properties = Some(val); + } + Section::Footer(val) => { + message.footer = Some(val); + } + Section::Properties(val) => { + message.properties = Some(val); + } + + // body + Section::AmqpSequence(val) => { + message.body.sequence.push(val); + } + Section::AmqpValue(val) => { + message.body.value = Some(val); + } + Section::Data(val) => { + message.body.data.push(val); + } + } + if buf.is_empty() { + break; + } + input = buf; + } + Ok((input, message)) + } +} + +impl Encode for InMessage { + fn encoded_size(&self) -> usize { + let size = self.size.get(); + if size != 0 { + return size; + } + + // body size, always add empty body if needed + let body_size = self.body.encoded_size(); + let mut size = if body_size == 0 { + // empty bytes + SECTION_PREFIX_LENGTH + 2 + } else { + body_size + }; + + if let Some(ref h) = self.header { + size += h.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref da) = self.delivery_annotations { + size += da.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref ma) = self.message_annotations { + size += ma.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref p) = self.properties { + size += p.encoded_size(); + } + if let Some(ref ap) = self.application_properties { + size += ap.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref f) = self.footer { + size += f.encoded_size() + SECTION_PREFIX_LENGTH; + } + self.size.set(size); + size + } + + fn encode(&self, dst: &mut BytesMut) { + if let Some(ref h) = self.header { + h.encode(dst); + } + if let Some(ref da) = self.delivery_annotations { + Descriptor::Ulong(113).encode(dst); + da.encode(dst); + } + if let Some(ref ma) = self.message_annotations { + Descriptor::Ulong(114).encode(dst); + ma.encode(dst); + } + if let Some(ref p) = self.properties { + p.encode(dst); + } + if let Some(ref ap) = self.application_properties { + Descriptor::Ulong(116).encode(dst); + ap.encode(dst); + } + + // message body + if self.body.encoded_size() == 0 { + // special treatment for empty body + Descriptor::Ulong(117).encode(dst); + dst.put_u8(FORMATCODE_BINARY8); + dst.put_u8(0); + } else { + self.body.encode(dst); + } + + // message footer, always last item + if let Some(ref f) = self.footer { + Descriptor::Ulong(120).encode(dst); + f.encode(dst); + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use bytestring::ByteString; + + use crate::codec::{Decode, Encode}; + use crate::errors::AmqpCodecError; + use crate::protocol::Header; + use crate::types::Variant; + + use super::InMessage; + + #[test] + fn test_properties() -> Result<(), AmqpCodecError> { + let msg = + InMessage::with_body(Bytes::from_static(b"Hello world")).set_properties(|props| { + props.message_id = Some(Bytes::from_static(b"msg1").into()); + props.content_type = Some("text".to_string().into()); + props.correlation_id = Some(Bytes::from_static(b"no1").into()); + props.content_encoding = Some("utf8+1".to_string().into()); + }); + + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = InMessage::decode(&buf)?.1; + let props = msg2.properties.as_ref().unwrap(); + assert_eq!(props.message_id, Some(Bytes::from_static(b"msg1").into())); + assert_eq!( + props.correlation_id, + Some(Bytes::from_static(b"no1").into()) + ); + Ok(()) + } + + #[test] + fn test_app_properties() -> Result<(), AmqpCodecError> { + let msg = InMessage::default().set_app_property(ByteString::from("test"), 1); + + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = InMessage::decode(&buf)?.1; + let props = msg2.application_properties.as_ref().unwrap(); + assert_eq!(*props.get("test").unwrap(), Variant::from(1)); + Ok(()) + } + + #[test] + fn test_header() -> Result<(), AmqpCodecError> { + let hdr = Header { + durable: false, + priority: 1, + ttl: None, + first_acquirer: false, + delivery_count: 1, + }; + + let msg = InMessage::default().set_header(hdr.clone()); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = InMessage::decode(&buf)?.1; + assert_eq!(msg2.header().unwrap(), &hdr); + Ok(()) + } + + #[test] + fn test_data() -> Result<(), AmqpCodecError> { + let data = Bytes::from_static(b"test data"); + + let msg = InMessage::default().set_body(|body| body.set_data(data.clone())); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = InMessage::decode(&buf)?.1; + assert_eq!(msg2.body.data().unwrap(), &data); + Ok(()) + } + + #[test] + fn test_data_empty() -> Result<(), AmqpCodecError> { + let msg = InMessage::default(); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = InMessage::decode(&buf)?.1; + assert_eq!(msg2.body.data().unwrap(), &Bytes::from_static(b"")); + Ok(()) + } + + #[test] + fn test_messages() -> Result<(), AmqpCodecError> { + let msg1 = InMessage::default().set_properties(|props| props.message_id = Some(1.into())); + let msg2 = InMessage::default().set_properties(|props| props.message_id = Some(2.into())); + + let msg = InMessage::default().set_body(|body| { + body.messages.push(msg1.clone().into()); + body.messages.push(msg2.clone().into()); + }); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg3 = InMessage::decode(&buf)?.1; + let msg4 = InMessage::decode(&msg3.body.data().unwrap())?.1; + assert_eq!(msg1.properties, msg4.properties); + + let msg5 = InMessage::decode(&msg3.body.data[1])?.1; + assert_eq!(msg2.properties, msg5.properties); + Ok(()) + } +} diff --git a/actix-amqp/codec/src/message/mod.rs b/actix-amqp/codec/src/message/mod.rs new file mode 100755 index 000000000..a0046b0f2 --- /dev/null +++ b/actix-amqp/codec/src/message/mod.rs @@ -0,0 +1,9 @@ +mod body; +mod inmessage; +mod outmessage; + +pub use self::body::MessageBody; +pub use self::inmessage::InMessage; +pub use self::outmessage::OutMessage; + +pub(self) const SECTION_PREFIX_LENGTH: usize = 3; diff --git a/actix-amqp/codec/src/message/outmessage.rs b/actix-amqp/codec/src/message/outmessage.rs new file mode 100755 index 000000000..23bc81496 --- /dev/null +++ b/actix-amqp/codec/src/message/outmessage.rs @@ -0,0 +1,440 @@ +use std::cell::Cell; + +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::codec::{Decode, Encode, FORMATCODE_BINARY8}; +use crate::errors::AmqpParseError; +use crate::protocol::{Annotations, Header, MessageFormat, Properties, Section, TransferBody}; +use crate::types::{Descriptor, Str, Symbol, Variant, VecStringMap, VecSymbolMap}; + +use super::body::MessageBody; +use super::inmessage::InMessage; +use super::SECTION_PREFIX_LENGTH; + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct OutMessage { + pub message_format: Option, + header: Option
, + delivery_annotations: Option, + message_annotations: Option, + properties: Option, + application_properties: Option, + footer: Option, + body: MessageBody, + size: Cell, +} + +impl OutMessage { + /// Create new message and set body + pub fn with_body(body: Bytes) -> OutMessage { + let mut msg = OutMessage::default(); + msg.body.data.push(body); + msg.message_format = Some(0); + msg + } + + /// Create new message and set messages as body + pub fn with_messages(messages: Vec) -> OutMessage { + let mut msg = OutMessage::default(); + msg.body.messages = messages; + msg.message_format = Some(0); + msg + } + + /// Header + pub fn header(&self) -> Option<&Header> { + self.header.as_ref() + } + + /// Set message header + pub fn set_header(&mut self, header: Header) -> &mut Self { + self.header = Some(header); + self.size.set(0); + self + } + + /// Message properties + pub fn properties(&self) -> Option<&Properties> { + self.properties.as_ref() + } + + /// Mutable reference to properties + pub fn properties_mut(&mut self) -> &mut Properties { + if self.properties.is_none() { + self.properties = Some(Properties::default()); + } + + self.size.set(0); + self.properties.as_mut().unwrap() + } + + /// Add property + pub fn set_properties(&mut self, f: F) -> &mut Self + where + F: Fn(&mut Properties), + { + if let Some(ref mut props) = self.properties { + f(props); + } else { + let mut props = Properties::default(); + f(&mut props); + self.properties = Some(props); + } + self.size.set(0); + self + } + + /// Get application property + pub fn app_properties(&self) -> Option<&VecStringMap> { + self.application_properties.as_ref() + } + + /// Add application property + pub fn set_app_property(&mut self, key: K, value: V) -> &mut Self + where + K: Into, + V: Into, + { + if let Some(ref mut props) = self.application_properties { + props.push((key.into(), value.into())); + } else { + let mut props = VecStringMap::default(); + props.push((key.into(), value.into())); + self.application_properties = Some(props); + } + self.size.set(0); + self + } + + /// Add message annotation + pub fn add_message_annotation(&mut self, key: K, value: V) -> &mut Self + where + K: Into, + V: Into, + { + if let Some(ref mut props) = self.message_annotations { + props.push((key.into(), value.into())); + } else { + let mut props = VecSymbolMap::default(); + props.push((key.into(), value.into())); + self.message_annotations = Some(props); + } + self.size.set(0); + self + } + + /// Call closure with message reference + pub fn update(self, f: F) -> Self + where + F: Fn(Self) -> Self, + { + self.size.set(0); + f(self) + } + + /// Call closure if value is Some value + pub fn if_some(self, value: &Option, f: F) -> Self + where + F: Fn(Self, &T) -> Self, + { + if let Some(ref val) = value { + self.size.set(0); + f(self, val) + } else { + self + } + } + + /// Message body + pub fn body(&self) -> &MessageBody { + &self.body + } + + /// Message value + pub fn value(&self) -> Option<&Variant> { + self.body.value.as_ref() + } + + /// Set message body value + pub fn set_value>(&mut self, v: V) -> &mut Self { + self.body.value = Some(v.into()); + self + } + + /// Set message body + pub fn set_body(&mut self, f: F) -> &mut Self + where + F: FnOnce(&mut MessageBody), + { + f(&mut self.body); + self.size.set(0); + self + } + + /// Create new message and set `correlation_id` property + pub fn reply_message(&self) -> OutMessage { + OutMessage::default().if_some(&self.properties, |mut msg, data| { + msg.set_properties(|props| props.correlation_id = data.message_id.clone()); + msg + }) + } +} + +impl From for OutMessage { + fn from(from: InMessage) -> Self { + let mut msg = OutMessage { + message_format: from.message_format, + header: from.header, + properties: from.properties, + delivery_annotations: from.delivery_annotations, + message_annotations: from.message_annotations.map(|v| v.into()), + application_properties: from.application_properties.map(|v| v.into()), + footer: from.footer, + body: from.body, + size: Cell::new(0), + }; + + if let Some(ref mut props) = msg.properties { + props.correlation_id = props.message_id.clone(); + }; + + msg + } +} + +impl Decode for OutMessage { + fn decode(mut input: &[u8]) -> Result<(&[u8], OutMessage), AmqpParseError> { + let mut message = OutMessage::default(); + + loop { + let (buf, sec) = Section::decode(input)?; + match sec { + Section::Header(val) => { + message.header = Some(val); + } + Section::DeliveryAnnotations(val) => { + message.delivery_annotations = Some(val); + } + Section::MessageAnnotations(val) => { + message.message_annotations = Some(VecSymbolMap( + val.into_iter().map(|(k, v)| (k.0.into(), v)).collect(), + )); + } + Section::ApplicationProperties(val) => { + message.application_properties = Some(VecStringMap( + val.into_iter().map(|(k, v)| (k.into(), v)).collect(), + )); + } + Section::Footer(val) => { + message.footer = Some(val); + } + Section::Properties(val) => { + message.properties = Some(val); + } + + // body + Section::AmqpSequence(val) => { + message.body.sequence.push(val); + } + Section::AmqpValue(val) => { + message.body.value = Some(val); + } + Section::Data(val) => { + message.body.data.push(val); + } + } + if buf.is_empty() { + break; + } + input = buf; + } + Ok((input, message)) + } +} + +impl Encode for OutMessage { + fn encoded_size(&self) -> usize { + let size = self.size.get(); + if size != 0 { + return size; + } + + // body size, always add empty body if needed + let body_size = self.body.encoded_size(); + let mut size = if body_size == 0 { + // empty bytes + SECTION_PREFIX_LENGTH + 2 + } else { + body_size + }; + + if let Some(ref h) = self.header { + size += h.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref da) = self.delivery_annotations { + size += da.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref ma) = self.message_annotations { + size += ma.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref p) = self.properties { + size += p.encoded_size(); + } + if let Some(ref ap) = self.application_properties { + size += ap.encoded_size() + SECTION_PREFIX_LENGTH; + } + if let Some(ref f) = self.footer { + size += f.encoded_size() + SECTION_PREFIX_LENGTH; + } + self.size.set(size); + size + } + + fn encode(&self, dst: &mut BytesMut) { + if let Some(ref h) = self.header { + h.encode(dst); + } + if let Some(ref da) = self.delivery_annotations { + Descriptor::Ulong(113).encode(dst); + da.encode(dst); + } + if let Some(ref ma) = self.message_annotations { + Descriptor::Ulong(114).encode(dst); + ma.encode(dst); + } + if let Some(ref p) = self.properties { + p.encode(dst); + } + if let Some(ref ap) = self.application_properties { + Descriptor::Ulong(116).encode(dst); + ap.encode(dst); + } + + // message body + if self.body.encoded_size() == 0 { + // special treatment for empty body + Descriptor::Ulong(117).encode(dst); + dst.put_u8(FORMATCODE_BINARY8); + dst.put_u8(0); + } else { + self.body.encode(dst); + } + + // message footer, always last item + if let Some(ref f) = self.footer { + Descriptor::Ulong(120).encode(dst); + f.encode(dst); + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use bytestring::ByteString; + + use crate::codec::{Decode, Encode}; + use crate::errors::AmqpCodecError; + use crate::protocol::Header; + use crate::types::Variant; + + use super::OutMessage; + + #[test] + fn test_properties() -> Result<(), AmqpCodecError> { + let mut msg = OutMessage::default(); + msg.set_properties(|props| props.message_id = Some(1.into())); + + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = OutMessage::decode(&buf)?.1; + let props = msg2.properties.as_ref().unwrap(); + assert_eq!(props.message_id, Some(1.into())); + Ok(()) + } + + #[test] + fn test_app_properties() -> Result<(), AmqpCodecError> { + let mut msg = OutMessage::default(); + msg.set_app_property(ByteString::from("test"), 1); + + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = OutMessage::decode(&buf)?.1; + let props = msg2.application_properties.as_ref().unwrap(); + assert_eq!(props[0].0.as_str(), "test"); + assert_eq!(props[0].1, Variant::from(1)); + Ok(()) + } + + #[test] + fn test_header() -> Result<(), AmqpCodecError> { + let hdr = Header { + durable: false, + priority: 1, + ttl: None, + first_acquirer: false, + delivery_count: 1, + }; + + let mut msg = OutMessage::default(); + msg.set_header(hdr.clone()); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = OutMessage::decode(&buf)?.1; + assert_eq!(msg2.header().unwrap(), &hdr); + Ok(()) + } + + #[test] + fn test_data() -> Result<(), AmqpCodecError> { + let data = Bytes::from_static(b"test data"); + + let mut msg = OutMessage::default(); + msg.set_body(|body| body.set_data(data.clone())); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = OutMessage::decode(&buf)?.1; + assert_eq!(msg2.body.data().unwrap(), &data); + Ok(()) + } + + #[test] + fn test_data_empty() -> Result<(), AmqpCodecError> { + let msg = OutMessage::default(); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg2 = OutMessage::decode(&buf)?.1; + assert_eq!(msg2.body.data().unwrap(), &Bytes::from_static(b"")); + Ok(()) + } + + #[test] + fn test_messages() -> Result<(), AmqpCodecError> { + let mut msg1 = OutMessage::default(); + msg1.set_properties(|props| props.message_id = Some(1.into())); + let mut msg2 = OutMessage::default(); + msg2.set_properties(|props| props.message_id = Some(2.into())); + + let mut msg = OutMessage::default(); + msg.set_body(|body| { + body.messages.push(msg1.clone().into()); + body.messages.push(msg2.clone().into()); + }); + let mut buf = BytesMut::with_capacity(msg.encoded_size()); + msg.encode(&mut buf); + + let msg3 = OutMessage::decode(&buf)?.1; + let msg4 = OutMessage::decode(&msg3.body.data().unwrap())?.1; + assert_eq!(msg1.properties, msg4.properties); + + let msg5 = OutMessage::decode(&msg3.body.data[1])?.1; + assert_eq!(msg2.properties, msg5.properties); + Ok(()) + } +} diff --git a/actix-amqp/codec/src/protocol/definitions.rs b/actix-amqp/codec/src/protocol/definitions.rs new file mode 100755 index 000000000..905fd4b68 --- /dev/null +++ b/actix-amqp/codec/src/protocol/definitions.rs @@ -0,0 +1,4661 @@ +#![allow(unused_assignments, unused_variables, unreachable_patterns)] +use super::*; +use crate::codec::{self, decode_format_code, decode_list_header, Decode, DecodeFormatted, Encode}; +use crate::errors::AmqpParseError; +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; +use derive_more::From; +use std::u8; +use uuid::Uuid; +#[derive(Clone, Debug, PartialEq, From)] +pub enum Frame { + Open(Open), + Begin(Begin), + Attach(Attach), + Flow(Flow), + Transfer(Transfer), + Disposition(Disposition), + Detach(Detach), + End(End), + Close(Close), + Empty, +} +impl Frame { + pub fn name(&self) -> &'static str { + match self { + Frame::Open(_) => "Open", + Frame::Begin(_) => "Begin", + Frame::Attach(_) => "Attach", + Frame::Flow(_) => "Flow", + Frame::Transfer(_) => "Transfer", + Frame::Disposition(_) => "Disposition", + Frame::Detach(_) => "Detach", + Frame::End(_) => "End", + Frame::Close(_) => "Close", + Frame::Empty => "Empty", + } + } +} +impl Decode for Frame { + fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> { + if input.is_empty() { + Ok((input, Frame::Empty)) + } else { + let (input, fmt) = decode_format_code(input)?; + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(16) => decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r))), + Descriptor::Ulong(17) => { + decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r))) + } + Descriptor::Ulong(18) => { + decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r))) + } + Descriptor::Ulong(19) => decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r))), + Descriptor::Ulong(20) => { + decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r))) + } + Descriptor::Ulong(21) => { + decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r))) + } + Descriptor::Ulong(22) => { + decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r))) + } + Descriptor::Ulong(23) => decode_end_inner(input).map(|(i, r)| (i, Frame::End(r))), + Descriptor::Ulong(24) => { + decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:open:list" => { + decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:begin:list" => { + decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:attach:list" => { + decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:flow:list" => { + decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:transfer:list" => { + decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:disposition:list" => { + decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:detach:list" => { + decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:end:list" => { + decode_end_inner(input).map(|(i, r)| (i, Frame::End(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:close:list" => { + decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } + } +} +impl Encode for Frame { + fn encoded_size(&self) -> usize { + match *self { + Frame::Open(ref v) => encoded_size_open_inner(v), + Frame::Begin(ref v) => encoded_size_begin_inner(v), + Frame::Attach(ref v) => encoded_size_attach_inner(v), + Frame::Flow(ref v) => encoded_size_flow_inner(v), + Frame::Transfer(ref v) => encoded_size_transfer_inner(v), + Frame::Disposition(ref v) => encoded_size_disposition_inner(v), + Frame::Detach(ref v) => encoded_size_detach_inner(v), + Frame::End(ref v) => encoded_size_end_inner(v), + Frame::Close(ref v) => encoded_size_close_inner(v), + Frame::Empty => 0, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + Frame::Open(ref v) => encode_open_inner(v, buf), + Frame::Begin(ref v) => encode_begin_inner(v, buf), + Frame::Attach(ref v) => encode_attach_inner(v, buf), + Frame::Flow(ref v) => encode_flow_inner(v, buf), + Frame::Transfer(ref v) => encode_transfer_inner(v, buf), + Frame::Disposition(ref v) => encode_disposition_inner(v, buf), + Frame::Detach(ref v) => encode_detach_inner(v, buf), + Frame::End(ref v) => encode_end_inner(v, buf), + Frame::Close(ref v) => encode_close_inner(v, buf), + Frame::Empty => (), + } + } +} +#[derive(Clone, Debug, PartialEq)] +pub enum DeliveryState { + Received(Received), + Accepted(Accepted), + Rejected(Rejected), + Released(Released), + Modified(Modified), +} +impl DecodeFormatted for DeliveryState { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(35) => { + decode_received_inner(input).map(|(i, r)| (i, DeliveryState::Received(r))) + } + Descriptor::Ulong(36) => { + decode_accepted_inner(input).map(|(i, r)| (i, DeliveryState::Accepted(r))) + } + Descriptor::Ulong(37) => { + decode_rejected_inner(input).map(|(i, r)| (i, DeliveryState::Rejected(r))) + } + Descriptor::Ulong(38) => { + decode_released_inner(input).map(|(i, r)| (i, DeliveryState::Released(r))) + } + Descriptor::Ulong(39) => { + decode_modified_inner(input).map(|(i, r)| (i, DeliveryState::Modified(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:received:list" => { + decode_received_inner(input).map(|(i, r)| (i, DeliveryState::Received(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:accepted:list" => { + decode_accepted_inner(input).map(|(i, r)| (i, DeliveryState::Accepted(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:rejected:list" => { + decode_rejected_inner(input).map(|(i, r)| (i, DeliveryState::Rejected(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:released:list" => { + decode_released_inner(input).map(|(i, r)| (i, DeliveryState::Released(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:modified:list" => { + decode_modified_inner(input).map(|(i, r)| (i, DeliveryState::Modified(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } +} +impl Encode for DeliveryState { + fn encoded_size(&self) -> usize { + match *self { + DeliveryState::Received(ref v) => encoded_size_received_inner(v), + DeliveryState::Accepted(ref v) => encoded_size_accepted_inner(v), + DeliveryState::Rejected(ref v) => encoded_size_rejected_inner(v), + DeliveryState::Released(ref v) => encoded_size_released_inner(v), + DeliveryState::Modified(ref v) => encoded_size_modified_inner(v), + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + DeliveryState::Received(ref v) => encode_received_inner(v, buf), + DeliveryState::Accepted(ref v) => encode_accepted_inner(v, buf), + DeliveryState::Rejected(ref v) => encode_rejected_inner(v, buf), + DeliveryState::Released(ref v) => encode_released_inner(v, buf), + DeliveryState::Modified(ref v) => encode_modified_inner(v, buf), + } + } +} +#[derive(Clone, Debug, PartialEq)] +pub enum SaslFrameBody { + SaslMechanisms(SaslMechanisms), + SaslInit(SaslInit), + SaslChallenge(SaslChallenge), + SaslResponse(SaslResponse), + SaslOutcome(SaslOutcome), +} +impl DecodeFormatted for SaslFrameBody { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(64) => decode_sasl_mechanisms_inner(input) + .map(|(i, r)| (i, SaslFrameBody::SaslMechanisms(r))), + Descriptor::Ulong(65) => { + decode_sasl_init_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslInit(r))) + } + Descriptor::Ulong(66) => decode_sasl_challenge_inner(input) + .map(|(i, r)| (i, SaslFrameBody::SaslChallenge(r))), + Descriptor::Ulong(67) => { + decode_sasl_response_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslResponse(r))) + } + Descriptor::Ulong(68) => { + decode_sasl_outcome_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslOutcome(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:sasl-mechanisms:list" => { + decode_sasl_mechanisms_inner(input) + .map(|(i, r)| (i, SaslFrameBody::SaslMechanisms(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:sasl-init:list" => { + decode_sasl_init_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslInit(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:sasl-challenge:list" => { + decode_sasl_challenge_inner(input) + .map(|(i, r)| (i, SaslFrameBody::SaslChallenge(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:sasl-response:list" => { + decode_sasl_response_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslResponse(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:sasl-outcome:list" => { + decode_sasl_outcome_inner(input).map(|(i, r)| (i, SaslFrameBody::SaslOutcome(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } +} +impl Encode for SaslFrameBody { + fn encoded_size(&self) -> usize { + match *self { + SaslFrameBody::SaslMechanisms(ref v) => encoded_size_sasl_mechanisms_inner(v), + SaslFrameBody::SaslInit(ref v) => encoded_size_sasl_init_inner(v), + SaslFrameBody::SaslChallenge(ref v) => encoded_size_sasl_challenge_inner(v), + SaslFrameBody::SaslResponse(ref v) => encoded_size_sasl_response_inner(v), + SaslFrameBody::SaslOutcome(ref v) => encoded_size_sasl_outcome_inner(v), + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + SaslFrameBody::SaslMechanisms(ref v) => encode_sasl_mechanisms_inner(v, buf), + SaslFrameBody::SaslInit(ref v) => encode_sasl_init_inner(v, buf), + SaslFrameBody::SaslChallenge(ref v) => encode_sasl_challenge_inner(v, buf), + SaslFrameBody::SaslResponse(ref v) => encode_sasl_response_inner(v, buf), + SaslFrameBody::SaslOutcome(ref v) => encode_sasl_outcome_inner(v, buf), + } + } +} +#[derive(Clone, Debug, PartialEq)] +pub enum Section { + Header(Header), + DeliveryAnnotations(DeliveryAnnotations), + MessageAnnotations(MessageAnnotations), + ApplicationProperties(ApplicationProperties), + Data(Data), + AmqpSequence(AmqpSequence), + AmqpValue(AmqpValue), + Footer(Footer), + Properties(Properties), +} +impl DecodeFormatted for Section { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(112) => { + decode_header_inner(input).map(|(i, r)| (i, Section::Header(r))) + } + Descriptor::Ulong(113) => decode_delivery_annotations_inner(input) + .map(|(i, r)| (i, Section::DeliveryAnnotations(r))), + Descriptor::Ulong(114) => decode_message_annotations_inner(input) + .map(|(i, r)| (i, Section::MessageAnnotations(r))), + Descriptor::Ulong(116) => decode_application_properties_inner(input) + .map(|(i, r)| (i, Section::ApplicationProperties(r))), + Descriptor::Ulong(117) => decode_data_inner(input).map(|(i, r)| (i, Section::Data(r))), + Descriptor::Ulong(118) => { + decode_amqp_sequence_inner(input).map(|(i, r)| (i, Section::AmqpSequence(r))) + } + Descriptor::Ulong(119) => { + decode_amqp_value_inner(input).map(|(i, r)| (i, Section::AmqpValue(r))) + } + Descriptor::Ulong(120) => { + decode_footer_inner(input).map(|(i, r)| (i, Section::Footer(r))) + } + Descriptor::Ulong(115) => { + decode_properties_inner(input).map(|(i, r)| (i, Section::Properties(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:header:list" => { + decode_header_inner(input).map(|(i, r)| (i, Section::Header(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:delivery-annotations:map" => { + decode_delivery_annotations_inner(input) + .map(|(i, r)| (i, Section::DeliveryAnnotations(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:message-annotations:map" => { + decode_message_annotations_inner(input) + .map(|(i, r)| (i, Section::MessageAnnotations(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:application-properties:map" => { + decode_application_properties_inner(input) + .map(|(i, r)| (i, Section::ApplicationProperties(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:data:binary" => { + decode_data_inner(input).map(|(i, r)| (i, Section::Data(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:amqp-sequence:list" => { + decode_amqp_sequence_inner(input).map(|(i, r)| (i, Section::AmqpSequence(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:amqp-value:*" => { + decode_amqp_value_inner(input).map(|(i, r)| (i, Section::AmqpValue(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:footer:map" => { + decode_footer_inner(input).map(|(i, r)| (i, Section::Footer(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:properties:list" => { + decode_properties_inner(input).map(|(i, r)| (i, Section::Properties(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } +} +impl Encode for Section { + fn encoded_size(&self) -> usize { + match *self { + Section::Header(ref v) => encoded_size_header_inner(v), + Section::DeliveryAnnotations(ref v) => encoded_size_delivery_annotations_inner(v), + Section::MessageAnnotations(ref v) => encoded_size_message_annotations_inner(v), + Section::ApplicationProperties(ref v) => encoded_size_application_properties_inner(v), + Section::Data(ref v) => encoded_size_data_inner(v), + Section::AmqpSequence(ref v) => encoded_size_amqp_sequence_inner(v), + Section::AmqpValue(ref v) => encoded_size_amqp_value_inner(v), + Section::Footer(ref v) => encoded_size_footer_inner(v), + Section::Properties(ref v) => encoded_size_properties_inner(v), + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + Section::Header(ref v) => encode_header_inner(v, buf), + Section::DeliveryAnnotations(ref v) => encode_delivery_annotations_inner(v, buf), + Section::MessageAnnotations(ref v) => encode_message_annotations_inner(v, buf), + Section::ApplicationProperties(ref v) => encode_application_properties_inner(v, buf), + Section::Data(ref v) => encode_data_inner(v, buf), + Section::AmqpSequence(ref v) => encode_amqp_sequence_inner(v, buf), + Section::AmqpValue(ref v) => encode_amqp_value_inner(v, buf), + Section::Footer(ref v) => encode_footer_inner(v, buf), + Section::Properties(ref v) => encode_properties_inner(v, buf), + } + } +} +#[derive(Clone, Debug, PartialEq)] +pub enum Outcome { + Accepted(Accepted), + Rejected(Rejected), + Released(Released), + Modified(Modified), +} +impl DecodeFormatted for Outcome { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + match descriptor { + Descriptor::Ulong(36) => { + decode_accepted_inner(input).map(|(i, r)| (i, Outcome::Accepted(r))) + } + Descriptor::Ulong(37) => { + decode_rejected_inner(input).map(|(i, r)| (i, Outcome::Rejected(r))) + } + Descriptor::Ulong(38) => { + decode_released_inner(input).map(|(i, r)| (i, Outcome::Released(r))) + } + Descriptor::Ulong(39) => { + decode_modified_inner(input).map(|(i, r)| (i, Outcome::Modified(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:accepted:list" => { + decode_accepted_inner(input).map(|(i, r)| (i, Outcome::Accepted(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:rejected:list" => { + decode_rejected_inner(input).map(|(i, r)| (i, Outcome::Rejected(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:released:list" => { + decode_released_inner(input).map(|(i, r)| (i, Outcome::Released(r))) + } + Descriptor::Symbol(ref a) if a.as_str() == "amqp:modified:list" => { + decode_modified_inner(input).map(|(i, r)| (i, Outcome::Modified(r))) + } + _ => Err(AmqpParseError::InvalidDescriptor(descriptor)), + } + } +} +impl Encode for Outcome { + fn encoded_size(&self) -> usize { + match *self { + Outcome::Accepted(ref v) => encoded_size_accepted_inner(v), + Outcome::Rejected(ref v) => encoded_size_rejected_inner(v), + Outcome::Released(ref v) => encoded_size_released_inner(v), + Outcome::Modified(ref v) => encoded_size_modified_inner(v), + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + Outcome::Accepted(ref v) => encode_accepted_inner(v, buf), + Outcome::Rejected(ref v) => encode_rejected_inner(v, buf), + Outcome::Released(ref v) => encode_released_inner(v, buf), + Outcome::Modified(ref v) => encode_modified_inner(v, buf), + } + } +} +pub type Handle = u32; +pub type Seconds = u32; +pub type Milliseconds = u32; +pub type DeliveryTag = Bytes; +pub type SequenceNo = u32; +pub type DeliveryNumber = SequenceNo; +pub type TransferNumber = SequenceNo; +pub type MessageFormat = u32; +pub type IetfLanguageTag = Symbol; +pub type NodeProperties = Fields; +pub type MessageIdUlong = u64; +pub type MessageIdUuid = Uuid; +pub type MessageIdBinary = Bytes; +pub type MessageIdString = ByteString; +pub type Address = ByteString; +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Role { + Sender, + Receiver, +} +impl Role { + pub fn try_from(v: bool) -> Result { + match v { + false => Ok(Role::Sender), + true => Ok(Role::Receiver), + _ => Err(AmqpParseError::UnknownEnumOption("Role")), + } + } +} +impl DecodeFormatted for Role { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = bool::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for Role { + fn encoded_size(&self) -> usize { + match *self { + Role::Sender => { + let v: bool = false; + v.encoded_size() + } + Role::Receiver => { + let v: bool = true; + v.encoded_size() + } + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + Role::Sender => { + let v: bool = false; + v.encode(buf); + } + Role::Receiver => { + let v: bool = true; + v.encode(buf); + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum SenderSettleMode { + Unsettled, + Settled, + Mixed, +} +impl SenderSettleMode { + pub fn try_from(v: u8) -> Result { + match v { + 0 => Ok(SenderSettleMode::Unsettled), + 1 => Ok(SenderSettleMode::Settled), + 2 => Ok(SenderSettleMode::Mixed), + _ => Err(AmqpParseError::UnknownEnumOption("SenderSettleMode")), + } + } +} +impl DecodeFormatted for SenderSettleMode { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = u8::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for SenderSettleMode { + fn encoded_size(&self) -> usize { + match *self { + SenderSettleMode::Unsettled => { + let v: u8 = 0; + v.encoded_size() + } + SenderSettleMode::Settled => { + let v: u8 = 1; + v.encoded_size() + } + SenderSettleMode::Mixed => { + let v: u8 = 2; + v.encoded_size() + } + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + SenderSettleMode::Unsettled => { + let v: u8 = 0; + v.encode(buf); + } + SenderSettleMode::Settled => { + let v: u8 = 1; + v.encode(buf); + } + SenderSettleMode::Mixed => { + let v: u8 = 2; + v.encode(buf); + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ReceiverSettleMode { + First, + Second, +} +impl ReceiverSettleMode { + pub fn try_from(v: u8) -> Result { + match v { + 0 => Ok(ReceiverSettleMode::First), + 1 => Ok(ReceiverSettleMode::Second), + _ => Err(AmqpParseError::UnknownEnumOption("ReceiverSettleMode")), + } + } +} +impl DecodeFormatted for ReceiverSettleMode { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = u8::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for ReceiverSettleMode { + fn encoded_size(&self) -> usize { + match *self { + ReceiverSettleMode::First => { + let v: u8 = 0; + v.encoded_size() + } + ReceiverSettleMode::Second => { + let v: u8 = 1; + v.encoded_size() + } + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + ReceiverSettleMode::First => { + let v: u8 = 0; + v.encode(buf); + } + ReceiverSettleMode::Second => { + let v: u8 = 1; + v.encode(buf); + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum AmqpError { + InternalError, + NotFound, + UnauthorizedAccess, + DecodeError, + ResourceLimitExceeded, + NotAllowed, + InvalidField, + NotImplemented, + ResourceLocked, + PreconditionFailed, + ResourceDeleted, + IllegalState, + FrameSizeTooSmall, +} +impl AmqpError { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + "amqp:internal-error" => Ok(AmqpError::InternalError), + "amqp:not-found" => Ok(AmqpError::NotFound), + "amqp:unauthorized-access" => Ok(AmqpError::UnauthorizedAccess), + "amqp:decode-error" => Ok(AmqpError::DecodeError), + "amqp:resource-limit-exceeded" => Ok(AmqpError::ResourceLimitExceeded), + "amqp:not-allowed" => Ok(AmqpError::NotAllowed), + "amqp:invalid-field" => Ok(AmqpError::InvalidField), + "amqp:not-implemented" => Ok(AmqpError::NotImplemented), + "amqp:resource-locked" => Ok(AmqpError::ResourceLocked), + "amqp:precondition-failed" => Ok(AmqpError::PreconditionFailed), + "amqp:resource-deleted" => Ok(AmqpError::ResourceDeleted), + "amqp:illegal-state" => Ok(AmqpError::IllegalState), + "amqp:frame-size-too-small" => Ok(AmqpError::FrameSizeTooSmall), + _ => Err(AmqpParseError::UnknownEnumOption("AmqpError")), + } + } +} +impl DecodeFormatted for AmqpError { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for AmqpError { + fn encoded_size(&self) -> usize { + match *self { + AmqpError::InternalError => 19 + 2, + AmqpError::NotFound => 14 + 2, + AmqpError::UnauthorizedAccess => 24 + 2, + AmqpError::DecodeError => 17 + 2, + AmqpError::ResourceLimitExceeded => 28 + 2, + AmqpError::NotAllowed => 16 + 2, + AmqpError::InvalidField => 18 + 2, + AmqpError::NotImplemented => 20 + 2, + AmqpError::ResourceLocked => 20 + 2, + AmqpError::PreconditionFailed => 24 + 2, + AmqpError::ResourceDeleted => 21 + 2, + AmqpError::IllegalState => 18 + 2, + AmqpError::FrameSizeTooSmall => 25 + 2, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + AmqpError::InternalError => StaticSymbol("amqp:internal-error").encode(buf), + AmqpError::NotFound => StaticSymbol("amqp:not-found").encode(buf), + AmqpError::UnauthorizedAccess => StaticSymbol("amqp:unauthorized-access").encode(buf), + AmqpError::DecodeError => StaticSymbol("amqp:decode-error").encode(buf), + AmqpError::ResourceLimitExceeded => { + StaticSymbol("amqp:resource-limit-exceeded").encode(buf) + } + AmqpError::NotAllowed => StaticSymbol("amqp:not-allowed").encode(buf), + AmqpError::InvalidField => StaticSymbol("amqp:invalid-field").encode(buf), + AmqpError::NotImplemented => StaticSymbol("amqp:not-implemented").encode(buf), + AmqpError::ResourceLocked => StaticSymbol("amqp:resource-locked").encode(buf), + AmqpError::PreconditionFailed => StaticSymbol("amqp:precondition-failed").encode(buf), + AmqpError::ResourceDeleted => StaticSymbol("amqp:resource-deleted").encode(buf), + AmqpError::IllegalState => StaticSymbol("amqp:illegal-state").encode(buf), + AmqpError::FrameSizeTooSmall => StaticSymbol("amqp:frame-size-too-small").encode(buf), + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ConnectionError { + ConnectionForced, + FramingError, + Redirect, +} +impl ConnectionError { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + "amqp:connection:forced" => Ok(ConnectionError::ConnectionForced), + "amqp:connection:framing-error" => Ok(ConnectionError::FramingError), + "amqp:connection:redirect" => Ok(ConnectionError::Redirect), + _ => Err(AmqpParseError::UnknownEnumOption("ConnectionError")), + } + } +} +impl DecodeFormatted for ConnectionError { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for ConnectionError { + fn encoded_size(&self) -> usize { + match *self { + ConnectionError::ConnectionForced => 22 + 2, + ConnectionError::FramingError => 29 + 2, + ConnectionError::Redirect => 24 + 2, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + ConnectionError::ConnectionForced => StaticSymbol("amqp:connection:forced").encode(buf), + ConnectionError::FramingError => { + StaticSymbol("amqp:connection:framing-error").encode(buf) + } + ConnectionError::Redirect => StaticSymbol("amqp:connection:redirect").encode(buf), + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum SessionError { + WindowViolation, + ErrantLink, + HandleInUse, + UnattachedHandle, +} +impl SessionError { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + "amqp:session:window-violation" => Ok(SessionError::WindowViolation), + "amqp:session:errant-link" => Ok(SessionError::ErrantLink), + "amqp:session:handle-in-use" => Ok(SessionError::HandleInUse), + "amqp:session:unattached-handle" => Ok(SessionError::UnattachedHandle), + _ => Err(AmqpParseError::UnknownEnumOption("SessionError")), + } + } +} +impl DecodeFormatted for SessionError { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for SessionError { + fn encoded_size(&self) -> usize { + match *self { + SessionError::WindowViolation => 29 + 2, + SessionError::ErrantLink => 24 + 2, + SessionError::HandleInUse => 26 + 2, + SessionError::UnattachedHandle => 30 + 2, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + SessionError::WindowViolation => { + StaticSymbol("amqp:session:window-violation").encode(buf) + } + SessionError::ErrantLink => StaticSymbol("amqp:session:errant-link").encode(buf), + SessionError::HandleInUse => StaticSymbol("amqp:session:handle-in-use").encode(buf), + SessionError::UnattachedHandle => { + StaticSymbol("amqp:session:unattached-handle").encode(buf) + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LinkError { + DetachForced, + TransferLimitExceeded, + MessageSizeExceeded, + Redirect, + Stolen, +} +impl LinkError { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + "amqp:link:detach-forced" => Ok(LinkError::DetachForced), + "amqp:link:transfer-limit-exceeded" => Ok(LinkError::TransferLimitExceeded), + "amqp:link:message-size-exceeded" => Ok(LinkError::MessageSizeExceeded), + "amqp:link:redirect" => Ok(LinkError::Redirect), + "amqp:link:stolen" => Ok(LinkError::Stolen), + _ => Err(AmqpParseError::UnknownEnumOption("LinkError")), + } + } +} +impl DecodeFormatted for LinkError { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for LinkError { + fn encoded_size(&self) -> usize { + match *self { + LinkError::DetachForced => 23 + 2, + LinkError::TransferLimitExceeded => 33 + 2, + LinkError::MessageSizeExceeded => 31 + 2, + LinkError::Redirect => 18 + 2, + LinkError::Stolen => 16 + 2, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + LinkError::DetachForced => StaticSymbol("amqp:link:detach-forced").encode(buf), + LinkError::TransferLimitExceeded => { + StaticSymbol("amqp:link:transfer-limit-exceeded").encode(buf) + } + LinkError::MessageSizeExceeded => { + StaticSymbol("amqp:link:message-size-exceeded").encode(buf) + } + LinkError::Redirect => StaticSymbol("amqp:link:redirect").encode(buf), + LinkError::Stolen => StaticSymbol("amqp:link:stolen").encode(buf), + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum SaslCode { + Ok, + Auth, + Sys, + SysPerm, + SysTemp, +} +impl SaslCode { + pub fn try_from(v: u8) -> Result { + match v { + 0 => Ok(SaslCode::Ok), + 1 => Ok(SaslCode::Auth), + 2 => Ok(SaslCode::Sys), + 3 => Ok(SaslCode::SysPerm), + 4 => Ok(SaslCode::SysTemp), + _ => Err(AmqpParseError::UnknownEnumOption("SaslCode")), + } + } +} +impl DecodeFormatted for SaslCode { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = u8::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for SaslCode { + fn encoded_size(&self) -> usize { + match *self { + SaslCode::Ok => { + let v: u8 = 0; + v.encoded_size() + } + SaslCode::Auth => { + let v: u8 = 1; + v.encoded_size() + } + SaslCode::Sys => { + let v: u8 = 2; + v.encoded_size() + } + SaslCode::SysPerm => { + let v: u8 = 3; + v.encoded_size() + } + SaslCode::SysTemp => { + let v: u8 = 4; + v.encoded_size() + } + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + SaslCode::Ok => { + let v: u8 = 0; + v.encode(buf); + } + SaslCode::Auth => { + let v: u8 = 1; + v.encode(buf); + } + SaslCode::Sys => { + let v: u8 = 2; + v.encode(buf); + } + SaslCode::SysPerm => { + let v: u8 = 3; + v.encode(buf); + } + SaslCode::SysTemp => { + let v: u8 = 4; + v.encode(buf); + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum TerminusDurability { + None, + Configuration, + UnsettledState, +} +impl TerminusDurability { + pub fn try_from(v: u32) -> Result { + match v { + 0 => Ok(TerminusDurability::None), + 1 => Ok(TerminusDurability::Configuration), + 2 => Ok(TerminusDurability::UnsettledState), + _ => Err(AmqpParseError::UnknownEnumOption("TerminusDurability")), + } + } +} +impl DecodeFormatted for TerminusDurability { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = u32::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(base)?)) + } +} +impl Encode for TerminusDurability { + fn encoded_size(&self) -> usize { + match *self { + TerminusDurability::None => { + let v: u32 = 0; + v.encoded_size() + } + TerminusDurability::Configuration => { + let v: u32 = 1; + v.encoded_size() + } + TerminusDurability::UnsettledState => { + let v: u32 = 2; + v.encoded_size() + } + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + TerminusDurability::None => { + let v: u32 = 0; + v.encode(buf); + } + TerminusDurability::Configuration => { + let v: u32 = 1; + v.encode(buf); + } + TerminusDurability::UnsettledState => { + let v: u32 = 2; + v.encode(buf); + } + } + } +} +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum TerminusExpiryPolicy { + LinkDetach, + SessionEnd, + ConnectionClose, + Never, +} +impl TerminusExpiryPolicy { + pub fn try_from(v: &Symbol) -> Result { + match v.as_str() { + "link-detach" => Ok(TerminusExpiryPolicy::LinkDetach), + "session-end" => Ok(TerminusExpiryPolicy::SessionEnd), + "connection-close" => Ok(TerminusExpiryPolicy::ConnectionClose), + "never" => Ok(TerminusExpiryPolicy::Never), + _ => Err(AmqpParseError::UnknownEnumOption("TerminusExpiryPolicy")), + } + } +} +impl DecodeFormatted for TerminusExpiryPolicy { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, base) = Symbol::decode_with_format(input, fmt)?; + Ok((input, Self::try_from(&base)?)) + } +} +impl Encode for TerminusExpiryPolicy { + fn encoded_size(&self) -> usize { + match *self { + TerminusExpiryPolicy::LinkDetach => 11 + 2, + TerminusExpiryPolicy::SessionEnd => 11 + 2, + TerminusExpiryPolicy::ConnectionClose => 16 + 2, + TerminusExpiryPolicy::Never => 5 + 2, + } + } + fn encode(&self, buf: &mut BytesMut) { + match *self { + TerminusExpiryPolicy::LinkDetach => StaticSymbol("link-detach").encode(buf), + TerminusExpiryPolicy::SessionEnd => StaticSymbol("session-end").encode(buf), + TerminusExpiryPolicy::ConnectionClose => StaticSymbol("connection-close").encode(buf), + TerminusExpiryPolicy::Never => StaticSymbol("never").encode(buf), + } + } +} +type DeliveryAnnotations = Annotations; +fn decode_delivery_annotations_inner( + input: &[u8], +) -> Result<(&[u8], DeliveryAnnotations), AmqpParseError> { + DeliveryAnnotations::decode(input) +} +fn encoded_size_delivery_annotations_inner(dr: &DeliveryAnnotations) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_delivery_annotations_inner(dr: &DeliveryAnnotations, buf: &mut BytesMut) { + Descriptor::Ulong(113).encode(buf); + dr.encode(buf); +} +type MessageAnnotations = Annotations; +fn decode_message_annotations_inner( + input: &[u8], +) -> Result<(&[u8], MessageAnnotations), AmqpParseError> { + MessageAnnotations::decode(input) +} +fn encoded_size_message_annotations_inner(dr: &MessageAnnotations) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_message_annotations_inner(dr: &MessageAnnotations, buf: &mut BytesMut) { + Descriptor::Ulong(114).encode(buf); + dr.encode(buf); +} +type ApplicationProperties = StringVariantMap; +fn decode_application_properties_inner( + input: &[u8], +) -> Result<(&[u8], ApplicationProperties), AmqpParseError> { + ApplicationProperties::decode(input) +} +fn encoded_size_application_properties_inner(dr: &ApplicationProperties) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_application_properties_inner(dr: &ApplicationProperties, buf: &mut BytesMut) { + Descriptor::Ulong(116).encode(buf); + dr.encode(buf); +} +type Data = Bytes; +fn decode_data_inner(input: &[u8]) -> Result<(&[u8], Data), AmqpParseError> { + Data::decode(input) +} +fn encoded_size_data_inner(dr: &Data) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_data_inner(dr: &Data, buf: &mut BytesMut) { + Descriptor::Ulong(117).encode(buf); + dr.encode(buf); +} +type AmqpSequence = List; +fn decode_amqp_sequence_inner(input: &[u8]) -> Result<(&[u8], AmqpSequence), AmqpParseError> { + AmqpSequence::decode(input) +} +fn encoded_size_amqp_sequence_inner(dr: &AmqpSequence) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_amqp_sequence_inner(dr: &AmqpSequence, buf: &mut BytesMut) { + Descriptor::Ulong(118).encode(buf); + dr.encode(buf); +} +type AmqpValue = Variant; +fn decode_amqp_value_inner(input: &[u8]) -> Result<(&[u8], AmqpValue), AmqpParseError> { + AmqpValue::decode(input) +} +fn encoded_size_amqp_value_inner(dr: &AmqpValue) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_amqp_value_inner(dr: &AmqpValue, buf: &mut BytesMut) { + Descriptor::Ulong(119).encode(buf); + dr.encode(buf); +} +type Footer = Annotations; +fn decode_footer_inner(input: &[u8]) -> Result<(&[u8], Footer), AmqpParseError> { + Footer::decode(input) +} +fn encoded_size_footer_inner(dr: &Footer) -> usize { + // descriptor size + actual size + 3 + dr.encoded_size() +} +fn encode_footer_inner(dr: &Footer, buf: &mut BytesMut) { + Descriptor::Ulong(120).encode(buf); + dr.encode(buf); +} +#[derive(Clone, Debug, PartialEq)] +pub struct Error { + pub condition: ErrorCondition, + pub description: Option, + pub info: Option, +} +impl Error { + pub fn condition(&self) -> &ErrorCondition { + &self.condition + } + pub fn description(&self) -> Option<&ByteString> { + self.description.as_ref() + } + pub fn info(&self) -> Option<&Fields> { + self.info.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_error_inner(input: &[u8]) -> Result<(&[u8], Error), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let condition: ErrorCondition; + if count > 0 { + let (in1, decoded) = ErrorCondition::decode(input)?; + condition = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("condition")); + } + let description: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + description = decoded.1; + count -= 1; + } else { + description = None; + } + let info: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + info = decoded.1; + count -= 1; + } else { + info = None; + } + Ok(( + remainder, + Error { + condition, + description, + info, + }, + )) +} +fn encoded_size_error_inner(list: &Error) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.condition.encoded_size() + + list.description.encoded_size() + + list.info.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_error_inner(list: &Error, buf: &mut BytesMut) { + Descriptor::Ulong(29).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.condition.encoded_size() + + list.description.encoded_size() + + list.info.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Error::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Error::FIELD_COUNT as u8); + } + list.condition.encode(buf); + list.description.encode(buf); + list.info.encode(buf); +} +impl DecodeFormatted for Error { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 29, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:error:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_error_inner(input) + } + } +} +impl Encode for Error { + fn encoded_size(&self) -> usize { + encoded_size_error_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_error_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Open { + pub container_id: ByteString, + pub hostname: Option, + pub max_frame_size: u32, + pub channel_max: u16, + pub idle_time_out: Option, + pub outgoing_locales: Option, + pub incoming_locales: Option, + pub offered_capabilities: Option, + pub desired_capabilities: Option, + pub properties: Option, +} +impl Open { + pub fn container_id(&self) -> &ByteString { + &self.container_id + } + pub fn hostname(&self) -> Option<&ByteString> { + self.hostname.as_ref() + } + pub fn max_frame_size(&self) -> u32 { + self.max_frame_size + } + pub fn channel_max(&self) -> u16 { + self.channel_max + } + pub fn idle_time_out(&self) -> Option { + self.idle_time_out + } + pub fn outgoing_locales(&self) -> Option<&IetfLanguageTags> { + self.outgoing_locales.as_ref() + } + pub fn incoming_locales(&self) -> Option<&IetfLanguageTags> { + self.incoming_locales.as_ref() + } + pub fn offered_capabilities(&self) -> Option<&Symbols> { + self.offered_capabilities.as_ref() + } + pub fn desired_capabilities(&self) -> Option<&Symbols> { + self.desired_capabilities.as_ref() + } + pub fn properties(&self) -> Option<&Fields> { + self.properties.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_open_inner(input: &[u8]) -> Result<(&[u8], Open), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let container_id: ByteString; + if count > 0 { + let (in1, decoded) = ByteString::decode(input)?; + container_id = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("container_id")); + } + let hostname: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + hostname = decoded.1; + count -= 1; + } else { + hostname = None; + } + let max_frame_size: u32; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + max_frame_size = decoded.unwrap_or(4294967295); + input = in1; + count -= 1; + } else { + max_frame_size = 4294967295; + } + let channel_max: u16; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + channel_max = decoded.unwrap_or(65535); + input = in1; + count -= 1; + } else { + channel_max = 65535; + } + let idle_time_out: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + idle_time_out = decoded.1; + count -= 1; + } else { + idle_time_out = None; + } + let outgoing_locales: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + outgoing_locales = decoded.1; + count -= 1; + } else { + outgoing_locales = None; + } + let incoming_locales: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + incoming_locales = decoded.1; + count -= 1; + } else { + incoming_locales = None; + } + let offered_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + offered_capabilities = decoded.1; + count -= 1; + } else { + offered_capabilities = None; + } + let desired_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + desired_capabilities = decoded.1; + count -= 1; + } else { + desired_capabilities = None; + } + let properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + properties = decoded.1; + count -= 1; + } else { + properties = None; + } + Ok(( + remainder, + Open { + container_id, + hostname, + max_frame_size, + channel_max, + idle_time_out, + outgoing_locales, + incoming_locales, + offered_capabilities, + desired_capabilities, + properties, + }, + )) +} +fn encoded_size_open_inner(list: &Open) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.container_id.encoded_size() + + list.hostname.encoded_size() + + list.max_frame_size.encoded_size() + + list.channel_max.encoded_size() + + list.idle_time_out.encoded_size() + + list.outgoing_locales.encoded_size() + + list.incoming_locales.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_open_inner(list: &Open, buf: &mut BytesMut) { + Descriptor::Ulong(16).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.container_id.encoded_size() + + list.hostname.encoded_size() + + list.max_frame_size.encoded_size() + + list.channel_max.encoded_size() + + list.idle_time_out.encoded_size() + + list.outgoing_locales.encoded_size() + + list.incoming_locales.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Open::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Open::FIELD_COUNT as u8); + } + list.container_id.encode(buf); + list.hostname.encode(buf); + list.max_frame_size.encode(buf); + list.channel_max.encode(buf); + list.idle_time_out.encode(buf); + list.outgoing_locales.encode(buf); + list.incoming_locales.encode(buf); + list.offered_capabilities.encode(buf); + list.desired_capabilities.encode(buf); + list.properties.encode(buf); +} +impl DecodeFormatted for Open { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 16, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:open:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_open_inner(input) + } + } +} +impl Encode for Open { + fn encoded_size(&self) -> usize { + encoded_size_open_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_open_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Begin { + pub remote_channel: Option, + pub next_outgoing_id: TransferNumber, + pub incoming_window: u32, + pub outgoing_window: u32, + pub handle_max: Handle, + pub offered_capabilities: Option, + pub desired_capabilities: Option, + pub properties: Option, +} +impl Begin { + pub fn remote_channel(&self) -> Option { + self.remote_channel + } + pub fn next_outgoing_id(&self) -> TransferNumber { + self.next_outgoing_id + } + pub fn incoming_window(&self) -> u32 { + self.incoming_window + } + pub fn outgoing_window(&self) -> u32 { + self.outgoing_window + } + pub fn handle_max(&self) -> Handle { + self.handle_max + } + pub fn offered_capabilities(&self) -> Option<&Symbols> { + self.offered_capabilities.as_ref() + } + pub fn desired_capabilities(&self) -> Option<&Symbols> { + self.desired_capabilities.as_ref() + } + pub fn properties(&self) -> Option<&Fields> { + self.properties.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_begin_inner(input: &[u8]) -> Result<(&[u8], Begin), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let remote_channel: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + remote_channel = decoded.1; + count -= 1; + } else { + remote_channel = None; + } + let next_outgoing_id: TransferNumber; + if count > 0 { + let (in1, decoded) = TransferNumber::decode(input)?; + next_outgoing_id = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("next_outgoing_id")); + } + let incoming_window: u32; + if count > 0 { + let (in1, decoded) = u32::decode(input)?; + incoming_window = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("incoming_window")); + } + let outgoing_window: u32; + if count > 0 { + let (in1, decoded) = u32::decode(input)?; + outgoing_window = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("outgoing_window")); + } + let handle_max: Handle; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + handle_max = decoded.unwrap_or(4294967295); + input = in1; + count -= 1; + } else { + handle_max = 4294967295; + } + let offered_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + offered_capabilities = decoded.1; + count -= 1; + } else { + offered_capabilities = None; + } + let desired_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + desired_capabilities = decoded.1; + count -= 1; + } else { + desired_capabilities = None; + } + let properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + properties = decoded.1; + count -= 1; + } else { + properties = None; + } + Ok(( + remainder, + Begin { + remote_channel, + next_outgoing_id, + incoming_window, + outgoing_window, + handle_max, + offered_capabilities, + desired_capabilities, + properties, + }, + )) +} +fn encoded_size_begin_inner(list: &Begin) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.remote_channel.encoded_size() + + list.next_outgoing_id.encoded_size() + + list.incoming_window.encoded_size() + + list.outgoing_window.encoded_size() + + list.handle_max.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_begin_inner(list: &Begin, buf: &mut BytesMut) { + Descriptor::Ulong(17).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.remote_channel.encoded_size() + + list.next_outgoing_id.encoded_size() + + list.incoming_window.encoded_size() + + list.outgoing_window.encoded_size() + + list.handle_max.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Begin::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Begin::FIELD_COUNT as u8); + } + list.remote_channel.encode(buf); + list.next_outgoing_id.encode(buf); + list.incoming_window.encode(buf); + list.outgoing_window.encode(buf); + list.handle_max.encode(buf); + list.offered_capabilities.encode(buf); + list.desired_capabilities.encode(buf); + list.properties.encode(buf); +} +impl DecodeFormatted for Begin { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 17, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:begin:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_begin_inner(input) + } + } +} +impl Encode for Begin { + fn encoded_size(&self) -> usize { + encoded_size_begin_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_begin_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Attach { + pub name: ByteString, + pub handle: Handle, + pub role: Role, + pub snd_settle_mode: SenderSettleMode, + pub rcv_settle_mode: ReceiverSettleMode, + pub source: Option, + pub target: Option, + pub unsettled: Option, + pub incomplete_unsettled: bool, + pub initial_delivery_count: Option, + pub max_message_size: Option, + pub offered_capabilities: Option, + pub desired_capabilities: Option, + pub properties: Option, +} +impl Attach { + pub fn name(&self) -> &ByteString { + &self.name + } + pub fn handle(&self) -> Handle { + self.handle + } + pub fn role(&self) -> Role { + self.role + } + pub fn snd_settle_mode(&self) -> SenderSettleMode { + self.snd_settle_mode + } + pub fn rcv_settle_mode(&self) -> ReceiverSettleMode { + self.rcv_settle_mode + } + pub fn source(&self) -> Option<&Source> { + self.source.as_ref() + } + pub fn target(&self) -> Option<&Target> { + self.target.as_ref() + } + pub fn unsettled(&self) -> Option<&Map> { + self.unsettled.as_ref() + } + pub fn incomplete_unsettled(&self) -> bool { + self.incomplete_unsettled + } + pub fn initial_delivery_count(&self) -> Option { + self.initial_delivery_count + } + pub fn max_message_size(&self) -> Option { + self.max_message_size + } + pub fn offered_capabilities(&self) -> Option<&Symbols> { + self.offered_capabilities.as_ref() + } + pub fn desired_capabilities(&self) -> Option<&Symbols> { + self.desired_capabilities.as_ref() + } + pub fn properties(&self) -> Option<&Fields> { + self.properties.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_attach_inner(input: &[u8]) -> Result<(&[u8], Attach), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let name: ByteString; + if count > 0 { + let (in1, decoded) = ByteString::decode(input)?; + name = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("name")); + } + let handle: Handle; + if count > 0 { + let (in1, decoded) = Handle::decode(input)?; + handle = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("handle")); + } + let role: Role; + if count > 0 { + let (in1, decoded) = Role::decode(input)?; + role = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("role")); + } + let snd_settle_mode: SenderSettleMode; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + snd_settle_mode = decoded.unwrap_or(SenderSettleMode::Mixed); + input = in1; + count -= 1; + } else { + snd_settle_mode = SenderSettleMode::Mixed; + } + let rcv_settle_mode: ReceiverSettleMode; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + rcv_settle_mode = decoded.unwrap_or(ReceiverSettleMode::First); + input = in1; + count -= 1; + } else { + rcv_settle_mode = ReceiverSettleMode::First; + } + let source: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + source = decoded.1; + count -= 1; + } else { + source = None; + } + let target: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + target = decoded.1; + count -= 1; + } else { + target = None; + } + let unsettled: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + unsettled = decoded.1; + count -= 1; + } else { + unsettled = None; + } + let incomplete_unsettled: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + incomplete_unsettled = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + incomplete_unsettled = false; + } + let initial_delivery_count: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + initial_delivery_count = decoded.1; + count -= 1; + } else { + initial_delivery_count = None; + } + let max_message_size: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + max_message_size = decoded.1; + count -= 1; + } else { + max_message_size = None; + } + let offered_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + offered_capabilities = decoded.1; + count -= 1; + } else { + offered_capabilities = None; + } + let desired_capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + desired_capabilities = decoded.1; + count -= 1; + } else { + desired_capabilities = None; + } + let properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + properties = decoded.1; + count -= 1; + } else { + properties = None; + } + Ok(( + remainder, + Attach { + name, + handle, + role, + snd_settle_mode, + rcv_settle_mode, + source, + target, + unsettled, + incomplete_unsettled, + initial_delivery_count, + max_message_size, + offered_capabilities, + desired_capabilities, + properties, + }, + )) +} +fn encoded_size_attach_inner(list: &Attach) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.name.encoded_size() + + list.handle.encoded_size() + + list.role.encoded_size() + + list.snd_settle_mode.encoded_size() + + list.rcv_settle_mode.encoded_size() + + list.source.encoded_size() + + list.target.encoded_size() + + list.unsettled.encoded_size() + + list.incomplete_unsettled.encoded_size() + + list.initial_delivery_count.encoded_size() + + list.max_message_size.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_attach_inner(list: &Attach, buf: &mut BytesMut) { + Descriptor::Ulong(18).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.name.encoded_size() + + list.handle.encoded_size() + + list.role.encoded_size() + + list.snd_settle_mode.encoded_size() + + list.rcv_settle_mode.encoded_size() + + list.source.encoded_size() + + list.target.encoded_size() + + list.unsettled.encoded_size() + + list.incomplete_unsettled.encoded_size() + + list.initial_delivery_count.encoded_size() + + list.max_message_size.encoded_size() + + list.offered_capabilities.encoded_size() + + list.desired_capabilities.encoded_size() + + list.properties.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Attach::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Attach::FIELD_COUNT as u8); + } + list.name.encode(buf); + list.handle.encode(buf); + list.role.encode(buf); + list.snd_settle_mode.encode(buf); + list.rcv_settle_mode.encode(buf); + list.source.encode(buf); + list.target.encode(buf); + list.unsettled.encode(buf); + list.incomplete_unsettled.encode(buf); + list.initial_delivery_count.encode(buf); + list.max_message_size.encode(buf); + list.offered_capabilities.encode(buf); + list.desired_capabilities.encode(buf); + list.properties.encode(buf); +} +impl DecodeFormatted for Attach { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 18, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:attach:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_attach_inner(input) + } + } +} +impl Encode for Attach { + fn encoded_size(&self) -> usize { + encoded_size_attach_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_attach_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Flow { + pub next_incoming_id: Option, + pub incoming_window: u32, + pub next_outgoing_id: TransferNumber, + pub outgoing_window: u32, + pub handle: Option, + pub delivery_count: Option, + pub link_credit: Option, + pub available: Option, + pub drain: bool, + pub echo: bool, + pub properties: Option, +} +impl Flow { + pub fn next_incoming_id(&self) -> Option { + self.next_incoming_id + } + pub fn incoming_window(&self) -> u32 { + self.incoming_window + } + pub fn next_outgoing_id(&self) -> TransferNumber { + self.next_outgoing_id + } + pub fn outgoing_window(&self) -> u32 { + self.outgoing_window + } + pub fn handle(&self) -> Option { + self.handle + } + pub fn delivery_count(&self) -> Option { + self.delivery_count + } + pub fn link_credit(&self) -> Option { + self.link_credit + } + pub fn available(&self) -> Option { + self.available + } + pub fn drain(&self) -> bool { + self.drain + } + pub fn echo(&self) -> bool { + self.echo + } + pub fn properties(&self) -> Option<&Fields> { + self.properties.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_flow_inner(input: &[u8]) -> Result<(&[u8], Flow), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let next_incoming_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + next_incoming_id = decoded.1; + count -= 1; + } else { + next_incoming_id = None; + } + let incoming_window: u32; + if count > 0 { + let (in1, decoded) = u32::decode(input)?; + incoming_window = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("incoming_window")); + } + let next_outgoing_id: TransferNumber; + if count > 0 { + let (in1, decoded) = TransferNumber::decode(input)?; + next_outgoing_id = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("next_outgoing_id")); + } + let outgoing_window: u32; + if count > 0 { + let (in1, decoded) = u32::decode(input)?; + outgoing_window = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("outgoing_window")); + } + let handle: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + handle = decoded.1; + count -= 1; + } else { + handle = None; + } + let delivery_count: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + delivery_count = decoded.1; + count -= 1; + } else { + delivery_count = None; + } + let link_credit: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + link_credit = decoded.1; + count -= 1; + } else { + link_credit = None; + } + let available: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + available = decoded.1; + count -= 1; + } else { + available = None; + } + let drain: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + drain = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + drain = false; + } + let echo: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + echo = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + echo = false; + } + let properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + properties = decoded.1; + count -= 1; + } else { + properties = None; + } + Ok(( + remainder, + Flow { + next_incoming_id, + incoming_window, + next_outgoing_id, + outgoing_window, + handle, + delivery_count, + link_credit, + available, + drain, + echo, + properties, + }, + )) +} +fn encoded_size_flow_inner(list: &Flow) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.next_incoming_id.encoded_size() + + list.incoming_window.encoded_size() + + list.next_outgoing_id.encoded_size() + + list.outgoing_window.encoded_size() + + list.handle.encoded_size() + + list.delivery_count.encoded_size() + + list.link_credit.encoded_size() + + list.available.encoded_size() + + list.drain.encoded_size() + + list.echo.encoded_size() + + list.properties.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_flow_inner(list: &Flow, buf: &mut BytesMut) { + Descriptor::Ulong(19).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.next_incoming_id.encoded_size() + + list.incoming_window.encoded_size() + + list.next_outgoing_id.encoded_size() + + list.outgoing_window.encoded_size() + + list.handle.encoded_size() + + list.delivery_count.encoded_size() + + list.link_credit.encoded_size() + + list.available.encoded_size() + + list.drain.encoded_size() + + list.echo.encoded_size() + + list.properties.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Flow::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Flow::FIELD_COUNT as u8); + } + list.next_incoming_id.encode(buf); + list.incoming_window.encode(buf); + list.next_outgoing_id.encode(buf); + list.outgoing_window.encode(buf); + list.handle.encode(buf); + list.delivery_count.encode(buf); + list.link_credit.encode(buf); + list.available.encode(buf); + list.drain.encode(buf); + list.echo.encode(buf); + list.properties.encode(buf); +} +impl DecodeFormatted for Flow { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 19, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:flow:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_flow_inner(input) + } + } +} +impl Encode for Flow { + fn encoded_size(&self) -> usize { + encoded_size_flow_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_flow_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Transfer { + pub handle: Handle, + pub delivery_id: Option, + pub delivery_tag: Option, + pub message_format: Option, + pub settled: Option, + pub more: bool, + pub rcv_settle_mode: Option, + pub state: Option, + pub resume: bool, + pub aborted: bool, + pub batchable: bool, + pub body: Option, +} +impl Transfer { + pub fn handle(&self) -> Handle { + self.handle + } + pub fn delivery_id(&self) -> Option { + self.delivery_id + } + pub fn delivery_tag(&self) -> Option<&DeliveryTag> { + self.delivery_tag.as_ref() + } + pub fn message_format(&self) -> Option { + self.message_format + } + pub fn settled(&self) -> Option { + self.settled + } + pub fn more(&self) -> bool { + self.more + } + pub fn rcv_settle_mode(&self) -> Option { + self.rcv_settle_mode + } + pub fn state(&self) -> Option<&DeliveryState> { + self.state.as_ref() + } + pub fn resume(&self) -> bool { + self.resume + } + pub fn aborted(&self) -> bool { + self.aborted + } + pub fn batchable(&self) -> bool { + self.batchable + } + pub fn body(&self) -> Option<&TransferBody> { + self.body.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_transfer_inner(input: &[u8]) -> Result<(&[u8], Transfer), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let handle: Handle; + if count > 0 { + let (in1, decoded) = Handle::decode(input)?; + handle = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("handle")); + } + let delivery_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + delivery_id = decoded.1; + count -= 1; + } else { + delivery_id = None; + } + let delivery_tag: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + delivery_tag = decoded.1; + count -= 1; + } else { + delivery_tag = None; + } + let message_format: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + message_format = decoded.1; + count -= 1; + } else { + message_format = None; + } + let settled: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + settled = decoded.1; + count -= 1; + } else { + settled = None; + } + let more: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + more = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + more = false; + } + let rcv_settle_mode: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + rcv_settle_mode = decoded.1; + count -= 1; + } else { + rcv_settle_mode = None; + } + let state: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + state = decoded.1; + count -= 1; + } else { + state = None; + } + let resume: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + resume = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + resume = false; + } + let aborted: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + aborted = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + aborted = false; + } + let batchable: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + batchable = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + batchable = false; + } + let body = if remainder.is_empty() { + None + } else { + let b = Bytes::copy_from_slice(remainder); + remainder = &[]; + Some(b.into()) + }; + Ok(( + remainder, + Transfer { + handle, + delivery_id, + delivery_tag, + message_format, + settled, + more, + rcv_settle_mode, + state, + resume, + aborted, + batchable, + body, + }, + )) +} +fn encoded_size_transfer_inner(list: &Transfer) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.handle.encoded_size() + + list.delivery_id.encoded_size() + + list.delivery_tag.encoded_size() + + list.message_format.encoded_size() + + list.settled.encoded_size() + + list.more.encoded_size() + + list.rcv_settle_mode.encoded_size() + + list.state.encoded_size() + + list.resume.encoded_size() + + list.aborted.encoded_size() + + list.batchable.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size + + list.body.as_ref().map(|b| b.len()).unwrap_or(0) +} +fn encode_transfer_inner(list: &Transfer, buf: &mut BytesMut) { + Descriptor::Ulong(20).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.handle.encoded_size() + + list.delivery_id.encoded_size() + + list.delivery_tag.encoded_size() + + list.message_format.encoded_size() + + list.settled.encoded_size() + + list.more.encoded_size() + + list.rcv_settle_mode.encoded_size() + + list.state.encoded_size() + + list.resume.encoded_size() + + list.aborted.encoded_size() + + list.batchable.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Transfer::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Transfer::FIELD_COUNT as u8); + } + list.handle.encode(buf); + list.delivery_id.encode(buf); + list.delivery_tag.encode(buf); + list.message_format.encode(buf); + list.settled.encode(buf); + list.more.encode(buf); + list.rcv_settle_mode.encode(buf); + list.state.encode(buf); + list.resume.encode(buf); + list.aborted.encode(buf); + list.batchable.encode(buf); + if let Some(ref body) = list.body { + body.encode(buf) + } +} +impl DecodeFormatted for Transfer { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 20, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:transfer:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_transfer_inner(input) + } + } +} +impl Encode for Transfer { + fn encoded_size(&self) -> usize { + encoded_size_transfer_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_transfer_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Disposition { + pub role: Role, + pub first: DeliveryNumber, + pub last: Option, + pub settled: bool, + pub state: Option, + pub batchable: bool, +} +impl Disposition { + pub fn role(&self) -> Role { + self.role + } + pub fn first(&self) -> DeliveryNumber { + self.first + } + pub fn last(&self) -> Option { + self.last + } + pub fn settled(&self) -> bool { + self.settled + } + pub fn state(&self) -> Option<&DeliveryState> { + self.state.as_ref() + } + pub fn batchable(&self) -> bool { + self.batchable + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_disposition_inner(input: &[u8]) -> Result<(&[u8], Disposition), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let role: Role; + if count > 0 { + let (in1, decoded) = Role::decode(input)?; + role = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("role")); + } + let first: DeliveryNumber; + if count > 0 { + let (in1, decoded) = DeliveryNumber::decode(input)?; + first = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("first")); + } + let last: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + last = decoded.1; + count -= 1; + } else { + last = None; + } + let settled: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + settled = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + settled = false; + } + let state: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + state = decoded.1; + count -= 1; + } else { + state = None; + } + let batchable: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + batchable = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + batchable = false; + } + Ok(( + remainder, + Disposition { + role, + first, + last, + settled, + state, + batchable, + }, + )) +} +fn encoded_size_disposition_inner(list: &Disposition) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.role.encoded_size() + + list.first.encoded_size() + + list.last.encoded_size() + + list.settled.encoded_size() + + list.state.encoded_size() + + list.batchable.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_disposition_inner(list: &Disposition, buf: &mut BytesMut) { + Descriptor::Ulong(21).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.role.encoded_size() + + list.first.encoded_size() + + list.last.encoded_size() + + list.settled.encoded_size() + + list.state.encoded_size() + + list.batchable.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Disposition::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Disposition::FIELD_COUNT as u8); + } + list.role.encode(buf); + list.first.encode(buf); + list.last.encode(buf); + list.settled.encode(buf); + list.state.encode(buf); + list.batchable.encode(buf); +} +impl DecodeFormatted for Disposition { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 21, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:disposition:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_disposition_inner(input) + } + } +} +impl Encode for Disposition { + fn encoded_size(&self) -> usize { + encoded_size_disposition_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_disposition_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Detach { + pub handle: Handle, + pub closed: bool, + pub error: Option, +} +impl Detach { + pub fn handle(&self) -> Handle { + self.handle + } + pub fn closed(&self) -> bool { + self.closed + } + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_detach_inner(input: &[u8]) -> Result<(&[u8], Detach), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let handle: Handle; + if count > 0 { + let (in1, decoded) = Handle::decode(input)?; + handle = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("handle")); + } + let closed: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + closed = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + closed = false; + } + let error: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + error = decoded.1; + count -= 1; + } else { + error = None; + } + Ok(( + remainder, + Detach { + handle, + closed, + error, + }, + )) +} +fn encoded_size_detach_inner(list: &Detach) -> usize { + #[allow(clippy::identity_op)] + let content_size = + 0 + list.handle.encoded_size() + list.closed.encoded_size() + list.error.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_detach_inner(list: &Detach, buf: &mut BytesMut) { + Descriptor::Ulong(22).encode(buf); + #[allow(clippy::identity_op)] + let content_size = + 0 + list.handle.encoded_size() + list.closed.encoded_size() + list.error.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Detach::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Detach::FIELD_COUNT as u8); + } + list.handle.encode(buf); + list.closed.encode(buf); + list.error.encode(buf); +} +impl DecodeFormatted for Detach { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 22, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:detach:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_detach_inner(input) + } + } +} +impl Encode for Detach { + fn encoded_size(&self) -> usize { + encoded_size_detach_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_detach_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct End { + pub error: Option, +} +impl End { + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_end_inner(input: &[u8]) -> Result<(&[u8], End), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let error: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + error = decoded.1; + count -= 1; + } else { + error = None; + } + Ok((remainder, End { error })) +} +fn encoded_size_end_inner(list: &End) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_end_inner(list: &End, buf: &mut BytesMut) { + Descriptor::Ulong(23).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(End::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(End::FIELD_COUNT as u8); + } + list.error.encode(buf); +} +impl DecodeFormatted for End { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 23, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:end:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_end_inner(input) + } + } +} +impl Encode for End { + fn encoded_size(&self) -> usize { + encoded_size_end_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_end_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Close { + pub error: Option, +} +impl Close { + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_close_inner(input: &[u8]) -> Result<(&[u8], Close), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let error: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + error = decoded.1; + count -= 1; + } else { + error = None; + } + Ok((remainder, Close { error })) +} +fn encoded_size_close_inner(list: &Close) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_close_inner(list: &Close, buf: &mut BytesMut) { + Descriptor::Ulong(24).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Close::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Close::FIELD_COUNT as u8); + } + list.error.encode(buf); +} +impl DecodeFormatted for Close { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 24, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:close:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_close_inner(input) + } + } +} +impl Encode for Close { + fn encoded_size(&self) -> usize { + encoded_size_close_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_close_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct SaslMechanisms { + pub sasl_server_mechanisms: Symbols, +} +impl SaslMechanisms { + pub fn sasl_server_mechanisms(&self) -> &Symbols { + &self.sasl_server_mechanisms + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_sasl_mechanisms_inner(input: &[u8]) -> Result<(&[u8], SaslMechanisms), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let sasl_server_mechanisms: Symbols; + if count > 0 { + let (in1, decoded) = Symbols::decode(input)?; + sasl_server_mechanisms = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted( + "sasl_server_mechanisms", + )); + } + Ok(( + remainder, + SaslMechanisms { + sasl_server_mechanisms, + }, + )) +} +fn encoded_size_sasl_mechanisms_inner(list: &SaslMechanisms) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.sasl_server_mechanisms.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_sasl_mechanisms_inner(list: &SaslMechanisms, buf: &mut BytesMut) { + Descriptor::Ulong(64).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.sasl_server_mechanisms.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(SaslMechanisms::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(SaslMechanisms::FIELD_COUNT as u8); + } + list.sasl_server_mechanisms.encode(buf); +} +impl DecodeFormatted for SaslMechanisms { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 64, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:sasl-mechanisms:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_sasl_mechanisms_inner(input) + } + } +} +impl Encode for SaslMechanisms { + fn encoded_size(&self) -> usize { + encoded_size_sasl_mechanisms_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_sasl_mechanisms_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct SaslInit { + pub mechanism: Symbol, + pub initial_response: Option, + pub hostname: Option, +} +impl SaslInit { + pub fn mechanism(&self) -> &Symbol { + &self.mechanism + } + pub fn initial_response(&self) -> Option<&Bytes> { + self.initial_response.as_ref() + } + pub fn hostname(&self) -> Option<&ByteString> { + self.hostname.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_sasl_init_inner(input: &[u8]) -> Result<(&[u8], SaslInit), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let mechanism: Symbol; + if count > 0 { + let (in1, decoded) = Symbol::decode(input)?; + mechanism = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("mechanism")); + } + let initial_response: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + initial_response = decoded.1; + count -= 1; + } else { + initial_response = None; + } + let hostname: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + hostname = decoded.1; + count -= 1; + } else { + hostname = None; + } + Ok(( + remainder, + SaslInit { + mechanism, + initial_response, + hostname, + }, + )) +} +fn encoded_size_sasl_init_inner(list: &SaslInit) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.mechanism.encoded_size() + + list.initial_response.encoded_size() + + list.hostname.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_sasl_init_inner(list: &SaslInit, buf: &mut BytesMut) { + Descriptor::Ulong(65).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.mechanism.encoded_size() + + list.initial_response.encoded_size() + + list.hostname.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(SaslInit::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(SaslInit::FIELD_COUNT as u8); + } + list.mechanism.encode(buf); + list.initial_response.encode(buf); + list.hostname.encode(buf); +} +impl DecodeFormatted for SaslInit { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 65, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:sasl-init:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_sasl_init_inner(input) + } + } +} +impl Encode for SaslInit { + fn encoded_size(&self) -> usize { + encoded_size_sasl_init_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_sasl_init_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct SaslChallenge { + pub challenge: Bytes, +} +impl SaslChallenge { + pub fn challenge(&self) -> &Bytes { + &self.challenge + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_sasl_challenge_inner(input: &[u8]) -> Result<(&[u8], SaslChallenge), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let challenge: Bytes; + if count > 0 { + let (in1, decoded) = Bytes::decode(input)?; + challenge = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("challenge")); + } + Ok((remainder, SaslChallenge { challenge })) +} +fn encoded_size_sasl_challenge_inner(list: &SaslChallenge) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.challenge.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_sasl_challenge_inner(list: &SaslChallenge, buf: &mut BytesMut) { + Descriptor::Ulong(66).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.challenge.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(SaslChallenge::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(SaslChallenge::FIELD_COUNT as u8); + } + list.challenge.encode(buf); +} +impl DecodeFormatted for SaslChallenge { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 66, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:sasl-challenge:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_sasl_challenge_inner(input) + } + } +} +impl Encode for SaslChallenge { + fn encoded_size(&self) -> usize { + encoded_size_sasl_challenge_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_sasl_challenge_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct SaslResponse { + pub response: Bytes, +} +impl SaslResponse { + pub fn response(&self) -> &Bytes { + &self.response + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_sasl_response_inner(input: &[u8]) -> Result<(&[u8], SaslResponse), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let response: Bytes; + if count > 0 { + let (in1, decoded) = Bytes::decode(input)?; + response = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("response")); + } + Ok((remainder, SaslResponse { response })) +} +fn encoded_size_sasl_response_inner(list: &SaslResponse) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.response.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_sasl_response_inner(list: &SaslResponse, buf: &mut BytesMut) { + Descriptor::Ulong(67).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.response.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(SaslResponse::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(SaslResponse::FIELD_COUNT as u8); + } + list.response.encode(buf); +} +impl DecodeFormatted for SaslResponse { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 67, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:sasl-response:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_sasl_response_inner(input) + } + } +} +impl Encode for SaslResponse { + fn encoded_size(&self) -> usize { + encoded_size_sasl_response_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_sasl_response_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct SaslOutcome { + pub code: SaslCode, + pub additional_data: Option, +} +impl SaslOutcome { + pub fn code(&self) -> SaslCode { + self.code + } + pub fn additional_data(&self) -> Option<&Bytes> { + self.additional_data.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_sasl_outcome_inner(input: &[u8]) -> Result<(&[u8], SaslOutcome), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let code: SaslCode; + if count > 0 { + let (in1, decoded) = SaslCode::decode(input)?; + code = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("code")); + } + let additional_data: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + additional_data = decoded.1; + count -= 1; + } else { + additional_data = None; + } + Ok(( + remainder, + SaslOutcome { + code, + additional_data, + }, + )) +} +fn encoded_size_sasl_outcome_inner(list: &SaslOutcome) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.code.encoded_size() + list.additional_data.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_sasl_outcome_inner(list: &SaslOutcome, buf: &mut BytesMut) { + Descriptor::Ulong(68).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.code.encoded_size() + list.additional_data.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(SaslOutcome::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(SaslOutcome::FIELD_COUNT as u8); + } + list.code.encode(buf); + list.additional_data.encode(buf); +} +impl DecodeFormatted for SaslOutcome { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 68, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:sasl-outcome:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_sasl_outcome_inner(input) + } + } +} +impl Encode for SaslOutcome { + fn encoded_size(&self) -> usize { + encoded_size_sasl_outcome_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_sasl_outcome_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Source { + pub address: Option
, + pub durable: TerminusDurability, + pub expiry_policy: TerminusExpiryPolicy, + pub timeout: Seconds, + pub dynamic: bool, + pub dynamic_node_properties: Option, + pub distribution_mode: Option, + pub filter: Option, + pub default_outcome: Option, + pub outcomes: Option, + pub capabilities: Option, +} +impl Source { + pub fn address(&self) -> Option<&Address> { + self.address.as_ref() + } + pub fn durable(&self) -> TerminusDurability { + self.durable + } + pub fn expiry_policy(&self) -> TerminusExpiryPolicy { + self.expiry_policy + } + pub fn timeout(&self) -> Seconds { + self.timeout + } + pub fn dynamic(&self) -> bool { + self.dynamic + } + pub fn dynamic_node_properties(&self) -> Option<&NodeProperties> { + self.dynamic_node_properties.as_ref() + } + pub fn distribution_mode(&self) -> Option<&DistributionMode> { + self.distribution_mode.as_ref() + } + pub fn filter(&self) -> Option<&FilterSet> { + self.filter.as_ref() + } + pub fn default_outcome(&self) -> Option<&Outcome> { + self.default_outcome.as_ref() + } + pub fn outcomes(&self) -> Option<&Symbols> { + self.outcomes.as_ref() + } + pub fn capabilities(&self) -> Option<&Symbols> { + self.capabilities.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_source_inner(input: &[u8]) -> Result<(&[u8], Source), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let address: Option
; + if count > 0 { + let decoded = Option::
::decode(input)?; + input = decoded.0; + address = decoded.1; + count -= 1; + } else { + address = None; + } + let durable: TerminusDurability; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + durable = decoded.unwrap_or(TerminusDurability::None); + input = in1; + count -= 1; + } else { + durable = TerminusDurability::None; + } + let expiry_policy: TerminusExpiryPolicy; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + expiry_policy = decoded.unwrap_or(TerminusExpiryPolicy::SessionEnd); + input = in1; + count -= 1; + } else { + expiry_policy = TerminusExpiryPolicy::SessionEnd; + } + let timeout: Seconds; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + timeout = decoded.unwrap_or(0); + input = in1; + count -= 1; + } else { + timeout = 0; + } + let dynamic: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + dynamic = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + dynamic = false; + } + let dynamic_node_properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + dynamic_node_properties = decoded.1; + count -= 1; + } else { + dynamic_node_properties = None; + } + let distribution_mode: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + distribution_mode = decoded.1; + count -= 1; + } else { + distribution_mode = None; + } + let filter: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + filter = decoded.1; + count -= 1; + } else { + filter = None; + } + let default_outcome: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + default_outcome = decoded.1; + count -= 1; + } else { + default_outcome = None; + } + let outcomes: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + outcomes = decoded.1; + count -= 1; + } else { + outcomes = None; + } + let capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + capabilities = decoded.1; + count -= 1; + } else { + capabilities = None; + } + Ok(( + remainder, + Source { + address, + durable, + expiry_policy, + timeout, + dynamic, + dynamic_node_properties, + distribution_mode, + filter, + default_outcome, + outcomes, + capabilities, + }, + )) +} +fn encoded_size_source_inner(list: &Source) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.address.encoded_size() + + list.durable.encoded_size() + + list.expiry_policy.encoded_size() + + list.timeout.encoded_size() + + list.dynamic.encoded_size() + + list.dynamic_node_properties.encoded_size() + + list.distribution_mode.encoded_size() + + list.filter.encoded_size() + + list.default_outcome.encoded_size() + + list.outcomes.encoded_size() + + list.capabilities.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_source_inner(list: &Source, buf: &mut BytesMut) { + Descriptor::Ulong(40).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.address.encoded_size() + + list.durable.encoded_size() + + list.expiry_policy.encoded_size() + + list.timeout.encoded_size() + + list.dynamic.encoded_size() + + list.dynamic_node_properties.encoded_size() + + list.distribution_mode.encoded_size() + + list.filter.encoded_size() + + list.default_outcome.encoded_size() + + list.outcomes.encoded_size() + + list.capabilities.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Source::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Source::FIELD_COUNT as u8); + } + list.address.encode(buf); + list.durable.encode(buf); + list.expiry_policy.encode(buf); + list.timeout.encode(buf); + list.dynamic.encode(buf); + list.dynamic_node_properties.encode(buf); + list.distribution_mode.encode(buf); + list.filter.encode(buf); + list.default_outcome.encode(buf); + list.outcomes.encode(buf); + list.capabilities.encode(buf); +} +impl DecodeFormatted for Source { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 40, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:source:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_source_inner(input) + } + } +} +impl Encode for Source { + fn encoded_size(&self) -> usize { + encoded_size_source_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_source_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Target { + pub address: Option
, + pub durable: TerminusDurability, + pub expiry_policy: TerminusExpiryPolicy, + pub timeout: Seconds, + pub dynamic: bool, + pub dynamic_node_properties: Option, + pub capabilities: Option, +} +impl Target { + pub fn address(&self) -> Option<&Address> { + self.address.as_ref() + } + pub fn durable(&self) -> TerminusDurability { + self.durable + } + pub fn expiry_policy(&self) -> TerminusExpiryPolicy { + self.expiry_policy + } + pub fn timeout(&self) -> Seconds { + self.timeout + } + pub fn dynamic(&self) -> bool { + self.dynamic + } + pub fn dynamic_node_properties(&self) -> Option<&NodeProperties> { + self.dynamic_node_properties.as_ref() + } + pub fn capabilities(&self) -> Option<&Symbols> { + self.capabilities.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_target_inner(input: &[u8]) -> Result<(&[u8], Target), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let address: Option
; + if count > 0 { + let decoded = Option::
::decode(input)?; + input = decoded.0; + address = decoded.1; + count -= 1; + } else { + address = None; + } + let durable: TerminusDurability; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + durable = decoded.unwrap_or(TerminusDurability::None); + input = in1; + count -= 1; + } else { + durable = TerminusDurability::None; + } + let expiry_policy: TerminusExpiryPolicy; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + expiry_policy = decoded.unwrap_or(TerminusExpiryPolicy::SessionEnd); + input = in1; + count -= 1; + } else { + expiry_policy = TerminusExpiryPolicy::SessionEnd; + } + let timeout: Seconds; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + timeout = decoded.unwrap_or(0); + input = in1; + count -= 1; + } else { + timeout = 0; + } + let dynamic: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + dynamic = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + dynamic = false; + } + let dynamic_node_properties: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + dynamic_node_properties = decoded.1; + count -= 1; + } else { + dynamic_node_properties = None; + } + let capabilities: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + capabilities = decoded.1; + count -= 1; + } else { + capabilities = None; + } + Ok(( + remainder, + Target { + address, + durable, + expiry_policy, + timeout, + dynamic, + dynamic_node_properties, + capabilities, + }, + )) +} +fn encoded_size_target_inner(list: &Target) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.address.encoded_size() + + list.durable.encoded_size() + + list.expiry_policy.encoded_size() + + list.timeout.encoded_size() + + list.dynamic.encoded_size() + + list.dynamic_node_properties.encoded_size() + + list.capabilities.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_target_inner(list: &Target, buf: &mut BytesMut) { + Descriptor::Ulong(41).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.address.encoded_size() + + list.durable.encoded_size() + + list.expiry_policy.encoded_size() + + list.timeout.encoded_size() + + list.dynamic.encoded_size() + + list.dynamic_node_properties.encoded_size() + + list.capabilities.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Target::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Target::FIELD_COUNT as u8); + } + list.address.encode(buf); + list.durable.encode(buf); + list.expiry_policy.encode(buf); + list.timeout.encode(buf); + list.dynamic.encode(buf); + list.dynamic_node_properties.encode(buf); + list.capabilities.encode(buf); +} +impl DecodeFormatted for Target { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 41, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:target:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_target_inner(input) + } + } +} +impl Encode for Target { + fn encoded_size(&self) -> usize { + encoded_size_target_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_target_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Header { + pub durable: bool, + pub priority: u8, + pub ttl: Option, + pub first_acquirer: bool, + pub delivery_count: u32, +} +impl Header { + pub fn durable(&self) -> bool { + self.durable + } + pub fn priority(&self) -> u8 { + self.priority + } + pub fn ttl(&self) -> Option { + self.ttl + } + pub fn first_acquirer(&self) -> bool { + self.first_acquirer + } + pub fn delivery_count(&self) -> u32 { + self.delivery_count + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_header_inner(input: &[u8]) -> Result<(&[u8], Header), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let durable: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + durable = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + durable = false; + } + let priority: u8; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + priority = decoded.unwrap_or(4); + input = in1; + count -= 1; + } else { + priority = 4; + } + let ttl: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + ttl = decoded.1; + count -= 1; + } else { + ttl = None; + } + let first_acquirer: bool; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + first_acquirer = decoded.unwrap_or(false); + input = in1; + count -= 1; + } else { + first_acquirer = false; + } + let delivery_count: u32; + if count > 0 { + let (in1, decoded) = Option::::decode(input)?; + delivery_count = decoded.unwrap_or(0); + input = in1; + count -= 1; + } else { + delivery_count = 0; + } + Ok(( + remainder, + Header { + durable, + priority, + ttl, + first_acquirer, + delivery_count, + }, + )) +} +fn encoded_size_header_inner(list: &Header) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.durable.encoded_size() + + list.priority.encoded_size() + + list.ttl.encoded_size() + + list.first_acquirer.encoded_size() + + list.delivery_count.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_header_inner(list: &Header, buf: &mut BytesMut) { + Descriptor::Ulong(112).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.durable.encoded_size() + + list.priority.encoded_size() + + list.ttl.encoded_size() + + list.first_acquirer.encoded_size() + + list.delivery_count.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Header::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Header::FIELD_COUNT as u8); + } + list.durable.encode(buf); + list.priority.encode(buf); + list.ttl.encode(buf); + list.first_acquirer.encode(buf); + list.delivery_count.encode(buf); +} +impl DecodeFormatted for Header { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 112, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:header:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_header_inner(input) + } + } +} +impl Encode for Header { + fn encoded_size(&self) -> usize { + encoded_size_header_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_header_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Properties { + pub message_id: Option, + pub user_id: Option, + pub to: Option
, + pub subject: Option, + pub reply_to: Option
, + pub correlation_id: Option, + pub content_type: Option, + pub content_encoding: Option, + pub absolute_expiry_time: Option, + pub creation_time: Option, + pub group_id: Option, + pub group_sequence: Option, + pub reply_to_group_id: Option, +} +impl Properties { + pub fn message_id(&self) -> Option<&MessageId> { + self.message_id.as_ref() + } + pub fn user_id(&self) -> Option<&Bytes> { + self.user_id.as_ref() + } + pub fn to(&self) -> Option<&Address> { + self.to.as_ref() + } + pub fn subject(&self) -> Option<&ByteString> { + self.subject.as_ref() + } + pub fn reply_to(&self) -> Option<&Address> { + self.reply_to.as_ref() + } + pub fn correlation_id(&self) -> Option<&MessageId> { + self.correlation_id.as_ref() + } + pub fn content_type(&self) -> Option<&Symbol> { + self.content_type.as_ref() + } + pub fn content_encoding(&self) -> Option<&Symbol> { + self.content_encoding.as_ref() + } + pub fn absolute_expiry_time(&self) -> Option { + self.absolute_expiry_time + } + pub fn creation_time(&self) -> Option { + self.creation_time + } + pub fn group_id(&self) -> Option<&ByteString> { + self.group_id.as_ref() + } + pub fn group_sequence(&self) -> Option { + self.group_sequence + } + pub fn reply_to_group_id(&self) -> Option<&ByteString> { + self.reply_to_group_id.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_properties_inner(input: &[u8]) -> Result<(&[u8], Properties), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let message_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + message_id = decoded.1; + count -= 1; + } else { + message_id = None; + } + let user_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + user_id = decoded.1; + count -= 1; + } else { + user_id = None; + } + let to: Option
; + if count > 0 { + let decoded = Option::
::decode(input)?; + input = decoded.0; + to = decoded.1; + count -= 1; + } else { + to = None; + } + let subject: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + subject = decoded.1; + count -= 1; + } else { + subject = None; + } + let reply_to: Option
; + if count > 0 { + let decoded = Option::
::decode(input)?; + input = decoded.0; + reply_to = decoded.1; + count -= 1; + } else { + reply_to = None; + } + let correlation_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + correlation_id = decoded.1; + count -= 1; + } else { + correlation_id = None; + } + let content_type: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + content_type = decoded.1; + count -= 1; + } else { + content_type = None; + } + let content_encoding: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + content_encoding = decoded.1; + count -= 1; + } else { + content_encoding = None; + } + let absolute_expiry_time: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + absolute_expiry_time = decoded.1; + count -= 1; + } else { + absolute_expiry_time = None; + } + let creation_time: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + creation_time = decoded.1; + count -= 1; + } else { + creation_time = None; + } + let group_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + group_id = decoded.1; + count -= 1; + } else { + group_id = None; + } + let group_sequence: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + group_sequence = decoded.1; + count -= 1; + } else { + group_sequence = None; + } + let reply_to_group_id: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + reply_to_group_id = decoded.1; + count -= 1; + } else { + reply_to_group_id = None; + } + Ok(( + remainder, + Properties { + message_id, + user_id, + to, + subject, + reply_to, + correlation_id, + content_type, + content_encoding, + absolute_expiry_time, + creation_time, + group_id, + group_sequence, + reply_to_group_id, + }, + )) +} +fn encoded_size_properties_inner(list: &Properties) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.message_id.encoded_size() + + list.user_id.encoded_size() + + list.to.encoded_size() + + list.subject.encoded_size() + + list.reply_to.encoded_size() + + list.correlation_id.encoded_size() + + list.content_type.encoded_size() + + list.content_encoding.encoded_size() + + list.absolute_expiry_time.encoded_size() + + list.creation_time.encoded_size() + + list.group_id.encoded_size() + + list.group_sequence.encoded_size() + + list.reply_to_group_id.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_properties_inner(list: &Properties, buf: &mut BytesMut) { + Descriptor::Ulong(115).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.message_id.encoded_size() + + list.user_id.encoded_size() + + list.to.encoded_size() + + list.subject.encoded_size() + + list.reply_to.encoded_size() + + list.correlation_id.encoded_size() + + list.content_type.encoded_size() + + list.content_encoding.encoded_size() + + list.absolute_expiry_time.encoded_size() + + list.creation_time.encoded_size() + + list.group_id.encoded_size() + + list.group_sequence.encoded_size() + + list.reply_to_group_id.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Properties::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Properties::FIELD_COUNT as u8); + } + list.message_id.encode(buf); + list.user_id.encode(buf); + list.to.encode(buf); + list.subject.encode(buf); + list.reply_to.encode(buf); + list.correlation_id.encode(buf); + list.content_type.encode(buf); + list.content_encoding.encode(buf); + list.absolute_expiry_time.encode(buf); + list.creation_time.encode(buf); + list.group_id.encode(buf); + list.group_sequence.encode(buf); + list.reply_to_group_id.encode(buf); +} +impl DecodeFormatted for Properties { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 115, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:properties:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_properties_inner(input) + } + } +} +impl Encode for Properties { + fn encoded_size(&self) -> usize { + encoded_size_properties_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_properties_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Received { + pub section_number: u32, + pub section_offset: u64, +} +impl Received { + pub fn section_number(&self) -> u32 { + self.section_number + } + pub fn section_offset(&self) -> u64 { + self.section_offset + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_received_inner(input: &[u8]) -> Result<(&[u8], Received), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let section_number: u32; + if count > 0 { + let (in1, decoded) = u32::decode(input)?; + section_number = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("section_number")); + } + let section_offset: u64; + if count > 0 { + let (in1, decoded) = u64::decode(input)?; + section_offset = decoded; + input = in1; + count -= 1; + } else { + return Err(AmqpParseError::RequiredFieldOmitted("section_offset")); + } + Ok(( + remainder, + Received { + section_number, + section_offset, + }, + )) +} +fn encoded_size_received_inner(list: &Received) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.section_number.encoded_size() + list.section_offset.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_received_inner(list: &Received, buf: &mut BytesMut) { + Descriptor::Ulong(35).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.section_number.encoded_size() + list.section_offset.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Received::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Received::FIELD_COUNT as u8); + } + list.section_number.encode(buf); + list.section_offset.encode(buf); +} +impl DecodeFormatted for Received { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 35, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:received:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_received_inner(input) + } + } +} +impl Encode for Received { + fn encoded_size(&self) -> usize { + encoded_size_received_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_received_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Accepted {} +impl Accepted { + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0; +} +#[allow(unused_mut)] +fn decode_accepted_inner(input: &[u8]) -> Result<(&[u8], Accepted), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let mut remainder = &input[size..]; + Ok((remainder, Accepted {})) +} +fn encoded_size_accepted_inner(list: &Accepted) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0; + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_accepted_inner(list: &Accepted, buf: &mut BytesMut) { + Descriptor::Ulong(36).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0; + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Accepted::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Accepted::FIELD_COUNT as u8); + } +} +impl DecodeFormatted for Accepted { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 36, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:accepted:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_accepted_inner(input) + } + } +} +impl Encode for Accepted { + fn encoded_size(&self) -> usize { + encoded_size_accepted_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_accepted_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Rejected { + pub error: Option, +} +impl Rejected { + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1; +} +#[allow(unused_mut)] +fn decode_rejected_inner(input: &[u8]) -> Result<(&[u8], Rejected), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let error: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + error = decoded.1; + count -= 1; + } else { + error = None; + } + Ok((remainder, Rejected { error })) +} +fn encoded_size_rejected_inner(list: &Rejected) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_rejected_inner(list: &Rejected, buf: &mut BytesMut) { + Descriptor::Ulong(37).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + list.error.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Rejected::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Rejected::FIELD_COUNT as u8); + } + list.error.encode(buf); +} +impl DecodeFormatted for Rejected { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 37, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:rejected:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_rejected_inner(input) + } + } +} +impl Encode for Rejected { + fn encoded_size(&self) -> usize { + encoded_size_rejected_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_rejected_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Released {} +impl Released { + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0; +} +#[allow(unused_mut)] +fn decode_released_inner(input: &[u8]) -> Result<(&[u8], Released), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let mut remainder = &input[size..]; + Ok((remainder, Released {})) +} +fn encoded_size_released_inner(list: &Released) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0; + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_released_inner(list: &Released, buf: &mut BytesMut) { + Descriptor::Ulong(38).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0; + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Released::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Released::FIELD_COUNT as u8); + } +} +impl DecodeFormatted for Released { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 38, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:released:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_released_inner(input) + } + } +} +impl Encode for Released { + fn encoded_size(&self) -> usize { + encoded_size_released_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_released_inner(self, buf) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct Modified { + pub delivery_failed: Option, + pub undeliverable_here: Option, + pub message_annotations: Option, +} +impl Modified { + pub fn delivery_failed(&self) -> Option { + self.delivery_failed + } + pub fn undeliverable_here(&self) -> Option { + self.undeliverable_here + } + pub fn message_annotations(&self) -> Option<&Fields> { + self.message_annotations.as_ref() + } + #[allow(clippy::identity_op)] + const FIELD_COUNT: usize = 0 + 1 + 1 + 1; +} +#[allow(unused_mut)] +fn decode_modified_inner(input: &[u8]) -> Result<(&[u8], Modified), AmqpParseError> { + let (input, format) = decode_format_code(input)?; + let (input, header) = decode_list_header(input, format)?; + let size = header.size as usize; + decode_check_len!(input, size); + let (mut input, mut remainder) = input.split_at(size); + let mut count = header.count; + let delivery_failed: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + delivery_failed = decoded.1; + count -= 1; + } else { + delivery_failed = None; + } + let undeliverable_here: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + undeliverable_here = decoded.1; + count -= 1; + } else { + undeliverable_here = None; + } + let message_annotations: Option; + if count > 0 { + let decoded = Option::::decode(input)?; + input = decoded.0; + message_annotations = decoded.1; + count -= 1; + } else { + message_annotations = None; + } + Ok(( + remainder, + Modified { + delivery_failed, + undeliverable_here, + message_annotations, + }, + )) +} +fn encoded_size_modified_inner(list: &Modified) -> usize { + #[allow(clippy::identity_op)] + let content_size = 0 + + list.delivery_failed.encoded_size() + + list.undeliverable_here.encoded_size() + + list.message_annotations.encoded_size(); + // header: 0x00 0x53 format_code size count + (if content_size + 1 > u8::MAX as usize { + 12 + } else { + 6 + }) + content_size +} +fn encode_modified_inner(list: &Modified, buf: &mut BytesMut) { + Descriptor::Ulong(39).encode(buf); + #[allow(clippy::identity_op)] + let content_size = 0 + + list.delivery_failed.encoded_size() + + list.undeliverable_here.encoded_size() + + list.message_annotations.encoded_size(); + if content_size + 1 > u8::MAX as usize { + buf.put_u8(codec::FORMATCODE_LIST32); + buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count + buf.put_u32(Modified::FIELD_COUNT as u32); + } else { + buf.put_u8(codec::FORMATCODE_LIST8); + buf.put_u8((content_size + 1) as u8); + buf.put_u8(Modified::FIELD_COUNT as u8); + } + list.delivery_failed.encode(buf); + list.undeliverable_here.encode(buf); + list.message_annotations.encode(buf); +} +impl DecodeFormatted for Modified { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + validate_code!(fmt, codec::FORMATCODE_DESCRIBED); + let (input, descriptor) = Descriptor::decode(input)?; + let is_match = match descriptor { + Descriptor::Ulong(val) => val == 39, + Descriptor::Symbol(ref sym) => sym.as_bytes() == b"amqp:modified:list", + }; + if !is_match { + Err(AmqpParseError::InvalidDescriptor(descriptor)) + } else { + decode_modified_inner(input) + } + } +} +impl Encode for Modified { + fn encoded_size(&self) -> usize { + encoded_size_modified_inner(self) + } + fn encode(&self, buf: &mut BytesMut) { + encode_modified_inner(self, buf) + } +} diff --git a/actix-amqp/codec/src/protocol/mod.rs b/actix-amqp/codec/src/protocol/mod.rs new file mode 100755 index 000000000..66a38680c --- /dev/null +++ b/actix-amqp/codec/src/protocol/mod.rs @@ -0,0 +1,268 @@ +use std::fmt; + +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; +use chrono::{DateTime, Utc}; +use derive_more::From; +use fxhash::FxHashMap; +use uuid::Uuid; + +use super::codec::{self, DecodeFormatted, Encode}; +use super::errors::AmqpParseError; +use super::message::{InMessage, OutMessage}; +use super::types::*; + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +#[derive(Debug)] +pub(crate) struct CompoundHeader { + pub size: u32, + pub count: u32, +} + +impl CompoundHeader { + pub fn empty() -> CompoundHeader { + CompoundHeader { size: 0, count: 0 } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub enum ProtocolId { + Amqp = 0, + AmqpTls = 2, + AmqpSasl = 3, +} + +pub type Map = FxHashMap; +pub type StringVariantMap = FxHashMap; +pub type Fields = FxHashMap; +pub type FilterSet = FxHashMap>; +pub type Timestamp = DateTime; +pub type Symbols = Multiple; +pub type IetfLanguageTags = Multiple; +pub type Annotations = FxHashMap; + +mod definitions; +pub use self::definitions::*; + +#[derive(Debug, Eq, PartialEq, Clone, From, Display)] +pub enum MessageId { + #[display(fmt = "{}", _0)] + Ulong(u64), + #[display(fmt = "{}", _0)] + Uuid(Uuid), + #[display(fmt = "{:?}", _0)] + Binary(Bytes), + #[display(fmt = "{}", _0)] + String(ByteString), +} + +impl From for MessageId { + fn from(id: usize) -> MessageId { + MessageId::Ulong(id as u64) + } +} + +impl From for MessageId { + fn from(id: i32) -> MessageId { + MessageId::Ulong(id as u64) + } +} + +impl DecodeFormatted for MessageId { + fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> { + match fmt { + codec::FORMATCODE_SMALLULONG | codec::FORMATCODE_ULONG | codec::FORMATCODE_ULONG_0 => { + u64::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Ulong(o))) + } + codec::FORMATCODE_UUID => { + Uuid::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Uuid(o))) + } + codec::FORMATCODE_BINARY8 | codec::FORMATCODE_BINARY32 => { + Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Binary(o))) + } + codec::FORMATCODE_STRING8 | codec::FORMATCODE_STRING32 => { + ByteString::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::String(o))) + } + _ => Err(AmqpParseError::InvalidFormatCode(fmt)), + } + } +} + +impl Encode for MessageId { + fn encoded_size(&self) -> usize { + match *self { + MessageId::Ulong(v) => v.encoded_size(), + MessageId::Uuid(ref v) => v.encoded_size(), + MessageId::Binary(ref v) => v.encoded_size(), + MessageId::String(ref v) => v.encoded_size(), + } + } + + fn encode(&self, buf: &mut BytesMut) { + match *self { + MessageId::Ulong(v) => v.encode(buf), + MessageId::Uuid(ref v) => v.encode(buf), + MessageId::Binary(ref v) => v.encode(buf), + MessageId::String(ref v) => v.encode(buf), + } + } +} + +#[derive(Clone, Debug, PartialEq, From)] +pub enum ErrorCondition { + AmqpError(AmqpError), + ConnectionError(ConnectionError), + SessionError(SessionError), + LinkError(LinkError), + Custom(Symbol), +} + +impl DecodeFormatted for ErrorCondition { + #[inline] + fn decode_with_format(input: &[u8], format: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, result) = Symbol::decode_with_format(input, format)?; + if let Ok(r) = AmqpError::try_from(&result) { + return Ok((input, ErrorCondition::AmqpError(r))); + } + if let Ok(r) = ConnectionError::try_from(&result) { + return Ok((input, ErrorCondition::ConnectionError(r))); + } + if let Ok(r) = SessionError::try_from(&result) { + return Ok((input, ErrorCondition::SessionError(r))); + } + if let Ok(r) = LinkError::try_from(&result) { + return Ok((input, ErrorCondition::LinkError(r))); + } + Ok((input, ErrorCondition::Custom(result))) + } +} + +impl Encode for ErrorCondition { + fn encoded_size(&self) -> usize { + match *self { + ErrorCondition::AmqpError(ref v) => v.encoded_size(), + ErrorCondition::ConnectionError(ref v) => v.encoded_size(), + ErrorCondition::SessionError(ref v) => v.encoded_size(), + ErrorCondition::LinkError(ref v) => v.encoded_size(), + ErrorCondition::Custom(ref v) => v.encoded_size(), + } + } + + fn encode(&self, buf: &mut BytesMut) { + match *self { + ErrorCondition::AmqpError(ref v) => v.encode(buf), + ErrorCondition::ConnectionError(ref v) => v.encode(buf), + ErrorCondition::SessionError(ref v) => v.encode(buf), + ErrorCondition::LinkError(ref v) => v.encode(buf), + ErrorCondition::Custom(ref v) => v.encode(buf), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum DistributionMode { + Move, + Copy, + Custom(Symbol), +} + +impl DecodeFormatted for DistributionMode { + fn decode_with_format(input: &[u8], format: u8) -> Result<(&[u8], Self), AmqpParseError> { + let (input, result) = Symbol::decode_with_format(input, format)?; + let result = match result.as_str() { + "move" => DistributionMode::Move, + "copy" => DistributionMode::Copy, + _ => DistributionMode::Custom(result), + }; + Ok((input, result)) + } +} + +impl Encode for DistributionMode { + fn encoded_size(&self) -> usize { + match *self { + DistributionMode::Move => 6, + DistributionMode::Copy => 6, + DistributionMode::Custom(ref v) => v.encoded_size(), + } + } + + fn encode(&self, buf: &mut BytesMut) { + match *self { + DistributionMode::Move => Symbol::from("move").encode(buf), + DistributionMode::Copy => Symbol::from("copy").encode(buf), + DistributionMode::Custom(ref v) => v.encode(buf), + } + } +} + +impl SaslInit { + pub fn prepare_response(authz_id: &str, authn_id: &str, password: &str) -> Bytes { + Bytes::from(format!("{}\x00{}\x00{}", authz_id, authn_id, password)) + } +} + +impl Default for Properties { + fn default() -> Properties { + Properties { + message_id: None, + user_id: None, + to: None, + subject: None, + reply_to: None, + correlation_id: None, + content_type: None, + content_encoding: None, + absolute_expiry_time: None, + creation_time: None, + group_id: None, + group_sequence: None, + reply_to_group_id: None, + } + } +} + +#[derive(Debug, Clone, From, PartialEq)] +pub enum TransferBody { + Data(Bytes), + MessageIn(InMessage), + MessageOut(OutMessage), +} + +impl TransferBody { + #[inline] + pub fn len(&self) -> usize { + self.encoded_size() + } + + #[inline] + pub fn message_format(&self) -> Option { + match self { + TransferBody::Data(_) => None, + TransferBody::MessageIn(ref data) => data.message_format, + TransferBody::MessageOut(ref data) => data.message_format, + } + } +} + +impl Encode for TransferBody { + fn encoded_size(&self) -> usize { + match self { + TransferBody::Data(ref data) => data.len(), + TransferBody::MessageIn(ref data) => data.encoded_size(), + TransferBody::MessageOut(ref data) => data.encoded_size(), + } + } + fn encode(&self, dst: &mut BytesMut) { + match *self { + TransferBody::Data(ref data) => dst.put_slice(&data), + TransferBody::MessageIn(ref data) => data.encode(dst), + TransferBody::MessageOut(ref data) => data.encode(dst), + } + } +} diff --git a/actix-amqp/codec/src/types/mod.rs b/actix-amqp/codec/src/types/mod.rs new file mode 100755 index 000000000..d42a8a1f8 --- /dev/null +++ b/actix-amqp/codec/src/types/mod.rs @@ -0,0 +1,194 @@ +use std::{borrow, fmt, hash, ops}; + +use bytestring::ByteString; + +mod symbol; +mod variant; + +pub use self::symbol::{StaticSymbol, Symbol}; +pub use self::variant::{Variant, VariantMap, VecStringMap, VecSymbolMap}; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Display)] +pub enum Descriptor { + Ulong(u64), + Symbol(Symbol), +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash, From)] +pub struct Multiple(pub Vec); + +impl Multiple { + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn iter(&self) -> ::std::slice::Iter { + self.0.iter() + } +} + +impl Default for Multiple { + fn default() -> Multiple { + Multiple(Vec::new()) + } +} + +impl ops::Deref for Multiple { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ops::DerefMut for Multiple { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct List(pub Vec); + +impl List { + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn iter(&self) -> ::std::slice::Iter { + self.0.iter() + } +} + +#[derive(Display, Clone, Eq, Ord, PartialOrd)] +pub enum Str { + String(String), + ByteStr(ByteString), + Static(&'static str), +} + +impl Str { + pub fn from_str(s: &str) -> Str { + Str::ByteStr(ByteString::from(s)) + } + + pub fn as_bytes(&self) -> &[u8] { + match self { + Str::String(s) => s.as_ref(), + Str::ByteStr(s) => s.as_ref(), + Str::Static(s) => s.as_bytes(), + } + } + + pub fn as_str(&self) -> &str { + match self { + Str::String(s) => s.as_str(), + Str::ByteStr(s) => s.as_ref(), + Str::Static(s) => s, + } + } + + pub fn to_bytes_str(&self) -> ByteString { + match self { + Str::String(s) => ByteString::from(s.as_str()), + Str::ByteStr(s) => s.clone(), + Str::Static(s) => ByteString::from_static(s), + } + } + + pub fn len(&self) -> usize { + match self { + Str::String(s) => s.len(), + Str::ByteStr(s) => s.len(), + Str::Static(s) => s.len(), + } + } +} + +impl From<&'static str> for Str { + fn from(s: &'static str) -> Str { + Str::Static(s) + } +} + +impl From for Str { + fn from(s: ByteString) -> Str { + Str::ByteStr(s) + } +} + +impl From for Str { + fn from(s: String) -> Str { + Str::String(s) + } +} + +impl hash::Hash for Str { + fn hash(&self, state: &mut H) { + match self { + Str::String(s) => (&*s).hash(state), + Str::ByteStr(s) => (&*s).hash(state), + Str::Static(s) => s.hash(state), + } + } +} + +impl borrow::Borrow for Str { + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for Str { + fn eq(&self, other: &Str) -> bool { + match self { + Str::String(s) => match other { + Str::String(o) => s == o, + Str::ByteStr(o) => o == s.as_str(), + Str::Static(o) => s == *o, + }, + Str::ByteStr(s) => match other { + Str::String(o) => s == o.as_str(), + Str::ByteStr(o) => s == o, + Str::Static(o) => s == *o, + }, + Str::Static(s) => match other { + Str::String(o) => o == s, + Str::ByteStr(o) => o == *s, + Str::Static(o) => s == o, + }, + } + } +} + +impl PartialEq for Str { + fn eq(&self, other: &str) -> bool { + match self { + Str::String(ref s) => s == other, + Str::ByteStr(ref s) => { + // workaround for possible compiler bug + let t: &str = &*s; + t == other + } + Str::Static(s) => *s == other, + } + } +} + +impl fmt::Debug for Str { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Str::String(s) => write!(f, "ST:\"{}\"", s), + Str::ByteStr(s) => write!(f, "B:\"{}\"", &*s), + Str::Static(s) => write!(f, "S:\"{}\"", s), + } + } +} diff --git a/actix-amqp/codec/src/types/symbol.rs b/actix-amqp/codec/src/types/symbol.rs new file mode 100755 index 000000000..1b8081686 --- /dev/null +++ b/actix-amqp/codec/src/types/symbol.rs @@ -0,0 +1,75 @@ +use std::{borrow, str}; + +use bytestring::ByteString; + +use super::Str; + +#[derive(Debug, Clone, Eq, PartialEq, Hash, Display)] +pub struct Symbol(pub Str); + +impl Symbol { + pub fn from_slice(s: &str) -> Symbol { + Symbol(Str::ByteStr(ByteString::from(s))) + } + + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + pub fn to_bytes_str(&self) -> ByteString { + self.0.to_bytes_str() + } + + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl From<&'static str> for Symbol { + fn from(s: &'static str) -> Symbol { + Symbol(Str::Static(s)) + } +} + +impl From for Symbol { + fn from(s: Str) -> Symbol { + Symbol(s) + } +} + +impl From for Symbol { + fn from(s: std::string::String) -> Symbol { + Symbol(Str::from(s)) + } +} + +impl From for Symbol { + fn from(s: ByteString) -> Symbol { + Symbol(Str::ByteStr(s)) + } +} + +impl borrow::Borrow for Symbol { + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for Symbol { + fn eq(&self, other: &str) -> bool { + self.0 == *other + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Display)] +pub struct StaticSymbol(pub &'static str); + +impl From<&'static str> for StaticSymbol { + fn from(s: &'static str) -> StaticSymbol { + StaticSymbol(s) + } +} diff --git a/actix-amqp/codec/src/types/variant.rs b/actix-amqp/codec/src/types/variant.rs new file mode 100755 index 000000000..7ef0be8f4 --- /dev/null +++ b/actix-amqp/codec/src/types/variant.rs @@ -0,0 +1,268 @@ +use std::hash::{Hash, Hasher}; + +use bytes::Bytes; +use bytestring::ByteString; +use chrono::{DateTime, Utc}; +use fxhash::FxHashMap; +use ordered_float::OrderedFloat; +use uuid::Uuid; + +use crate::protocol::Annotations; +use crate::types::{Descriptor, List, StaticSymbol, Str, Symbol}; + +/// Represents an AMQP type for use in polymorphic collections +#[derive(Debug, Eq, PartialEq, Hash, Clone, Display, From)] +pub enum Variant { + /// Indicates an empty value. + Null, + + /// Represents a true or false value. + Boolean(bool), + + /// Integer in the range 0 to 2^8 - 1 inclusive. + Ubyte(u8), + + /// Integer in the range 0 to 2^16 - 1 inclusive. + Ushort(u16), + + /// Integer in the range 0 to 2^32 - 1 inclusive. + Uint(u32), + + /// Integer in the range 0 to 2^64 - 1 inclusive. + Ulong(u64), + + /// Integer in the range 0 to 2^7 - 1 inclusive. + Byte(i8), + + /// Integer in the range 0 to 2^15 - 1 inclusive. + Short(i16), + + /// Integer in the range 0 to 2^32 - 1 inclusive. + Int(i32), + + /// Integer in the range 0 to 2^64 - 1 inclusive. + Long(i64), + + /// 32-bit floating point number (IEEE 754-2008 binary32). + Float(OrderedFloat), + + /// 64-bit floating point number (IEEE 754-2008 binary64). + Double(OrderedFloat), + + // Decimal32(d32), + // Decimal64(d64), + // Decimal128(d128), + /// A single Unicode character. + Char(char), + + /// An absolute point in time. + /// Represents an approximate point in time using the Unix time encoding of + /// UTC with a precision of milliseconds. For example, 1311704463521 + /// represents the moment 2011-07-26T18:21:03.521Z. + Timestamp(DateTime), + + /// A universally unique identifier as defined by RFC-4122 section 4.1.2 + Uuid(Uuid), + + /// A sequence of octets. + #[display(fmt = "Binary({:?})", _0)] + Binary(Bytes), + + /// A sequence of Unicode characters + String(Str), + + /// Symbolic values from a constrained domain. + Symbol(Symbol), + + /// Same as Symbol but for static refs + StaticSymbol(StaticSymbol), + + /// List + #[display(fmt = "List({:?})", _0)] + List(List), + + /// Map + Map(VariantMap), + + /// Described value + #[display(fmt = "Described{:?}", _0)] + Described((Descriptor, Box)), +} + +impl From for Variant { + fn from(s: ByteString) -> Self { + Str::from(s).into() + } +} + +impl From for Variant { + fn from(s: String) -> Self { + Str::from(ByteString::from(s)).into() + } +} + +impl From<&'static str> for Variant { + fn from(s: &'static str) -> Self { + Str::from(s).into() + } +} + +impl PartialEq for Variant { + fn eq(&self, other: &str) -> bool { + match self { + Variant::String(s) => s == other, + Variant::Symbol(s) => s == other, + _ => false, + } + } +} + +impl Variant { + pub fn as_str(&self) -> Option<&str> { + match self { + Variant::String(s) => Some(s.as_str()), + Variant::Symbol(s) => Some(s.as_str()), + _ => None, + } + } + + pub fn as_int(&self) -> Option { + match self { + Variant::Int(v) => Some(*v as i32), + _ => None, + } + } + + pub fn as_long(&self) -> Option { + match self { + Variant::Ubyte(v) => Some(*v as i64), + Variant::Ushort(v) => Some(*v as i64), + Variant::Uint(v) => Some(*v as i64), + Variant::Ulong(v) => Some(*v as i64), + Variant::Byte(v) => Some(*v as i64), + Variant::Short(v) => Some(*v as i64), + Variant::Int(v) => Some(*v as i64), + Variant::Long(v) => Some(*v as i64), + _ => None, + } + } + + pub fn to_bytes_str(&self) -> Option { + match self { + Variant::String(s) => Some(s.to_bytes_str()), + Variant::Symbol(s) => Some(s.to_bytes_str()), + _ => None, + } + } +} + +#[derive(PartialEq, Eq, Clone, Debug, Display)] +#[display(fmt = "{:?}", map)] +pub struct VariantMap { + pub map: FxHashMap, +} + +impl VariantMap { + pub fn new(map: FxHashMap) -> VariantMap { + VariantMap { map } + } +} + +impl Hash for VariantMap { + fn hash(&self, _state: &mut H) { + unimplemented!() + } +} + +#[derive(PartialEq, Clone, Debug, Display)] +#[display(fmt = "{:?}", _0)] +pub struct VecSymbolMap(pub Vec<(Symbol, Variant)>); + +impl Default for VecSymbolMap { + fn default() -> Self { + VecSymbolMap(Vec::with_capacity(8)) + } +} + +impl From for VecSymbolMap { + fn from(anns: Annotations) -> VecSymbolMap { + VecSymbolMap(anns.into_iter().collect()) + } +} + +impl std::ops::Deref for VecSymbolMap { + type Target = Vec<(Symbol, Variant)>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for VecSymbolMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(PartialEq, Clone, Debug, Display)] +#[display(fmt = "{:?}", _0)] +pub struct VecStringMap(pub Vec<(Str, Variant)>); + +impl Default for VecStringMap { + fn default() -> Self { + VecStringMap(Vec::with_capacity(8)) + } +} + +impl From> for VecStringMap { + fn from(map: FxHashMap) -> VecStringMap { + VecStringMap(map.into_iter().collect()) + } +} + +impl std::ops::Deref for VecStringMap { + type Target = Vec<(Str, Variant)>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for VecStringMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bytes_eq() { + let bytes1 = Variant::Binary(Bytes::from(&b"hello"[..])); + let bytes2 = Variant::Binary(Bytes::from(&b"hello"[..])); + let bytes3 = Variant::Binary(Bytes::from(&b"world"[..])); + + assert_eq!(bytes1, bytes2); + assert!(bytes1 != bytes3); + } + + #[test] + fn string_eq() { + let a = Variant::String(ByteString::from("hello").into()); + let b = Variant::String(ByteString::from("world!").into()); + + assert_eq!(Variant::String(ByteString::from("hello").into()), a); + assert!(a != b); + } + + #[test] + fn symbol_eq() { + let a = Variant::Symbol(Symbol::from("hello")); + let b = Variant::Symbol(Symbol::from("world!")); + + assert_eq!(Variant::Symbol(Symbol::from("hello")), a); + assert!(a != b); + } +} diff --git a/actix-amqp/src/cell.rs b/actix-amqp/src/cell.rs new file mode 100755 index 000000000..bb7780d5a --- /dev/null +++ b/actix-amqp/src/cell.rs @@ -0,0 +1,72 @@ +//! Custom cell impl +use std::cell::UnsafeCell; +use std::ops::Deref; +use std::rc::{Rc, Weak}; + +pub(crate) struct Cell { + inner: Rc>, +} + +pub(crate) struct WeakCell { + inner: Weak>, +} + +impl Clone for Cell { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl Deref for Cell { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.get_ref() + } +} + +impl std::fmt::Debug for Cell { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl Cell { + pub fn new(inner: T) -> Self { + Self { + inner: Rc::new(UnsafeCell::new(inner)), + } + } + + pub fn downgrade(&self) -> WeakCell { + WeakCell { + inner: Rc::downgrade(&self.inner), + } + } + + pub fn get_ref(&self) -> &T { + unsafe { &*self.inner.as_ref().get() } + } + + pub fn get_mut(&self) -> &mut T { + unsafe { &mut *self.inner.as_ref().get() } + } +} + +impl std::fmt::Debug for WeakCell { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl WeakCell { + pub fn upgrade(&self) -> Option> { + if let Some(inner) = self.inner.upgrade() { + Some(Cell { inner }) + } else { + None + } + } +} diff --git a/actix-amqp/src/client/connect.rs b/actix-amqp/src/client/connect.rs new file mode 100755 index 000000000..6a20097a9 --- /dev/null +++ b/actix-amqp/src/client/connect.rs @@ -0,0 +1,11 @@ +use actix_codec::Framed; + +use crate::Configuration; + +trait IntoFramed { + fn into_framed(self) -> Framed; +} + +pub struct Handshake { + _cfg: Configuration, +} diff --git a/actix-amqp/src/client/mod.rs b/actix-amqp/src/client/mod.rs new file mode 100755 index 000000000..1d0e73eb7 --- /dev/null +++ b/actix-amqp/src/client/mod.rs @@ -0,0 +1,5 @@ +mod connect; +mod protocol; + +pub use self::connect::Handshake; +pub use self::protocol::ProtocolNegotiation; diff --git a/actix-amqp/src/client/protocol.rs b/actix-amqp/src/client/protocol.rs new file mode 100755 index 000000000..8c668beb9 --- /dev/null +++ b/actix-amqp/src/client/protocol.rs @@ -0,0 +1,78 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::Service; +use futures::{Future, SinkExt, StreamExt}; + +use amqp_codec::protocol::ProtocolId; +use amqp_codec::{ProtocolIdCodec, ProtocolIdError}; + +pub struct ProtocolNegotiation { + proto: ProtocolId, + _r: PhantomData, +} + +impl Clone for ProtocolNegotiation { + fn clone(&self) -> Self { + ProtocolNegotiation { + proto: self.proto.clone(), + _r: PhantomData, + } + } +} + +impl ProtocolNegotiation { + pub fn new(proto: ProtocolId) -> Self { + ProtocolNegotiation { + proto, + _r: PhantomData, + } + } + + pub fn framed(stream: Io) -> Framed + where + Io: AsyncRead + AsyncWrite, + { + Framed::new(stream, ProtocolIdCodec) + } +} + +impl Default for ProtocolNegotiation { + fn default() -> Self { + Self::new(ProtocolId::Amqp) + } +} + +impl Service for ProtocolNegotiation +where + Io: AsyncRead + AsyncWrite + 'static, +{ + type Request = Framed; + type Response = Framed; + type Error = ProtocolIdError; + type Future = Pin>>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut framed: Framed) -> Self::Future { + let proto = self.proto; + + Box::pin(async move { + framed.send(proto).await?; + let protocol = framed.next().await.ok_or(ProtocolIdError::Disconnected)??; + + if proto == protocol { + Ok(framed) + } else { + Err(ProtocolIdError::Unexpected { + exp: proto, + got: protocol, + }) + } + }) + } +} diff --git a/actix-amqp/src/connection.rs b/actix-amqp/src/connection.rs new file mode 100755 index 000000000..ab9e6c28b --- /dev/null +++ b/actix-amqp/src/connection.rs @@ -0,0 +1,577 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_utils::oneshot; +use actix_utils::task::LocalWaker; +use actix_utils::time::LowResTimeService; +use futures::future::{err, Either}; +use futures::{future, Sink, Stream}; +use fxhash::FxHashMap; + +use amqp_codec::protocol::{Begin, Close, End, Error, Frame}; +use amqp_codec::{AmqpCodec, AmqpCodecError, AmqpFrame}; + +use crate::cell::{Cell, WeakCell}; +use crate::errors::AmqpTransportError; +use crate::hb::{Heartbeat, HeartbeatAction}; +use crate::session::{Session, SessionInner}; +use crate::Configuration; + +pub struct Connection { + inner: Cell, + framed: Framed>, + hb: Heartbeat, +} + +pub(crate) enum ChannelState { + Opening(Option>, WeakCell), + Established(Cell), + Closing(Option>>), +} + +impl ChannelState { + fn is_opening(&self) -> bool { + match self { + ChannelState::Opening(_, _) => true, + _ => false, + } + } +} + +pub(crate) struct ConnectionInner { + local: Configuration, + remote: Configuration, + write_queue: VecDeque, + write_task: LocalWaker, + sessions: slab::Slab, + sessions_map: FxHashMap, + error: Option, + state: State, +} + +#[derive(PartialEq)] +enum State { + Normal, + Closing, + RemoteClose, + Drop, +} + +impl Connection { + pub fn new( + framed: Framed>, + local: Configuration, + remote: Configuration, + time: Option, + ) -> Connection { + Connection { + framed, + hb: Heartbeat::new( + local.timeout().unwrap(), + remote.timeout(), + time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))), + ), + inner: Cell::new(ConnectionInner::new(local, remote)), + } + } + + pub(crate) fn new_server( + framed: Framed>, + inner: Cell, + time: Option, + ) -> Connection { + let l_timeout = inner.get_ref().local.timeout().unwrap(); + let r_timeout = inner.get_ref().remote.timeout(); + Connection { + framed, + inner, + hb: Heartbeat::new( + l_timeout, + r_timeout, + time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))), + ), + } + } + + /// Connection controller + pub fn controller(&self) -> ConnectionController { + ConnectionController(self.inner.clone()) + } + + /// Get remote configuration + pub fn remote_config(&self) -> &Configuration { + &self.inner.get_ref().remote + } + + /// Gracefully close connection + pub fn close(&mut self) -> impl Future> { + future::ok(()) + } + + // TODO: implement + /// Close connection with error + pub fn close_with_error( + &mut self, + _err: Error, + ) -> impl Future> { + future::ok(()) + } + + /// Opens the session + pub fn open_session(&mut self) -> impl Future> { + let cell = self.inner.downgrade(); + let inner = self.inner.clone(); + + async move { + let inner = inner.get_mut(); + + if let Some(ref e) = inner.error { + Err(e.clone()) + } else { + let (tx, rx) = oneshot::channel(); + + let entry = inner.sessions.vacant_entry(); + let token = entry.key(); + + if token >= inner.local.channel_max { + Err(AmqpTransportError::TooManyChannels) + } else { + entry.insert(ChannelState::Opening(Some(tx), cell)); + + let begin = Begin { + remote_channel: None, + next_outgoing_id: 1, + incoming_window: std::u32::MAX, + outgoing_window: std::u32::MAX, + handle_max: std::u32::MAX, + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + inner.post_frame(AmqpFrame::new(token as u16, begin.into())); + + rx.await.map_err(|_| AmqpTransportError::Disconnected) + } + } + } + } + + /// Get session by id. This method panics if session does not exists or in opening/closing state. + pub(crate) fn get_session(&self, id: usize) -> Cell { + if let Some(channel) = self.inner.get_ref().sessions.get(id) { + if let ChannelState::Established(ref session) = channel { + return session.clone(); + } + } + panic!("Session not found: {}", id); + } + + pub(crate) fn register_remote_session(&mut self, channel_id: u16, begin: &Begin) { + trace!("remote session opened: {:?}", channel_id); + + let cell = self.inner.clone(); + let inner = self.inner.get_mut(); + let entry = inner.sessions.vacant_entry(); + let token = entry.key(); + + let session = Cell::new(SessionInner::new( + token, + false, + ConnectionController(cell), + token as u16, + begin.next_outgoing_id(), + begin.incoming_window(), + begin.outgoing_window(), + )); + entry.insert(ChannelState::Established(session)); + inner.sessions_map.insert(channel_id, token); + + let begin = Begin { + remote_channel: Some(channel_id), + next_outgoing_id: 1, + incoming_window: std::u32::MAX, + outgoing_window: begin.incoming_window(), + handle_max: std::u32::MAX, + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + inner.post_frame(AmqpFrame::new(token as u16, begin.into())); + } + + pub(crate) fn send_frame(&mut self, frame: AmqpFrame) { + self.inner.get_mut().post_frame(frame) + } + + pub(crate) fn register_write_task(&self, cx: &mut Context) { + self.inner.write_task.register(cx.waker()); + } + + pub(crate) fn poll_outgoing(&mut self, cx: &mut Context) -> Poll> { + let inner = self.inner.get_mut(); + let mut update = false; + loop { + while !self.framed.is_write_buf_full() { + if let Some(frame) = inner.pop_next_frame() { + trace!("outgoing: {:#?}", frame); + update = true; + if let Err(e) = self.framed.write(frame) { + inner.set_error(e.clone().into()); + return Poll::Ready(Err(e)); + } + } else { + break; + } + } + + if !self.framed.is_write_buf_empty() { + match self.framed.flush(cx) { + Poll::Pending => break, + Poll::Ready(Err(e)) => { + trace!("error sending data: {}", e); + inner.set_error(e.clone().into()); + return Poll::Ready(Err(e)); + } + Poll::Ready(_) => (), + } + } else { + break; + } + } + self.hb.update_remote(update); + + if inner.state == State::Drop { + Poll::Ready(Ok(())) + } else if inner.state == State::RemoteClose + && inner.write_queue.is_empty() + && self.framed.is_write_buf_empty() + { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + pub(crate) fn poll_incoming( + &mut self, + cx: &mut Context, + ) -> Poll>> { + let inner = self.inner.get_mut(); + + let mut update = false; + loop { + match Pin::new(&mut self.framed).poll_next(cx) { + Poll::Ready(Some(Ok(frame))) => { + trace!("incoming: {:#?}", frame); + + update = true; + + if let Frame::Empty = frame.performative() { + self.hb.update_local(update); + continue; + } + + // handle connection close + if let Frame::Close(ref close) = frame.performative() { + inner.set_error(AmqpTransportError::Closed(close.error.clone())); + + if inner.state == State::Closing { + inner.sessions.clear(); + return Poll::Ready(None); + } else { + let close = Close { error: None }; + inner.post_frame(AmqpFrame::new(0, close.into())); + inner.state = State::RemoteClose; + } + } + + if inner.error.is_some() { + error!("connection closed but new framed is received: {:?}", frame); + return Poll::Ready(None); + } + + // get local session id + let channel_id = + if let Some(token) = inner.sessions_map.get(&frame.channel_id()) { + *token + } else { + // we dont have channel info, only Begin frame is allowed on new channel + if let Frame::Begin(ref begin) = frame.performative() { + if begin.remote_channel().is_some() { + inner.complete_session_creation(frame.channel_id(), begin); + } else { + return Poll::Ready(Some(Ok(frame))); + } + } else { + warn!("Unexpected frame: {:#?}", frame); + } + continue; + }; + + // handle session frames + if let Some(channel) = inner.sessions.get_mut(channel_id) { + match channel { + ChannelState::Opening(_, _) => { + error!("Unexpected opening state: {}", channel_id); + } + ChannelState::Established(ref mut session) => { + match frame.performative() { + Frame::Attach(attach) => { + let cell = session.clone(); + if !session.get_mut().handle_attach(attach, cell) { + return Poll::Ready(Some(Ok(frame))); + } + } + Frame::Flow(_) | Frame::Detach(_) => { + return Poll::Ready(Some(Ok(frame))); + } + Frame::End(remote_end) => { + trace!("Remote session end: {}", frame.channel_id()); + let end = End { error: None }; + session.get_mut().set_error( + AmqpTransportError::SessionEnded( + remote_end.error.clone(), + ), + ); + let id = session.get_mut().id(); + inner.post_frame(AmqpFrame::new(id, end.into())); + inner.sessions.remove(channel_id); + inner.sessions_map.remove(&frame.channel_id()); + } + _ => session.get_mut().handle_frame(frame.into_parts().1), + } + } + ChannelState::Closing(ref mut tx) => match frame.performative() { + Frame::End(_) => { + if let Some(tx) = tx.take() { + let _ = tx.send(Ok(())); + } + inner.sessions.remove(channel_id); + inner.sessions_map.remove(&frame.channel_id()); + } + frm => trace!("Got frame after initiated session end: {:?}", frm), + }, + } + } else { + error!("Can not find channel: {}", channel_id); + continue; + } + } + Poll::Ready(None) => { + inner.set_error(AmqpTransportError::Disconnected); + return Poll::Ready(None); + } + Poll::Pending => { + self.hb.update_local(update); + break; + } + Poll::Ready(Some(Err(e))) => { + trace!("error reading: {:?}", e); + inner.set_error(e.clone().into()); + return Poll::Ready(Some(Err(e.into()))); + } + } + } + + Poll::Pending + } +} + +impl Drop for Connection { + fn drop(&mut self) { + self.inner + .get_mut() + .set_error(AmqpTransportError::Disconnected); + } +} + +impl Future for Connection { + type Output = Result<(), AmqpCodecError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // connection heartbeat + match self.hb.poll(cx) { + Ok(act) => match act { + HeartbeatAction::None => (), + HeartbeatAction::Close => { + self.inner.get_mut().set_error(AmqpTransportError::Timeout); + return Poll::Ready(Ok(())); + } + HeartbeatAction::Heartbeat => { + self.inner + .get_mut() + .write_queue + .push_back(AmqpFrame::new(0, Frame::Empty)); + } + }, + Err(e) => { + self.inner.get_mut().set_error(e); + return Poll::Ready(Ok(())); + } + } + + loop { + match self.poll_incoming(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Ok(frame))) => { + if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) { + if let ChannelState::Established(ref session) = channel { + session.get_mut().handle_frame(frame.into_parts().1); + continue; + } + } + warn!("Unexpected frame: {:?}", frame); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Pending => break, + } + } + let _ = self.poll_outgoing(cx)?; + self.register_write_task(cx); + + match self.poll_incoming(cx) { + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Ready(Some(Ok(frame))) => { + if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) { + if let ChannelState::Established(ref session) = channel { + session.get_mut().handle_frame(frame.into_parts().1); + return Poll::Pending; + } + } + warn!("Unexpected frame: {:?}", frame); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Pending => (), + } + + Poll::Pending + } +} + +#[derive(Clone)] +pub struct ConnectionController(pub(crate) Cell); + +impl ConnectionController { + pub(crate) fn new(local: Configuration) -> ConnectionController { + ConnectionController(Cell::new(ConnectionInner { + local, + remote: Configuration::default(), + write_queue: VecDeque::new(), + write_task: LocalWaker::new(), + sessions: slab::Slab::with_capacity(8), + sessions_map: FxHashMap::default(), + error: None, + state: State::Normal, + })) + } + + pub(crate) fn set_remote(&mut self, remote: Configuration) { + self.0.get_mut().remote = remote; + } + + #[inline] + /// Get remote connection configuration + pub fn remote_config(&self) -> &Configuration { + &self.0.get_ref().remote + } + + #[inline] + /// Drop connection + pub fn drop_connection(&mut self) { + let inner = self.0.get_mut(); + inner.state = State::Drop; + inner.write_task.wake() + } + + pub(crate) fn post_frame(&mut self, frame: AmqpFrame) { + self.0.get_mut().post_frame(frame) + } + + pub(crate) fn drop_session_copy(&mut self, _id: usize) {} +} + +impl ConnectionInner { + pub(crate) fn new(local: Configuration, remote: Configuration) -> ConnectionInner { + ConnectionInner { + local, + remote, + write_queue: VecDeque::new(), + write_task: LocalWaker::new(), + sessions: slab::Slab::with_capacity(8), + sessions_map: FxHashMap::default(), + error: None, + state: State::Normal, + } + } + + fn set_error(&mut self, err: AmqpTransportError) { + for (_, channel) in self.sessions.iter_mut() { + match channel { + ChannelState::Opening(_, _) | ChannelState::Closing(_) => (), + ChannelState::Established(ref mut ses) => { + ses.get_mut().set_error(err.clone()); + } + } + } + self.sessions.clear(); + self.sessions_map.clear(); + + self.error = Some(err); + } + + fn pop_next_frame(&mut self) -> Option { + self.write_queue.pop_front() + } + + fn post_frame(&mut self, frame: AmqpFrame) { + // trace!("POST-FRAME: {:#?}", frame.performative()); + self.write_queue.push_back(frame); + self.write_task.wake(); + } + + fn complete_session_creation(&mut self, channel_id: u16, begin: &Begin) { + trace!( + "session opened: {:?} {:?}", + channel_id, + begin.remote_channel() + ); + + let id = begin.remote_channel().unwrap() as usize; + + if let Some(channel) = self.sessions.get_mut(id) { + if channel.is_opening() { + if let ChannelState::Opening(tx, cell) = channel { + let cell = cell.upgrade().unwrap(); + let session = Cell::new(SessionInner::new( + id, + true, + ConnectionController(cell), + channel_id, + begin.next_outgoing_id(), + begin.incoming_window(), + begin.outgoing_window(), + )); + self.sessions_map.insert(channel_id, id); + + if tx + .take() + .unwrap() + .send(Session::new(session.clone())) + .is_err() + { + // todo: send end session + } + *channel = ChannelState::Established(session) + } + } else { + // send error response + } + } else { + // todo: rogue begin right now - do nothing. in future might indicate incoming attach + } + } +} diff --git a/actix-amqp/src/errors.rs b/actix-amqp/src/errors.rs new file mode 100755 index 000000000..4d145bd62 --- /dev/null +++ b/actix-amqp/src/errors.rs @@ -0,0 +1,31 @@ +use amqp_codec::{protocol, AmqpCodecError, ProtocolIdError}; + +#[derive(Debug, Display, Clone)] +pub enum AmqpTransportError { + Codec(AmqpCodecError), + TooManyChannels, + Disconnected, + Timeout, + #[display(fmt = "Connection closed, error: {:?}", _0)] + Closed(Option), + #[display(fmt = "Session ended, error: {:?}", _0)] + SessionEnded(Option), + #[display(fmt = "Link detached, error: {:?}", _0)] + LinkDetached(Option), +} + +impl From for AmqpTransportError { + fn from(err: AmqpCodecError) -> Self { + AmqpTransportError::Codec(err) + } +} + +#[derive(Debug, Display, From)] +pub enum SaslConnectError { + Protocol(ProtocolIdError), + AmqpError(AmqpCodecError), + #[display(fmt = "Sasl error code: {:?}", _0)] + Sasl(protocol::SaslCode), + ExpectedOpenFrame, + Disconnected, +} diff --git a/actix-amqp/src/hb.rs b/actix-amqp/src/hb.rs new file mode 100755 index 000000000..487d19329 --- /dev/null +++ b/actix-amqp/src/hb.rs @@ -0,0 +1,93 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_rt::time::{delay_until, Delay, Instant}; +use actix_utils::time::LowResTimeService; + +use crate::errors::AmqpTransportError; + +pub(crate) enum HeartbeatAction { + None, + Heartbeat, + Close, +} + +pub(crate) struct Heartbeat { + expire_local: Instant, + expire_remote: Instant, + local: Duration, + remote: Option, + time: LowResTimeService, + delay: Delay, +} + +impl Heartbeat { + pub(crate) fn new(local: Duration, remote: Option, time: LowResTimeService) -> Self { + let now = Instant::from_std(time.now()); + let delay = if let Some(remote) = remote { + delay_until(now + std::cmp::min(local, remote)) + } else { + delay_until(now + local) + }; + + Heartbeat { + expire_local: now, + expire_remote: now, + local, + remote, + time, + delay, + } + } + + pub(crate) fn update_local(&mut self, update: bool) { + if update { + self.expire_local = Instant::from_std(self.time.now()); + } + } + + pub(crate) fn update_remote(&mut self, update: bool) { + if update && self.remote.is_some() { + self.expire_remote = Instant::from_std(self.time.now()); + } + } + + fn next_expire(&self) -> Instant { + if let Some(remote) = self.remote { + let t1 = self.expire_local + self.local; + let t2 = self.expire_remote + remote; + if t1 < t2 { + t1 + } else { + t2 + } + } else { + self.expire_local + self.local + } + } + + pub(crate) fn poll(&mut self, cx: &mut Context) -> Result { + match Pin::new(&mut self.delay).poll(cx) { + Poll::Ready(_) => { + let mut act = HeartbeatAction::None; + let dl = self.delay.deadline(); + if dl >= self.expire_local + self.local { + // close connection + return Ok(HeartbeatAction::Close); + } + if let Some(remote) = self.remote { + if dl >= self.expire_remote + remote { + // send heartbeat + act = HeartbeatAction::Heartbeat; + } + } + self.delay.reset(self.next_expire()); + let _ = Pin::new(&mut self.delay).poll(cx); + Ok(act) + } + Poll::Pending => Ok(HeartbeatAction::None), + } + } +} diff --git a/actix-amqp/src/lib.rs b/actix-amqp/src/lib.rs new file mode 100755 index 000000000..72b3ce19d --- /dev/null +++ b/actix-amqp/src/lib.rs @@ -0,0 +1,167 @@ +#![allow(unused_imports, dead_code)] + +#[macro_use] +extern crate derive_more; +#[macro_use] +extern crate log; + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_utils::oneshot; +use amqp_codec::protocol::{Disposition, Handle, Milliseconds, Open}; +use bytes::Bytes; +use bytestring::ByteString; +use uuid::Uuid; + +mod cell; +pub mod client; +mod connection; +mod errors; +mod hb; +mod rcvlink; +pub mod sasl; +pub mod server; +mod service; +mod session; +mod sndlink; + +pub use self::connection::{Connection, ConnectionController}; +pub use self::errors::AmqpTransportError; +pub use self::rcvlink::{ReceiverLink, ReceiverLinkBuilder}; +pub use self::session::Session; +pub use self::sndlink::{SenderLink, SenderLinkBuilder}; + +pub enum Delivery { + Resolved(Result), + Pending(oneshot::Receiver>), + Gone, +} + +type DeliveryPromise = oneshot::Sender>; + +impl Future for Delivery { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + if let Delivery::Pending(ref mut receiver) = *self { + return match Pin::new(receiver).poll(cx) { + Poll::Ready(Ok(r)) => Poll::Ready(r.map(|state| state)), + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + trace!("delivery oneshot is gone: {:?}", e); + Poll::Ready(Err(AmqpTransportError::Disconnected)) + } + }; + } + + let old_v = ::std::mem::replace(&mut *self, Delivery::Gone); + if let Delivery::Resolved(r) = old_v { + return match r { + Ok(state) => Poll::Ready(Ok(state)), + Err(e) => Poll::Ready(Err(e)), + }; + } + panic!("Polling Delivery after it was polled as ready is an error."); + } +} + +/// Amqp1 transport configuration. +#[derive(Debug, Clone)] +pub struct Configuration { + pub max_frame_size: u32, + pub channel_max: usize, + pub idle_time_out: Option, + pub hostname: Option, +} + +impl Default for Configuration { + fn default() -> Self { + Self::new() + } +} + +impl Configuration { + /// Create connection configuration. + pub fn new() -> Self { + Configuration { + max_frame_size: std::u16::MAX as u32, + channel_max: 1024, + idle_time_out: Some(120000), + hostname: None, + } + } + + /// The channel-max value is the highest channel number that + /// may be used on the Connection. This value plus one is the maximum + /// number of Sessions that can be simultaneously active on the Connection. + /// + /// By default channel max value is set to 1024 + pub fn channel_max(&mut self, num: u16) -> &mut Self { + self.channel_max = num as usize; + self + } + + /// Set max frame size for the connection. + /// + /// By default max size is set to 64kb + pub fn max_frame_size(&mut self, size: u32) -> &mut Self { + self.max_frame_size = size; + self + } + + /// Get max frame size for the connection. + pub fn get_max_frame_size(&self) -> usize { + self.max_frame_size as usize + } + + /// Set idle time-out for the connection in milliseconds + /// + /// By default idle time-out is set to 120000 milliseconds + pub fn idle_timeout(&mut self, timeout: u32) -> &mut Self { + self.idle_time_out = Some(timeout as Milliseconds); + self + } + + /// Set connection hostname + /// + /// Hostname is not set by default + pub fn hostname(&mut self, hostname: &str) -> &mut Self { + self.hostname = Some(ByteString::from(hostname)); + self + } + + /// Create `Open` performative for this configuration. + pub fn to_open(&self) -> Open { + Open { + container_id: ByteString::from(Uuid::new_v4().to_simple().to_string()), + hostname: self.hostname.clone(), + max_frame_size: self.max_frame_size, + channel_max: self.channel_max as u16, + idle_time_out: self.idle_time_out, + outgoing_locales: None, + incoming_locales: None, + offered_capabilities: None, + desired_capabilities: None, + properties: None, + } + } + + pub(crate) fn timeout(&self) -> Option { + self.idle_time_out + .map(|v| Duration::from_millis(((v as f32) * 0.8) as u64)) + } +} + +impl<'a> From<&'a Open> for Configuration { + fn from(open: &'a Open) -> Self { + Configuration { + max_frame_size: open.max_frame_size, + channel_max: open.channel_max as usize, + idle_time_out: open.idle_time_out, + hostname: open.hostname.clone(), + } + } +} diff --git a/actix-amqp/src/rcvlink.rs b/actix-amqp/src/rcvlink.rs new file mode 100755 index 000000000..8abff7928 --- /dev/null +++ b/actix-amqp/src/rcvlink.rs @@ -0,0 +1,263 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::u32; + +use actix_utils::oneshot; +use actix_utils::task::LocalWaker; +use amqp_codec::protocol::{ + Attach, DeliveryNumber, Disposition, Error, Handle, LinkError, ReceiverSettleMode, Role, + SenderSettleMode, Source, TerminusDurability, TerminusExpiryPolicy, Transfer, +}; +use bytes::Bytes; +use bytestring::ByteString; +use futures::Stream; + +use crate::cell::Cell; +use crate::errors::AmqpTransportError; +use crate::session::{Session, SessionInner}; +use crate::Configuration; + +#[derive(Clone, Debug)] +pub struct ReceiverLink { + pub(crate) inner: Cell, +} + +impl ReceiverLink { + pub(crate) fn new(inner: Cell) -> ReceiverLink { + ReceiverLink { inner } + } + + pub fn handle(&self) -> Handle { + self.inner.get_ref().handle as Handle + } + + pub fn credit(&self) -> u32 { + self.inner.get_ref().credit + } + + pub fn session(&self) -> &Session { + &self.inner.get_ref().session + } + + pub fn session_mut(&mut self) -> &mut Session { + &mut self.inner.get_mut().session + } + + pub fn frame(&self) -> &Attach { + &self.inner.get_ref().attach + } + + pub fn open(&mut self) { + let inner = self.inner.get_mut(); + inner + .session + .inner + .get_mut() + .confirm_receiver_link(inner.handle, &inner.attach); + } + + pub fn set_link_credit(&mut self, credit: u32) { + self.inner.get_mut().set_link_credit(credit); + } + + /// Send disposition frame + pub fn send_disposition(&mut self, disp: Disposition) { + self.inner + .get_mut() + .session + .inner + .get_mut() + .post_frame(disp.into()); + } + + /// Wait for disposition with specified number + pub fn wait_disposition( + &mut self, + id: DeliveryNumber, + ) -> impl Future> { + self.inner.get_mut().session.wait_disposition(id) + } + + pub fn close(&self) -> impl Future> { + self.inner.get_mut().close(None) + } + + pub fn close_with_error( + &self, + error: Error, + ) -> impl Future> { + self.inner.get_mut().close(Some(error)) + } + + #[inline] + /// Get remote connection configuration + pub fn remote_config(&self) -> &Configuration { + &self.inner.session.remote_config() + } +} + +impl Stream for ReceiverLink { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let inner = self.inner.get_mut(); + + if let Some(tr) = inner.queue.pop_front() { + Poll::Ready(Some(Ok(tr))) + } else { + if inner.closed { + Poll::Ready(None) + } else { + inner.reader_task.register(cx.waker()); + Poll::Pending + } + } + } +} + +#[derive(Debug)] +pub(crate) struct ReceiverLinkInner { + handle: Handle, + attach: Attach, + session: Session, + closed: bool, + reader_task: LocalWaker, + queue: VecDeque, + credit: u32, + delivery_count: u32, +} + +impl ReceiverLinkInner { + pub(crate) fn new( + session: Cell, + handle: Handle, + attach: Attach, + ) -> ReceiverLinkInner { + ReceiverLinkInner { + handle, + session: Session::new(session), + closed: false, + reader_task: LocalWaker::new(), + queue: VecDeque::with_capacity(4), + credit: 0, + delivery_count: attach.initial_delivery_count().unwrap_or(0), + attach, + } + } + + pub fn name(&self) -> &ByteString { + &self.attach.name + } + + pub fn close( + &mut self, + error: Option, + ) -> impl Future> { + let (tx, rx) = oneshot::channel(); + if self.closed { + let _ = tx.send(Ok(())); + } else { + self.session + .inner + .get_mut() + .detach_receiver_link(self.handle, true, error, tx); + } + async move { + match rx.await { + Ok(Ok(_)) => Ok(()), + Ok(Err(e)) => Err(e), + Err(_) => Err(AmqpTransportError::Disconnected), + } + } + } + + pub fn set_link_credit(&mut self, credit: u32) { + self.credit += credit; + self.session + .inner + .get_mut() + .rcv_link_flow(self.handle as u32, self.delivery_count, credit); + } + + pub fn handle_transfer(&mut self, transfer: Transfer) { + if self.credit == 0 { + // check link credit + let err = Error { + condition: LinkError::TransferLimitExceeded.into(), + description: None, + info: None, + }; + let _ = self.close(Some(err)); + } else { + self.credit -= 1; + self.delivery_count += 1; + self.queue.push_back(transfer); + if self.queue.len() == 1 { + self.reader_task.wake() + } + } + } +} + +pub struct ReceiverLinkBuilder { + frame: Attach, + session: Cell, +} + +impl ReceiverLinkBuilder { + pub(crate) fn new(name: ByteString, address: ByteString, session: Cell) -> Self { + let source = Source { + address: Some(address), + durable: TerminusDurability::None, + expiry_policy: TerminusExpiryPolicy::SessionEnd, + timeout: 0, + dynamic: false, + dynamic_node_properties: None, + distribution_mode: None, + filter: None, + default_outcome: None, + outcomes: None, + capabilities: None, + }; + let frame = Attach { + name, + handle: 0 as Handle, + role: Role::Receiver, + snd_settle_mode: SenderSettleMode::Mixed, + rcv_settle_mode: ReceiverSettleMode::First, + source: Some(source), + target: None, + unsettled: None, + incomplete_unsettled: false, + initial_delivery_count: None, + max_message_size: Some(65536 * 4), + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + + ReceiverLinkBuilder { frame, session } + } + + pub fn max_message_size(mut self, size: u64) -> Self { + self.frame.max_message_size = Some(size); + self + } + + pub async fn open(self) -> Result { + let cell = self.session.clone(); + let res = self + .session + .get_mut() + .open_local_receiver_link(cell, self.frame) + .await; + + match res { + Ok(Ok(res)) => Ok(res), + Ok(Err(err)) => Err(err), + Err(_) => Err(AmqpTransportError::Disconnected), + } + } +} diff --git a/actix-amqp/src/sasl.rs b/actix-amqp/src/sasl.rs new file mode 100755 index 000000000..b7a7a743e --- /dev/null +++ b/actix-amqp/src/sasl.rs @@ -0,0 +1,190 @@ +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_connect::{Connect as TcpConnect, Connection as TcpConnection}; +use actix_service::{apply_fn, pipeline, IntoService, Service}; +use actix_utils::time::LowResTimeService; +use bytestring::ByteString; +use either::Either; +use futures::future::{ok, Future}; +use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; +use http::Uri; + +use amqp_codec::protocol::{Frame, ProtocolId, SaslCode, SaslFrameBody, SaslInit}; +use amqp_codec::types::Symbol; +use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, SaslFrame}; + +use crate::connection::Connection; +use crate::service::ProtocolNegotiation; + +use super::Configuration; +pub use crate::errors::SaslConnectError; + +#[derive(Debug)] +/// Sasl connect request +pub struct SaslConnect { + pub uri: Uri, + pub config: Configuration, + pub auth: SaslAuth, + pub time: Option, +} + +#[derive(Debug)] +/// Sasl authentication parameters +pub struct SaslAuth { + pub authz_id: String, + pub authn_id: String, + pub password: String, +} + +/// Create service that connects to amqp server and authenticate itself via sasl. +/// This service uses supplied connector service. Service resolves to +/// a `Connection<_>` instance. +pub fn connect_service( + connector: T, +) -> impl Service< + Request = SaslConnect, + Response = Connection, + Error = either::Either, +> +where + T: Service, Response = TcpConnection>, + T::Error: 'static, + Io: AsyncRead + AsyncWrite + 'static, +{ + pipeline(|connect: SaslConnect| { + let SaslConnect { + uri, + config, + auth, + time, + } = connect; + ok::<_, either::Either>((uri, config, auth, time)) + }) + // connect to host + .and_then(apply_fn( + connector.map_err(|e| either::Right(e)), + |(uri, config, auth, time): (Uri, Configuration, _, _), srv| { + let fut = srv.call(uri.clone().into()); + async move { + fut.await.map(|stream| { + let (io, _) = stream.into_parts(); + (io, uri, config, auth, time) + }) + } + }, + )) + // sasl protocol negotiation + .and_then(apply_fn( + ProtocolNegotiation::new(ProtocolId::AmqpSasl) + .map_err(|e| Either::Left(SaslConnectError::from(e))), + |(io, uri, config, auth, time): (Io, _, _, _, _), srv| { + let framed = Framed::new(io, ProtocolIdCodec); + let fut = srv.call(framed); + async move { + fut.await + .map(move |framed| (framed, uri, config, auth, time)) + } + }, + )) + // sasl auth + .and_then(apply_fn( + sasl_connect.into_service().map_err(Either::Left), + |(framed, uri, config, auth, time): (_, Uri, _, _, _), srv| { + let fut = srv.call((framed, uri.clone(), auth)); + async move { fut.await.map(move |framed| (uri, config, framed, time)) } + }, + )) + // re-negotiate amqp protocol negotiation + .and_then(apply_fn( + ProtocolNegotiation::new(ProtocolId::Amqp) + .map_err(|e| Either::Left(SaslConnectError::from(e))), + |(uri, config, framed, time): (_, _, Framed, _), srv| { + let fut = srv.call(framed); + async move { fut.await.map(move |framed| (uri, config, framed, time)) } + }, + )) + // open connection + .and_then( + |(uri, mut config, framed, time): (Uri, Configuration, Framed, _)| { + async move { + let mut framed = framed.into_framed(AmqpCodec::::new()); + if let Some(hostname) = uri.host() { + config.hostname(hostname); + } + let open = config.to_open(); + trace!("Open connection: {:?}", open); + framed + .send(AmqpFrame::new(0, Frame::Open(open))) + .await + .map_err(|e| Either::Left(SaslConnectError::from(e))) + .map(move |_| (config, framed, time)) + } + }, + ) + // read open frame + .and_then( + move |(config, mut framed, time): (Configuration, Framed<_, AmqpCodec>, _)| { + async move { + let frame = framed + .next() + .await + .ok_or(Either::Left(SaslConnectError::Disconnected))? + .map_err(|e| Either::Left(SaslConnectError::from(e)))?; + + if let Frame::Open(open) = frame.performative() { + trace!("Open confirmed: {:?}", open); + Ok(Connection::new(framed, config, open.into(), time)) + } else { + Err(Either::Left(SaslConnectError::ExpectedOpenFrame)) + } + } + }, + ) +} + +async fn sasl_connect( + (framed, uri, auth): (Framed, Uri, SaslAuth), +) -> Result, SaslConnectError> { + let mut sasl_io = framed.into_framed(AmqpCodec::::new()); + + // processing sasl-mechanisms + let _ = sasl_io + .next() + .await + .ok_or(SaslConnectError::Disconnected)? + .map_err(SaslConnectError::from)?; + + let initial_response = + SaslInit::prepare_response(&auth.authz_id, &auth.authn_id, &auth.password); + + let hostname = uri.host().map(|host| ByteString::from(host)); + + let sasl_init = SaslInit { + hostname, + mechanism: Symbol::from("PLAIN"), + initial_response: Some(initial_response), + }; + + sasl_io + .send(sasl_init.into()) + .await + .map_err(SaslConnectError::from)?; + + // processing sasl-outcome + let sasl_frame = sasl_io + .next() + .await + .ok_or(SaslConnectError::Disconnected)? + .map_err(SaslConnectError::from)?; + + if let SaslFrame { + body: SaslFrameBody::SaslOutcome(outcome), + } = sasl_frame + { + if outcome.code() != SaslCode::Ok { + return Err(SaslConnectError::Sasl(outcome.code())); + } + } else { + return Err(SaslConnectError::Disconnected); + } + Ok(sasl_io.into_framed(ProtocolIdCodec)) +} diff --git a/actix-amqp/src/server/app.rs b/actix-amqp/src/server/app.rs new file mode 100755 index 000000000..bd7e9e8b9 --- /dev/null +++ b/actix-amqp/src/server/app.rs @@ -0,0 +1,263 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_router::{IntoPattern, Router}; +use actix_service::{boxed, fn_factory_with_config, IntoServiceFactory, Service, ServiceFactory}; +use amqp_codec::protocol::{DeliveryNumber, DeliveryState, Disposition, Error, Rejected, Role}; +use futures::future::{err, ok, Either, Ready}; +use futures::{Stream, StreamExt}; + +use crate::cell::Cell; +use crate::rcvlink::ReceiverLink; + +use super::errors::LinkError; +use super::link::Link; +use super::message::{Message, Outcome}; +use super::State; + +type Handle = boxed::BoxServiceFactory, Message, Outcome, Error, Error>; + +pub struct App(Vec<(Vec, Handle)>); + +impl App { + pub fn new() -> App { + App(Vec::new()) + } + + pub fn service(mut self, address: T, service: F) -> Self + where + T: IntoPattern, + F: IntoServiceFactory, + U: ServiceFactory, Request = Message, Response = Outcome>, + U::Error: Into, + U::InitError: Into, + { + self.0.push(( + address.patterns(), + boxed::factory( + service + .into_factory() + .map_init_err(|e| e.into()) + .map_err(|e| e.into()), + ), + )); + + self + } + + pub fn finish( + self, + ) -> impl ServiceFactory< + Config = State, + Request = Link, + Response = (), + Error = Error, + InitError = Error, + > { + let mut router = Router::build(); + for (addr, hnd) in self.0 { + router.path(addr, hnd); + } + let router = Cell::new(router.finish()); + + fn_factory_with_config(move |_: State| { + ok(AppService { + router: router.clone(), + }) + }) + } +} + +struct AppService { + router: Cell>>, +} + +impl Service for AppService { + type Request = Link; + type Response = (); + type Error = Error; + type Future = Either>, AppServiceResponse>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut link: Link) -> Self::Future { + let path = link + .frame() + .target + .as_ref() + .and_then(|target| target.address.as_ref().map(|addr| addr.clone())); + + if let Some(path) = path { + link.path_mut().set(path); + if let Some((hnd, _info)) = self.router.recognize(link.path_mut()) { + let fut = hnd.new_service(link.clone()); + Either::Right(AppServiceResponse { + link: link.link.clone(), + app_state: link.state.clone(), + state: AppServiceResponseState::NewService(fut), + // has_credit: true, + }) + } else { + Either::Left(err(LinkError::force_detach() + .description(format!( + "Target address is not supported: {}", + link.path().get_ref() + )) + .into())) + } + } else { + Either::Left(err(LinkError::force_detach() + .description("Target address is required") + .into())) + } + } +} + +struct AppServiceResponse { + link: ReceiverLink, + app_state: State, + state: AppServiceResponseState, + // has_credit: bool, +} + +enum AppServiceResponseState { + Service(boxed::BoxService, Outcome, Error>), + NewService( + Pin, Outcome, Error>, Error>>>>, + ), +} + +impl Future for AppServiceResponse { + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut(); + let mut link = this.link.clone(); + let app_state = this.app_state.clone(); + + loop { + match this.state { + AppServiceResponseState::Service(ref mut srv) => { + // check readiness + match srv.poll_ready(cx) { + Poll::Ready(Ok(_)) => (), + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + let _ = this.link.close_with_error( + LinkError::force_detach() + .description(format!("error: {}", e)) + .into(), + ); + return Poll::Ready(Ok(())); + } + } + + match Pin::new(&mut link).poll_next(cx) { + Poll::Ready(Some(Ok(transfer))) => { + // #2.7.5 delivery_id MUST be set. batching is not supported atm + if transfer.delivery_id.is_none() { + let _ = this.link.close_with_error( + LinkError::force_detach() + .description("delivery_id MUST be set") + .into(), + ); + return Poll::Ready(Ok(())); + } + if link.credit() == 0 { + // self.has_credit = self.link.credit() != 0; + link.set_link_credit(50); + } + + let delivery_id = transfer.delivery_id.unwrap(); + let msg = Message::new(app_state.clone(), transfer, link.clone()); + + let mut fut = srv.call(msg); + match Pin::new(&mut fut).poll(cx) { + Poll::Ready(Ok(outcome)) => settle( + &mut this.link, + delivery_id, + outcome.into_delivery_state(), + ), + Poll::Pending => { + actix_rt::spawn(HandleMessage { + fut, + delivery_id, + link: this.link.clone(), + }); + } + Poll::Ready(Err(e)) => settle( + &mut this.link, + delivery_id, + DeliveryState::Rejected(Rejected { error: Some(e) }), + ), + } + } + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Err(_))) => { + let _ = this.link.close_with_error(LinkError::force_detach().into()); + return Poll::Ready(Ok(())); + } + } + } + AppServiceResponseState::NewService(ref mut fut) => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(srv)) => { + this.link.open(); + this.link.set_link_credit(50); + this.state = AppServiceResponseState::Service(srv); + continue; + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + }, + } + } + } +} + +struct HandleMessage { + link: ReceiverLink, + delivery_id: DeliveryNumber, + fut: Pin>>>, +} + +impl Future for HandleMessage { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut this = self.as_mut(); + + match Pin::new(&mut this.fut).poll(cx) { + Poll::Ready(Ok(outcome)) => { + let delivery_id = this.delivery_id; + settle(&mut this.link, delivery_id, outcome.into_delivery_state()); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + let delivery_id = this.delivery_id; + settle( + &mut this.link, + delivery_id, + DeliveryState::Rejected(Rejected { error: Some(e) }), + ); + Poll::Ready(()) + } + } + } +} + +fn settle(link: &mut ReceiverLink, id: DeliveryNumber, state: DeliveryState) { + let disposition = Disposition { + state: Some(state), + role: Role::Receiver, + first: id, + last: None, + settled: true, + batchable: false, + }; + link.send_disposition(disposition); +} diff --git a/actix-amqp/src/server/connect.rs b/actix-amqp/src/server/connect.rs new file mode 100755 index 000000000..37d1729d9 --- /dev/null +++ b/actix-amqp/src/server/connect.rs @@ -0,0 +1,120 @@ +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use amqp_codec::protocol::{Frame, Open}; +use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec}; +use futures::{Future, StreamExt}; + +use super::errors::ServerError; +use crate::connection::ConnectionController; + +/// Open new connection +pub struct Connect { + conn: Framed, + controller: ConnectionController, +} + +impl Connect { + pub(crate) fn new(conn: Framed, controller: ConnectionController) -> Self { + Self { conn, controller } + } + + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.conn.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.conn.get_mut() + } +} + +impl Connect { + /// Wait for connection open frame + pub async fn open(self) -> Result, ServerError<()>> { + let mut framed = self.conn.into_framed(AmqpCodec::::new()); + let mut controller = self.controller; + + let frame = framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(ServerError::from)?; + + let frame = frame.into_parts().1; + match frame { + Frame::Open(frame) => { + trace!("Got open frame: {:?}", frame); + controller.set_remote((&frame).into()); + Ok(ConnectOpened { + frame, + framed, + controller, + }) + } + frame => Err(ServerError::Unexpected(frame)), + } + } +} + +/// Connection is opened +pub struct ConnectOpened { + frame: Open, + framed: Framed>, + controller: ConnectionController, +} + +impl ConnectOpened { + pub(crate) fn new( + frame: Open, + framed: Framed>, + controller: ConnectionController, + ) -> Self { + ConnectOpened { + frame, + framed, + controller, + } + } + + /// Get reference to remote `Open` frame + pub fn frame(&self) -> &Open { + &self.frame + } + + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.framed.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.framed.get_mut() + } + + /// Connection controller + pub fn connection(&self) -> &ConnectionController { + &self.controller + } + + /// Ack connect message and set state + pub fn ack(self, state: St) -> ConnectAck { + ConnectAck { + state, + framed: self.framed, + controller: self.controller, + } + } +} + +/// Ack connect message +pub struct ConnectAck { + state: St, + framed: Framed>, + controller: ConnectionController, +} + +impl ConnectAck { + pub(crate) fn into_inner(self) -> (St, Framed>, ConnectionController) { + (self.state, self.framed, self.controller) + } +} diff --git a/actix-amqp/src/server/control.rs b/actix-amqp/src/server/control.rs new file mode 100755 index 000000000..1b875a591 --- /dev/null +++ b/actix-amqp/src/server/control.rs @@ -0,0 +1,69 @@ +use actix_service::boxed::{BoxService, BoxServiceFactory}; +use amqp_codec::protocol; + +use crate::cell::Cell; +use crate::rcvlink::ReceiverLink; +use crate::session::Session; +use crate::sndlink::SenderLink; + +use super::errors::LinkError; +use super::State; + +pub(crate) type ControlFrameService = BoxService, (), LinkError>; +pub(crate) type ControlFrameNewService = + BoxServiceFactory<(), ControlFrame, (), LinkError, ()>; + +pub struct ControlFrame(pub(super) Cell>); + +pub(super) struct FrameInner { + pub(super) kind: ControlFrameKind, + pub(super) state: State, + pub(super) session: Session, +} + +#[derive(Debug)] +pub enum ControlFrameKind { + Attach(protocol::Attach), + Flow(protocol::Flow, SenderLink), + DetachSender(protocol::Detach, SenderLink), + DetachReceiver(protocol::Detach, ReceiverLink), +} + +impl ControlFrame { + pub(crate) fn new(state: State, session: Session, kind: ControlFrameKind) -> Self { + ControlFrame(Cell::new(FrameInner { + state, + session, + kind, + })) + } + + pub(crate) fn clone(&self) -> Self { + ControlFrame(self.0.clone()) + } + + #[inline] + pub fn state(&self) -> &St { + self.0.state.get_ref() + } + + #[inline] + pub fn state_mut(&mut self) -> &mut St { + self.0.get_mut().state.get_mut() + } + + #[inline] + pub fn session(&self) -> &Session { + &self.0.session + } + + #[inline] + pub fn session_mut(&mut self) -> &mut Session { + &mut self.0.get_mut().session + } + + #[inline] + pub fn frame(&self) -> &ControlFrameKind { + &self.0.kind + } +} diff --git a/actix-amqp/src/server/dispatcher.rs b/actix-amqp/src/server/dispatcher.rs new file mode 100755 index 000000000..100978722 --- /dev/null +++ b/actix-amqp/src/server/dispatcher.rs @@ -0,0 +1,287 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_service::Service; +use amqp_codec::protocol::{Error, Frame, Role}; +use amqp_codec::AmqpCodecError; +use slab::Slab; + +use crate::cell::Cell; +use crate::connection::{ChannelState, Connection}; +use crate::rcvlink::ReceiverLink; +use crate::session::Session; + +use super::control::{ControlFrame, ControlFrameKind, ControlFrameService}; +use super::errors::LinkError; +use super::{Link, State}; + +/// Amqp server connection dispatcher. +#[pin_project::pin_project] +pub struct Dispatcher +where + Io: AsyncRead + AsyncWrite, + Sr: Service, Response = ()>, +{ + conn: Connection, + state: State, + service: Sr, + control_srv: Option>, + control_frame: Option>, + #[pin] + control_fut: Option< as Service>::Future>, + receivers: Vec<(ReceiverLink, Sr::Future)>, + _channels: slab::Slab, +} + +enum IncomingResult { + Control, + Done, + Disconnect, +} + +impl Dispatcher +where + Io: AsyncRead + AsyncWrite, + Sr: Service, Response = ()>, + Sr::Error: fmt::Display + Into, +{ + pub(crate) fn new( + conn: Connection, + state: State, + service: Sr, + control_srv: Option>, + ) -> Self { + Dispatcher { + conn, + service, + state, + control_srv, + control_frame: None, + control_fut: None, + receivers: Vec::with_capacity(16), + _channels: Slab::with_capacity(16), + } + } + + fn handle_control_fut(&mut self, cx: &mut Context) -> bool { + // process control frame + if let Some(ref mut fut) = self.control_fut { + match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(_)) => { + self.control_fut.take(); + let frame = self.control_frame.take().unwrap(); + self.handle_control_frame(&frame, None); + } + Poll::Pending => return false, + Poll::Ready(Err(e)) => { + let _ = self.control_fut.take(); + let frame = self.control_frame.take().unwrap(); + self.handle_control_frame(&frame, Some(e)); + } + } + } + true + } + + fn handle_control_frame(&self, frame: &ControlFrame, err: Option) { + if let Some(e) = err { + error!("Error in link handler: {}", e); + } else { + match frame.0.kind { + ControlFrameKind::Attach(ref frm) => { + let cell = frame.0.session.inner.clone(); + frame + .0 + .session + .inner + .get_mut() + .confirm_sender_link(cell, &frm); + } + ControlFrameKind::Flow(ref frm, ref link) => { + if let Some(err) = err { + let _ = link.close_with_error(err.into()); + } else { + frame.0.session.inner.get_mut().apply_flow(frm); + } + } + ControlFrameKind::DetachSender(ref frm, ref link) => { + if let Some(err) = err { + let _ = link.close_with_error(err.into()); + } else { + frame.0.session.inner.get_mut().handle_detach(&frm); + } + } + ControlFrameKind::DetachReceiver(ref frm, ref link) => { + if let Some(err) = err { + let _ = link.close_with_error(err.into()); + } else { + frame.0.session.inner.get_mut().handle_detach(frm); + } + } + } + } + } + + fn poll_incoming(&mut self, cx: &mut Context) -> Result { + loop { + // handle remote begin and attach + match self.conn.poll_incoming(cx) { + Poll::Ready(Some(Ok(frame))) => { + let (channel_id, frame) = frame.into_parts(); + let channel_id = channel_id as usize; + + match frame { + Frame::Begin(frm) => { + self.conn.register_remote_session(channel_id as u16, &frm); + } + Frame::Flow(frm) => { + // apply flow to specific link + let session = self.conn.get_session(channel_id); + if self.control_srv.is_some() { + if let Some(link) = + session.get_sender_link_by_handle(frm.handle.unwrap()) + { + self.control_frame = Some(ControlFrame::new( + self.state.clone(), + Session::new(session.clone()), + ControlFrameKind::Flow(frm, link.clone()), + )); + return Ok(IncomingResult::Control); + } + } + session.get_mut().apply_flow(&frm); + } + Frame::Attach(attach) => match attach.role { + Role::Receiver => { + // remotly opened sender link + let session = self.conn.get_session(channel_id); + let cell = session.clone(); + if self.control_srv.is_some() { + self.control_frame = Some(ControlFrame::new( + self.state.clone(), + Session::new(cell.clone()), + ControlFrameKind::Attach(attach), + )); + return Ok(IncomingResult::Control); + } else { + session.get_mut().confirm_sender_link(cell, &attach); + } + } + Role::Sender => { + // receiver link + let session = self.conn.get_session(channel_id); + let cell = session.clone(); + let link = session.get_mut().open_receiver_link(cell, attach); + let fut = self + .service + .call(Link::new(link.clone(), self.state.clone())); + self.receivers.push((link, fut)); + } + }, + Frame::Detach(frm) => { + let session = self.conn.get_session(channel_id); + let cell = session.clone(); + + if self.control_srv.is_some() { + if let Some(link) = session.get_sender_link_by_handle(frm.handle) { + self.control_frame = Some(ControlFrame::new( + self.state.clone(), + Session::new(cell.clone()), + ControlFrameKind::DetachSender(frm, link.clone()), + )); + + return Ok(IncomingResult::Control); + } else if let Some(link) = + session.get_receiver_link_by_handle(frm.handle) + { + self.control_frame = Some(ControlFrame::new( + self.state.clone(), + Session::new(cell.clone()), + ControlFrameKind::DetachReceiver(frm, link.clone()), + )); + + return Ok(IncomingResult::Control); + } + } + session.get_mut().handle_frame(Frame::Detach(frm)); + } + _ => { + trace!("Unexpected frame {:#?}", frame); + } + } + } + Poll::Pending => break, + Poll::Ready(None) => return Ok(IncomingResult::Disconnect), + Poll::Ready(Some(Err(e))) => return Err(e), + } + } + + Ok(IncomingResult::Done) + } +} + +impl Future for Dispatcher +where + Io: AsyncRead + AsyncWrite, + Sr: Service, Response = ()>, + Sr::Error: fmt::Display + Into, +{ + type Output = Result<(), AmqpCodecError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + // process control frame + if !self.handle_control_fut(cx) { + return Poll::Pending; + } + + // check control frames service + if self.control_frame.is_some() { + let srv = self.control_srv.as_mut().unwrap(); + match srv.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + let frame = self.control_frame.as_ref().unwrap().clone(); + let srv = self.control_srv.as_mut().unwrap(); + self.control_fut = Some(srv.call(frame)); + if !self.handle_control_fut(cx) { + return Poll::Pending; + } + } + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + let frame = self.control_frame.take().unwrap(); + self.handle_control_frame(&frame, Some(e)); + } + } + } + + match self.poll_incoming(cx)? { + IncomingResult::Control => return self.poll(cx), + IncomingResult::Disconnect => return Poll::Ready(Ok(())), + IncomingResult::Done => (), + } + + // process service responses + let mut idx = 0; + while idx < self.receivers.len() { + match unsafe { Pin::new_unchecked(&mut self.receivers[idx].1) }.poll(cx) { + Poll::Ready(Ok(_detach)) => { + let (link, _) = self.receivers.swap_remove(idx); + let _ = link.close(); + } + Poll::Pending => idx += 1, + Poll::Ready(Err(e)) => { + let (link, _) = self.receivers.swap_remove(idx); + error!("Error in link handler: {}", e); + let _ = link.close_with_error(e.into()); + } + } + } + + let res = self.conn.poll_outgoing(cx); + self.conn.register_write_task(cx); + res + } +} diff --git a/actix-amqp/src/server/errors.rs b/actix-amqp/src/server/errors.rs new file mode 100755 index 000000000..9891eb0de --- /dev/null +++ b/actix-amqp/src/server/errors.rs @@ -0,0 +1,185 @@ +use std::io; + +use amqp_codec::{protocol, AmqpCodecError, ProtocolIdError, SaslFrame}; +use bytestring::ByteString; +use derive_more::Display; + +pub use amqp_codec::protocol::Error; + +/// Errors which can occur when attempting to handle amqp connection. +#[derive(Debug, Display)] +pub enum ServerError { + #[display(fmt = "Message handler service error")] + /// Message handler service error + Service(E), + /// Control service init error + ControlServiceInit, + #[display(fmt = "Amqp error: {}", _0)] + /// Amqp error + Amqp(AmqpError), + #[display(fmt = "Protocol negotiation error: {}", _0)] + /// Amqp protocol negotiation error + Handshake(ProtocolIdError), + /// Amqp handshake timeout + HandshakeTimeout, + /// Amqp codec error + #[display(fmt = "Amqp codec error: {:?}", _0)] + Protocol(AmqpCodecError), + #[display(fmt = "Protocol error: {}", _0)] + /// Amqp protocol error + ProtocolError(Error), + #[display(fmt = "Expected open frame, got: {:?}", _0)] + Unexpected(protocol::Frame), + #[display(fmt = "Unexpected sasl frame: {:?}", _0)] + UnexpectedSaslFrame(SaslFrame), + #[display(fmt = "Unexpected sasl frame body: {:?}", _0)] + UnexpectedSaslBodyFrame(protocol::SaslFrameBody), + /// Peer disconnect + Disconnected, + /// Unexpected io error + Io(io::Error), +} + +impl Into for ServerError { + fn into(self) -> protocol::Error { + protocol::Error { + condition: protocol::AmqpError::InternalError.into(), + description: Some(ByteString::from(format!("{}", self))), + info: None, + } + } +} + +impl From for ServerError { + fn from(err: AmqpError) -> Self { + ServerError::Amqp(err) + } +} + +impl From for ServerError { + fn from(err: AmqpCodecError) -> Self { + ServerError::Protocol(err) + } +} + +impl From for ServerError { + fn from(err: ProtocolIdError) -> Self { + ServerError::Handshake(err) + } +} + +impl From for ServerError { + fn from(err: SaslFrame) -> Self { + ServerError::UnexpectedSaslFrame(err) + } +} + +impl From for ServerError { + fn from(err: io::Error) -> Self { + ServerError::Io(err) + } +} + +#[derive(Debug, Display)] +#[display(fmt = "Amqp error: {:?} {:?} ({:?})", err, description, info)] +pub struct AmqpError { + err: protocol::AmqpError, + description: Option, + info: Option, +} + +impl AmqpError { + pub fn new(err: protocol::AmqpError) -> Self { + AmqpError { + err, + description: None, + info: None, + } + } + + pub fn internal_error() -> Self { + Self::new(protocol::AmqpError::InternalError) + } + + pub fn not_found() -> Self { + Self::new(protocol::AmqpError::NotFound) + } + + pub fn unauthorized_access() -> Self { + Self::new(protocol::AmqpError::UnauthorizedAccess) + } + + pub fn decode_error() -> Self { + Self::new(protocol::AmqpError::DecodeError) + } + + pub fn invalid_field() -> Self { + Self::new(protocol::AmqpError::InvalidField) + } + + pub fn not_allowed() -> Self { + Self::new(protocol::AmqpError::NotAllowed) + } + + pub fn not_implemented() -> Self { + Self::new(protocol::AmqpError::NotImplemented) + } + + pub fn description>(mut self, text: T) -> Self { + self.description = Some(ByteString::from(text.as_ref())); + self + } + + pub fn set_description(mut self, text: ByteString) -> Self { + self.description = Some(text); + self + } +} + +impl Into for AmqpError { + fn into(self) -> protocol::Error { + protocol::Error { + condition: self.err.into(), + description: self.description, + info: self.info, + } + } +} + +#[derive(Debug, Display)] +#[display(fmt = "Link error: {:?} {:?} ({:?})", err, description, info)] +pub struct LinkError { + err: protocol::LinkError, + description: Option, + info: Option, +} + +impl LinkError { + pub fn force_detach() -> Self { + LinkError { + err: protocol::LinkError::DetachForced, + description: None, + info: None, + } + } + + pub fn description>(mut self, text: T) -> Self { + self.description = Some(ByteString::from(text.as_ref())); + self + } + + pub fn set_description(mut self, text: ByteString) -> Self { + self.description = Some(text); + self + } +} + +impl Into for LinkError { + fn into(self) -> protocol::Error { + protocol::Error { + condition: self.err.into(), + description: self.description, + info: self.info, + } + } +} diff --git a/actix-amqp/src/server/handshake.rs b/actix-amqp/src/server/handshake.rs new file mode 100755 index 000000000..a21ce8eeb --- /dev/null +++ b/actix-amqp/src/server/handshake.rs @@ -0,0 +1,50 @@ +use actix_service::{IntoServiceFactory, ServiceFactory}; + +use super::connect::ConnectAck; + +pub fn handshake(srv: F) -> Handshake +where + F: IntoServiceFactory, + A: ServiceFactory>, +{ + Handshake::new(srv) +} + +pub struct Handshake { + a: A, + _t: std::marker::PhantomData<(Io, St)>, +} + +impl Handshake +where + A: ServiceFactory, +{ + pub fn new(srv: F) -> Handshake + where + F: IntoServiceFactory, + { + Handshake { + a: srv.into_factory(), + _t: std::marker::PhantomData, + } + } +} + +impl Handshake +where + A: ServiceFactory>, +{ + pub fn sasl(self, srv: F) -> actix_utils::either::Either + where + F: IntoServiceFactory, + B: ServiceFactory< + Config = (), + Response = A::Response, + Error = A::Error, + InitError = A::InitError, + >, + B::Error: Into, + { + actix_utils::either::Either::new(self.a, srv.into_factory()) + } +} diff --git a/actix-amqp/src/server/link.rs b/actix-amqp/src/server/link.rs new file mode 100755 index 000000000..3e6ba7d2e --- /dev/null +++ b/actix-amqp/src/server/link.rs @@ -0,0 +1,87 @@ +use std::fmt; + +use actix_router::Path; +use amqp_codec::protocol::Attach; +use bytestring::ByteString; + +use crate::cell::Cell; +use crate::rcvlink::ReceiverLink; +use crate::server::State; +use crate::session::Session; +use crate::{Configuration, Handle}; + +pub struct Link { + pub(crate) state: State, + pub(crate) link: ReceiverLink, + pub(crate) path: Path, +} + +impl Link { + pub(crate) fn new(link: ReceiverLink, state: State) -> Self { + Link { + state, + link, + path: Path::new(ByteString::from_static("")), + } + } + + pub fn path(&self) -> &Path { + &self.path + } + + pub fn path_mut(&mut self) -> &mut Path { + &mut self.path + } + + pub fn frame(&self) -> &Attach { + self.link.frame() + } + + pub fn state(&self) -> &S { + self.state.get_ref() + } + + pub fn state_mut(&mut self) -> &mut S { + self.state.get_mut() + } + + pub fn handle(&self) -> Handle { + self.link.handle() + } + + pub fn session(&self) -> &Session { + self.link.session() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.link.session_mut() + } + + pub fn link_credit(mut self, credit: u32) { + self.link.set_link_credit(credit); + } + + #[inline] + /// Get remote connection configuration + pub fn remote_config(&self) -> &Configuration { + &self.link.remote_config() + } +} + +impl Clone for Link { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + link: self.link.clone(), + path: self.path.clone(), + } + } +} + +impl fmt::Debug for Link { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("Link") + .field("frame", self.link.frame()) + .finish() + } +} diff --git a/actix-amqp/src/server/message.rs b/actix-amqp/src/server/message.rs new file mode 100755 index 000000000..534f0fbaa --- /dev/null +++ b/actix-amqp/src/server/message.rs @@ -0,0 +1,96 @@ +use std::fmt; + +use amqp_codec::protocol::{Accepted, DeliveryState, Error, Rejected, Transfer, TransferBody}; +use amqp_codec::Decode; +use bytes::Bytes; + +use crate::rcvlink::ReceiverLink; +use crate::session::Session; + +use super::errors::AmqpError; +use super::State; + +pub struct Message { + state: State, + frame: Transfer, + link: ReceiverLink, +} + +#[derive(Debug)] +pub enum Outcome { + Accept, + Reject, + Error(Error), +} + +impl From for Outcome +where + T: Into, +{ + fn from(err: T) -> Self { + Outcome::Error(err.into()) + } +} + +impl Outcome { + pub(crate) fn into_delivery_state(self) -> DeliveryState { + match self { + Outcome::Accept => DeliveryState::Accepted(Accepted {}), + Outcome::Reject => DeliveryState::Rejected(Rejected { error: None }), + Outcome::Error(e) => DeliveryState::Rejected(Rejected { error: Some(e) }), + } + } +} + +impl Message { + pub(crate) fn new(state: State, frame: Transfer, link: ReceiverLink) -> Self { + Message { state, frame, link } + } + + pub fn state(&self) -> &S { + self.state.get_ref() + } + + pub fn state_mut(&mut self) -> &mut S { + self.state.get_mut() + } + + pub fn session(&self) -> &Session { + self.link.session() + } + + pub fn session_mut(&mut self) -> &mut Session { + self.link.session_mut() + } + + pub fn frame(&self) -> &Transfer { + &self.frame + } + + pub fn body(&self) -> Option<&Bytes> { + match self.frame.body { + Some(TransferBody::Data(ref b)) => Some(b), + _ => None, + } + } + + pub fn load_message(&self) -> Result { + if let Some(TransferBody::Data(ref b)) = self.frame.body { + if let Ok((_, msg)) = T::decode(b) { + Ok(msg) + } else { + Err(AmqpError::decode_error().description("Can not decode message")) + } + } else { + Err(AmqpError::invalid_field().description("Unknown body")) + } + } +} + +impl fmt::Debug for Message { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("Message") + .field("frame", &self.frame) + .finish() + } +} diff --git a/actix-amqp/src/server/mod.rs b/actix-amqp/src/server/mod.rs new file mode 100755 index 000000000..36e1158bc --- /dev/null +++ b/actix-amqp/src/server/mod.rs @@ -0,0 +1,55 @@ +mod app; +mod connect; +mod control; +mod dispatcher; +pub mod errors; +mod handshake; +mod link; +mod message; +pub mod sasl; +mod service; + +pub use self::app::App; +pub use self::connect::{Connect, ConnectAck, ConnectOpened}; +pub use self::control::{ControlFrame, ControlFrameKind}; +pub use self::handshake::{handshake, Handshake}; +pub use self::link::Link; +pub use self::message::{Message, Outcome}; +pub use self::sasl::Sasl; +pub use self::service::Server; + +use crate::cell::Cell; + +pub struct State(Cell); + +impl State { + pub(crate) fn new(st: St) -> Self { + State(Cell::new(st)) + } + + pub(crate) fn clone(&self) -> Self { + State(self.0.clone()) + } + + pub fn get_ref(&self) -> &St { + self.0.get_ref() + } + + pub fn get_mut(&mut self) -> &mut St { + self.0.get_mut() + } +} + +impl std::ops::Deref for State { + type Target = St; + + fn deref(&self) -> &Self::Target { + self.get_ref() + } +} + +impl std::ops::DerefMut for State { + fn deref_mut(&mut self) -> &mut Self::Target { + self.get_mut() + } +} diff --git a/actix-amqp/src/server/sasl.rs b/actix-amqp/src/server/sasl.rs new file mode 100755 index 000000000..edf4fbeca --- /dev/null +++ b/actix-amqp/src/server/sasl.rs @@ -0,0 +1,338 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::{Service, ServiceFactory}; +use amqp_codec::protocol::{ + self, ProtocolId, SaslChallenge, SaslCode, SaslFrameBody, SaslMechanisms, SaslOutcome, Symbols, +}; +use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, SaslFrame}; +use bytes::Bytes; +use bytestring::ByteString; +use futures::future::{err, ok, Either, Ready}; +use futures::{SinkExt, StreamExt}; + +use super::connect::{ConnectAck, ConnectOpened}; +use super::errors::{AmqpError, ServerError}; +use crate::connection::ConnectionController; + +pub struct Sasl { + framed: Framed, + mechanisms: Symbols, + controller: ConnectionController, +} + +impl fmt::Debug for Sasl { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("SaslAuth") + .field("mechanisms", &self.mechanisms) + .finish() + } +} + +impl Sasl { + pub(crate) fn new( + framed: Framed, + controller: ConnectionController, + ) -> Self { + Sasl { + framed, + controller, + mechanisms: Symbols::default(), + } + } +} + +impl Sasl +where + Io: AsyncRead + AsyncWrite, +{ + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.framed.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.framed.get_mut() + } + + /// Add supported sasl mechanism + pub fn mechanism>(mut self, symbol: U) -> Self { + self.mechanisms.push(ByteString::from(symbol.into()).into()); + self + } + + /// Initialize sasl auth procedure + pub async fn init(self) -> Result, ServerError<()>> { + let Sasl { + framed, + mechanisms, + controller, + .. + } = self; + + let mut framed = framed.into_framed(AmqpCodec::::new()); + let frame = SaslMechanisms { + sasl_server_mechanisms: mechanisms, + } + .into(); + + framed.send(frame).await.map_err(ServerError::from)?; + let frame = framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(ServerError::from)?; + + match frame.body { + SaslFrameBody::SaslInit(frame) => Ok(Init { + frame, + framed, + controller, + }), + body => Err(ServerError::UnexpectedSaslBodyFrame(body)), + } + } +} + +/// Initialization stage of sasl negotiation +pub struct Init { + frame: protocol::SaslInit, + framed: Framed>, + controller: ConnectionController, +} + +impl fmt::Debug for Init { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("SaslInit") + .field("frame", &self.frame) + .finish() + } +} + +impl Init +where + Io: AsyncRead + AsyncWrite, +{ + /// Sasl mechanism + pub fn mechanism(&self) -> &str { + self.frame.mechanism.as_str() + } + + /// Sasl initial response + pub fn initial_response(&self) -> Option<&[u8]> { + self.frame.initial_response.as_ref().map(|b| b.as_ref()) + } + + /// Sasl initial response + pub fn hostname(&self) -> Option<&str> { + self.frame.hostname.as_ref().map(|b| b.as_ref()) + } + + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.framed.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.framed.get_mut() + } + + /// Initiate sasl challenge + pub async fn challenge(self) -> Result, ServerError<()>> { + self.challenge_with(Bytes::new()).await + } + + /// Initiate sasl challenge with challenge payload + pub async fn challenge_with(self, challenge: Bytes) -> Result, ServerError<()>> { + let mut framed = self.framed; + let controller = self.controller; + let frame = SaslChallenge { challenge }.into(); + + framed.send(frame).await.map_err(ServerError::from)?; + let frame = framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(ServerError::from)?; + + match frame.body { + SaslFrameBody::SaslResponse(frame) => Ok(Response { + frame, + framed, + controller, + }), + body => Err(ServerError::UnexpectedSaslBodyFrame(body)), + } + } + + /// Sasl challenge outcome + pub async fn outcome(self, code: SaslCode) -> Result, ServerError<()>> { + let mut framed = self.framed; + let controller = self.controller; + + let frame = SaslOutcome { + code, + additional_data: None, + } + .into(); + framed.send(frame).await.map_err(ServerError::from)?; + + Ok(Success { framed, controller }) + } +} + +pub struct Response { + frame: protocol::SaslResponse, + framed: Framed>, + controller: ConnectionController, +} + +impl fmt::Debug for Response { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("SaslResponse") + .field("frame", &self.frame) + .finish() + } +} + +impl Response +where + Io: AsyncRead + AsyncWrite, +{ + /// Client response payload + pub fn response(&self) -> &[u8] { + &self.frame.response[..] + } + + /// Sasl challenge outcome + pub async fn outcome(self, code: SaslCode) -> Result, ServerError<()>> { + let mut framed = self.framed; + let controller = self.controller; + let frame = SaslOutcome { + code, + additional_data: None, + } + .into(); + + framed.send(frame).await.map_err(ServerError::from)?; + framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(|res| ServerError::from(res))?; + + Ok(Success { framed, controller }) + } +} + +pub struct Success { + framed: Framed>, + controller: ConnectionController, +} + +impl Success +where + Io: AsyncRead + AsyncWrite, +{ + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.framed.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.framed.get_mut() + } + + /// Wait for connection open frame + pub async fn open(self) -> Result, ServerError<()>> { + let mut framed = self.framed.into_framed(ProtocolIdCodec); + let mut controller = self.controller; + + let protocol = framed + .next() + .await + .ok_or(ServerError::from(ProtocolIdError::Disconnected))? + .map_err(ServerError::from)?; + + match protocol { + ProtocolId::Amqp => { + // confirm protocol + framed + .send(ProtocolId::Amqp) + .await + .map_err(ServerError::from)?; + + // Wait for connection open frame + let mut framed = framed.into_framed(AmqpCodec::::new()); + let frame = framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(ServerError::from)?; + + let frame = frame.into_parts().1; + match frame { + protocol::Frame::Open(frame) => { + trace!("Got open frame: {:?}", frame); + controller.set_remote((&frame).into()); + Ok(ConnectOpened::new(frame, framed, controller)) + } + frame => Err(ServerError::Unexpected(frame)), + } + } + proto => Err(ProtocolIdError::Unexpected { + exp: ProtocolId::Amqp, + got: proto, + } + .into()), + } + } +} + +/// Create service factory with disabled sasl support +pub fn no_sasl() -> NoSaslService { + NoSaslService::default() +} + +pub struct NoSaslService(std::marker::PhantomData<(Io, St, E)>); + +impl Default for NoSaslService { + fn default() -> Self { + NoSaslService(std::marker::PhantomData) + } +} + +impl ServiceFactory for NoSaslService { + type Config = (); + type Request = Sasl; + type Response = ConnectAck; + type Error = AmqpError; + type InitError = E; + type Service = NoSaslService; + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ok(NoSaslService(std::marker::PhantomData)) + } +} + +impl Service for NoSaslService { + type Request = Sasl; + type Response = ConnectAck; + type Error = AmqpError; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Self::Request) -> Self::Future { + err(AmqpError::not_implemented()) + } +} diff --git a/actix-amqp/src/server/service.rs b/actix-amqp/src/server/service.rs new file mode 100755 index 000000000..6b80cfadc --- /dev/null +++ b/actix-amqp/src/server/service.rs @@ -0,0 +1,360 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::{boxed, IntoServiceFactory, Service, ServiceFactory}; +use amqp_codec::protocol::{Error, ProtocolId}; +use amqp_codec::{AmqpCodecError, AmqpFrame, ProtocolIdCodec, ProtocolIdError}; +use futures::future::{err, poll_fn, Either}; +use futures::{FutureExt, SinkExt, StreamExt}; + +use crate::cell::Cell; +use crate::connection::{Connection, ConnectionController}; +use crate::Configuration; + +use super::connect::{Connect, ConnectAck}; +use super::control::{ControlFrame, ControlFrameNewService}; +use super::dispatcher::Dispatcher; +use super::errors::{LinkError, ServerError}; +use super::link::Link; +use super::sasl::Sasl; +use super::State; + +/// Amqp connection type +pub type AmqpConnect = either::Either, Sasl>; + +/// Server dispatcher factory +pub struct Server { + connect: Cn, + config: Configuration, + control: Option>, + disconnect: Option>)>>, + max_size: usize, + handshake_timeout: u64, + _t: PhantomData<(Io, St)>, +} + +pub(super) struct ServerInner { + connect: Cn, + publish: Pb, + config: Configuration, + control: Option>, + disconnect: Option>)>>, + max_size: usize, + handshake_timeout: u64, +} + +impl Server +where + St: 'static, + Io: AsyncRead + AsyncWrite + 'static, + Cn: ServiceFactory, Response = ConnectAck> + + 'static, +{ + /// Create server factory and provide connect service + pub fn new(connect: F) -> Self + where + F: IntoServiceFactory, + { + Self { + connect: connect.into_factory(), + config: Configuration::default(), + control: None, + disconnect: None, + max_size: 0, + handshake_timeout: 0, + _t: PhantomData, + } + } + + /// Provide connection configuration + pub fn config(mut self, config: Configuration) -> Self { + self.config = config; + self + } + + /// Set max inbound frame size. + /// + /// If max size is set to `0`, size is unlimited. + /// By default max size is set to `0` + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set handshake timeout in millis. + /// + /// By default handshake timeuot is disabled. + pub fn handshake_timeout(mut self, timeout: u64) -> Self { + self.handshake_timeout = timeout; + self + } + + /// Service to call with control frames + pub fn control(self, f: F) -> Self + where + F: IntoServiceFactory, + S: ServiceFactory, Response = (), InitError = ()> + + 'static, + S::Error: Into, + { + Server { + connect: self.connect, + config: self.config, + disconnect: self.disconnect, + control: Some(boxed::factory( + f.into_factory() + .map_err(|e| e.into()) + .map_init_err(|e| e.into()), + )), + max_size: self.max_size, + handshake_timeout: self.handshake_timeout, + _t: PhantomData, + } + } + + /// Callback to execute on disconnect + /// + /// Second parameter indicates error occured during disconnect. + pub fn disconnect(self, disconnect: F) -> Self + where + F: Fn(&mut St, Option<&ServerError>) -> Out + 'static, + Out: Future + 'static, + { + Server { + connect: self.connect, + config: self.config, + control: self.control, + disconnect: Some(Box::new(move |st, err| { + let fut = disconnect(st, err); + actix_rt::spawn(fut.map(|_| ())); + })), + max_size: self.max_size, + handshake_timeout: self.handshake_timeout, + _t: PhantomData, + } + } + + /// Set service to execute for incoming links and create service factory + pub fn finish( + self, + service: F, + ) -> impl ServiceFactory> + where + F: IntoServiceFactory, + Pb: ServiceFactory, Request = Link, Response = ()> + 'static, + Pb::Error: fmt::Display + Into, + Pb::InitError: fmt::Display + Into, + { + ServerImpl { + inner: Cell::new(ServerInner { + connect: self.connect, + config: self.config, + publish: service.into_factory(), + control: self.control, + disconnect: self.disconnect, + max_size: self.max_size, + handshake_timeout: self.handshake_timeout, + }), + _t: PhantomData, + } + } +} + +struct ServerImpl { + inner: Cell>, + _t: PhantomData<(Io,)>, +} + +impl ServiceFactory for ServerImpl +where + St: 'static, + Io: AsyncRead + AsyncWrite + 'static, + Cn: ServiceFactory, Response = ConnectAck> + + 'static, + Pb: ServiceFactory, Request = Link, Response = ()> + 'static, + Pb::Error: fmt::Display + Into, + Pb::InitError: fmt::Display + Into, +{ + type Config = (); + type Request = Io; + type Response = (); + type Error = ServerError; + type Service = ServerImplService; + type InitError = Cn::InitError; + type Future = Pin>>>; + + fn new_service(&self, _: ()) -> Self::Future { + let inner = self.inner.clone(); + + Box::pin(async move { + inner + .connect + .new_service(()) + .await + .map(move |connect| ServerImplService { + inner, + connect: Cell::new(connect), + _t: PhantomData, + }) + }) + } +} + +struct ServerImplService { + connect: Cell, + inner: Cell>, + _t: PhantomData<(Io,)>, +} + +impl Service for ServerImplService +where + St: 'static, + Io: AsyncRead + AsyncWrite + 'static, + Cn: ServiceFactory, Response = ConnectAck> + + 'static, + Pb: ServiceFactory, Request = Link, Response = ()> + 'static, + Pb::Error: fmt::Display + Into, + Pb::InitError: fmt::Display + Into, +{ + type Request = Io; + type Response = (); + type Error = ServerError; + type Future = Pin>>>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.connect + .get_mut() + .poll_ready(cx) + .map(|res| res.map_err(|e| ServerError::Service(e))) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let timeout = self.inner.handshake_timeout; + if timeout == 0 { + Box::pin(handshake( + self.inner.max_size, + self.connect.clone(), + self.inner.clone(), + req, + )) + } else { + Box::pin( + actix_rt::time::timeout( + time::Duration::from_millis(timeout), + handshake( + self.inner.max_size, + self.connect.clone(), + self.inner.clone(), + req, + ), + ) + .map(|res| match res { + Ok(res) => res, + Err(_) => Err(ServerError::HandshakeTimeout), + }), + ) + } + } +} + +async fn handshake( + max_size: usize, + connect: Cell, + inner: Cell>, + io: Io, +) -> Result<(), ServerError> +where + St: 'static, + Io: AsyncRead + AsyncWrite + 'static, + Cn: ServiceFactory, Response = ConnectAck>, + Pb: ServiceFactory, Request = Link, Response = ()> + 'static, + Pb::Error: fmt::Display + Into, + Pb::InitError: fmt::Display + Into, +{ + let inner2 = inner.clone(); + let mut framed = Framed::new(io, ProtocolIdCodec); + + let protocol = framed + .next() + .await + .ok_or(ServerError::Disconnected)? + .map_err(ServerError::Handshake)?; + + let (st, srv, conn) = match protocol { + // start amqp processing + ProtocolId::Amqp | ProtocolId::AmqpSasl => { + framed.send(protocol).await.map_err(ServerError::from)?; + + let cfg = inner.get_ref().config.clone(); + let controller = ConnectionController::new(cfg.clone()); + + let ack = connect + .get_mut() + .call(if protocol == ProtocolId::Amqp { + either::Either::Left(Connect::new(framed, controller)) + } else { + either::Either::Right(Sasl::new(framed, controller)) + }) + .await + .map_err(|e| ServerError::Service(e))?; + + let (st, mut framed, controller) = ack.into_inner(); + let st = State::new(st); + framed.get_codec_mut().max_size(max_size); + + // confirm Open + let local = cfg.to_open(); + framed + .send(AmqpFrame::new(0, local.into())) + .await + .map_err(ServerError::from)?; + + let conn = Connection::new_server(framed, controller.0, None); + + // create publish service + let srv = inner.publish.new_service(st.clone()).await.map_err(|e| { + error!("Can not construct app service"); + ServerError::ProtocolError(e.into()) + })?; + + (st, srv, conn) + } + ProtocolId::AmqpTls => { + return Err(ServerError::from(ProtocolIdError::Unexpected { + exp: ProtocolId::Amqp, + got: ProtocolId::AmqpTls, + })) + } + }; + + let mut st2 = st.clone(); + + if let Some(ref control_srv) = inner2.control { + let control = control_srv + .new_service(()) + .await + .map_err(|_| ServerError::ControlServiceInit)?; + + let res = Dispatcher::new(conn, st, srv, Some(control)) + .await + .map_err(ServerError::from); + + if inner2.disconnect.is_some() { + (*inner2.get_mut().disconnect.as_mut().unwrap())(st2.get_mut(), res.as_ref().err()) + } + res + } else { + let res = Dispatcher::new(conn, st, srv, None) + .await + .map_err(ServerError::from); + + if inner2.disconnect.is_some() { + (*inner2.get_mut().disconnect.as_mut().unwrap())(st2.get_mut(), res.as_ref().err()) + } + res + } +} diff --git a/actix-amqp/src/service.rs b/actix-amqp/src/service.rs new file mode 100755 index 000000000..b4ed51a9c --- /dev/null +++ b/actix-amqp/src/service.rs @@ -0,0 +1,65 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::Service; +use futures::{Future, SinkExt, StreamExt}; + +use amqp_codec::protocol::ProtocolId; +use amqp_codec::{ProtocolIdCodec, ProtocolIdError}; + +pub struct ProtocolNegotiation { + proto: ProtocolId, + _r: PhantomData, +} + +impl Clone for ProtocolNegotiation { + fn clone(&self) -> Self { + ProtocolNegotiation { + proto: self.proto.clone(), + _r: PhantomData, + } + } +} + +impl ProtocolNegotiation { + pub fn new(proto: ProtocolId) -> Self { + ProtocolNegotiation { + proto, + _r: PhantomData, + } + } +} + +impl Service for ProtocolNegotiation +where + T: AsyncRead + AsyncWrite + 'static, +{ + type Request = Framed; + type Response = Framed; + type Error = ProtocolIdError; + type Future = Pin>>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut framed: Framed) -> Self::Future { + let proto = self.proto; + + Box::pin(async move { + framed.send(proto).await?; + + let protocol = framed.next().await.ok_or(ProtocolIdError::Disconnected)??; + if proto == protocol { + Ok(framed) + } else { + Err(ProtocolIdError::Unexpected { + exp: proto, + got: protocol, + }) + } + }) + } +} diff --git a/actix-amqp/src/session.rs b/actix-amqp/src/session.rs new file mode 100755 index 000000000..eb7a7fa2d --- /dev/null +++ b/actix-amqp/src/session.rs @@ -0,0 +1,911 @@ +use std::collections::VecDeque; +use std::future::Future; + +use actix_utils::oneshot; +use bytes::{BufMut, Bytes, BytesMut}; +use bytestring::ByteString; +use either::Either; +use futures::future::ok; +use fxhash::FxHashMap; +use slab::Slab; + +use amqp_codec::protocol::{ + Accepted, Attach, DeliveryNumber, DeliveryState, Detach, Disposition, Error, Flow, Frame, + Handle, ReceiverSettleMode, Role, SenderSettleMode, Transfer, TransferBody, TransferNumber, +}; +use amqp_codec::AmqpFrame; + +use crate::cell::Cell; +use crate::connection::ConnectionController; +use crate::errors::AmqpTransportError; +use crate::rcvlink::{ReceiverLink, ReceiverLinkBuilder, ReceiverLinkInner}; +use crate::sndlink::{SenderLink, SenderLinkBuilder, SenderLinkInner}; +use crate::{Configuration, DeliveryPromise}; + +const INITIAL_OUTGOING_ID: TransferNumber = 0; + +#[derive(Clone)] +pub struct Session { + pub(crate) inner: Cell, +} + +impl Drop for Session { + fn drop(&mut self) { + self.inner.get_mut().drop_session() + } +} + +impl std::fmt::Debug for Session { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("Session").finish() + } +} + +impl Session { + pub(crate) fn new(inner: Cell) -> Session { + Session { inner } + } + + #[inline] + /// Get remote connection configuration + pub fn remote_config(&self) -> &Configuration { + self.inner.connection.remote_config() + } + + pub fn close(&self) -> impl Future> { + ok(()) + } + + pub fn get_sender_link(&self, name: &str) -> Option<&SenderLink> { + let inner = self.inner.get_ref(); + + if let Some(id) = inner.links_by_name.get(name) { + if let Some(Either::Left(SenderLinkState::Established(ref link))) = inner.links.get(*id) + { + return Some(link); + } + } + None + } + + pub fn get_sender_link_by_handle(&self, hnd: Handle) -> Option<&SenderLink> { + self.inner.get_ref().get_sender_link_by_handle(hnd) + } + + pub fn get_receiver_link_by_handle(&self, hnd: Handle) -> Option<&ReceiverLink> { + self.inner.get_ref().get_receiver_link_by_handle(hnd) + } + + /// Open sender link + pub fn build_sender_link, U: Into>( + &mut self, + name: U, + address: T, + ) -> SenderLinkBuilder { + let name = ByteString::from(name.into()); + let address = ByteString::from(address.into()); + SenderLinkBuilder::new(name, address, self.inner.clone()) + } + + /// Open receiver link + pub fn build_receiver_link, U: Into>( + &mut self, + name: U, + address: T, + ) -> ReceiverLinkBuilder { + let name = ByteString::from(name.into()); + let address = ByteString::from(address.into()); + ReceiverLinkBuilder::new(name, address, self.inner.clone()) + } + + /// Detach receiver link + pub fn detach_receiver_link( + &mut self, + handle: Handle, + error: Option, + ) -> impl Future> { + let (tx, rx) = oneshot::channel(); + + self.inner + .get_mut() + .detach_receiver_link(handle, false, error, tx); + + async move { + match rx.await { + Ok(Ok(_)) => Ok(()), + Ok(Err(e)) => Err(e), + Err(_) => Err(AmqpTransportError::Disconnected), + } + } + } + + pub fn wait_disposition( + &mut self, + id: DeliveryNumber, + ) -> impl Future> { + self.inner.get_mut().wait_disposition(id) + } +} + +#[derive(Debug)] +enum SenderLinkState { + Opening(oneshot::Sender), + Established(SenderLink), + Closing(Option>>), +} + +#[derive(Debug)] +enum ReceiverLinkState { + Opening(Option>), + OpeningLocal( + Option<( + Cell, + oneshot::Sender>, + )>, + ), + Established(ReceiverLink), + Closing(Option>>), +} + +impl SenderLinkState { + fn is_opening(&self) -> bool { + match self { + SenderLinkState::Opening(_) => true, + _ => false, + } + } +} + +impl ReceiverLinkState { + fn is_opening(&self) -> bool { + match self { + ReceiverLinkState::OpeningLocal(_) => true, + _ => false, + } + } +} + +pub(crate) struct SessionInner { + id: usize, + connection: ConnectionController, + next_outgoing_id: TransferNumber, + local: bool, + + remote_channel_id: u16, + next_incoming_id: TransferNumber, + remote_outgoing_window: u32, + remote_incoming_window: u32, + + unsettled_deliveries: FxHashMap, + + links: Slab>, + links_by_name: FxHashMap, + remote_handles: FxHashMap, + pending_transfers: VecDeque, + disposition_subscribers: FxHashMap>, + error: Option, +} + +struct PendingTransfer { + link_handle: Handle, + idx: u32, + body: Option, + promise: DeliveryPromise, + tag: Option, + settled: Option, +} + +impl SessionInner { + pub fn new( + id: usize, + local: bool, + connection: ConnectionController, + remote_channel_id: u16, + next_incoming_id: DeliveryNumber, + remote_incoming_window: u32, + remote_outgoing_window: u32, + ) -> SessionInner { + SessionInner { + id, + local, + connection, + next_incoming_id, + remote_channel_id, + remote_incoming_window, + remote_outgoing_window, + next_outgoing_id: INITIAL_OUTGOING_ID, + unsettled_deliveries: FxHashMap::default(), + links: Slab::new(), + links_by_name: FxHashMap::default(), + remote_handles: FxHashMap::default(), + pending_transfers: VecDeque::new(), + disposition_subscribers: FxHashMap::default(), + error: None, + } + } + + /// Local channel id + pub fn id(&self) -> u16 { + self.id as u16 + } + + /// Set error. New operations will return error. + pub(crate) fn set_error(&mut self, err: AmqpTransportError) { + // drop pending transfers + for tr in self.pending_transfers.drain(..) { + let _ = tr.promise.send(Err(err.clone())); + } + + // drop links + self.links_by_name.clear(); + for (_, st) in self.links.iter_mut() { + match st { + Either::Left(SenderLinkState::Opening(_)) => (), + Either::Left(SenderLinkState::Established(ref mut link)) => { + link.inner.get_mut().set_error(err.clone()) + } + Either::Left(SenderLinkState::Closing(ref mut link)) => { + if let Some(tx) = link.take() { + let _ = tx.send(Err(err.clone())); + } + } + _ => (), + } + } + self.links.clear(); + + self.error = Some(err); + } + + fn drop_session(&mut self) { + self.connection.drop_session_copy(self.id); + } + + fn wait_disposition( + &mut self, + id: DeliveryNumber, + ) -> impl Future> { + let (tx, rx) = oneshot::channel(); + self.disposition_subscribers.insert(id, tx); + async move { rx.await.map_err(|_| AmqpTransportError::Disconnected) } + } + + /// Register remote sender link + pub(crate) fn confirm_sender_link(&mut self, cell: Cell, attach: &Attach) { + trace!("Remote sender link opened: {:?}", attach.name()); + let entry = self.links.vacant_entry(); + let token = entry.key(); + let delivery_count = attach.initial_delivery_count.unwrap_or(0); + + let mut name = None; + if let Some(ref source) = attach.source { + if let Some(ref addr) = source.address { + name = Some(addr.clone()); + self.links_by_name.insert(addr.clone(), token); + } + } + + self.remote_handles.insert(attach.handle(), token); + let link = Cell::new(SenderLinkInner::new( + token, + name.unwrap_or_else(|| ByteString::default()), + attach.handle(), + delivery_count, + cell, + )); + entry.insert(Either::Left(SenderLinkState::Established(SenderLink::new( + link, + )))); + + let attach = Attach { + name: attach.name.clone(), + handle: token as Handle, + role: Role::Sender, + snd_settle_mode: SenderSettleMode::Mixed, + rcv_settle_mode: ReceiverSettleMode::First, + source: attach.source.clone(), + target: attach.target.clone(), + unsettled: None, + incomplete_unsettled: false, + initial_delivery_count: Some(delivery_count), + max_message_size: Some(65536), + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + self.post_frame(attach.into()); + } + + /// Register receiver link + pub(crate) fn open_receiver_link( + &mut self, + cell: Cell, + attach: Attach, + ) -> ReceiverLink { + let handle = attach.handle(); + let entry = self.links.vacant_entry(); + let token = entry.key(); + + let inner = Cell::new(ReceiverLinkInner::new(cell, token as u32, attach)); + entry.insert(Either::Right(ReceiverLinkState::Opening(Some( + inner.clone(), + )))); + self.remote_handles.insert(handle, token); + ReceiverLink::new(inner) + } + + pub(crate) fn open_local_receiver_link( + &mut self, + cell: Cell, + mut frame: Attach, + ) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + + let entry = self.links.vacant_entry(); + let token = entry.key(); + + let inner = Cell::new(ReceiverLinkInner::new(cell, token as u32, frame.clone())); + entry.insert(Either::Right(ReceiverLinkState::OpeningLocal(Some(( + inner.clone(), + tx, + ))))); + + frame.handle = token as Handle; + + self.links_by_name.insert(frame.name.clone(), token); + self.post_frame(Frame::Attach(frame)); + rx + } + + pub(crate) fn confirm_receiver_link(&mut self, token: Handle, attach: &Attach) { + if let Some(Either::Right(link)) = self.links.get_mut(token as usize) { + match link { + ReceiverLinkState::Opening(l) => { + let attach = Attach { + name: attach.name.clone(), + handle: token as Handle, + role: Role::Receiver, + snd_settle_mode: SenderSettleMode::Mixed, + rcv_settle_mode: ReceiverSettleMode::First, + source: attach.source.clone(), + target: attach.target.clone(), + unsettled: None, + incomplete_unsettled: false, + initial_delivery_count: Some(0), + max_message_size: Some(65536), + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + *link = ReceiverLinkState::Established(ReceiverLink::new(l.take().unwrap())); + self.post_frame(attach.into()); + } + _ => error!("Unexpected receiver link state"), + } + } + } + + /// Close receiver link + pub(crate) fn detach_receiver_link( + &mut self, + id: Handle, + closed: bool, + error: Option, + tx: oneshot::Sender>, + ) { + if let Some(Either::Right(link)) = self.links.get_mut(id as usize) { + match link { + ReceiverLinkState::Opening(inner) => { + let attach = Attach { + name: inner.as_ref().unwrap().get_ref().name().clone(), + handle: id as Handle, + role: Role::Sender, + snd_settle_mode: SenderSettleMode::Mixed, + rcv_settle_mode: ReceiverSettleMode::First, + source: None, + target: None, + unsettled: None, + incomplete_unsettled: false, + initial_delivery_count: None, + max_message_size: None, + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + let detach = Detach { + handle: id, + closed, + error, + }; + *link = ReceiverLinkState::Closing(Some(tx)); + self.post_frame(attach.into()); + self.post_frame(detach.into()); + } + ReceiverLinkState::Established(_) => { + let detach = Detach { + handle: id, + closed, + error, + }; + *link = ReceiverLinkState::Closing(Some(tx)); + self.post_frame(detach.into()); + } + ReceiverLinkState::Closing(_) => { + let _ = tx.send(Ok(())); + error!("Unexpected receiver link state: closing - {}", id); + } + ReceiverLinkState::OpeningLocal(_inner) => unimplemented!(), + } + } else { + let _ = tx.send(Ok(())); + error!("Receiver link does not exist while detaching: {}", id); + } + } + + pub(crate) fn detach_sender_link( + &mut self, + id: usize, + closed: bool, + error: Option, + tx: oneshot::Sender>, + ) { + if let Some(Either::Left(link)) = self.links.get_mut(id) { + match link { + SenderLinkState::Opening(_) => { + let detach = Detach { + handle: id as u32, + closed, + error, + }; + *link = SenderLinkState::Closing(Some(tx)); + self.post_frame(detach.into()); + } + SenderLinkState::Established(_) => { + let detach = Detach { + handle: id as u32, + closed, + error, + }; + *link = SenderLinkState::Closing(Some(tx)); + self.post_frame(detach.into()); + } + SenderLinkState::Closing(_) => { + let _ = tx.send(Ok(())); + error!("Unexpected receiver link state: closing - {}", id); + } + } + } else { + let _ = tx.send(Ok(())); + error!("Receiver link does not exist while detaching: {}", id); + } + } + + pub(crate) fn get_sender_link_by_handle(&self, hnd: Handle) -> Option<&SenderLink> { + if let Some(id) = self.remote_handles.get(&hnd) { + if let Some(Either::Left(SenderLinkState::Established(ref link))) = self.links.get(*id) + { + return Some(link); + } + } + None + } + + pub(crate) fn get_receiver_link_by_handle(&self, hnd: Handle) -> Option<&ReceiverLink> { + if let Some(id) = self.remote_handles.get(&hnd) { + if let Some(Either::Right(ReceiverLinkState::Established(ref link))) = + self.links.get(*id) + { + return Some(link); + } + } + None + } + + pub fn handle_frame(&mut self, frame: Frame) { + if self.error.is_none() { + match frame { + Frame::Flow(flow) => self.apply_flow(&flow), + Frame::Disposition(disp) => { + if let Some(sender) = self.disposition_subscribers.remove(&disp.first) { + let _ = sender.send(disp); + } else { + self.settle_deliveries(disp); + } + } + Frame::Transfer(transfer) => { + let idx = if let Some(idx) = self.remote_handles.get(&transfer.handle()) { + *idx + } else { + error!("Transfer's link {:?} is unknown", transfer.handle()); + return; + }; + + if let Some(link) = self.links.get_mut(idx) { + match link { + Either::Left(_) => error!("Got trasfer from sender link"), + Either::Right(link) => match link { + ReceiverLinkState::Opening(_) => { + error!( + "Got transfer for opening link: {} -> {}", + transfer.handle(), + idx + ); + } + ReceiverLinkState::OpeningLocal(_) => { + error!( + "Got transfer for opening link: {} -> {}", + transfer.handle(), + idx + ); + } + ReceiverLinkState::Established(link) => { + // self.outgoing_window -= 1; + let _ = self.next_incoming_id.wrapping_add(1); + link.inner.get_mut().handle_transfer(transfer); + } + ReceiverLinkState::Closing(_) => (), + }, + } + } else { + error!( + "Remote link handle mapped to non-existing link: {} -> {}", + transfer.handle(), + idx + ); + } + } + Frame::Detach(detach) => { + self.handle_detach(&detach); + } + frame => error!("Unexpected frame: {:?}", frame), + } + } + } + + /// Handle `Attach` frame. return false if attach frame is remote and can not be handled + pub fn handle_attach(&mut self, attach: &Attach, cell: Cell) -> bool { + let name = attach.name(); + + if let Some(index) = self.links_by_name.get(name) { + match self.links.get_mut(*index) { + Some(Either::Left(item)) => { + if item.is_opening() { + trace!( + "sender link opened: {:?} {} -> {}", + name, + index, + attach.handle() + ); + + self.remote_handles.insert(attach.handle(), *index); + let delivery_count = attach.initial_delivery_count.unwrap_or(0); + let link = Cell::new(SenderLinkInner::new( + *index, + name.clone(), + attach.handle(), + delivery_count, + cell, + )); + let local_sender = std::mem::replace( + item, + SenderLinkState::Established(SenderLink::new(link.clone())), + ); + + if let SenderLinkState::Opening(tx) = local_sender { + let _ = tx.send(SenderLink::new(link)); + } + } + } + Some(Either::Right(item)) => { + if item.is_opening() { + trace!( + "receiver link opened: {:?} {} -> {}", + name, + index, + attach.handle() + ); + if let ReceiverLinkState::OpeningLocal(opt_item) = item { + let (link, tx) = opt_item.take().unwrap(); + self.remote_handles.insert(attach.handle(), *index); + + *item = ReceiverLinkState::Established(ReceiverLink::new(link.clone())); + let _ = tx.send(Ok(ReceiverLink::new(link))); + } + } + } + _ => { + // TODO: error in proto, have to close connection + } + } + true + } else { + // cannot handle remote attach + false + } + } + + /// Handle `Detach` frame. + pub fn handle_detach(&mut self, detach: &Detach) { + // get local link instance + let idx = if let Some(idx) = self.remote_handles.get(&detach.handle()) { + *idx + } else { + // should not happen, error + return; + }; + + let remove = if let Some(link) = self.links.get_mut(idx) { + match link { + Either::Left(link) => match link { + SenderLinkState::Opening(_) => true, + SenderLinkState::Established(link) => { + // detach from remote endpoint + let detach = Detach { + handle: link.inner.get_ref().id(), + closed: true, + error: detach.error.clone(), + }; + let err = AmqpTransportError::LinkDetached(detach.error.clone()); + + // remove name + self.links_by_name.remove(link.inner.name()); + + // drop pending transfers + let mut idx = 0; + let handle = link.inner.get_ref().remote_handle(); + while idx < self.pending_transfers.len() { + if self.pending_transfers[idx].link_handle == handle { + let tr = self.pending_transfers.remove(idx).unwrap(); + let _ = tr.promise.send(Err(err.clone())); + } else { + idx += 1; + } + } + + // detach snd link + link.inner.get_mut().detached(err); + self.connection + .post_frame(AmqpFrame::new(self.remote_channel_id, detach.into())); + true + } + SenderLinkState::Closing(_) => true, + }, + Either::Right(link) => match link { + ReceiverLinkState::Opening(_) => false, + ReceiverLinkState::OpeningLocal(_) => false, + ReceiverLinkState::Established(link) => { + // detach from remote endpoint + let detach = Detach { + handle: link.handle(), + closed: true, + error: None, + }; + + // detach rcv link + self.connection + .post_frame(AmqpFrame::new(self.remote_channel_id, detach.into())); + true + } + ReceiverLinkState::Closing(tx) => { + // detach confirmation + if let Some(tx) = tx.take() { + if let Some(err) = detach.error.clone() { + let _ = tx.send(Err(AmqpTransportError::LinkDetached(Some(err)))); + } else { + let _ = tx.send(Ok(())); + } + } + true + } + }, + } + } else { + false + }; + + if remove { + self.links.remove(idx); + self.remote_handles.remove(&detach.handle()); + } + } + + fn settle_deliveries(&mut self, disposition: Disposition) { + trace!("settle delivery: {:#?}", disposition); + + let from = disposition.first; + let to = disposition.last.unwrap_or(from); + + if from == to { + let _ = self + .unsettled_deliveries + .remove(&from) + .unwrap() + .send(Ok(disposition)); + } else { + for k in from..=to { + let _ = self + .unsettled_deliveries + .remove(&k) + .unwrap() + .send(Ok(disposition.clone())); + } + } + } + + pub(crate) fn apply_flow(&mut self, flow: &Flow) { + // # AMQP1.0 2.5.6 + self.next_incoming_id = flow.next_outgoing_id(); + self.remote_outgoing_window = flow.outgoing_window(); + + self.remote_incoming_window = flow + .next_incoming_id() + .unwrap_or(INITIAL_OUTGOING_ID) + .saturating_add(flow.incoming_window()) + .saturating_sub(self.next_outgoing_id); + + trace!( + "session received credit. window: {}, pending: {}", + self.remote_outgoing_window, + self.pending_transfers.len() + ); + + while let Some(t) = self.pending_transfers.pop_front() { + self.send_transfer(t.link_handle, t.idx, t.body, t.promise, t.tag, t.settled); + if self.remote_outgoing_window == 0 { + break; + } + } + + // apply link flow + if let Some(Either::Left(link)) = flow.handle().and_then(|h| self.links.get_mut(h as usize)) + { + match link { + SenderLinkState::Established(ref mut link) => { + link.inner.get_mut().apply_flow(&flow); + } + _ => warn!("Received flow frame"), + } + } + if flow.echo() { + self.send_flow(); + } + } + + fn send_flow(&mut self) { + let flow = Flow { + next_incoming_id: if self.local { + Some(self.next_incoming_id) + } else { + None + }, + incoming_window: std::u32::MAX, + next_outgoing_id: self.next_outgoing_id, + outgoing_window: self.remote_incoming_window, + handle: None, + delivery_count: None, + link_credit: None, + available: None, + drain: false, + echo: false, + properties: None, + }; + self.post_frame(flow.into()); + } + + pub(crate) fn rcv_link_flow(&mut self, handle: u32, delivery_count: u32, credit: u32) { + let flow = Flow { + next_incoming_id: if self.local { + Some(self.next_incoming_id) + } else { + None + }, + incoming_window: std::u32::MAX, + next_outgoing_id: self.next_outgoing_id, + outgoing_window: self.remote_incoming_window, + handle: Some(handle), + delivery_count: Some(delivery_count), + link_credit: Some(credit), + available: None, + drain: false, + echo: false, + properties: None, + }; + self.post_frame(flow.into()); + } + + pub fn post_frame(&mut self, frame: Frame) { + self.connection + .post_frame(AmqpFrame::new(self.remote_channel_id, frame)); + } + + pub(crate) fn open_sender_link(&mut self, mut frame: Attach) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + let entry = self.links.vacant_entry(); + let token = entry.key(); + entry.insert(Either::Left(SenderLinkState::Opening(tx))); + + frame.handle = token as Handle; + + self.links_by_name.insert(frame.name.clone(), token); + self.post_frame(Frame::Attach(frame)); + rx + } + + pub fn send_transfer( + &mut self, + link_handle: Handle, + idx: u32, + body: Option, + promise: DeliveryPromise, + tag: Option, + settled: Option, + ) { + if self.remote_incoming_window == 0 { + self.pending_transfers.push_back(PendingTransfer { + link_handle, + idx, + body, + promise, + tag, + settled, + }); + return; + } + let frame = self.prepare_transfer(link_handle, body, promise, tag, settled); + self.post_frame(frame); + } + + pub fn prepare_transfer( + &mut self, + link_handle: Handle, + body: Option, + promise: DeliveryPromise, + delivery_tag: Option, + settled: Option, + ) -> Frame { + let delivery_id = self.next_outgoing_id; + + let tag = if let Some(tag) = delivery_tag { + tag + } else { + let mut buf = BytesMut::new(); + buf.put_u32(delivery_id); + buf.freeze() + }; + + self.next_outgoing_id += 1; + self.remote_incoming_window -= 1; + + let message_format = if let Some(ref body) = body { + body.message_format() + } else { + None + }; + + let settled2 = settled.clone().unwrap_or(false); + let state = if settled2 { + Some(DeliveryState::Accepted(Accepted {})) + } else { + None + }; + + let transfer = Transfer { + settled, + message_format, + handle: link_handle, + delivery_id: Some(delivery_id), + delivery_tag: Some(tag), + more: false, + rcv_settle_mode: None, + state, //: Some(DeliveryState::Accepted(Accepted {})), + resume: false, + aborted: false, + batchable: false, + body: body, + }; + self.unsettled_deliveries.insert(delivery_id, promise); + + Frame::Transfer(transfer) + } +} diff --git a/actix-amqp/src/sndlink.rs b/actix-amqp/src/sndlink.rs new file mode 100755 index 000000000..da7c37d82 --- /dev/null +++ b/actix-amqp/src/sndlink.rs @@ -0,0 +1,326 @@ +use std::collections::VecDeque; +use std::future::Future; + +use actix_utils::oneshot; +use amqp_codec::protocol::{ + Attach, DeliveryNumber, DeliveryState, Disposition, Error, Flow, ReceiverSettleMode, Role, + SenderSettleMode, SequenceNo, Target, TerminusDurability, TerminusExpiryPolicy, TransferBody, +}; +use bytes::Bytes; +use bytestring::ByteString; +use futures::future::{ok, Either}; + +use crate::cell::Cell; +use crate::errors::AmqpTransportError; +use crate::session::{Session, SessionInner}; +use crate::{Delivery, DeliveryPromise, Handle}; + +#[derive(Clone)] +pub struct SenderLink { + pub(crate) inner: Cell, +} + +impl std::fmt::Debug for SenderLink { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_tuple("SenderLink") + .field(&std::ops::Deref::deref(&self.inner.get_ref().name)) + .finish() + } +} + +pub(crate) struct SenderLinkInner { + id: usize, + idx: u32, + name: ByteString, + session: Session, + remote_handle: Handle, + delivery_count: SequenceNo, + link_credit: u32, + pending_transfers: VecDeque, + error: Option, + closed: bool, +} + +struct PendingTransfer { + idx: u32, + tag: Option, + body: Option, + promise: DeliveryPromise, + settle: Option, +} + +impl SenderLink { + pub(crate) fn new(inner: Cell) -> SenderLink { + SenderLink { inner } + } + + pub fn id(&self) -> u32 { + self.inner.id as u32 + } + + pub fn name(&self) -> &ByteString { + &self.inner.name + } + + pub fn remote_handle(&self) -> Handle { + self.inner.remote_handle + } + + pub fn session(&self) -> &Session { + &self.inner.get_ref().session + } + + pub fn session_mut(&mut self) -> &mut Session { + &mut self.inner.get_mut().session + } + + pub fn send(&self, body: T) -> impl Future> + where + T: Into, + { + self.inner.get_mut().send(body, None) + } + + pub fn send_with_tag( + &self, + body: T, + tag: Bytes, + ) -> impl Future> + where + T: Into, + { + self.inner.get_mut().send(body, Some(tag)) + } + + pub fn settle_message(&self, id: DeliveryNumber, state: DeliveryState) { + self.inner.get_mut().settle_message(id, state) + } + + pub fn close(&self) -> impl Future> { + self.inner.get_mut().close(None) + } + + pub fn close_with_error( + &self, + error: Error, + ) -> impl Future> { + self.inner.get_mut().close(Some(error)) + } +} + +impl SenderLinkInner { + pub(crate) fn new( + id: usize, + name: ByteString, + handle: Handle, + delivery_count: SequenceNo, + session: Cell, + ) -> SenderLinkInner { + SenderLinkInner { + id, + name, + delivery_count, + idx: 0, + session: Session::new(session), + remote_handle: handle, + link_credit: 0, + pending_transfers: VecDeque::new(), + error: None, + closed: false, + } + } + + pub fn id(&self) -> u32 { + self.id as u32 + } + + pub fn remote_handle(&self) -> Handle { + self.remote_handle + } + + pub(crate) fn name(&self) -> &ByteString { + &self.name + } + + pub(crate) fn detached(&mut self, err: AmqpTransportError) { + // drop pending transfers + for tr in self.pending_transfers.drain(..) { + let _ = tr.promise.send(Err(err.clone())); + } + + self.error = Some(err); + } + + pub(crate) fn close( + &self, + error: Option, + ) -> impl Future> { + if self.closed { + Either::Left(ok(())) + } else { + let (tx, rx) = oneshot::channel(); + + self.session + .inner + .get_mut() + .detach_sender_link(self.id, true, error, tx); + + Either::Right(async move { + match rx.await { + Ok(Ok(_)) => Ok(()), + Ok(Err(e)) => Err(e), + Err(_) => Err(AmqpTransportError::Disconnected), + } + }) + } + } + + pub(crate) fn set_error(&mut self, err: AmqpTransportError) { + // drop pending transfers + for tr in self.pending_transfers.drain(..) { + let _ = tr.promise.send(Err(err.clone())); + } + + self.error = Some(err); + } + + pub fn apply_flow(&mut self, flow: &Flow) { + // #2.7.6 + if let Some(credit) = flow.link_credit() { + let delta = flow + .delivery_count + .unwrap_or(0) + .saturating_add(credit) + .saturating_sub(self.delivery_count); + + let session = self.session.inner.get_mut(); + + // credit became available => drain pending_transfers + self.link_credit += delta; + while self.link_credit > 0 { + if let Some(transfer) = self.pending_transfers.pop_front() { + self.link_credit -= 1; + let _ = self.delivery_count.saturating_add(1); + session.send_transfer( + self.remote_handle, + transfer.idx, + transfer.body, + transfer.promise, + transfer.tag, + transfer.settle, + ); + } else { + break; + } + } + } + + if flow.echo() { + // todo: send flow + } + } + + pub fn send>(&mut self, body: T, tag: Option) -> Delivery { + if let Some(ref err) = self.error { + Delivery::Resolved(Err(err.clone())) + } else { + let body = body.into(); + let (delivery_tx, delivery_rx) = oneshot::channel(); + if self.link_credit == 0 { + self.pending_transfers.push_back(PendingTransfer { + tag, + settle: Some(false), + body: Some(body), + idx: self.idx, + promise: delivery_tx, + }); + } else { + let session = self.session.inner.get_mut(); + self.link_credit -= 1; + let _ = self.delivery_count.saturating_add(1); + session.send_transfer( + self.remote_handle, + self.idx, + Some(body), + delivery_tx, + tag, + Some(false), + ); + } + let _ = self.idx.saturating_add(1); + Delivery::Pending(delivery_rx) + } + } + + pub fn settle_message(&mut self, id: DeliveryNumber, state: DeliveryState) { + let _ = self.delivery_count.saturating_add(1); + + let disp = Disposition { + role: Role::Sender, + first: id, + last: None, + settled: true, + state: Some(state), + batchable: false, + }; + self.session.inner.get_mut().post_frame(disp.into()); + } +} + +pub struct SenderLinkBuilder { + frame: Attach, + session: Cell, +} + +impl SenderLinkBuilder { + pub(crate) fn new(name: ByteString, address: ByteString, session: Cell) -> Self { + let target = Target { + address: Some(address), + durable: TerminusDurability::None, + expiry_policy: TerminusExpiryPolicy::SessionEnd, + timeout: 0, + dynamic: false, + dynamic_node_properties: None, + capabilities: None, + }; + let frame = Attach { + name, + handle: 0 as Handle, + role: Role::Sender, + snd_settle_mode: SenderSettleMode::Mixed, + rcv_settle_mode: ReceiverSettleMode::First, + source: None, + target: Some(target), + unsettled: None, + incomplete_unsettled: false, + initial_delivery_count: None, + max_message_size: Some(65536 * 4), + offered_capabilities: None, + desired_capabilities: None, + properties: None, + }; + + SenderLinkBuilder { frame, session } + } + + pub fn max_message_size(mut self, size: u64) -> Self { + self.frame.max_message_size = Some(size); + self + } + + pub fn with_frame(mut self, f: F) -> Self + where + F: FnOnce(&mut Attach), + { + f(&mut self.frame); + self + } + + pub async fn open(self) -> Result { + self.session + .get_mut() + .open_sender_link(self.frame) + .await + .map_err(|_e| AmqpTransportError::Disconnected) + } +} diff --git a/actix-amqp/tests/test_server.rs b/actix-amqp/tests/test_server.rs new file mode 100755 index 000000000..772b2bd43 --- /dev/null +++ b/actix-amqp/tests/test_server.rs @@ -0,0 +1,131 @@ +use std::convert::TryFrom; + +use actix_amqp::server::{self, errors}; +use actix_amqp::{sasl, Configuration}; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_connect::{default_connector, TcpConnector}; +use actix_service::{fn_factory_with_config, pipeline_factory, Service}; +use actix_testing::TestServer; +use futures::future::{err, Ready}; +use futures::Future; +use http::Uri; + +fn server( + link: server::Link<()>, +) -> impl Future< + Output = Result< + Box< + dyn Service< + Request = server::Message<()>, + Response = server::Outcome, + Error = errors::AmqpError, + Future = Ready, server::Outcome>>, + > + 'static, + >, + errors::LinkError, + >, +> { + println!("OPEN LINK: {:?}", link); + err(errors::LinkError::force_detach().description("unimplemented")) +} + +#[actix_rt::test] +async fn test_simple() -> std::io::Result<()> { + std::env::set_var( + "RUST_LOG", + "actix_codec=info,actix_server=trace,actix_connector=trace,amqp_transport=trace", + ); + env_logger::init(); + + let srv = TestServer::with(|| { + server::Server::new( + server::Handshake::new(|conn: server::Connect<_>| async move { + let conn = conn.open().await.unwrap(); + Ok::<_, errors::AmqpError>(conn.ack(())) + }) + .sasl(server::sasl::no_sasl()), + ) + .finish( + server::App::<()>::new() + .service("test", fn_factory_with_config(server)) + .finish(), + ) + }); + + let uri = Uri::try_from(format!("amqp://{}:{}", srv.host(), srv.port())).unwrap(); + let mut sasl_srv = sasl::connect_service(default_connector()); + let req = sasl::SaslConnect { + uri, + config: Configuration::default(), + time: None, + auth: sasl::SaslAuth { + authz_id: "".to_string(), + authn_id: "user1".to_string(), + password: "password1".to_string(), + }, + }; + let res = sasl_srv.call(req).await; + println!("E: {:?}", res.err()); + + Ok(()) +} + +async fn sasl_auth( + auth: server::Sasl, +) -> Result, server::errors::ServerError<()>> { + let init = auth + .mechanism("PLAIN") + .mechanism("ANONYMOUS") + .mechanism("MSSBCBS") + .mechanism("AMQPCBS") + .init() + .await?; + + if init.mechanism() == "PLAIN" { + if let Some(resp) = init.initial_response() { + if resp == b"\0user1\0password1" { + let succ = init.outcome(amqp_codec::protocol::SaslCode::Ok).await?; + return Ok(succ.open().await?.ack(())); + } + } + } + + let succ = init.outcome(amqp_codec::protocol::SaslCode::Auth).await?; + Ok(succ.open().await?.ack(())) +} + +#[actix_rt::test] +async fn test_sasl() -> std::io::Result<()> { + let srv = TestServer::with(|| { + server::Server::new( + server::Handshake::new(|conn: server::Connect<_>| async move { + let conn = conn.open().await.unwrap(); + Ok::<_, errors::Error>(conn.ack(())) + }) + .sasl(pipeline_factory(sasl_auth).map_err(|e| e.into())), + ) + .finish( + server::App::<()>::new() + .service("test", fn_factory_with_config(server)) + .finish(), + ) + }); + + let uri = Uri::try_from(format!("amqp://{}:{}", srv.host(), srv.port())).unwrap(); + let mut sasl_srv = sasl::connect_service(TcpConnector::new()); + + let req = sasl::SaslConnect { + uri, + config: Configuration::default(), + time: None, + auth: sasl::SaslAuth { + authz_id: "".to_string(), + authn_id: "user1".to_string(), + password: "password1".to_string(), + }, + }; + let res = sasl_srv.call(req).await; + println!("E: {:?}", res.err()); + + Ok(()) +} diff --git a/actix-mqtt/CHANGES.md b/actix-mqtt/CHANGES.md new file mode 100755 index 000000000..3cd8415c4 --- /dev/null +++ b/actix-mqtt/CHANGES.md @@ -0,0 +1,24 @@ +# Changes + +## Unreleased - 2020-xx-xx + + +## 0.2.3 - 2020-03-10 +* Add server handshake timeout + + +## 0.2.2 - 2020-02-04 +* Fix server keep-alive impl + + +## 0.2.1 - 2019-12-25 +* Allow to specify multi-pattern for topics + + +## 0.2.0 - 2019-12-11 +* Migrate to `std::future` +* Support publish with QoS 1 + + +## 0.1.0 - 2019-09-25 +* Initial release diff --git a/actix-mqtt/Cargo.toml b/actix-mqtt/Cargo.toml new file mode 100755 index 000000000..6e11e3f34 --- /dev/null +++ b/actix-mqtt/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "actix-mqtt" +version = "0.2.3" +authors = ["Nikolay Kim "] +description = "MQTT v3.1.1 Client/Server framework" +documentation = "https://docs.rs/actix-mqtt" +repository = "https://github.com/actix/actix-extras.git" +categories = ["network-programming"] +keywords = ["MQTT", "IoT", "messaging"] +license = "MIT OR Apache-2.0" +edition = "2018" + +[dependencies] +mqtt-codec = "0.3.0" +actix-codec = "0.2.0" +actix-service = "1.0.1" +actix-utils = "1.0.4" +actix-router = "0.2.2" +actix-ioframe = "0.4.1" +actix-rt = "1.0.0" + +derive_more = "0.99.2" +bytes = "0.5.3" +either = "1.5.2" +futures = "0.3.1" +pin-project = "0.4.6" +log = "0.4" +bytestring = "0.1.2" +serde = "1.0" +serde_json = "1.0" +uuid = { version = "0.8", features = ["v4"] } + +[dev-dependencies] +env_logger = "0.6" +actix-connect = "1.0.1" +actix-server = "1.0.0" +actix-testing = "1.0.0" +actix-rt = "1.0.0" diff --git a/actix-mqtt/LICENSE-APACHE b/actix-mqtt/LICENSE-APACHE new file mode 100755 index 000000000..b3dcc4422 --- /dev/null +++ b/actix-mqtt/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019-NOW Nikolay Kim + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/actix-mqtt/LICENSE-MIT b/actix-mqtt/LICENSE-MIT new file mode 100755 index 000000000..ecc97592e --- /dev/null +++ b/actix-mqtt/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2019 Nikolay Kim + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/actix-mqtt/README.md b/actix-mqtt/README.md new file mode 100755 index 000000000..e4e3a537e --- /dev/null +++ b/actix-mqtt/README.md @@ -0,0 +1 @@ +# MQTT 3.1.1 Client/Server framework [![Build Status](https://travis-ci.org/actix/actix-mqtt.svg?branch=master)](https://travis-ci.org/actix/actix-mqtt) [![codecov](https://codecov.io/gh/actix/actix-mqtt/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-mqtt) [![crates.io](https://meritbadge.herokuapp.com/actix-mqtt)](https://crates.io/crates/actix-mqtt) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) diff --git a/actix-mqtt/codec/CHANGES.md b/actix-mqtt/codec/CHANGES.md new file mode 100755 index 000000000..78c0593d7 --- /dev/null +++ b/actix-mqtt/codec/CHANGES.md @@ -0,0 +1,17 @@ +# Changes + +## Unreleased - 2020-xx-xx + + +## 0.3.0 - 2019-12-11 +* Use `bytestring` instead of string +* Upgrade actix-codec to 0.2.0 + + +## 0.2.0 - 2019-09-09 +* Remove `Packet::Empty` +* Add `max frame size` config + + +## 0.1.0 - 2019-06-18 +* Initial release diff --git a/actix-mqtt/codec/Cargo.toml b/actix-mqtt/codec/Cargo.toml new file mode 100755 index 000000000..5be257028 --- /dev/null +++ b/actix-mqtt/codec/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "mqtt-codec" +version = "0.3.0" +authors = [ + "Max Gortman ", + "Nikolay Kim ", + "Flier Lu ", +] +description = "MQTT v3.1.1 Codec" +documentation = "https://docs.rs/mqtt-codec" +repository = "https://github.com/actix/actix-mqtt.git" +readme = "README.md" +keywords = ["MQTT", "IoT", "messaging"] +license = "MIT/Apache-2.0" +edition = "2018" + +[dependencies] +bitflags = "1.0" +bytes = "0.5.2" +bytestring = "0.1.0" +actix-codec = "0.2.0" diff --git a/actix-mqtt/codec/LICENSE-APACHE b/actix-mqtt/codec/LICENSE-APACHE new file mode 100755 index 000000000..b3dcc4422 --- /dev/null +++ b/actix-mqtt/codec/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019-NOW Nikolay Kim + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/actix-mqtt/codec/LICENSE-MIT b/actix-mqtt/codec/LICENSE-MIT new file mode 100755 index 000000000..ecc97592e --- /dev/null +++ b/actix-mqtt/codec/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2019 Nikolay Kim + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/actix-mqtt/codec/README.md b/actix-mqtt/codec/README.md new file mode 100755 index 000000000..7bf1dd08c --- /dev/null +++ b/actix-mqtt/codec/README.md @@ -0,0 +1,3 @@ +# MQTT v3.1 Codec + +MQTT v3.1 Codec implementation diff --git a/actix-mqtt/codec/src/codec/decode.rs b/actix-mqtt/codec/src/codec/decode.rs new file mode 100755 index 000000000..f037dcefc --- /dev/null +++ b/actix-mqtt/codec/src/codec/decode.rs @@ -0,0 +1,616 @@ +use std::convert::TryFrom; +use std::num::{NonZeroU16, NonZeroU32}; + +use bytes::{Buf, Bytes}; +use bytestring::ByteString; + +use crate::error::ParseError; +use crate::packet::*; +use crate::proto::*; + +use super::{ConnectAckFlags, ConnectFlags, FixedHeader, WILL_QOS_SHIFT}; + +pub(crate) fn read_packet(mut src: Bytes, header: FixedHeader) -> Result { + match header.packet_type { + CONNECT => decode_connect_packet(&mut src), + CONNACK => decode_connect_ack_packet(&mut src), + PUBLISH => decode_publish_packet(&mut src, header), + PUBACK => Ok(Packet::PublishAck { + packet_id: NonZeroU16::parse(&mut src)?, + }), + PUBREC => Ok(Packet::PublishReceived { + packet_id: NonZeroU16::parse(&mut src)?, + }), + PUBREL => Ok(Packet::PublishRelease { + packet_id: NonZeroU16::parse(&mut src)?, + }), + PUBCOMP => Ok(Packet::PublishComplete { + packet_id: NonZeroU16::parse(&mut src)?, + }), + SUBSCRIBE => decode_subscribe_packet(&mut src), + SUBACK => decode_subscribe_ack_packet(&mut src), + UNSUBSCRIBE => decode_unsubscribe_packet(&mut src), + UNSUBACK => Ok(Packet::UnsubscribeAck { + packet_id: NonZeroU16::parse(&mut src)?, + }), + PINGREQ => Ok(Packet::PingRequest), + PINGRESP => Ok(Packet::PingResponse), + DISCONNECT => Ok(Packet::Disconnect), + _ => Err(ParseError::UnsupportedPacketType), + } +} + +macro_rules! check_flag { + ($flags:expr, $flag:expr) => { + ($flags & $flag.bits()) == $flag.bits() + }; +} + +macro_rules! ensure { + ($cond:expr, $e:expr) => { + if !($cond) { + return Err($e); + } + }; + ($cond:expr, $fmt:expr, $($arg:tt)+) => { + if !($cond) { + return Err($fmt, $($arg)+); + } + }; +} + +pub fn decode_variable_length(src: &[u8]) -> Result, ParseError> { + if let Some((len, consumed, more)) = src + .iter() + .enumerate() + .scan((0, true), |state, (idx, x)| { + if !state.1 || idx > 3 { + return None; + } + state.0 += ((x & 0x7F) as usize) << (idx * 7); + state.1 = x & 0x80 != 0; + Some((state.0, idx + 1, state.1)) + }) + .last() + { + ensure!(!more || consumed < 4, ParseError::InvalidLength); + return Ok(Some((len, consumed))); + } + + Ok(None) +} + +fn decode_connect_packet(src: &mut Bytes) -> Result { + ensure!(src.remaining() >= 10, ParseError::InvalidLength); + let len = src.get_u16(); + + ensure!( + len == 4 && &src.bytes()[0..4] == b"MQTT", + ParseError::InvalidProtocol + ); + src.advance(4); + + let level = src.get_u8(); + ensure!( + level == DEFAULT_MQTT_LEVEL, + ParseError::UnsupportedProtocolLevel + ); + + let flags = src.get_u8(); + ensure!((flags & 0x01) == 0, ParseError::ConnectReservedFlagSet); + + let keep_alive = src.get_u16(); + let client_id = decode_utf8_str(src)?; + + ensure!( + !client_id.is_empty() || check_flag!(flags, ConnectFlags::CLEAN_SESSION), + ParseError::InvalidClientId + ); + + let topic = if check_flag!(flags, ConnectFlags::WILL) { + Some(decode_utf8_str(src)?) + } else { + None + }; + let message = if check_flag!(flags, ConnectFlags::WILL) { + Some(decode_length_bytes(src)?) + } else { + None + }; + let username = if check_flag!(flags, ConnectFlags::USERNAME) { + Some(decode_utf8_str(src)?) + } else { + None + }; + let password = if check_flag!(flags, ConnectFlags::PASSWORD) { + Some(decode_length_bytes(src)?) + } else { + None + }; + let last_will = if let Some(topic) = topic { + Some(LastWill { + qos: QoS::from((flags & ConnectFlags::WILL_QOS.bits()) >> WILL_QOS_SHIFT), + retain: check_flag!(flags, ConnectFlags::WILL_RETAIN), + topic, + message: message.unwrap(), + }) + } else { + None + }; + + Ok(Packet::Connect(Connect { + protocol: Protocol::MQTT(level), + clean_session: check_flag!(flags, ConnectFlags::CLEAN_SESSION), + keep_alive, + client_id, + last_will, + username, + password, + })) +} + +fn decode_connect_ack_packet(src: &mut Bytes) -> Result { + ensure!(src.remaining() >= 2, ParseError::InvalidLength); + let flags = src.get_u8(); + ensure!( + (flags & 0b1111_1110) == 0, + ParseError::ConnAckReservedFlagSet + ); + + let return_code = src.get_u8(); + Ok(Packet::ConnectAck { + session_present: check_flag!(flags, ConnectAckFlags::SESSION_PRESENT), + return_code: ConnectCode::from(return_code), + }) +} + +fn decode_publish_packet(src: &mut Bytes, header: FixedHeader) -> Result { + let topic = decode_utf8_str(src)?; + let qos = QoS::from((header.packet_flags & 0b0110) >> 1); + let packet_id = if qos == QoS::AtMostOnce { + None + } else { + Some(NonZeroU16::parse(src)?) + }; + + let len = src.remaining(); + let payload = src.split_to(len); + + Ok(Packet::Publish(Publish { + dup: (header.packet_flags & 0b1000) == 0b1000, + qos, + retain: (header.packet_flags & 0b0001) == 0b0001, + topic, + packet_id, + payload, + })) +} + +fn decode_subscribe_packet(src: &mut Bytes) -> Result { + let packet_id = NonZeroU16::parse(src)?; + let mut topic_filters = Vec::new(); + while src.remaining() > 0 { + let topic = decode_utf8_str(src)?; + ensure!(src.remaining() >= 1, ParseError::InvalidLength); + let qos = QoS::from(src.get_u8() & 0x03); + topic_filters.push((topic, qos)); + } + + Ok(Packet::Subscribe { + packet_id, + topic_filters, + }) +} + +fn decode_subscribe_ack_packet(src: &mut Bytes) -> Result { + let packet_id = NonZeroU16::parse(src)?; + let status = src + .bytes() + .iter() + .map(|code| { + if *code == 0x80 { + SubscribeReturnCode::Failure + } else { + SubscribeReturnCode::Success(QoS::from(code & 0x03)) + } + }) + .collect(); + Ok(Packet::SubscribeAck { packet_id, status }) +} + +fn decode_unsubscribe_packet(src: &mut Bytes) -> Result { + let packet_id = NonZeroU16::parse(src)?; + let mut topic_filters = Vec::new(); + while src.remaining() > 0 { + topic_filters.push(decode_utf8_str(src)?); + } + Ok(Packet::Unsubscribe { + packet_id, + topic_filters, + }) +} + +fn decode_length_bytes(src: &mut Bytes) -> Result { + let len: u16 = NonZeroU16::parse(src)?.into(); + ensure!(src.remaining() >= len as usize, ParseError::InvalidLength); + Ok(src.split_to(len as usize)) +} + +fn decode_utf8_str(src: &mut Bytes) -> Result { + Ok(ByteString::try_from(decode_length_bytes(src)?)?) +} + +pub(crate) trait ByteBuf: Buf { + fn inner_mut(&mut self) -> &mut Bytes; +} + +impl ByteBuf for Bytes { + fn inner_mut(&mut self) -> &mut Bytes { + self + } +} + +impl ByteBuf for bytes::buf::ext::Take<&mut Bytes> { + fn inner_mut(&mut self) -> &mut Bytes { + self.get_mut() + } +} + +pub(crate) trait Parse: Sized { + fn parse(src: &mut B) -> Result; +} + +impl Parse for bool { + fn parse(src: &mut B) -> Result { + ensure!(src.has_remaining(), ParseError::InvalidLength); // expected more data within the field + let v = src.get_u8(); + ensure!(v <= 0x1, ParseError::MalformedPacket); // value is invalid + Ok(v == 0x1) + } +} + +impl Parse for u16 { + fn parse(src: &mut B) -> Result { + ensure!(src.remaining() >= 2, ParseError::InvalidLength); + Ok(src.get_u16()) + } +} + +impl Parse for u32 { + fn parse(src: &mut B) -> Result { + ensure!(src.remaining() >= 4, ParseError::InvalidLength); // expected more data within the field + let val = src.get_u32(); + Ok(val) + } +} + +impl Parse for NonZeroU32 { + fn parse(src: &mut B) -> Result { + let val = NonZeroU32::new(u32::parse(src)?).ok_or(ParseError::MalformedPacket)?; + Ok(val) + } +} + +impl Parse for NonZeroU16 { + fn parse(src: &mut B) -> Result { + Ok(NonZeroU16::new(u16::parse(src)?).ok_or(ParseError::MalformedPacket)?) + } +} + +impl Parse for Bytes { + fn parse(src: &mut B) -> Result { + let len = u16::parse(src)? as usize; + ensure!(src.remaining() >= len, ParseError::InvalidLength); + Ok(src.inner_mut().split_to(len)) + } +} + +pub(crate) type ByteStr = ByteString; + +impl Parse for ByteStr { + fn parse(src: &mut B) -> Result { + let bytes = Bytes::parse(src)?; + Ok(ByteString::try_from(bytes)?) + } +} + +impl Parse for (ByteStr, ByteStr) { + fn parse(src: &mut B) -> Result { + let key = ByteStr::parse(src)?; + let val = ByteStr::parse(src)?; + Ok((key, val)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_decode_packet ( + ($bytes:expr, $res:expr) => {{ + let fixed = $bytes.as_ref()[0]; + let (_len, consumned) = decode_variable_length(&$bytes[1..]).unwrap().unwrap(); + let hdr = FixedHeader { + packet_type: fixed >> 4, + packet_flags: fixed & 0xF, + remaining_length: $bytes.len() - consumned - 1, + }; + let cur = Bytes::from_static(&$bytes[consumned + 1..]); + assert_eq!(read_packet(cur, hdr), Ok($res)); + }}; + ); + + fn packet_id(v: u16) -> NonZeroU16 { + NonZeroU16::new(v).unwrap() + } + + #[test] + fn test_decode_variable_length() { + macro_rules! assert_variable_length ( + ($bytes:expr, $res:expr) => {{ + assert_eq!(decode_variable_length($bytes), Ok(Some($res))); + }}; + + ($bytes:expr, $res:expr, $rest:expr) => {{ + assert_eq!(decode_variable_length($bytes), Ok(Some($res))); + }}; + ); + + assert_variable_length!(b"\x7f\x7f", (127, 1), b"\x7f"); + + //assert_eq!(decode_variable_length(b"\xff\xff\xff"), Ok(None)); + assert_eq!( + decode_variable_length(b"\xff\xff\xff\xff\xff\xff"), + Err(ParseError::InvalidLength) + ); + + assert_variable_length!(b"\x00", (0, 1)); + assert_variable_length!(b"\x7f", (127, 1)); + assert_variable_length!(b"\x80\x01", (128, 2)); + assert_variable_length!(b"\xff\x7f", (16383, 2)); + assert_variable_length!(b"\x80\x80\x01", (16384, 3)); + assert_variable_length!(b"\xff\xff\x7f", (2097151, 3)); + assert_variable_length!(b"\x80\x80\x80\x01", (2097152, 4)); + assert_variable_length!(b"\xff\xff\xff\x7f", (268435455, 4)); + } + + // #[test] + // fn test_decode_header() { + // assert_eq!( + // decode_header(b"\x20\x7f"), + // Done( + // &b""[..], + // FixedHeader { + // packet_type: CONNACK, + // packet_flags: 0, + // remaining_length: 127, + // } + // ) + // ); + + // assert_eq!( + // decode_header(b"\x3C\x82\x7f"), + // Done( + // &b""[..], + // FixedHeader { + // packet_type: PUBLISH, + // packet_flags: 0x0C, + // remaining_length: 16258, + // } + // ) + // ); + + // assert_eq!(decode_header(b"\x20"), Incomplete(Needed::Unknown)); + // } + + #[test] + fn test_decode_connect_packets() { + assert_eq!( + decode_connect_packet(&mut Bytes::from_static( + b"\x00\x04MQTT\x04\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass" + )), + Ok(Packet::Connect(Connect { + protocol: Protocol::MQTT(4), + clean_session: false, + keep_alive: 60, + client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(), + last_will: None, + username: Some(ByteString::try_from(Bytes::from_static(b"user")).unwrap()), + password: Some(Bytes::from(&b"pass"[..])), + })) + ); + + assert_eq!( + decode_connect_packet(&mut Bytes::from_static( + b"\x00\x04MQTT\x04\x14\x00\x3C\x00\x0512345\x00\x05topic\x00\x07message" + )), + Ok(Packet::Connect(Connect { + protocol: Protocol::MQTT(4), + clean_session: false, + keep_alive: 60, + client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(), + last_will: Some(LastWill { + qos: QoS::ExactlyOnce, + retain: false, + topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(), + message: Bytes::from(&b"message"[..]), + }), + username: None, + password: None, + })) + ); + + assert_eq!( + decode_connect_packet(&mut Bytes::from_static(b"\x00\x02MQ00000000000000000000")), + Err(ParseError::InvalidProtocol), + ); + assert_eq!( + decode_connect_packet(&mut Bytes::from_static(b"\x00\x10MQ00000000000000000000")), + Err(ParseError::InvalidProtocol), + ); + assert_eq!( + decode_connect_packet(&mut Bytes::from_static(b"\x00\x04MQAA00000000000000000000")), + Err(ParseError::InvalidProtocol), + ); + assert_eq!( + decode_connect_packet(&mut Bytes::from_static( + b"\x00\x04MQTT\x0300000000000000000000" + )), + Err(ParseError::UnsupportedProtocolLevel), + ); + assert_eq!( + decode_connect_packet(&mut Bytes::from_static( + b"\x00\x04MQTT\x04\xff00000000000000000000" + )), + Err(ParseError::ConnectReservedFlagSet) + ); + + assert_eq!( + decode_connect_ack_packet(&mut Bytes::from_static(b"\x01\x04")), + Ok(Packet::ConnectAck { + session_present: true, + return_code: ConnectCode::BadUserNameOrPassword + }) + ); + + assert_eq!( + decode_connect_ack_packet(&mut Bytes::from_static(b"\x03\x04")), + Err(ParseError::ConnAckReservedFlagSet) + ); + + assert_decode_packet!( + b"\x20\x02\x01\x04", + Packet::ConnectAck { + session_present: true, + return_code: ConnectCode::BadUserNameOrPassword, + } + ); + + assert_decode_packet!(b"\xe0\x00", Packet::Disconnect); + } + + #[test] + fn test_decode_publish_packets() { + //assert_eq!( + // decode_publish_packet(b"\x00\x05topic\x12\x34"), + // Done(&b""[..], ("topic".to_owned(), 0x1234)) + //); + + assert_decode_packet!( + b"\x3d\x0D\x00\x05topic\x43\x21data", + Packet::Publish(Publish { + dup: true, + retain: true, + qos: QoS::ExactlyOnce, + topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(), + packet_id: Some(packet_id(0x4321)), + payload: Bytes::from_static(b"data"), + }) + ); + assert_decode_packet!( + b"\x30\x0b\x00\x05topicdata", + Packet::Publish(Publish { + dup: false, + retain: false, + qos: QoS::AtMostOnce, + topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(), + packet_id: None, + payload: Bytes::from_static(b"data"), + }) + ); + + assert_decode_packet!( + b"\x40\x02\x43\x21", + Packet::PublishAck { + packet_id: packet_id(0x4321), + } + ); + assert_decode_packet!( + b"\x50\x02\x43\x21", + Packet::PublishReceived { + packet_id: packet_id(0x4321), + } + ); + assert_decode_packet!( + b"\x60\x02\x43\x21", + Packet::PublishRelease { + packet_id: packet_id(0x4321), + } + ); + assert_decode_packet!( + b"\x70\x02\x43\x21", + Packet::PublishComplete { + packet_id: packet_id(0x4321), + } + ); + } + + #[test] + fn test_decode_subscribe_packets() { + let p = Packet::Subscribe { + packet_id: packet_id(0x1234), + topic_filters: vec![ + ( + ByteString::try_from(Bytes::from_static(b"test")).unwrap(), + QoS::AtLeastOnce, + ), + ( + ByteString::try_from(Bytes::from_static(b"filter")).unwrap(), + QoS::ExactlyOnce, + ), + ], + }; + + assert_eq!( + decode_subscribe_packet(&mut Bytes::from_static( + b"\x12\x34\x00\x04test\x01\x00\x06filter\x02" + )), + Ok(p.clone()) + ); + assert_decode_packet!(b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02", p); + + let p = Packet::SubscribeAck { + packet_id: packet_id(0x1234), + status: vec![ + SubscribeReturnCode::Success(QoS::AtLeastOnce), + SubscribeReturnCode::Failure, + SubscribeReturnCode::Success(QoS::ExactlyOnce), + ], + }; + + assert_eq!( + decode_subscribe_ack_packet(&mut Bytes::from_static(b"\x12\x34\x01\x80\x02")), + Ok(p.clone()) + ); + assert_decode_packet!(b"\x90\x05\x12\x34\x01\x80\x02", p); + + let p = Packet::Unsubscribe { + packet_id: packet_id(0x1234), + topic_filters: vec![ + ByteString::try_from(Bytes::from_static(b"test")).unwrap(), + ByteString::try_from(Bytes::from_static(b"filter")).unwrap(), + ], + }; + + assert_eq!( + decode_unsubscribe_packet(&mut Bytes::from_static( + b"\x12\x34\x00\x04test\x00\x06filter" + )), + Ok(p.clone()) + ); + assert_decode_packet!(b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter", p); + + assert_decode_packet!( + b"\xb0\x02\x43\x21", + Packet::UnsubscribeAck { + packet_id: packet_id(0x4321), + } + ); + } + + #[test] + fn test_decode_ping_packets() { + assert_decode_packet!(b"\xc0\x00", Packet::PingRequest); + assert_decode_packet!(b"\xd0\x00", Packet::PingResponse); + } +} diff --git a/actix-mqtt/codec/src/codec/encode.rs b/actix-mqtt/codec/src/codec/encode.rs new file mode 100755 index 000000000..3cafa308b --- /dev/null +++ b/actix-mqtt/codec/src/codec/encode.rs @@ -0,0 +1,446 @@ +use bytes::{BufMut, BytesMut}; + +use crate::packet::*; +use crate::proto::*; + +use super::{ConnectFlags, WILL_QOS_SHIFT}; + +pub fn write_packet(packet: &Packet, dst: &mut BytesMut, content_size: usize) { + write_fixed_header(packet, dst, content_size); + write_content(packet, dst); +} + +pub fn get_encoded_size(packet: &Packet) -> usize { + match *packet { + Packet::Connect ( ref connect ) => { + match *connect { + Connect {ref last_will, ref client_id, ref username, ref password, ..} => + { + // Protocol Name + Protocol Level + Connect Flags + Keep Alive + let mut n = 2 + 4 + 1 + 1 + 2; + + // Client Id + n += 2 + client_id.len(); + + // Will Topic + Will Message + if let Some(LastWill { ref topic, ref message, .. }) = *last_will { + n += 2 + topic.len() + 2 + message.len(); + } + + if let Some(ref s) = *username { + n += 2 + s.len(); + } + + if let Some(ref s) = *password { + n += 2 + s.len(); + } + + n + } + } + } + + Packet::Publish( Publish{ qos, ref topic, ref payload, .. }) => { + // Topic + Packet Id + Payload + if qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce { + 4 + topic.len() + payload.len() + } else { + 2 + topic.len() + payload.len() + } + } + + Packet::ConnectAck { .. } | // Flags + Return Code + Packet::PublishAck { .. } | // Packet Id + Packet::PublishReceived { .. } | // Packet Id + Packet::PublishRelease { .. } | // Packet Id + Packet::PublishComplete { .. } | // Packet Id + Packet::UnsubscribeAck { .. } => 2, // Packet Id + + Packet::Subscribe { ref topic_filters, .. } => { + 2 + topic_filters.iter().fold(0, |acc, &(ref filter, _)| acc + 2 + filter.len() + 1) + } + + Packet::SubscribeAck { ref status, .. } => 2 + status.len(), + + Packet::Unsubscribe { ref topic_filters, .. } => { + 2 + topic_filters.iter().fold(0, |acc, filter| acc + 2 + filter.len()) + } + + Packet::PingRequest | Packet::PingResponse | Packet::Disconnect => 0, + } +} + +#[inline] +fn write_fixed_header(packet: &Packet, dst: &mut BytesMut, content_size: usize) { + dst.put_u8((packet.packet_type() << 4) | packet.packet_flags()); + write_variable_length(content_size, dst); +} + +fn write_content(packet: &Packet, dst: &mut BytesMut) { + match *packet { + Packet::Connect(ref connect) => match *connect { + Connect { + protocol, + clean_session, + keep_alive, + ref last_will, + ref client_id, + ref username, + ref password, + } => { + write_slice(protocol.name().as_bytes(), dst); + + let mut flags = ConnectFlags::empty(); + + if username.is_some() { + flags |= ConnectFlags::USERNAME; + } + if password.is_some() { + flags |= ConnectFlags::PASSWORD; + } + + if let Some(LastWill { qos, retain, .. }) = *last_will { + flags |= ConnectFlags::WILL; + + if retain { + flags |= ConnectFlags::WILL_RETAIN; + } + + let b: u8 = qos as u8; + + flags |= ConnectFlags::from_bits_truncate(b << WILL_QOS_SHIFT); + } + + if clean_session { + flags |= ConnectFlags::CLEAN_SESSION; + } + + dst.put_slice(&[protocol.level(), flags.bits()]); + + dst.put_u16(keep_alive); + + write_slice(client_id.as_bytes(), dst); + + if let Some(LastWill { + ref topic, + ref message, + .. + }) = *last_will + { + write_slice(topic.as_bytes(), dst); + write_slice(&message, dst); + } + + if let Some(ref s) = *username { + write_slice(s.as_bytes(), dst); + } + + if let Some(ref s) = *password { + write_slice(s, dst); + } + } + }, + + Packet::ConnectAck { + session_present, + return_code, + } => { + dst.put_slice(&[if session_present { 0x01 } else { 0x00 }, return_code as u8]); + } + + Packet::Publish(Publish { + qos, + ref topic, + packet_id, + ref payload, + .. + }) => { + write_slice(topic.as_bytes(), dst); + + if qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce { + dst.put_u16(packet_id.unwrap().into()); + } + + dst.put(payload.as_ref()); + } + + Packet::PublishAck { packet_id } + | Packet::PublishReceived { packet_id } + | Packet::PublishRelease { packet_id } + | Packet::PublishComplete { packet_id } + | Packet::UnsubscribeAck { packet_id } => { + dst.put_u16(packet_id.into()); + } + + Packet::Subscribe { + packet_id, + ref topic_filters, + } => { + dst.put_u16(packet_id.into()); + + for &(ref filter, qos) in topic_filters { + write_slice(filter.as_ref(), dst); + dst.put_slice(&[qos as u8]); + } + } + + Packet::SubscribeAck { + packet_id, + ref status, + } => { + dst.put_u16(packet_id.into()); + + let buf: Vec = status + .iter() + .map(|s| { + if let SubscribeReturnCode::Success(qos) = *s { + qos as u8 + } else { + 0x80 + } + }) + .collect(); + + dst.put_slice(&buf); + } + + Packet::Unsubscribe { + packet_id, + ref topic_filters, + } => { + dst.put_u16(packet_id.into()); + + for filter in topic_filters { + write_slice(filter.as_ref(), dst); + } + } + + Packet::PingRequest | Packet::PingResponse | Packet::Disconnect => {} + } +} + +#[inline] +fn write_slice(r: &[u8], dst: &mut BytesMut) { + dst.put_u16(r.len() as u16); + dst.put_slice(r); +} + +#[inline] +fn write_variable_length(size: usize, dst: &mut BytesMut) { + // todo: verify at higher level + // if size > MAX_VARIABLE_LENGTH { + // Err(Error::new(ErrorKind::Other, "out of range")) + if size <= 127 { + dst.put_u8(size as u8); + } else if size <= 16383 { + // 127 + 127 << 7 + dst.put_slice(&[((size % 128) | 0x80) as u8, (size >> 7) as u8]); + } else if size <= 2_097_151 { + // 127 + 127 << 7 + 127 << 14 + dst.put_slice(&[ + ((size % 128) | 0x80) as u8, + (((size >> 7) % 128) | 0x80) as u8, + (size >> 14) as u8, + ]); + } else { + dst.put_slice(&[ + ((size % 128) | 0x80) as u8, + (((size >> 7) % 128) | 0x80) as u8, + (((size >> 14) % 128) | 0x80) as u8, + (size >> 21) as u8, + ]); + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use bytestring::ByteString; + use std::num::NonZeroU16; + + use super::*; + + fn packet_id(v: u16) -> NonZeroU16 { + NonZeroU16::new(v).unwrap() + } + + #[test] + fn test_encode_variable_length() { + let mut v = BytesMut::new(); + + write_variable_length(123, &mut v); + assert_eq!(v, [123].as_ref()); + + v.clear(); + + write_variable_length(129, &mut v); + assert_eq!(v, b"\x81\x01".as_ref()); + + v.clear(); + + write_variable_length(16383, &mut v); + assert_eq!(v, b"\xff\x7f".as_ref()); + + v.clear(); + + write_variable_length(2097151, &mut v); + assert_eq!(v, b"\xff\xff\x7f".as_ref()); + + v.clear(); + + write_variable_length(268435455, &mut v); + assert_eq!(v, b"\xff\xff\xff\x7f".as_ref()); + + // assert!(v.write_variable_length(MAX_VARIABLE_LENGTH + 1).is_err()) + } + + #[test] + fn test_encode_fixed_header() { + let mut v = BytesMut::new(); + let p = Packet::PingRequest; + + assert_eq!(get_encoded_size(&p), 0); + write_fixed_header(&p, &mut v, 0); + assert_eq!(v, b"\xc0\x00".as_ref()); + + v.clear(); + + let p = Packet::Publish(Publish { + dup: true, + retain: true, + qos: QoS::ExactlyOnce, + topic: ByteString::from_static("topic"), + packet_id: Some(packet_id(0x4321)), + payload: (0..255).collect::>().into(), + }); + + assert_eq!(get_encoded_size(&p), 264); + write_fixed_header(&p, &mut v, 264); + assert_eq!(v, b"\x3d\x88\x02".as_ref()); + } + + macro_rules! assert_packet { + ($p:expr, $data:expr) => { + let mut v = BytesMut::with_capacity(1024); + write_packet(&$p, &mut v, get_encoded_size($p)); + assert_eq!(v.len(), $data.len()); + assert_eq!(v, &$data[..]); + // assert_eq!(read_packet($data.cursor()).unwrap(), (&b""[..], $p)); + }; + } + + #[test] + fn test_encode_connect_packets() { + assert_packet!( + &Packet::Connect(Connect { + protocol: Protocol::MQTT(4), + clean_session: false, + keep_alive: 60, + client_id: ByteString::from_static("12345"), + last_will: None, + username: Some(ByteString::from_static("user")), + password: Some(Bytes::from_static(b"pass")), + }), + &b"\x10\x1D\x00\x04MQTT\x04\xC0\x00\x3C\x00\ +\x0512345\x00\x04user\x00\x04pass"[..] + ); + + assert_packet!( + &Packet::Connect(Connect { + protocol: Protocol::MQTT(4), + clean_session: false, + keep_alive: 60, + client_id: ByteString::from_static("12345"), + last_will: Some(LastWill { + qos: QoS::ExactlyOnce, + retain: false, + topic: ByteString::from_static("topic"), + message: Bytes::from_static(b"message"), + }), + username: None, + password: None, + }), + &b"\x10\x21\x00\x04MQTT\x04\x14\x00\x3C\x00\ +\x0512345\x00\x05topic\x00\x07message"[..] + ); + + assert_packet!(&Packet::Disconnect, b"\xe0\x00"); + } + + #[test] + fn test_encode_publish_packets() { + assert_packet!( + &Packet::Publish(Publish { + dup: true, + retain: true, + qos: QoS::ExactlyOnce, + topic: ByteString::from_static("topic"), + packet_id: Some(packet_id(0x4321)), + payload: Bytes::from_static(b"data"), + }), + b"\x3d\x0D\x00\x05topic\x43\x21data" + ); + + assert_packet!( + &Packet::Publish(Publish { + dup: false, + retain: false, + qos: QoS::AtMostOnce, + topic: ByteString::from_static("topic"), + packet_id: None, + payload: Bytes::from_static(b"data"), + }), + b"\x30\x0b\x00\x05topicdata" + ); + } + + #[test] + fn test_encode_subscribe_packets() { + assert_packet!( + &Packet::Subscribe { + packet_id: packet_id(0x1234), + topic_filters: vec![ + (ByteString::from_static("test"), QoS::AtLeastOnce), + (ByteString::from_static("filter"), QoS::ExactlyOnce) + ], + }, + b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02" + ); + + assert_packet!( + &Packet::SubscribeAck { + packet_id: packet_id(0x1234), + status: vec![ + SubscribeReturnCode::Success(QoS::AtLeastOnce), + SubscribeReturnCode::Failure, + SubscribeReturnCode::Success(QoS::ExactlyOnce) + ], + }, + b"\x90\x05\x12\x34\x01\x80\x02" + ); + + assert_packet!( + &Packet::Unsubscribe { + packet_id: packet_id(0x1234), + topic_filters: vec![ + ByteString::from_static("test"), + ByteString::from_static("filter"), + ], + }, + b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter" + ); + + assert_packet!( + &Packet::UnsubscribeAck { + packet_id: packet_id(0x4321) + }, + b"\xb0\x02\x43\x21" + ); + } + + #[test] + fn test_encode_ping_packets() { + assert_packet!(&Packet::PingRequest, b"\xc0\x00"); + assert_packet!(&Packet::PingResponse, b"\xd0\x00"); + } +} diff --git a/actix-mqtt/codec/src/codec/mod.rs b/actix-mqtt/codec/src/codec/mod.rs new file mode 100755 index 000000000..602276a74 --- /dev/null +++ b/actix-mqtt/codec/src/codec/mod.rs @@ -0,0 +1,161 @@ +use actix_codec::{Decoder, Encoder}; +use bytes::{buf::Buf, BytesMut}; + +use crate::error::ParseError; +use crate::proto::QoS; +use crate::{Packet, Publish}; + +mod decode; +mod encode; + +use self::decode::*; +use self::encode::*; + +bitflags! { + pub struct ConnectFlags: u8 { + const USERNAME = 0b1000_0000; + const PASSWORD = 0b0100_0000; + const WILL_RETAIN = 0b0010_0000; + const WILL_QOS = 0b0001_1000; + const WILL = 0b0000_0100; + const CLEAN_SESSION = 0b0000_0010; + } +} + +pub const WILL_QOS_SHIFT: u8 = 3; + +bitflags! { + pub struct ConnectAckFlags: u8 { + const SESSION_PRESENT = 0b0000_0001; + } +} + +#[derive(Debug)] +pub struct Codec { + state: DecodeState, + max_size: usize, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + FrameHeader, + Frame(FixedHeader), +} + +impl Codec { + /// Create `Codec` instance + pub fn new() -> Self { + Codec { + state: DecodeState::FrameHeader, + max_size: 0, + } + } + + /// Set max inbound frame size. + /// + /// If max size is set to `0`, size is unlimited. + /// By default max size is set to `0` + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } +} + +impl Default for Codec { + fn default() -> Self { + Self::new() + } +} + +impl Decoder for Codec { + type Item = Packet; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, ParseError> { + loop { + match self.state { + DecodeState::FrameHeader => { + if src.len() < 2 { + return Ok(None); + } + let fixed = src.as_ref()[0]; + match decode_variable_length(&src.as_ref()[1..])? { + Some((remaining_length, consumed)) => { + // check max message size + if self.max_size != 0 && self.max_size < remaining_length { + return Err(ParseError::MaxSizeExceeded); + } + src.advance(consumed + 1); + self.state = DecodeState::Frame(FixedHeader { + packet_type: fixed >> 4, + packet_flags: fixed & 0xF, + remaining_length, + }); + // todo: validate remaining_length against max frame size config + if src.len() < remaining_length { + // todo: subtract? + src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager? + return Ok(None); + } + } + None => { + return Ok(None); + } + } + } + DecodeState::Frame(fixed) => { + if src.len() < fixed.remaining_length { + return Ok(None); + } + let packet_buf = src.split_to(fixed.remaining_length); + let packet = read_packet(packet_buf.freeze(), fixed)?; + self.state = DecodeState::FrameHeader; + src.reserve(2); + return Ok(Some(packet)); + } + } + } + } +} + +impl Encoder for Codec { + type Item = Packet; + type Error = ParseError; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), ParseError> { + if let Packet::Publish(Publish { qos, packet_id, .. }) = item { + if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() { + return Err(ParseError::PacketIdRequired); + } + } + let content_size = get_encoded_size(&item); + dst.reserve(content_size + 5); + write_packet(&item, dst, content_size); + Ok(()) + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) struct FixedHeader { + /// MQTT Control Packet type + pub packet_type: u8, + /// Flags specific to each MQTT Control Packet type + pub packet_flags: u8, + /// the number of bytes remaining within the current packet, + /// including data in the variable header and the payload. + pub remaining_length: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_max_size() { + let mut codec = Codec::new().max_size(5); + + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"\0\x09"); + assert_eq!(codec.decode(&mut buf), Err(ParseError::MaxSizeExceeded)); + } +} diff --git a/actix-mqtt/codec/src/error.rs b/actix-mqtt/codec/src/error.rs new file mode 100755 index 000000000..9a6cf1670 --- /dev/null +++ b/actix-mqtt/codec/src/error.rs @@ -0,0 +1,57 @@ +use std::{io, str}; + +#[derive(Debug)] +pub enum ParseError { + InvalidProtocol, + InvalidLength, + MalformedPacket, + UnsupportedProtocolLevel, + ConnectReservedFlagSet, + ConnAckReservedFlagSet, + InvalidClientId, + UnsupportedPacketType, + PacketIdRequired, + MaxSizeExceeded, + IoError(io::Error), + Utf8Error(str::Utf8Error), +} + +impl PartialEq for ParseError { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ParseError::InvalidProtocol, ParseError::InvalidProtocol) => true, + (ParseError::InvalidLength, ParseError::InvalidLength) => true, + (ParseError::UnsupportedProtocolLevel, ParseError::UnsupportedProtocolLevel) => { + true + } + (ParseError::ConnectReservedFlagSet, ParseError::ConnectReservedFlagSet) => true, + (ParseError::ConnAckReservedFlagSet, ParseError::ConnAckReservedFlagSet) => true, + (ParseError::InvalidClientId, ParseError::InvalidClientId) => true, + (ParseError::UnsupportedPacketType, ParseError::UnsupportedPacketType) => true, + (ParseError::PacketIdRequired, ParseError::PacketIdRequired) => true, + (ParseError::MaxSizeExceeded, ParseError::MaxSizeExceeded) => true, + (ParseError::MalformedPacket, ParseError::MalformedPacket) => true, + (ParseError::IoError(_), _) => false, + (ParseError::Utf8Error(_), _) => false, + _ => false, + } + } +} + +impl From for ParseError { + fn from(err: io::Error) -> Self { + ParseError::IoError(err) + } +} + +impl From for ParseError { + fn from(err: str::Utf8Error) -> Self { + ParseError::Utf8Error(err) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum TopicError { + InvalidTopic, + InvalidLevel, +} diff --git a/actix-mqtt/codec/src/lib.rs b/actix-mqtt/codec/src/lib.rs new file mode 100755 index 000000000..376b8887f --- /dev/null +++ b/actix-mqtt/codec/src/lib.rs @@ -0,0 +1,22 @@ +#[macro_use] +extern crate bitflags; + +extern crate bytestring; + +mod error; +#[macro_use] +mod topic; +#[macro_use] +mod proto; +mod codec; +mod packet; + +pub use self::codec::Codec; +pub use self::error::{ParseError, TopicError}; +pub use self::packet::{Connect, ConnectCode, LastWill, Packet, Publish, SubscribeReturnCode}; +pub use self::proto::{Protocol, QoS}; +pub use self::topic::{Level, Topic}; + +// http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml +pub const TCP_PORT: u16 = 1883; +pub const SSL_PORT: u16 = 8883; diff --git a/actix-mqtt/codec/src/packet.rs b/actix-mqtt/codec/src/packet.rs new file mode 100755 index 000000000..f11e81aeb --- /dev/null +++ b/actix-mqtt/codec/src/packet.rs @@ -0,0 +1,251 @@ +use bytes::Bytes; +use bytestring::ByteString; +use std::num::NonZeroU16; + +use crate::proto::{Protocol, QoS}; + +#[repr(u8)] +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +/// Connect Return Code +pub enum ConnectCode { + /// Connection accepted + ConnectionAccepted = 0, + /// Connection Refused, unacceptable protocol version + UnacceptableProtocolVersion = 1, + /// Connection Refused, identifier rejected + IdentifierRejected = 2, + /// Connection Refused, Server unavailable + ServiceUnavailable = 3, + /// Connection Refused, bad user name or password + BadUserNameOrPassword = 4, + /// Connection Refused, not authorized + NotAuthorized = 5, + /// Reserved + Reserved = 6, +} + +const_enum!(ConnectCode: u8); + +impl ConnectCode { + pub fn reason(self) -> &'static str { + match self { + ConnectCode::ConnectionAccepted => "Connection Accepted", + ConnectCode::UnacceptableProtocolVersion => { + "Connection Refused, unacceptable protocol version" + } + ConnectCode::IdentifierRejected => "Connection Refused, identifier rejected", + ConnectCode::ServiceUnavailable => "Connection Refused, Server unavailable", + ConnectCode::BadUserNameOrPassword => { + "Connection Refused, bad user name or password" + } + ConnectCode::NotAuthorized => "Connection Refused, not authorized", + _ => "Connection Refused", + } + } +} + +#[derive(Debug, PartialEq, Clone)] +/// Connection Will +pub struct LastWill { + /// the QoS level to be used when publishing the Will Message. + pub qos: QoS, + /// the Will Message is to be Retained when it is published. + pub retain: bool, + /// the Will Topic + pub topic: ByteString, + /// defines the Application Message that is to be published to the Will Topic + pub message: Bytes, +} + +#[derive(Debug, PartialEq, Clone)] +/// Connect packet content +pub struct Connect { + /// mqtt protocol version + pub protocol: Protocol, + /// the handling of the Session state. + pub clean_session: bool, + /// a time interval measured in seconds. + pub keep_alive: u16, + /// Will Message be stored on the Server and associated with the Network Connection. + pub last_will: Option, + /// identifies the Client to the Server. + pub client_id: ByteString, + /// username can be used by the Server for authentication and authorization. + pub username: Option, + /// password can be used by the Server for authentication and authorization. + pub password: Option, +} + +#[derive(Debug, PartialEq, Clone)] +/// Publish message +pub struct Publish { + /// this might be re-delivery of an earlier attempt to send the Packet. + pub dup: bool, + pub retain: bool, + /// the level of assurance for delivery of an Application Message. + pub qos: QoS, + /// the information channel to which payload data is published. + pub topic: ByteString, + /// only present in PUBLISH Packets where the QoS level is 1 or 2. + pub packet_id: Option, + /// the Application Message that is being published. + pub payload: Bytes, +} + +#[derive(Debug, PartialEq, Copy, Clone)] +/// Subscribe Return Code +pub enum SubscribeReturnCode { + Success(QoS), + Failure, +} + +#[derive(Debug, PartialEq, Clone)] +/// MQTT Control Packets +pub enum Packet { + /// Client request to connect to Server + Connect(Connect), + + /// Connect acknowledgment + ConnectAck { + /// enables a Client to establish whether the Client and Server have a consistent view + /// about whether there is already stored Session state. + session_present: bool, + return_code: ConnectCode, + }, + + /// Publish message + Publish(Publish), + + /// Publish acknowledgment + PublishAck { + /// Packet Identifier + packet_id: NonZeroU16, + }, + /// Publish received (assured delivery part 1) + PublishReceived { + /// Packet Identifier + packet_id: NonZeroU16, + }, + /// Publish release (assured delivery part 2) + PublishRelease { + /// Packet Identifier + packet_id: NonZeroU16, + }, + /// Publish complete (assured delivery part 3) + PublishComplete { + /// Packet Identifier + packet_id: NonZeroU16, + }, + + /// Client subscribe request + Subscribe { + /// Packet Identifier + packet_id: NonZeroU16, + /// the list of Topic Filters and QoS to which the Client wants to subscribe. + topic_filters: Vec<(ByteString, QoS)>, + }, + /// Subscribe acknowledgment + SubscribeAck { + packet_id: NonZeroU16, + /// corresponds to a Topic Filter in the SUBSCRIBE Packet being acknowledged. + status: Vec, + }, + + /// Unsubscribe request + Unsubscribe { + /// Packet Identifier + packet_id: NonZeroU16, + /// the list of Topic Filters that the Client wishes to unsubscribe from. + topic_filters: Vec, + }, + /// Unsubscribe acknowledgment + UnsubscribeAck { + /// Packet Identifier + packet_id: NonZeroU16, + }, + + /// PING request + PingRequest, + /// PING response + PingResponse, + + /// Client is disconnecting + Disconnect, +} + +impl Packet { + #[inline] + /// MQTT Control Packet type + pub fn packet_type(&self) -> u8 { + match *self { + Packet::Connect { .. } => CONNECT, + Packet::ConnectAck { .. } => CONNACK, + Packet::Publish { .. } => PUBLISH, + Packet::PublishAck { .. } => PUBACK, + Packet::PublishReceived { .. } => PUBREC, + Packet::PublishRelease { .. } => PUBREL, + Packet::PublishComplete { .. } => PUBCOMP, + Packet::Subscribe { .. } => SUBSCRIBE, + Packet::SubscribeAck { .. } => SUBACK, + Packet::Unsubscribe { .. } => UNSUBSCRIBE, + Packet::UnsubscribeAck { .. } => UNSUBACK, + Packet::PingRequest => PINGREQ, + Packet::PingResponse => PINGRESP, + Packet::Disconnect => DISCONNECT, + } + } + + /// Flags specific to each MQTT Control Packet type + pub fn packet_flags(&self) -> u8 { + match *self { + Packet::Publish(Publish { + dup, qos, retain, .. + }) => { + let mut b = qos as u8; + + b <<= 1; + + if dup { + b |= 0b1000; + } + + if retain { + b |= 0b0001; + } + + b + } + Packet::PublishRelease { .. } + | Packet::Subscribe { .. } + | Packet::Unsubscribe { .. } => 0b0010, + _ => 0, + } + } +} + +impl From for Packet { + fn from(val: Connect) -> Packet { + Packet::Connect(val) + } +} + +impl From for Packet { + fn from(val: Publish) -> Packet { + Packet::Publish(val) + } +} + +pub const CONNECT: u8 = 1; +pub const CONNACK: u8 = 2; +pub const PUBLISH: u8 = 3; +pub const PUBACK: u8 = 4; +pub const PUBREC: u8 = 5; +pub const PUBREL: u8 = 6; +pub const PUBCOMP: u8 = 7; +pub const SUBSCRIBE: u8 = 8; +pub const SUBACK: u8 = 9; +pub const UNSUBSCRIBE: u8 = 10; +pub const UNSUBACK: u8 = 11; +pub const PINGREQ: u8 = 12; +pub const PINGRESP: u8 = 13; +pub const DISCONNECT: u8 = 14; diff --git a/actix-mqtt/codec/src/proto.rs b/actix-mqtt/codec/src/proto.rs new file mode 100755 index 000000000..2d0bba755 --- /dev/null +++ b/actix-mqtt/codec/src/proto.rs @@ -0,0 +1,64 @@ +#[macro_export] +macro_rules! const_enum { + ($name:ty : $repr:ty) => { + impl ::std::convert::From<$repr> for $name { + fn from(u: $repr) -> Self { + unsafe { ::std::mem::transmute(u) } + } + } + }; +} + +pub const DEFAULT_MQTT_LEVEL: u8 = 4; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + MQTT(u8), +} + +impl Protocol { + pub fn name(self) -> &'static str { + match self { + Protocol::MQTT(_) => "MQTT", + } + } + + pub fn level(self) -> u8 { + match self { + Protocol::MQTT(level) => level, + } + } +} + +impl Default for Protocol { + fn default() -> Self { + Protocol::MQTT(DEFAULT_MQTT_LEVEL) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Quality of Service levels +pub enum QoS { + /// At most once delivery + /// + /// The message is delivered according to the capabilities of the underlying network. + /// No response is sent by the receiver and no retry is performed by the sender. + /// The message arrives at the receiver either once or not at all. + AtMostOnce = 0, + /// At least once delivery + /// + /// This quality of service ensures that the message arrives at the receiver at least once. + /// A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header + /// and is acknowledged by a PUBACK Packet. + AtLeastOnce = 1, + /// Exactly once delivery + /// + /// This is the highest quality of service, + /// for use when neither loss nor duplication of messages are acceptable. + /// There is an increased overhead associated with this quality of service. + ExactlyOnce = 2, +} + +const_enum!(QoS: u8); diff --git a/actix-mqtt/codec/src/topic.rs b/actix-mqtt/codec/src/topic.rs new file mode 100755 index 000000000..31b4b0359 --- /dev/null +++ b/actix-mqtt/codec/src/topic.rs @@ -0,0 +1,520 @@ +use std::fmt::{self, Write}; +use std::{io, ops, str::FromStr}; + +use crate::error::TopicError; + +#[inline] +fn is_metadata>(s: T) -> bool { + s.as_ref().chars().nth(0) == Some('$') +} + +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub enum Level { + Normal(String), + Metadata(String), // $SYS + Blank, + SingleWildcard, // Single level wildcard + + MultiWildcard, // Multi-level wildcard # +} + +impl Level { + pub fn parse>(s: T) -> Result { + Level::from_str(s.as_ref()) + } + + pub fn normal>(s: T) -> Level { + if s.as_ref().contains(|c| c == '+' || c == '#') { + panic!("invalid normal level `{}` contains +|#", s.as_ref()); + } + + if s.as_ref().chars().nth(0) == Some('$') { + panic!("invalid normal level `{}` starts with $", s.as_ref()) + } + + Level::Normal(String::from(s.as_ref())) + } + + pub fn metadata>(s: T) -> Level { + if s.as_ref().contains(|c| c == '+' || c == '#') { + panic!("invalid metadata level `{}` contains +|#", s.as_ref()); + } + + if s.as_ref().chars().nth(0) != Some('$') { + panic!("invalid metadata level `{}` not starts with $", s.as_ref()) + } + + Level::Metadata(String::from(s.as_ref())) + } + + #[inline] + pub fn value(&self) -> Option<&str> { + match *self { + Level::Normal(ref s) | Level::Metadata(ref s) => Some(s), + _ => None, + } + } + + #[inline] + pub fn is_normal(&self) -> bool { + if let Level::Normal(_) = *self { + true + } else { + false + } + } + + #[inline] + pub fn is_metadata(&self) -> bool { + if let Level::Metadata(_) = *self { + true + } else { + false + } + } + + #[inline] + pub fn is_valid(&self) -> bool { + match *self { + Level::Normal(ref s) => { + s.chars().nth(0) != Some('$') && !s.contains(|c| c == '+' || c == '#') + } + Level::Metadata(ref s) => { + s.chars().nth(0) == Some('$') && !s.contains(|c| c == '+' || c == '#') + } + _ => true, + } + } +} + +#[derive(Debug, Eq, Clone)] +pub struct Topic(Vec); + +impl Topic { + #[inline] + pub fn levels(&self) -> &Vec { + &self.0 + } + + #[inline] + pub fn is_valid(&self) -> bool { + self.0 + .iter() + .position(|level| !level.is_valid()) + .or_else(|| { + self.0 + .iter() + .enumerate() + .position(|(pos, level)| match *level { + Level::MultiWildcard => pos != self.0.len() - 1, + Level::Metadata(_) => pos != 0, + _ => false, + }) + }) + .is_none() + } +} + +macro_rules! match_topic { + ($topic:expr, $levels:expr) => {{ + let mut lhs = $topic.0.iter(); + + for rhs in $levels { + match lhs.next() { + Some(&Level::SingleWildcard) => { + if !rhs.match_level(&Level::SingleWildcard) { + break; + } + } + Some(&Level::MultiWildcard) => { + return rhs.match_level(&Level::MultiWildcard); + } + Some(level) if rhs.match_level(level) => continue, + _ => return false, + } + } + + match lhs.next() { + Some(&Level::MultiWildcard) => true, + Some(_) => false, + None => true, + } + }}; +} + +impl PartialEq for Topic { + fn eq(&self, other: &Topic) -> bool { + match_topic!(self, &other.0) + } +} + +impl> PartialEq for Topic { + fn eq(&self, other: &T) -> bool { + match_topic!(self, other.as_ref().split('/')) + } +} + +impl<'a> From<&'a [Level]> for Topic { + fn from(s: &[Level]) -> Self { + let mut v = vec![]; + + v.extend_from_slice(s); + + Topic(v) + } +} + +impl From> for Topic { + fn from(v: Vec) -> Self { + Topic(v) + } +} + +impl Into> for Topic { + fn into(self) -> Vec { + self.0 + } +} + +impl ops::Deref for Topic { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ops::DerefMut for Topic { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[macro_export] +macro_rules! topic { + ($s:expr) => { + $s.parse::().unwrap() + }; +} + +pub trait MatchLevel { + fn match_level(&self, level: &Level) -> bool; +} + +impl MatchLevel for Level { + fn match_level(&self, level: &Level) -> bool { + match *level { + Level::Normal(ref lhs) => { + if let Level::Normal(ref rhs) = *self { + lhs == rhs + } else { + false + } + } + Level::Metadata(ref lhs) => { + if let Level::Metadata(ref rhs) = *self { + lhs == rhs + } else { + false + } + } + Level::Blank => true, + Level::SingleWildcard | Level::MultiWildcard => !self.is_metadata(), + } + } +} + +impl> MatchLevel for T { + fn match_level(&self, level: &Level) -> bool { + match *level { + Level::Normal(ref lhs) => !is_metadata(self) && lhs == self.as_ref(), + Level::Metadata(ref lhs) => is_metadata(self) && lhs == self.as_ref(), + Level::Blank => self.as_ref().is_empty(), + Level::SingleWildcard | Level::MultiWildcard => !is_metadata(self), + } + } +} + +impl FromStr for Level { + type Err = TopicError; + + #[inline] + fn from_str(s: &str) -> Result { + match s { + "+" => Ok(Level::SingleWildcard), + "#" => Ok(Level::MultiWildcard), + "" => Ok(Level::Blank), + _ => { + if s.contains(|c| c == '+' || c == '#') { + Err(TopicError::InvalidLevel) + } else if is_metadata(s) { + Ok(Level::Metadata(String::from(s))) + } else { + Ok(Level::Normal(String::from(s))) + } + } + } + } +} + +impl FromStr for Topic { + type Err = TopicError; + + #[inline] + fn from_str(s: &str) -> Result { + s.split('/') + .map(|level| Level::from_str(level)) + .collect::, TopicError>>() + .map(Topic) + .and_then(|topic| { + if topic.is_valid() { + Ok(topic) + } else { + Err(TopicError::InvalidTopic) + } + }) + } +} + +impl fmt::Display for Level { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Level::Normal(ref s) | Level::Metadata(ref s) => f.write_str(s.as_str()), + Level::Blank => Ok(()), + Level::SingleWildcard => f.write_char('+'), + Level::MultiWildcard => f.write_char('#'), + } + } +} + +impl fmt::Display for Topic { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut first = true; + + for level in &self.0 { + if first { + first = false; + } else { + f.write_char('/')?; + } + + level.fmt(f)?; + } + + Ok(()) + } +} + +pub trait WriteTopicExt: io::Write { + fn write_level(&mut self, level: &Level) -> io::Result { + match *level { + Level::Normal(ref s) | Level::Metadata(ref s) => self.write(s.as_str().as_bytes()), + Level::Blank => Ok(0), + Level::SingleWildcard => self.write(b"+"), + Level::MultiWildcard => self.write(b"#"), + } + } + + fn write_topic(&mut self, topic: &Topic) -> io::Result { + let mut n = 0; + let mut first = true; + + for level in topic.levels() { + if first { + first = false; + } else { + n += self.write(b"/")?; + } + + n += self.write_level(level)?; + } + + Ok(n) + } +} + +impl WriteTopicExt for W {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_level() { + assert!(Level::normal("sport").is_normal()); + assert!(Level::metadata("$SYS").is_metadata()); + + assert_eq!(Level::normal("sport").value(), Some("sport")); + assert_eq!(Level::metadata("$SYS").value(), Some("$SYS")); + + assert_eq!(Level::normal("sport"), "sport".parse().unwrap()); + assert_eq!(Level::metadata("$SYS"), "$SYS".parse().unwrap()); + + assert!(Level::Normal(String::from("sport")).is_valid()); + assert!(Level::Metadata(String::from("$SYS")).is_valid()); + + assert!(!Level::Normal(String::from("$sport")).is_valid()); + assert!(!Level::Metadata(String::from("SYS")).is_valid()); + + assert!(!Level::Normal(String::from("sport#")).is_valid()); + assert!(!Level::Metadata(String::from("SYS+")).is_valid()); + } + + #[test] + fn test_valid_topic() { + assert!(Topic(vec![ + Level::normal("sport"), + Level::normal("tennis"), + Level::normal("player1") + ]) + .is_valid()); + + assert!(Topic(vec![ + Level::normal("sport"), + Level::normal("tennis"), + Level::MultiWildcard + ]) + .is_valid()); + assert!(Topic(vec![ + Level::metadata("$SYS"), + Level::normal("tennis"), + Level::MultiWildcard + ]) + .is_valid()); + + assert!(Topic(vec![ + Level::normal("sport"), + Level::SingleWildcard, + Level::normal("player1") + ]) + .is_valid()); + + assert!(!Topic(vec![ + Level::normal("sport"), + Level::MultiWildcard, + Level::normal("player1") + ]) + .is_valid()); + assert!(!Topic(vec![ + Level::normal("sport"), + Level::metadata("$SYS"), + Level::normal("player1") + ]) + .is_valid()); + } + + #[test] + fn test_parse_topic() { + assert_eq!( + topic!("sport/tennis/player1"), + Topic::from(vec![ + Level::normal("sport"), + Level::normal("tennis"), + Level::normal("player1") + ]) + ); + + assert_eq!(topic!(""), Topic(vec![Level::Blank])); + assert_eq!( + topic!("/finance"), + Topic::from(vec![Level::Blank, Level::normal("finance")]) + ); + + assert_eq!(topic!("$SYS"), Topic::from(vec![Level::metadata("$SYS")])); + assert!("sport/$SYS".parse::().is_err()); + } + + #[test] + fn test_multi_wildcard_topic() { + assert_eq!( + topic!("sport/tennis/#"), + Topic::from(vec![ + Level::normal("sport"), + Level::normal("tennis"), + Level::MultiWildcard + ]) + ); + + assert_eq!(topic!("#"), Topic::from(vec![Level::MultiWildcard])); + assert!("sport/tennis#".parse::().is_err()); + assert!("sport/tennis/#/ranking".parse::().is_err()); + } + + #[test] + fn test_single_wildcard_topic() { + assert_eq!(topic!("+"), Topic::from(vec![Level::SingleWildcard])); + + assert_eq!( + topic!("+/tennis/#"), + Topic::from(vec![ + Level::SingleWildcard, + Level::normal("tennis"), + Level::MultiWildcard + ]) + ); + + assert_eq!( + topic!("sport/+/player1"), + Topic::from(vec![ + Level::normal("sport"), + Level::SingleWildcard, + Level::normal("player1") + ]) + ); + + assert!("sport+".parse::().is_err()); + } + + #[test] + fn test_write_topic() { + let mut v = vec![]; + let t = vec![ + Level::SingleWildcard, + Level::normal("tennis"), + Level::MultiWildcard, + ] + .into(); + + assert_eq!(v.write_topic(&t).unwrap(), 10); + assert_eq!(v, b"+/tennis/#"); + + assert_eq!(format!("{}", t), "+/tennis/#"); + assert_eq!(t.to_string(), "+/tennis/#"); + } + + #[test] + fn test_match_topic() { + assert!("test".match_level(&Level::normal("test"))); + assert!("$SYS".match_level(&Level::metadata("$SYS"))); + + let t: Topic = "sport/tennis/player1/#".parse().unwrap(); + + assert_eq!(t, "sport/tennis/player1"); + assert_eq!(t, "sport/tennis/player1/ranking"); + assert_eq!(t, "sport/tennis/player1/score/wimbledon"); + + assert_eq!(Topic::from_str("sport/#").unwrap(), "sport"); + + let t: Topic = "sport/tennis/+".parse().unwrap(); + + assert_eq!(t, "sport/tennis/player1"); + assert_eq!(t, "sport/tennis/player2"); + assert!(t != "sport/tennis/player1/ranking"); + + let t: Topic = "sport/+".parse().unwrap(); + + assert!(t != "sport"); + assert_eq!(t, "sport/"); + + assert_eq!(Topic::from_str("+/+").unwrap(), "/finance"); + assert_eq!(Topic::from_str("/+").unwrap(), "/finance",); + assert!(Topic::from_str("+").unwrap() != "/finance",); + + assert!(Topic::from_str("#").unwrap() != "$SYS"); + assert!(Topic::from_str("+/monitor/Clients").unwrap() != "$SYS/monitor/Clients"); + assert_eq!(Topic::from_str(&"$SYS/#").unwrap(), "$SYS/"); + assert_eq!( + Topic::from_str("$SYS/monitor/+").unwrap(), + "$SYS/monitor/Clients", + ); + } +} diff --git a/actix-mqtt/examples/basic.rs b/actix-mqtt/examples/basic.rs new file mode 100755 index 000000000..02b33869e --- /dev/null +++ b/actix-mqtt/examples/basic.rs @@ -0,0 +1,35 @@ +use actix_mqtt::{Connect, ConnectAck, MqttServer, Publish}; + +#[derive(Clone)] +struct Session; + +async fn connect(connect: Connect) -> Result, ()> { + log::info!("new connection: {:?}", connect); + Ok(connect.ack(Session, false)) +} + +async fn publish(publish: Publish) -> Result<(), ()> { + log::info!( + "incoming publish: {:?} -> {:?}", + publish.id(), + publish.topic() + ); + Ok(()) +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + std::env::set_var( + "RUST_LOG", + "actix_server=trace,actix_mqtt=trace,basic=trace", + ); + env_logger::init(); + + actix_server::Server::build() + .bind("mqtt", "127.0.0.1:1883", || { + MqttServer::new(connect).finish(publish) + })? + .workers(1) + .run() + .await +} diff --git a/actix-mqtt/src/cell.rs b/actix-mqtt/src/cell.rs new file mode 100755 index 000000000..9ed09c0d8 --- /dev/null +++ b/actix-mqtt/src/cell.rs @@ -0,0 +1,64 @@ +//! Custom cell impl +use std::cell::UnsafeCell; +use std::ops::Deref; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_service::Service; + +pub(crate) struct Cell { + inner: Rc>, +} + +impl Clone for Cell { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl Deref for Cell { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.get_ref() + } +} + +impl std::fmt::Debug for Cell { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl Cell { + pub fn new(inner: T) -> Self { + Self { + inner: Rc::new(UnsafeCell::new(inner)), + } + } + + pub fn get_ref(&self) -> &T { + unsafe { &*self.inner.as_ref().get() } + } + + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.inner.as_ref().get() } + } +} + +impl Service for Cell { + type Request = T::Request; + type Response = T::Response; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.get_mut().poll_ready(cx) + } + + fn call(&mut self, req: T::Request) -> T::Future { + self.get_mut().call(req) + } +} diff --git a/actix-mqtt/src/client.rs b/actix-mqtt/src/client.rs new file mode 100755 index 000000000..6714a79ce --- /dev/null +++ b/actix-mqtt/src/client.rs @@ -0,0 +1,406 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_ioframe as ioframe; +use actix_service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory}; +use bytes::Bytes; +use bytestring::ByteString; +use futures::future::{FutureExt, LocalBoxFuture}; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use mqtt_codec as mqtt; + +use crate::cell::Cell; +use crate::default::{SubsNotImplemented, UnsubsNotImplemented}; +use crate::dispatcher::{dispatcher, MqttState}; +use crate::error::MqttError; +use crate::publish::Publish; +use crate::sink::MqttSink; +use crate::subs::{Subscribe, SubscribeResult, Unsubscribe}; + +/// Mqtt client +#[derive(Clone)] +pub struct Client { + client_id: ByteString, + clean_session: bool, + protocol: mqtt::Protocol, + keep_alive: u16, + last_will: Option, + username: Option, + password: Option, + inflight: usize, + _t: PhantomData<(Io, St)>, +} + +impl Client +where + St: 'static, +{ + /// Create new client and provide client id + pub fn new(client_id: ByteString) -> Self { + Client { + client_id, + clean_session: true, + protocol: mqtt::Protocol::default(), + keep_alive: 30, + last_will: None, + username: None, + password: None, + inflight: 15, + _t: PhantomData, + } + } + + /// Mqtt protocol version + pub fn protocol(mut self, val: mqtt::Protocol) -> Self { + self.protocol = val; + self + } + + /// The handling of the Session state. + pub fn clean_session(mut self, val: bool) -> Self { + self.clean_session = val; + self + } + + /// A time interval measured in seconds. + /// + /// keep-alive is set to 30 seconds by default. + pub fn keep_alive(mut self, val: u16) -> Self { + self.keep_alive = val; + self + } + + /// Will Message be stored on the Server and associated with the Network Connection. + /// + /// by default last will value is not set + pub fn last_will(mut self, val: mqtt::LastWill) -> Self { + self.last_will = Some(val); + self + } + + /// Username can be used by the Server for authentication and authorization. + pub fn username(mut self, val: ByteString) -> Self { + self.username = Some(val); + self + } + + /// Password can be used by the Server for authentication and authorization. + pub fn password(mut self, val: Bytes) -> Self { + self.password = Some(val); + self + } + + /// Number of in-flight concurrent messages. + /// + /// in-flight is set to 15 messages + pub fn inflight(mut self, val: usize) -> Self { + self.inflight = val; + self + } + + /// Set state service + /// + /// State service verifies connect ack packet and construct connection state. + pub fn state(self, state: F) -> ServiceBuilder + where + F: IntoService, + Io: AsyncRead + AsyncWrite, + C: Service, Response = ConnectAckResult>, + C::Error: 'static, + { + ServiceBuilder { + state: Cell::new(state.into_service()), + packet: mqtt::Connect { + client_id: self.client_id, + clean_session: self.clean_session, + protocol: self.protocol, + keep_alive: self.keep_alive, + last_will: self.last_will, + username: self.username, + password: self.password, + }, + subscribe: Rc::new(boxed::factory(SubsNotImplemented::default())), + unsubscribe: Rc::new(boxed::factory(UnsubsNotImplemented::default())), + disconnect: None, + keep_alive: self.keep_alive.into(), + inflight: self.inflight, + _t: PhantomData, + } + } +} + +pub struct ServiceBuilder { + state: Cell, + packet: mqtt::Connect, + subscribe: Rc< + boxed::BoxServiceFactory< + St, + Subscribe, + SubscribeResult, + MqttError, + MqttError, + >, + >, + unsubscribe: Rc< + boxed::BoxServiceFactory< + St, + Unsubscribe, + (), + MqttError, + MqttError, + >, + >, + disconnect: Option>>>, + keep_alive: u64, + inflight: usize, + + _t: PhantomData<(Io, St, C)>, +} + +impl ServiceBuilder +where + St: Clone + 'static, + Io: AsyncRead + AsyncWrite + 'static, + C: Service, Response = ConnectAckResult> + 'static, + C::Error: 'static, +{ + /// Service to execute on disconnect + pub fn disconnect(mut self, srv: UF) -> Self + where + UF: IntoService, + U: Service + 'static, + { + self.disconnect = Some(Cell::new(boxed::service( + srv.into_service().map_err(MqttError::Service), + ))); + self + } + + pub fn finish( + self, + service: F, + ) -> impl Service> + where + F: IntoServiceFactory, + T: ServiceFactory< + Config = St, + Request = Publish, + Response = (), + Error = C::Error, + InitError = C::Error, + > + 'static, + { + ioframe::Builder::new() + .service(ConnectService { + connect: self.state, + packet: self.packet, + keep_alive: self.keep_alive, + inflight: self.inflight, + _t: PhantomData, + }) + .finish(dispatcher( + service + .into_factory() + .map_err(MqttError::Service) + .map_init_err(MqttError::Service), + self.subscribe, + self.unsubscribe, + )) + .map_err(|e| match e { + ioframe::ServiceError::Service(e) => e, + ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e), + ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e), + }) + } +} + +struct ConnectService { + connect: Cell, + packet: mqtt::Connect, + keep_alive: u64, + inflight: usize, + _t: PhantomData<(Io, St)>, +} + +impl Service for ConnectService +where + St: 'static, + Io: AsyncRead + AsyncWrite + 'static, + C: Service, Response = ConnectAckResult> + 'static, + C::Error: 'static, +{ + type Request = ioframe::Connect; + type Response = ioframe::ConnectResult, mqtt::Codec>; + type Error = MqttError; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.connect + .get_mut() + .poll_ready(cx) + .map_err(MqttError::Service) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let mut srv = self.connect.clone(); + let packet = self.packet.clone(); + let keep_alive = Duration::from_secs(self.keep_alive as u64); + let inflight = self.inflight; + + // send Connect packet + async move { + let mut framed = req.codec(mqtt::Codec::new()); + framed + .send(mqtt::Packet::Connect(packet)) + .await + .map_err(MqttError::Protocol)?; + + let packet = framed + .next() + .await + .ok_or(MqttError::Disconnected) + .and_then(|res| res.map_err(MqttError::Protocol))?; + + match packet { + mqtt::Packet::ConnectAck { + session_present, + return_code, + } => { + let sink = MqttSink::new(framed.sink().clone()); + let ack = ConnectAck { + sink, + session_present, + return_code, + keep_alive, + inflight, + io: framed, + }; + Ok(srv + .get_mut() + .call(ack) + .await + .map_err(MqttError::Service) + .map(|ack| ack.io.state(ack.state))?) + } + p => Err(MqttError::Unexpected(p, "Expected CONNECT-ACK packet")), + } + } + .boxed_local() + } +} + +pub struct ConnectAck { + io: ioframe::ConnectResult, + sink: MqttSink, + session_present: bool, + return_code: mqtt::ConnectCode, + keep_alive: Duration, + inflight: usize, +} + +impl ConnectAck { + #[inline] + /// Indicates whether there is already stored Session state + pub fn session_present(&self) -> bool { + self.session_present + } + + #[inline] + /// Connect return code + pub fn return_code(&self) -> mqtt::ConnectCode { + self.return_code + } + + #[inline] + /// Mqtt client sink object + pub fn sink(&self) -> &MqttSink { + &self.sink + } + + #[inline] + /// Set connection state and create result object + pub fn state(self, state: St) -> ConnectAckResult { + ConnectAckResult { + io: self.io, + state: MqttState::new(state, self.sink, self.keep_alive, self.inflight), + } + } +} + +impl Stream for ConnectAck +where + Io: AsyncRead + AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_next(cx) + } +} + +impl Sink for ConnectAck +where + Io: AsyncRead + AsyncWrite + Unpin, +{ + type Error = mqtt::ParseError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> { + Pin::new(&mut self.io).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_close(cx) + } +} + +#[pin_project::pin_project] +pub struct ConnectAckResult { + state: MqttState, + io: ioframe::ConnectResult, +} + +impl Stream for ConnectAckResult +where + Io: AsyncRead + AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_next(cx) + } +} + +impl Sink for ConnectAckResult +where + Io: AsyncRead + AsyncWrite + Unpin, +{ + type Error = mqtt::ParseError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> { + Pin::new(&mut self.io).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.io).poll_close(cx) + } +} diff --git a/actix-mqtt/src/connect.rs b/actix-mqtt/src/connect.rs new file mode 100755 index 000000000..ce64311f1 --- /dev/null +++ b/actix-mqtt/src/connect.rs @@ -0,0 +1,150 @@ +use std::fmt; +use std::ops::Deref; +use std::time::Duration; + +use actix_ioframe as ioframe; +use mqtt_codec as mqtt; + +use crate::sink::MqttSink; + +/// Connect message +pub struct Connect { + connect: mqtt::Connect, + sink: MqttSink, + keep_alive: Duration, + inflight: usize, + io: ioframe::ConnectResult, +} + +impl Connect { + pub(crate) fn new( + connect: mqtt::Connect, + io: ioframe::ConnectResult, + sink: MqttSink, + inflight: usize, + ) -> Self { + Self { + keep_alive: Duration::from_secs(connect.keep_alive as u64), + connect, + io, + sink, + inflight, + } + } + + /// Returns reference to io object + pub fn get_ref(&self) -> &Io { + self.io.get_ref() + } + + /// Returns mutable reference to io object + pub fn get_mut(&mut self) -> &mut Io { + self.io.get_mut() + } + + /// Returns mqtt server sink + pub fn sink(&self) -> &MqttSink { + &self.sink + } + + /// Ack connect message and set state + pub fn ack(self, st: St, session_present: bool) -> ConnectAck { + ConnectAck::new(self.io, st, session_present, self.keep_alive, self.inflight) + } + + /// Create connect ack object with `identifier rejected` return code + pub fn identifier_rejected(self) -> ConnectAck { + ConnectAck { + io: self.io, + session: None, + session_present: false, + return_code: mqtt::ConnectCode::IdentifierRejected, + keep_alive: Duration::from_secs(5), + inflight: 15, + } + } + + /// Create connect ack object with `bad user name or password` return code + pub fn bad_username_or_pwd(self) -> ConnectAck { + ConnectAck { + io: self.io, + session: None, + session_present: false, + return_code: mqtt::ConnectCode::BadUserNameOrPassword, + keep_alive: Duration::from_secs(5), + inflight: 15, + } + } + + /// Create connect ack object with `not authorized` return code + pub fn not_authorized(self) -> ConnectAck { + ConnectAck { + io: self.io, + session: None, + session_present: false, + return_code: mqtt::ConnectCode::NotAuthorized, + keep_alive: Duration::from_secs(5), + inflight: 15, + } + } +} + +impl Deref for Connect { + type Target = mqtt::Connect; + + fn deref(&self) -> &Self::Target { + &self.connect + } +} + +impl fmt::Debug for Connect { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.connect.fmt(f) + } +} + +/// Ack connect message +pub struct ConnectAck { + pub(crate) io: ioframe::ConnectResult, + pub(crate) session: Option, + pub(crate) session_present: bool, + pub(crate) return_code: mqtt::ConnectCode, + pub(crate) keep_alive: Duration, + pub(crate) inflight: usize, +} + +impl ConnectAck { + /// Create connect ack, `session_present` indicates that previous session is presents + pub(crate) fn new( + io: ioframe::ConnectResult, + session: St, + session_present: bool, + keep_alive: Duration, + inflight: usize, + ) -> Self { + Self { + io, + session_present, + keep_alive, + inflight, + session: Some(session), + return_code: mqtt::ConnectCode::ConnectionAccepted, + } + } + + /// Set idle time-out for the connection in milliseconds + /// + /// By default idle time-out is set to 300000 milliseconds + pub fn idle_timeout(mut self, timeout: Duration) -> Self { + self.keep_alive = timeout; + self + } + + /// Set in-flight count. Total number of `in-flight` packets + /// + /// By default in-flight count is set to 15 + pub fn in_flight(mut self, in_flight: usize) -> Self { + self.inflight = in_flight; + self + } +} diff --git a/actix-mqtt/src/default.rs b/actix-mqtt/src/default.rs new file mode 100755 index 000000000..d6bfdb46f --- /dev/null +++ b/actix-mqtt/src/default.rs @@ -0,0 +1,125 @@ +use std::marker::PhantomData; +use std::task::{Context, Poll}; + +use actix_service::{Service, ServiceFactory}; +use futures::future::{ok, Ready}; + +use crate::publish::Publish; +use crate::subs::{Subscribe, SubscribeResult, Unsubscribe}; + +/// Not implemented publish service +pub struct NotImplemented(PhantomData<(S, E)>); + +impl Default for NotImplemented { + fn default() -> Self { + NotImplemented(PhantomData) + } +} + +impl ServiceFactory for NotImplemented { + type Config = S; + type Request = Publish; + type Response = (); + type Error = E; + type InitError = E; + type Service = NotImplemented; + type Future = Ready>; + + fn new_service(&self, _: S) -> Self::Future { + ok(NotImplemented(PhantomData)) + } +} + +impl Service for NotImplemented { + type Request = Publish; + type Response = (); + type Error = E; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Publish) -> Self::Future { + log::warn!("MQTT Publish is not implemented"); + ok(()) + } +} + +/// Not implemented subscribe service +pub struct SubsNotImplemented(PhantomData<(S, E)>); + +impl Default for SubsNotImplemented { + fn default() -> Self { + SubsNotImplemented(PhantomData) + } +} + +impl ServiceFactory for SubsNotImplemented { + type Config = S; + type Request = Subscribe; + type Response = SubscribeResult; + type Error = E; + type InitError = E; + type Service = SubsNotImplemented; + type Future = Ready>; + + fn new_service(&self, _: S) -> Self::Future { + ok(SubsNotImplemented(PhantomData)) + } +} + +impl Service for SubsNotImplemented { + type Request = Subscribe; + type Response = SubscribeResult; + type Error = E; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, subs: Subscribe) -> Self::Future { + log::warn!("MQTT Subscribe is not implemented"); + ok(subs.into_result()) + } +} + +/// Not implemented unsubscribe service +pub struct UnsubsNotImplemented(PhantomData<(S, E)>); + +impl Default for UnsubsNotImplemented { + fn default() -> Self { + UnsubsNotImplemented(PhantomData) + } +} + +impl ServiceFactory for UnsubsNotImplemented { + type Config = S; + type Request = Unsubscribe; + type Response = (); + type Error = E; + type InitError = E; + type Service = UnsubsNotImplemented; + type Future = Ready>; + + fn new_service(&self, _: S) -> Self::Future { + ok(UnsubsNotImplemented(PhantomData)) + } +} + +impl Service for UnsubsNotImplemented { + type Request = Unsubscribe; + type Response = (); + type Error = E; + type Future = Ready>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Unsubscribe) -> Self::Future { + log::warn!("MQTT Unsubscribe is not implemented"); + ok(()) + } +} diff --git a/actix-mqtt/src/dispatcher.rs b/actix-mqtt/src/dispatcher.rs new file mode 100755 index 000000000..1500fa236 --- /dev/null +++ b/actix-mqtt/src/dispatcher.rs @@ -0,0 +1,286 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::num::NonZeroU16; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use actix_ioframe as ioframe; +use actix_service::{boxed, fn_factory_with_config, pipeline, Service, ServiceFactory}; +use actix_utils::inflight::InFlightService; +use actix_utils::keepalive::KeepAliveService; +use actix_utils::order::{InOrder, InOrderError}; +use actix_utils::time::LowResTimeService; +use futures::future::{join3, ok, Either, FutureExt, LocalBoxFuture, Ready}; +use futures::ready; +use mqtt_codec as mqtt; + +use crate::cell::Cell; +use crate::error::MqttError; +use crate::publish::Publish; +use crate::sink::MqttSink; +use crate::subs::{Subscribe, SubscribeResult, Unsubscribe}; + +pub(crate) struct MqttState { + inner: Cell>, +} + +struct MqttStateInner { + pub(crate) st: St, + pub(crate) sink: MqttSink, + pub(self) timeout: Duration, + pub(self) in_flight: usize, +} + +impl Clone for MqttState { + fn clone(&self) -> Self { + MqttState { + inner: self.inner.clone(), + } + } +} + +impl MqttState { + pub(crate) fn new(st: St, sink: MqttSink, timeout: Duration, in_flight: usize) -> Self { + MqttState { + inner: Cell::new(MqttStateInner { + st, + sink, + timeout, + in_flight, + }), + } + } + + pub(crate) fn sink(&self) -> &MqttSink { + &self.inner.sink + } + + pub(crate) fn session(&self) -> &St { + &self.inner.get_ref().st + } + + pub(crate) fn session_mut(&mut self) -> &mut St { + &mut self.inner.get_mut().st + } +} + +// dispatcher factory +pub(crate) fn dispatcher( + publish: T, + subscribe: Rc< + boxed::BoxServiceFactory< + St, + Subscribe, + SubscribeResult, + MqttError, + MqttError, + >, + >, + unsubscribe: Rc< + boxed::BoxServiceFactory, (), MqttError, MqttError>, + >, +) -> impl ServiceFactory< + Config = MqttState, + Request = ioframe::Item, mqtt::Codec>, + Response = Option, + Error = MqttError, + InitError = MqttError, +> +where + E: 'static, + St: Clone + 'static, + T: ServiceFactory< + Config = St, + Request = Publish, + Response = (), + Error = MqttError, + InitError = MqttError, + > + 'static, +{ + let time = LowResTimeService::with(Duration::from_secs(1)); + + fn_factory_with_config(move |cfg: MqttState| { + let time = time.clone(); + let state = cfg.session().clone(); + let timeout = cfg.inner.timeout; + let inflight = cfg.inner.in_flight; + + // create services + let fut = join3( + publish.new_service(state.clone()), + subscribe.new_service(state.clone()), + unsubscribe.new_service(state.clone()), + ); + + async move { + let (publish, subscribe, unsubscribe) = fut.await; + + // mqtt dispatcher + Ok(Dispatcher::new( + // keep-alive connection + pipeline(KeepAliveService::new(timeout, time, || { + MqttError::KeepAliveTimeout + })) + .and_then( + // limit number of in-flight messages + InFlightService::new( + inflight, + // mqtt spec requires ack ordering, so enforce response ordering + InOrder::service(publish?).map_err(|e| match e { + InOrderError::Service(e) => e, + InOrderError::Disconnected => MqttError::Disconnected, + }), + ), + ), + subscribe?, + unsubscribe?, + )) + } + }) +} + +/// PUBLIS/SUBSCRIBER/UNSUBSCRIBER packets dispatcher +pub(crate) struct Dispatcher { + publish: T, + subscribe: boxed::BoxService, SubscribeResult, T::Error>, + unsubscribe: boxed::BoxService, (), T::Error>, +} + +impl Dispatcher +where + T: Service, Response = ()>, +{ + pub(crate) fn new( + publish: T, + subscribe: boxed::BoxService, SubscribeResult, T::Error>, + unsubscribe: boxed::BoxService, (), T::Error>, + ) -> Self { + Self { + publish, + subscribe, + unsubscribe, + } + } +} + +impl Service for Dispatcher +where + T: Service, Response = ()>, + T::Error: 'static, +{ + type Request = ioframe::Item, mqtt::Codec>; + type Response = Option; + type Error = T::Error; + type Future = Either< + Either< + Ready>, + LocalBoxFuture<'static, Result>, + >, + PublishResponse, + >; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + let res1 = self.publish.poll_ready(cx)?; + let res2 = self.subscribe.poll_ready(cx)?; + let res3 = self.unsubscribe.poll_ready(cx)?; + + if res1.is_pending() || res2.is_pending() || res3.is_pending() { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + fn call(&mut self, req: ioframe::Item, mqtt::Codec>) -> Self::Future { + let (mut state, _, packet) = req.into_parts(); + + log::trace!("Dispatch packet: {:#?}", packet); + match packet { + mqtt::Packet::PingRequest => { + Either::Left(Either::Left(ok(Some(mqtt::Packet::PingResponse)))) + } + mqtt::Packet::Disconnect => Either::Left(Either::Left(ok(None))), + mqtt::Packet::Publish(publish) => { + let packet_id = publish.packet_id; + Either::Right(PublishResponse { + packet_id, + fut: self.publish.call(Publish::new(state, publish)), + _t: PhantomData, + }) + } + mqtt::Packet::PublishAck { packet_id } => { + state.inner.get_mut().sink.complete_publish_qos1(packet_id); + Either::Left(Either::Left(ok(None))) + } + mqtt::Packet::Subscribe { + packet_id, + topic_filters, + } => Either::Left(Either::Right( + SubscribeResponse { + packet_id, + fut: self.subscribe.call(Subscribe::new(state, topic_filters)), + } + .boxed_local(), + )), + mqtt::Packet::Unsubscribe { + packet_id, + topic_filters, + } => Either::Left(Either::Right( + self.unsubscribe + .call(Unsubscribe::new(state, topic_filters)) + .map(move |_| Ok(Some(mqtt::Packet::UnsubscribeAck { packet_id }))) + .boxed_local(), + )), + _ => Either::Left(Either::Left(ok(None))), + } + } +} + +/// Publish service response future +#[pin_project::pin_project] +pub(crate) struct PublishResponse { + #[pin] + fut: T, + packet_id: Option, + _t: PhantomData, +} + +impl Future for PublishResponse +where + T: Future>, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + + ready!(this.fut.poll(cx))?; + if let Some(packet_id) = this.packet_id { + Poll::Ready(Ok(Some(mqtt::Packet::PublishAck { + packet_id: *packet_id, + }))) + } else { + Poll::Ready(Ok(None)) + } + } +} + +/// Subscribe service response future +pub(crate) struct SubscribeResponse { + fut: LocalBoxFuture<'static, Result>, + packet_id: NonZeroU16, +} + +impl Future for SubscribeResponse { + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let res = ready!(Pin::new(&mut self.fut).poll(cx))?; + Poll::Ready(Ok(Some(mqtt::Packet::SubscribeAck { + status: res.codes, + packet_id: self.packet_id, + }))) + } +} diff --git a/actix-mqtt/src/error.rs b/actix-mqtt/src/error.rs new file mode 100755 index 000000000..19537447a --- /dev/null +++ b/actix-mqtt/src/error.rs @@ -0,0 +1,34 @@ +use std::io; + +/// Errors which can occur when attempting to handle mqtt connection. +#[derive(Debug)] +pub enum MqttError { + /// Message handler service error + Service(E), + /// Mqtt parse error + Protocol(mqtt_codec::ParseError), + /// Unexpected packet + Unexpected(mqtt_codec::Packet, &'static str), + /// "SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier [MQTT-2.3.1-1]." + PacketIdRequired, + /// Keep alive timeout + KeepAliveTimeout, + /// Handshake timeout + HandshakeTimeout, + /// Peer disconnect + Disconnected, + /// Unexpected io error + Io(io::Error), +} + +impl From for MqttError { + fn from(err: mqtt_codec::ParseError) -> Self { + MqttError::Protocol(err) + } +} + +impl From for MqttError { + fn from(err: io::Error) -> Self { + MqttError::Io(err) + } +} diff --git a/actix-mqtt/src/lib.rs b/actix-mqtt/src/lib.rs new file mode 100755 index 000000000..36b6e9c3e --- /dev/null +++ b/actix-mqtt/src/lib.rs @@ -0,0 +1,23 @@ +#![allow(clippy::type_complexity, clippy::new_ret_no_self)] +//! MQTT v3.1 Server framework + +mod cell; +pub mod client; +mod connect; +mod default; +mod dispatcher; +mod error; +mod publish; +mod router; +mod server; +mod sink; +mod subs; + +pub use self::client::Client; +pub use self::connect::{Connect, ConnectAck}; +pub use self::error::MqttError; +pub use self::publish::Publish; +pub use self::router::Router; +pub use self::server::MqttServer; +pub use self::sink::MqttSink; +pub use self::subs::{Subscribe, SubscribeIter, SubscribeResult, Subscription, Unsubscribe}; diff --git a/actix-mqtt/src/publish.rs b/actix-mqtt/src/publish.rs new file mode 100755 index 000000000..73463e046 --- /dev/null +++ b/actix-mqtt/src/publish.rs @@ -0,0 +1,137 @@ +use std::convert::TryFrom; +use std::num::NonZeroU16; + +use actix_router::Path; +use bytes::Bytes; +use bytestring::ByteString; +use mqtt_codec as mqtt; +use serde::de::DeserializeOwned; +use serde_json::Error as JsonError; + +use crate::dispatcher::MqttState; +use crate::sink::MqttSink; + +/// Publish message +pub struct Publish { + publish: mqtt::Publish, + sink: MqttSink, + state: MqttState, + topic: Path, + query: Option, +} + +impl Publish { + pub(crate) fn new(state: MqttState, publish: mqtt::Publish) -> Self { + let (topic, query) = if let Some(pos) = publish.topic.find('?') { + ( + ByteString::try_from(publish.topic.get_ref().slice(0..pos)).unwrap(), + Some( + ByteString::try_from( + publish.topic.get_ref().slice(pos + 1..publish.topic.len()), + ) + .unwrap(), + ), + ) + } else { + (publish.topic.clone(), None) + }; + let topic = Path::new(topic); + let sink = state.sink().clone(); + Self { + sink, + publish, + state, + topic, + query, + } + } + + #[inline] + /// this might be re-delivery of an earlier attempt to send the Packet. + pub fn dup(&self) -> bool { + self.publish.dup + } + + #[inline] + pub fn retain(&self) -> bool { + self.publish.retain + } + + #[inline] + /// the level of assurance for delivery of an Application Message. + pub fn qos(&self) -> mqtt::QoS { + self.publish.qos + } + + #[inline] + /// the information channel to which payload data is published. + pub fn publish_topic(&self) -> &str { + &self.publish.topic + } + + #[inline] + /// returns reference to a connection session + pub fn session(&self) -> &S { + self.state.session() + } + + #[inline] + /// returns mutable reference to a connection session + pub fn session_mut(&mut self) -> &mut S { + self.state.session_mut() + } + + #[inline] + /// only present in PUBLISH Packets where the QoS level is 1 or 2. + pub fn id(&self) -> Option { + self.publish.packet_id + } + + #[inline] + pub fn topic(&self) -> &Path { + &self.topic + } + + #[inline] + pub fn topic_mut(&mut self) -> &mut Path { + &mut self.topic + } + + #[inline] + pub fn query(&self) -> &str { + self.query.as_ref().map(|s| s.as_ref()).unwrap_or("") + } + + #[inline] + pub fn packet(&self) -> &mqtt::Publish { + &self.publish + } + + #[inline] + /// the Application Message that is being published. + pub fn payload(&self) -> &Bytes { + &self.publish.payload + } + + /// Extract Bytes from packet payload + pub fn take_payload(&self) -> Bytes { + self.publish.payload.clone() + } + + #[inline] + /// Mqtt client sink object + pub fn sink(&self) -> &MqttSink { + &self.sink + } + + /// Loads and parse `application/json` encoded body. + pub fn json(&mut self) -> Result { + serde_json::from_slice(&self.publish.payload) + } +} + +impl std::fmt::Debug for Publish { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.publish.fmt(f) + } +} diff --git a/actix-mqtt/src/router.rs b/actix-mqtt/src/router.rs new file mode 100755 index 000000000..6ce129b7f --- /dev/null +++ b/actix-mqtt/src/router.rs @@ -0,0 +1,206 @@ +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use actix_router::{IntoPattern, RouterBuilder}; +use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_service::{fn_service, IntoServiceFactory, Service, ServiceFactory}; +use futures::future::{join_all, ok, JoinAll, LocalBoxFuture}; + +use crate::publish::Publish; + +type Handler = BoxServiceFactory, (), E, E>; +type HandlerService = BoxService, (), E>; + +/// Router - structure that follows the builder pattern +/// for building publish packet router instances for mqtt server. +pub struct Router { + router: RouterBuilder, + handlers: Vec>, + default: Handler, +} + +impl Router +where + S: Clone + 'static, + E: 'static, +{ + /// Create mqtt application. + /// + /// **Note** Default service acks all publish packets + pub fn new() -> Self { + Router { + router: actix_router::Router::build(), + handlers: Vec::new(), + default: boxed::factory( + fn_service(|p: Publish| { + log::warn!("Unknown topic {:?}", p.publish_topic()); + ok::<_, E>(()) + }) + .map_init_err(|_| panic!()), + ), + } + } + + /// Configure mqtt resource for a specific topic. + pub fn resource(mut self, address: T, service: F) -> Self + where + T: IntoPattern, + F: IntoServiceFactory, + U: ServiceFactory, Response = (), Error = E>, + E: From, + { + self.router.path(address, self.handlers.len()); + self.handlers + .push(boxed::factory(service.into_factory().map_init_err(E::from))); + self + } + + /// Default service to be used if no matching resource could be found. + pub fn default_resource(mut self, service: F) -> Self + where + F: IntoServiceFactory, + U: ServiceFactory< + Config = S, + Request = Publish, + Response = (), + Error = E, + InitError = E, + >, + { + self.default = boxed::factory(service.into_factory()); + self + } +} + +impl IntoServiceFactory> for Router +where + S: Clone + 'static, + E: 'static, +{ + fn into_factory(self) -> RouterFactory { + RouterFactory { + router: Rc::new(self.router.finish()), + handlers: self.handlers, + default: self.default, + } + } +} + +pub struct RouterFactory { + router: Rc>, + handlers: Vec>, + default: Handler, +} + +impl ServiceFactory for RouterFactory +where + S: Clone + 'static, + E: 'static, +{ + type Config = S; + type Request = Publish; + type Response = (); + type Error = E; + type InitError = E; + type Service = RouterService; + type Future = RouterFactoryFut; + + fn new_service(&self, session: S) -> Self::Future { + let fut: Vec<_> = self + .handlers + .iter() + .map(|h| h.new_service(session.clone())) + .collect(); + + RouterFactoryFut { + router: self.router.clone(), + handlers: join_all(fut), + default: Some(either::Either::Left(self.default.new_service(session))), + } + } +} + +pub struct RouterFactoryFut { + router: Rc>, + handlers: JoinAll, E>>>, + default: Option< + either::Either< + LocalBoxFuture<'static, Result, E>>, + HandlerService, + >, + >, +} + +impl Future for RouterFactoryFut { + type Output = Result, E>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let res = match self.default.as_mut().unwrap() { + either::Either::Left(ref mut fut) => { + let default = match futures::ready!(Pin::new(fut).poll(cx)) { + Ok(default) => default, + Err(e) => return Poll::Ready(Err(e)), + }; + self.default = Some(either::Either::Right(default)); + return self.poll(cx); + } + either::Either::Right(_) => futures::ready!(Pin::new(&mut self.handlers).poll(cx)), + }; + + let mut handlers = Vec::new(); + for handler in res { + match handler { + Ok(h) => handlers.push(h), + Err(e) => return Poll::Ready(Err(e)), + } + } + + Poll::Ready(Ok(RouterService { + handlers, + router: self.router.clone(), + default: self.default.take().unwrap().right().unwrap(), + })) + } +} + +pub struct RouterService { + router: Rc>, + handlers: Vec, (), E>>, + default: BoxService, (), E>, +} + +impl Service for RouterService +where + S: 'static, + E: 'static, +{ + type Request = Publish; + type Response = (); + type Error = E; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + let mut not_ready = false; + for hnd in &mut self.handlers { + if let Poll::Pending = hnd.poll_ready(cx)? { + not_ready = true; + } + } + + if not_ready { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + fn call(&mut self, mut req: Publish) -> Self::Future { + if let Some((idx, _info)) = self.router.recognize(req.topic_mut()) { + self.handlers[*idx].call(req) + } else { + self.default.call(req) + } + } +} diff --git a/actix-mqtt/src/server.rs b/actix-mqtt/src/server.rs new file mode 100755 index 000000000..e9a849343 --- /dev/null +++ b/actix-mqtt/src/server.rs @@ -0,0 +1,331 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::rc::Rc; +use std::time::Duration; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_ioframe as ioframe; +use actix_service::{apply, apply_fn, boxed, fn_factory, pipeline_factory, unit_config}; +use actix_service::{IntoServiceFactory, Service, ServiceFactory}; +use actix_utils::timeout::{Timeout, TimeoutError}; +use futures::{FutureExt, SinkExt, StreamExt}; +use mqtt_codec as mqtt; + +use crate::cell::Cell; +use crate::connect::{Connect, ConnectAck}; +use crate::default::{SubsNotImplemented, UnsubsNotImplemented}; +use crate::dispatcher::{dispatcher, MqttState}; +use crate::error::MqttError; +use crate::publish::Publish; +use crate::sink::MqttSink; +use crate::subs::{Subscribe, SubscribeResult, Unsubscribe}; + +/// Mqtt Server +pub struct MqttServer { + connect: C, + subscribe: boxed::BoxServiceFactory< + St, + Subscribe, + SubscribeResult, + MqttError, + MqttError, + >, + unsubscribe: boxed::BoxServiceFactory< + St, + Unsubscribe, + (), + MqttError, + MqttError, + >, + disconnect: U, + max_size: usize, + inflight: usize, + handshake_timeout: u64, + _t: PhantomData<(Io, St)>, +} + +fn default_disconnect(_: St, _: bool) {} + +impl MqttServer +where + St: 'static, + C: ServiceFactory, Response = ConnectAck> + + 'static, +{ + /// Create server factory and provide connect service + pub fn new(connect: F) -> MqttServer + where + F: IntoServiceFactory, + { + MqttServer { + connect: connect.into_factory(), + subscribe: boxed::factory( + pipeline_factory(SubsNotImplemented::default()) + .map_err(MqttError::Service) + .map_init_err(MqttError::Service), + ), + unsubscribe: boxed::factory( + pipeline_factory(UnsubsNotImplemented::default()) + .map_err(MqttError::Service) + .map_init_err(MqttError::Service), + ), + max_size: 0, + inflight: 15, + disconnect: default_disconnect, + handshake_timeout: 0, + _t: PhantomData, + } + } +} + +impl MqttServer +where + St: Clone + 'static, + U: Fn(St, bool) + 'static, + C: ServiceFactory, Response = ConnectAck> + + 'static, +{ + /// Set handshake timeout in millis. + /// + /// Handshake includes `connect` packet and response `connect-ack`. + /// By default handshake timeuot is disabled. + pub fn handshake_timeout(mut self, timeout: u64) -> Self { + self.handshake_timeout = timeout; + self + } + + /// Set max inbound frame size. + /// + /// If max size is set to `0`, size is unlimited. + /// By default max size is set to `0` + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Number of in-flight concurrent messages. + /// + /// in-flight is set to 15 messages + pub fn inflight(mut self, val: usize) -> Self { + self.inflight = val; + self + } + + /// Service to execute for subscribe packet + pub fn subscribe(mut self, subscribe: F) -> Self + where + F: IntoServiceFactory, + Srv: ServiceFactory, Response = SubscribeResult> + + 'static, + C::Error: From + From, + { + self.subscribe = boxed::factory( + subscribe + .into_factory() + .map_err(|e| MqttError::Service(e.into())) + .map_init_err(|e| MqttError::Service(e.into())), + ); + self + } + + /// Service to execute for unsubscribe packet + pub fn unsubscribe(mut self, unsubscribe: F) -> Self + where + F: IntoServiceFactory, + Srv: ServiceFactory, Response = ()> + 'static, + C::Error: From + From, + { + self.unsubscribe = boxed::factory( + unsubscribe + .into_factory() + .map_err(|e| MqttError::Service(e.into())) + .map_init_err(|e| MqttError::Service(e.into())), + ); + self + } + + /// Callback to execute on disconnect + /// + /// Second parameter indicates error occured during disconnect. + pub fn disconnect(self, disconnect: F) -> MqttServer + where + F: Fn(St, bool) -> Out, + Out: Future + 'static, + { + MqttServer { + connect: self.connect, + subscribe: self.subscribe, + unsubscribe: self.unsubscribe, + max_size: self.max_size, + inflight: self.inflight, + handshake_timeout: self.handshake_timeout, + disconnect: move |st: St, err| { + let fut = disconnect(st, err); + actix_rt::spawn(fut.map(|_| ())); + }, + _t: PhantomData, + } + } + + /// Set service to execute for publish packet and create service factory + pub fn finish( + self, + publish: F, + ) -> impl ServiceFactory> + where + Io: AsyncRead + AsyncWrite + 'static, + F: IntoServiceFactory

, + P: ServiceFactory, Response = ()> + 'static, + C::Error: From + From, + { + let connect = self.connect; + let max_size = self.max_size; + let handshake_timeout = self.handshake_timeout; + let disconnect = self.disconnect; + let publish = boxed::factory( + publish + .into_factory() + .map_err(|e| MqttError::Service(e.into())) + .map_init_err(|e| MqttError::Service(e.into())), + ); + + unit_config( + ioframe::Builder::new() + .factory(connect_service_factory( + connect, + max_size, + self.inflight, + handshake_timeout, + )) + .disconnect(move |cfg, err| disconnect(cfg.session().clone(), err)) + .finish(dispatcher( + publish, + Rc::new(self.subscribe), + Rc::new(self.unsubscribe), + )) + .map_err(|e| match e { + ioframe::ServiceError::Service(e) => e, + ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e), + ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e), + }), + ) + } +} + +fn connect_service_factory( + factory: C, + max_size: usize, + inflight: usize, + handshake_timeout: u64, +) -> impl ServiceFactory< + Config = (), + Request = ioframe::Connect, + Response = ioframe::ConnectResult, mqtt::Codec>, + Error = MqttError, +> +where + Io: AsyncRead + AsyncWrite, + C: ServiceFactory, Response = ConnectAck>, +{ + apply( + Timeout::new(Duration::from_millis(handshake_timeout)), + fn_factory(move || { + let fut = factory.new_service(()); + + async move { + let service = Cell::new(fut.await?); + + Ok::<_, C::InitError>(apply_fn( + service.map_err(MqttError::Service), + move |conn: ioframe::Connect, service| { + let mut srv = service.clone(); + let mut framed = conn.codec(mqtt::Codec::new().max_size(max_size)); + + async move { + // read first packet + let packet = framed + .next() + .await + .ok_or(MqttError::Disconnected) + .and_then(|res| res.map_err(|e| MqttError::Protocol(e)))?; + + match packet { + mqtt::Packet::Connect(connect) => { + let sink = MqttSink::new(framed.sink().clone()); + + // authenticate mqtt connection + let mut ack = srv + .call(Connect::new( + connect, + framed, + sink.clone(), + inflight, + )) + .await?; + + match ack.session { + Some(session) => { + log::trace!( + "Sending: {:#?}", + mqtt::Packet::ConnectAck { + session_present: ack.session_present, + return_code: + mqtt::ConnectCode::ConnectionAccepted, + } + ); + ack.io + .send(mqtt::Packet::ConnectAck { + session_present: ack.session_present, + return_code: + mqtt::ConnectCode::ConnectionAccepted, + }) + .await?; + + Ok(ack.io.state(MqttState::new( + session, + sink, + ack.keep_alive, + ack.inflight, + ))) + } + None => { + log::trace!( + "Sending: {:#?}", + mqtt::Packet::ConnectAck { + session_present: false, + return_code: ack.return_code, + } + ); + + ack.io + .send(mqtt::Packet::ConnectAck { + session_present: false, + return_code: ack.return_code, + }) + .await?; + Err(MqttError::Disconnected) + } + } + } + packet => { + log::info!( + "MQTT-3.1.0-1: Expected CONNECT packet, received {}", + packet.packet_type() + ); + Err(MqttError::Unexpected( + packet, + "MQTT-3.1.0-1: Expected CONNECT packet", + )) + } + } + } + }, + )) + } + }), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => MqttError::HandshakeTimeout, + }) +} diff --git a/actix-mqtt/src/sink.rs b/actix-mqtt/src/sink.rs new file mode 100755 index 000000000..5afd23c28 --- /dev/null +++ b/actix-mqtt/src/sink.rs @@ -0,0 +1,107 @@ +use std::collections::VecDeque; +use std::fmt; +use std::num::NonZeroU16; + +use actix_ioframe::Sink; +use actix_utils::oneshot; +use bytes::Bytes; +use bytestring::ByteString; +use futures::future::{Future, TryFutureExt}; +use mqtt_codec as mqtt; + +use crate::cell::Cell; + +#[derive(Clone)] +pub struct MqttSink { + sink: Sink, + pub(crate) inner: Cell, +} + +#[derive(Default)] +pub(crate) struct MqttSinkInner { + pub(crate) idx: u16, + pub(crate) queue: VecDeque<(u16, oneshot::Sender<()>)>, +} + +impl MqttSink { + pub(crate) fn new(sink: Sink) -> Self { + MqttSink { + sink, + inner: Cell::new(MqttSinkInner::default()), + } + } + + /// Close mqtt connection + pub fn close(&self) { + self.sink.close(); + } + + /// Send publish packet with qos set to 0 + pub fn publish_qos0(&self, topic: ByteString, payload: Bytes, dup: bool) { + log::trace!("Publish (QoS0) to {:?}", topic); + let publish = mqtt::Publish { + topic, + payload, + dup, + retain: false, + qos: mqtt::QoS::AtMostOnce, + packet_id: None, + }; + self.sink.send(mqtt::Packet::Publish(publish)); + } + + /// Send publish packet + pub fn publish_qos1( + &mut self, + topic: ByteString, + payload: Bytes, + dup: bool, + ) -> impl Future> { + let (tx, rx) = oneshot::channel(); + + let inner = self.inner.get_mut(); + inner.idx += 1; + if inner.idx == 0 { + inner.idx = 1 + } + inner.queue.push_back((inner.idx, tx)); + + let publish = mqtt::Packet::Publish(mqtt::Publish { + topic, + payload, + dup, + retain: false, + qos: mqtt::QoS::AtLeastOnce, + packet_id: NonZeroU16::new(inner.idx), + }); + log::trace!("Publish (QoS1) to {:#?}", publish); + + self.sink.send(publish); + rx.map_err(|_| ()) + } + + pub(crate) fn complete_publish_qos1(&mut self, packet_id: NonZeroU16) { + if let Some((idx, tx)) = self.inner.get_mut().queue.pop_front() { + if idx != packet_id.get() { + log::trace!( + "MQTT protocol error, packet_id order does not match, expected {}, got: {}", + idx, + packet_id + ); + self.close(); + } else { + log::trace!("Ack publish packet with id: {}", packet_id); + let _ = tx.send(()); + } + } else { + log::trace!("Unexpected PublishAck packet"); + self.close(); + } + } +} + +impl fmt::Debug for MqttSink { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("MqttSink").finish() + } +} diff --git a/actix-mqtt/src/subs.rs b/actix-mqtt/src/subs.rs new file mode 100755 index 000000000..884b4a200 --- /dev/null +++ b/actix-mqtt/src/subs.rs @@ -0,0 +1,191 @@ +use std::marker::PhantomData; + +use bytestring::ByteString; +use mqtt_codec as mqtt; + +use crate::dispatcher::MqttState; +use crate::sink::MqttSink; + +/// Subscribe message +pub struct Subscribe { + topics: Vec<(ByteString, mqtt::QoS)>, + codes: Vec, + state: MqttState, +} + +/// Result of a subscribe message +pub struct SubscribeResult { + pub(crate) codes: Vec, +} + +impl Subscribe { + pub(crate) fn new(state: MqttState, topics: Vec<(ByteString, mqtt::QoS)>) -> Self { + let mut codes = Vec::with_capacity(topics.len()); + (0..topics.len()).for_each(|_| codes.push(mqtt::SubscribeReturnCode::Failure)); + + Self { + topics, + state, + codes, + } + } + + #[inline] + /// returns reference to a connection session + pub fn session(&self) -> &S { + self.state.session() + } + + #[inline] + /// returns mutable reference to a connection session + pub fn session_mut(&mut self) -> &mut S { + self.state.session_mut() + } + + #[inline] + /// Mqtt client sink object + pub fn sink(&self) -> MqttSink { + self.state.sink().clone() + } + + #[inline] + /// returns iterator over subscription topics + pub fn iter_mut(&mut self) -> SubscribeIter { + SubscribeIter { + subs: self as *const _ as *mut _, + entry: 0, + lt: PhantomData, + } + } + + #[inline] + /// convert subscription to a result + pub fn into_result(self) -> SubscribeResult { + SubscribeResult { codes: self.codes } + } +} + +impl<'a, S> IntoIterator for &'a mut Subscribe { + type Item = Subscription<'a, S>; + type IntoIter = SubscribeIter<'a, S>; + + fn into_iter(self) -> SubscribeIter<'a, S> { + self.iter_mut() + } +} + +/// Iterator over subscription topics +pub struct SubscribeIter<'a, S> { + subs: *mut Subscribe, + entry: usize, + lt: PhantomData<&'a mut Subscribe>, +} + +impl<'a, S> SubscribeIter<'a, S> { + fn next_unsafe(&mut self) -> Option> { + let subs = unsafe { &mut *self.subs }; + + if self.entry < subs.topics.len() { + let s = Subscription { + topic: &subs.topics[self.entry].0, + qos: subs.topics[self.entry].1, + state: subs.state.clone(), + code: &mut subs.codes[self.entry], + }; + self.entry += 1; + Some(s) + } else { + None + } + } +} + +impl<'a, S> Iterator for SubscribeIter<'a, S> { + type Item = Subscription<'a, S>; + + #[inline] + fn next(&mut self) -> Option> { + self.next_unsafe() + } +} + +/// Subscription topic +pub struct Subscription<'a, S> { + topic: &'a ByteString, + state: MqttState, + qos: mqtt::QoS, + code: &'a mut mqtt::SubscribeReturnCode, +} + +impl<'a, S> Subscription<'a, S> { + #[inline] + /// reference to a connection session + pub fn session(&self) -> &S { + self.state.session() + } + + #[inline] + /// mutable reference to a connection session + pub fn session_mut(&mut self) -> &mut S { + self.state.session_mut() + } + + #[inline] + /// subscription topic + pub fn topic(&self) -> &'a ByteString { + &self.topic + } + + #[inline] + /// the level of assurance for delivery of an Application Message. + pub fn qos(&self) -> mqtt::QoS { + self.qos + } + + #[inline] + /// fail to subscribe to the topic + pub fn fail(&mut self) { + *self.code = mqtt::SubscribeReturnCode::Failure + } + + #[inline] + /// subscribe to a topic with specific qos + pub fn subscribe(&mut self, qos: mqtt::QoS) { + *self.code = mqtt::SubscribeReturnCode::Success(qos) + } +} + +/// Unsubscribe message +pub struct Unsubscribe { + state: MqttState, + topics: Vec, +} + +impl Unsubscribe { + pub(crate) fn new(state: MqttState, topics: Vec) -> Self { + Self { topics, state } + } + + #[inline] + /// reference to a connection session + pub fn session(&self) -> &S { + self.state.session() + } + + #[inline] + /// mutable reference to a connection session + pub fn session_mut(&mut self) -> &mut S { + self.state.session_mut() + } + + #[inline] + /// Mqtt client sink object + pub fn sink(&self) -> MqttSink { + self.state.sink().clone() + } + + /// returns iterator over unsubscribe topics + pub fn iter(&self) -> impl Iterator { + self.topics.iter() + } +} diff --git a/actix-mqtt/tests/test_server.rs b/actix-mqtt/tests/test_server.rs new file mode 100755 index 000000000..29abf621c --- /dev/null +++ b/actix-mqtt/tests/test_server.rs @@ -0,0 +1,52 @@ +use actix_service::Service; +use actix_testing::TestServer; +use bytes::Bytes; +use bytestring::ByteString; +use futures::future::ok; + +use actix_mqtt::{client, Connect, ConnectAck, MqttServer, Publish}; + +#[derive(Clone)] +struct Session; + +async fn connect(packet: Connect) -> Result, ()> { + println!("CONNECT: {:?}", packet); + Ok(packet.ack(Session, false)) +} + +#[actix_rt::test] +async fn test_simple() -> std::io::Result<()> { + std::env::set_var( + "RUST_LOG", + "actix_codec=info,actix_server=trace,actix_connector=trace", + ); + env_logger::init(); + + let srv = TestServer::with(|| MqttServer::new(connect).finish(|_t| ok(()))); + + #[derive(Clone)] + struct ClientSession; + + let mut client = client::Client::new(ByteString::from_static("user")) + .state(|ack: client::ConnectAck<_>| async move { + ack.sink() + .publish_qos0(ByteString::from_static("#"), Bytes::new(), false); + ack.sink().close(); + Ok(ack.state(ClientSession)) + }) + .finish(|_t: Publish<_>| { + async { + // t.sink().close(); + Ok(()) + } + }); + + let conn = actix_connect::default_connector() + .call(actix_connect::Connect::with(String::new(), srv.addr())) + .await + .unwrap(); + + client.call(conn.into_parts().0).await.unwrap(); + + Ok(()) +}