1
0
mirror of https://github.com/actix/actix-extras.git synced 2025-01-22 06:45:56 +01:00

add MQTT and AMQP packages

This commit is contained in:
Rob Ede 2020-09-28 02:59:57 +01:00
parent e3da3094f0
commit 7de64899b4
No known key found for this signature in database
GPG Key ID: C2A3B36E841A91E6
85 changed files with 20566 additions and 7 deletions

1
.gitignore vendored
View File

@ -11,3 +11,4 @@ guide/build/
*.pid
*.sock
*~
.DS_Store

View File

@ -1,13 +1,19 @@
[workspace]
members = [
"actix-cors",
"actix-identity",
"actix-protobuf",
"actix-protobuf/examples/prost-example",
"actix-redis",
"actix-session",
"actix-web-httpauth",
"actix-amqp",
"actix-amqp/codec",
"actix-cors",
"actix-identity",
"actix-mqtt",
"actix-mqtt/codec",
"actix-protobuf",
"actix-protobuf/examples/prost-example",
"actix-redis",
"actix-session",
"actix-web-httpauth",
]
[patch.crates-io]
actix-session = { path = "actix-session" }
mqtt-codec = { path = "actix-mqtt/codec" }
amqp-codec = { path = "actix-amqp/codec" }

25
actix-amqp/CHANGES.md Executable file
View File

@ -0,0 +1,25 @@
# Changes
## Unreleased - 2020-xx-xx
## 0.1.4 - 2020-03-05
* Add server handshake timeout
## 0.1.3 - 2020-02-10
* Allow to override sender link attach frame
## 0.1.2 - 2019-12-25
* Allow to specify multi-pattern for topics
## 0.1.1 - 2019-12-18
* Separate control frame entries for detach sender qand detach receiver
* Proper detach remote receiver
* Replace `async fn` with `impl Future`
## 0.1.0 - 2019-12-11
* Initial release

35
actix-amqp/Cargo.toml Executable file
View File

@ -0,0 +1,35 @@
[package]
name = "actix-amqp"
version = "0.1.4"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "AMQP 1.0 Client/Server framework"
documentation = "https://docs.rs/actix-amqp"
repository = "https://github.com/actix/actix-extras.git"
categories = ["network-programming"]
keywords = ["AMQP", "IoT", "messaging"]
license = "MIT OR Apache-2.0"
edition = "2018"
[dependencies]
amqp-codec = "0.1.0"
actix-codec = "0.2.0"
actix-service = "1.0.1"
actix-connect = "1.0.1"
actix-router = "0.2.2"
actix-utils = "1.0.4"
actix-rt = "1.0.0"
bytes = "0.5.3"
bytestring = "0.1.2"
derive_more = "0.99.2"
either = "1.5.3"
futures = "0.3.1"
fxhash = "0.2.1"
http = "0.2.0"
log = "0.4"
pin-project = "0.4.6"
uuid = { version = "0.8", features = ["v4"] }
slab = "0.4"
[dev-dependencies]
env_logger = "0.6"
actix-testing = "1.0.0"

3
actix-amqp/README.md Executable file
View File

@ -0,0 +1,3 @@
# AMQP 1.0 Server Framework
[![Build Status](https://travis-ci.org/actix/actix-amqp.svg?branch=master)](https://travis-ci.org/actix/actix-amqp)

31
actix-amqp/codec/Cargo.toml Executable file
View File

@ -0,0 +1,31 @@
[package]
name = "amqp-codec"
version = "0.1.0"
description = "AMQP 1.0 Protocol Codec"
authors = ["Nikolay Kim <fafhrd91@gmail.com>", "Max Gortman <mgortman@microsoft.com>", "Mike Yagley <myagley@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2018"
[dependencies]
actix-codec = "0.2.0"
bytes = "0.5.2"
byteorder = "1.3.1"
bytestring = "0.1.1"
chrono = "0.4"
derive_more = "0.99.2"
fxhash = "0.2.1"
ordered-float = "1.0"
uuid = { version = "0.8", features = ["v4"] }
[build-dependencies]
handlebars = { version = "0.27", optional = true }
serde = { version = "1.0", optional = true }
serde_derive = { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }
lazy_static = { version = "1.0", optional = true }
regex = { version = "1.3", optional = true }
[features]
default = []
from-spec = ["handlebars", "serde", "serde_derive", "serde_json", "lazy_static", "regex"]

3
actix-amqp/codec/README.md Executable file
View File

@ -0,0 +1,3 @@
# AMQP 1.0 Protocol Codec
[![Build Status](https://travis-ci.org/fafhrd91/amqp-codec.svg?branch=master)](https://travis-ci.org/fafhrd91/amqp-codec)

102
actix-amqp/codec/build.rs Executable file
View File

@ -0,0 +1,102 @@
#[cfg(feature = "from-spec")]
extern crate handlebars;
#[cfg(feature = "from-spec")]
#[macro_use]
extern crate lazy_static;
#[cfg(feature = "from-spec")]
extern crate serde;
#[cfg(feature = "from-spec")]
#[macro_use]
extern crate serde_derive;
#[cfg(feature = "from-spec")]
extern crate regex;
#[cfg(feature = "from-spec")]
extern crate serde_json;
#[cfg(feature = "from-spec")]
mod codegen;
fn main() {
#[cfg(feature = "from-spec")]
{
generate_from_spec();
}
}
#[cfg(feature = "from-spec")]
fn generate_from_spec() {
use handlebars::{Handlebars, Helper, RenderContext, RenderError};
use std::env;
use std::fs::File;
use std::io::Write;
let template = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/codegen/definitions.rs"
));
let spec = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/codegen/specification.json"
));
let out_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is not defined")
+ "/src/protocol/";
let definitions = codegen::parse(spec);
let mut codegen = Handlebars::new();
codegen.register_helper(
"snake",
Box::new(
|h: &Helper, _: &Handlebars, rc: &mut RenderContext| -> Result<(), RenderError> {
let value = h
.param(0)
.ok_or_else(|| RenderError::new("Param not found for helper \"snake\""))?;
let param = value.value().as_str().ok_or_else(|| {
RenderError::new("Non-string param given to helper \"snake\"")
})?;
rc.writer.write_all(codegen::snake_case(param).as_bytes())?;
Ok(())
},
),
);
codegen
.register_template_string("definitions", template.to_string())
.expect("Failed to register template.");
let mut data = std::collections::BTreeMap::new();
data.insert("defs", definitions);
let def_path = std::path::Path::new(&out_dir).join("definitions.rs");
{
let mut f = File::create(def_path.clone()).expect("Failed to create target file.");
let rendered = codegen
.render("definitions", &data)
.expect("Failed to render template.");
writeln!(f, "{}", rendered).expect("Failed to write to file.");
}
reformat_file(&def_path);
fn reformat_file(path: &std::path::Path) {
use std::fs::OpenOptions;
use std::io::{Read, Seek};
std::process::Command::new("rustfmt")
.arg(path.to_str().unwrap())
.output()
.expect("failed to format definitions.rs");
let mut f = OpenOptions::new()
.read(true)
.write(true)
.open(path)
.expect("failed to open file");
let mut data = String::new();
f.read_to_string(&mut data).expect("failed to read file.");
let regex = regex::Regex::new("(?P<a>[\r]?\n)[\\s]*\r?\n").unwrap();
data = regex.replace_all(&data, "$a").into();
f.seek(std::io::SeekFrom::Start(0)).unwrap();
f.set_len(data.len() as u64).unwrap();
f.write_all(data.as_bytes())
.expect("Error writing reformatted file");
}
}

View File

@ -0,0 +1,460 @@
#![allow(unused_assignments, unused_variables, unreachable_patterns)]
use std::u8;
use derive_more::From;
use bytes::{BufMut, Bytes, BytesMut};
use bytestring::ByteString;
use uuid::Uuid;
use super::*;
use crate::errors::AmqpParseError;
use crate::codec::{self, decode_format_code, decode_list_header, Decode, DecodeFormatted, Encode};
#[derive(Clone, Debug, PartialEq, From)]
pub enum Frame {
Open(Open),
Begin(Begin),
Attach(Attach),
Flow(Flow),
Transfer(Transfer),
Disposition(Disposition),
Detach(Detach),
End(End),
Close(Close),
Empty,
}
impl Frame {
pub fn name(&self) -> &'static str {
match self {
Frame::Open(_) => "Open",
Frame::Begin(_) => "Begin",
Frame::Attach(_) => "Attach",
Frame::Flow(_) => "Flow",
Frame::Transfer(_) => "Transfer",
Frame::Disposition(_) => "Disposition",
Frame::Detach(_) => "Detach",
Frame::End(_) => "End",
Frame::Close(_) => "Close",
Frame::Empty => "Empty",
}
}
}
impl Decode for Frame {
fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> {
if input.is_empty() {
Ok((input, Frame::Empty))
} else {
let (input, fmt) = decode_format_code(input)?;
validate_code!(fmt, codec::FORMATCODE_DESCRIBED);
let (input, descriptor) = Descriptor::decode(input)?;
match descriptor {
Descriptor::Ulong(16) => decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r))),
Descriptor::Ulong(17) => {
decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r)))
}
Descriptor::Ulong(18) => {
decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r)))
}
Descriptor::Ulong(19) => decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r))),
Descriptor::Ulong(20) => {
decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r)))
}
Descriptor::Ulong(21) => {
decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r)))
}
Descriptor::Ulong(22) => {
decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r)))
}
Descriptor::Ulong(23) => decode_end_inner(input).map(|(i, r)| (i, Frame::End(r))),
Descriptor::Ulong(24) => {
decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:open:list" => {
decode_open_inner(input).map(|(i, r)| (i, Frame::Open(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:begin:list" => {
decode_begin_inner(input).map(|(i, r)| (i, Frame::Begin(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:attach:list" => {
decode_attach_inner(input).map(|(i, r)| (i, Frame::Attach(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:flow:list" => {
decode_flow_inner(input).map(|(i, r)| (i, Frame::Flow(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:transfer:list" => {
decode_transfer_inner(input).map(|(i, r)| (i, Frame::Transfer(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:disposition:list" => {
decode_disposition_inner(input).map(|(i, r)| (i, Frame::Disposition(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:detach:list" => {
decode_detach_inner(input).map(|(i, r)| (i, Frame::Detach(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:end:list" => {
decode_end_inner(input).map(|(i, r)| (i, Frame::End(r)))
}
Descriptor::Symbol(ref a) if a.as_str() == "amqp:close:list" => {
decode_close_inner(input).map(|(i, r)| (i, Frame::Close(r)))
}
_ => Err(AmqpParseError::InvalidDescriptor(descriptor)),
}
}
}
}
impl Encode for Frame {
fn encoded_size(&self) -> usize {
match *self {
Frame::Open(ref v) => encoded_size_open_inner(v),
Frame::Begin(ref v) => encoded_size_begin_inner(v),
Frame::Attach(ref v) => encoded_size_attach_inner(v),
Frame::Flow(ref v) => encoded_size_flow_inner(v),
Frame::Transfer(ref v) => encoded_size_transfer_inner(v),
Frame::Disposition(ref v) => encoded_size_disposition_inner(v),
Frame::Detach(ref v) => encoded_size_detach_inner(v),
Frame::End(ref v) => encoded_size_end_inner(v),
Frame::Close(ref v) => encoded_size_close_inner(v),
Frame::Empty => 0,
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
Frame::Open(ref v) => encode_open_inner(v, buf),
Frame::Begin(ref v) => encode_begin_inner(v, buf),
Frame::Attach(ref v) => encode_attach_inner(v, buf),
Frame::Flow(ref v) => encode_flow_inner(v, buf),
Frame::Transfer(ref v) => encode_transfer_inner(v, buf),
Frame::Disposition(ref v) => encode_disposition_inner(v, buf),
Frame::Detach(ref v) => encode_detach_inner(v, buf),
Frame::End(ref v) => encode_end_inner(v, buf),
Frame::Close(ref v) => encode_close_inner(v, buf),
Frame::Empty => (),
}
}
}
{{#each defs.provides as |provide|}}
{{#if provide.described}}
#[derive(Clone, Debug, PartialEq)]
pub enum {{provide.name}} {
{{#each provide.options as |option|}}
{{option.ty}}({{option.ty}}),
{{/each}}
}
impl DecodeFormatted for {{provide.name}} {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_DESCRIBED);
let (input, descriptor) = Descriptor::decode(input)?;
match descriptor {
{{#each provide.options as |option|}}
Descriptor::Ulong({{option.descriptor.code}}) => decode_{{snake option.ty}}_inner(input).map(|(i, r)| (i, {{provide.name}}::{{option.ty}}(r))),
{{/each}}
{{#each provide.options as |option|}}
Descriptor::Symbol(ref a) if a.as_str() == "{{option.descriptor.name}}" => decode_{{snake option.ty}}_inner(input).map(|(i, r)| (i, {{provide.name}}::{{option.ty}}(r))),
{{/each}}
_ => Err(AmqpParseError::InvalidDescriptor(descriptor))
}
}
}
impl Encode for {{provide.name}} {
fn encoded_size(&self) -> usize {
match *self {
{{#each provide.options as |option|}}
{{provide.name}}::{{option.ty}}(ref v) => encoded_size_{{snake option.ty}}_inner(v),
{{/each}}
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
{{#each provide.options as |option|}}
{{provide.name}}::{{option.ty}}(ref v) => encode_{{snake option.ty}}_inner(v, buf),
{{/each}}
}
}
}
{{/if}}
{{/each}}
{{#each defs.aliases as |alias|}}
pub type {{alias.name}} = {{alias.source}};
{{/each}}
{{#each defs.enums as |enum|}}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum {{enum.name}} {
{{#each enum.items as |item|}}
{{item.name}},
{{/each}}
}
{{#if enum.is_symbol}}
impl {{enum.name}} {
pub fn try_from(v: &Symbol) -> Result<Self, AmqpParseError> {
match v.as_str() {
{{#each enum.items as |item|}}
"{{item.value}}" => Ok({{enum.name}}::{{item.name}}),
{{/each}}
_ => Err(AmqpParseError::UnknownEnumOption("{{enum.name}}"))
}
}
}
impl DecodeFormatted for {{enum.name}} {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, base) = Symbol::decode_with_format(input, fmt)?;
Ok((input, Self::try_from(&base)?))
}
}
impl Encode for {{enum.name}} {
fn encoded_size(&self) -> usize {
match *self {
{{#each enum.items as |item|}}
{{enum.name}}::{{item.name}} => {{item.value_len}} + 2,
{{/each}}
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
{{#each enum.items as |item|}}
{{enum.name}}::{{item.name}} => StaticSymbol("{{item.value}}").encode(buf),
{{/each}}
}
}
}
{{else}}
impl {{enum.name}} {
pub fn try_from(v: {{enum.ty}}) -> Result<Self, AmqpParseError> {
match v {
{{#each enum.items as |item|}}
{{item.value}} => Ok({{enum.name}}::{{item.name}}),
{{/each}}
_ => Err(AmqpParseError::UnknownEnumOption("{{enum.name}}"))
}
}
}
impl DecodeFormatted for {{enum.name}} {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, base) = {{enum.ty}}::decode_with_format(input, fmt)?;
Ok((input, Self::try_from(base)?))
}
}
impl Encode for {{enum.name}} {
fn encoded_size(&self) -> usize {
match *self {
{{#each enum.items as |item|}}
{{enum.name}}::{{item.name}} => {
let v : {{enum.ty}} = {{item.value}};
v.encoded_size()
},
{{/each}}
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
{{#each enum.items as |item|}}
{{enum.name}}::{{item.name}} => {
let v : {{enum.ty}} = {{item.value}};
v.encode(buf);
},
{{/each}}
}
}
}
{{/if}}
{{/each}}
{{#each defs.described_restricted as |dr|}}
type {{dr.name}} = {{dr.ty}};
fn decode_{{snake dr.name}}_inner(input: &[u8]) -> Result<(&[u8], {{dr.name}}), AmqpParseError> {
{{dr.name}}::decode(input)
}
fn encoded_size_{{snake dr.name}}_inner(dr: &{{dr.name}}) -> usize {
// descriptor size + actual size
3 + dr.encoded_size()
}
fn encode_{{snake dr.name}}_inner(dr: &{{dr.name}}, buf: &mut BytesMut) {
Descriptor::Ulong({{dr.descriptor.code}}).encode(buf);
dr.encode(buf);
}
{{/each}}
{{#each defs.lists as |list|}}
#[derive(Clone, Debug, PartialEq)]
pub struct {{list.name}} {
{{#each list.fields as |field|}}
{{#if field.optional}}
pub {{field.name}}: Option<{{{field.ty}}}>,
{{else}}
pub {{field.name}}: {{{field.ty}}},
{{/if}}
{{/each}}
{{#if list.transfer}}
pub body: Option<TransferBody>,
{{/if}}
}
impl {{list.name}} {
{{#each list.fields as |field|}}
{{#if field.is_str}}
{{#if field.optional}}
pub fn {{field.name}}(&self) -> Option<&str> {
match self.{{field.name}} {
None => None,
Some(ref s) => Some(s.as_str())
}
}
{{else}}
pub fn {{field.name}}(&self) -> &str { self.{{field.name}}.as_str() }
{{/if}}
{{else}}
{{#if field.is_ref}}
{{#if field.optional}}
pub fn {{field.name}}(&self) -> Option<&{{{field.ty}}}> { self.{{field.name}}.as_ref() }
{{else}}
pub fn {{field.name}}(&self) -> &{{{field.ty}}} { &self.{{field.name}} }
{{/if}}
{{else}}
{{#if field.optional}}
pub fn {{field.name}}(&self) -> Option<{{{field.ty}}}> { self.{{field.name}} }
{{else}}
pub fn {{field.name}}(&self) -> {{{field.ty}}} { self.{{field.name}} }
{{/if}}
{{/if}}
{{/if}}
{{/each}}
{{#if list.transfer}}
pub fn body(&self) -> Option<&TransferBody> {
self.body.as_ref()
}
{{/if}}
#[allow(clippy::identity_op)]
const FIELD_COUNT: usize = 0 {{#each list.fields as |field|}} + 1{{/each}};
}
#[allow(unused_mut)]
fn decode_{{snake list.name}}_inner(input: &[u8]) -> Result<(&[u8], {{list.name}}), AmqpParseError> {
let (input, format) = decode_format_code(input)?;
let (input, header) = decode_list_header(input, format)?;
let size = header.size as usize;
decode_check_len!(input, size);
{{#if list.fields}}
let (mut input, mut remainder) = input.split_at(size);
let mut count = header.count;
{{#each list.fields as |field|}}
{{#if field.optional}}
let {{field.name}}: Option<{{{field.ty}}}>;
if count > 0 {
let decoded = Option::<{{{field.ty}}}>::decode(input)?;
input = decoded.0;
{{field.name}} = decoded.1;
count -= 1;
}
else {
{{field.name}} = None;
}
{{else}}
let {{field.name}}: {{{field.ty}}};
if count > 0 {
{{#if field.default}}
let (in1, decoded) = Option::<{{{field.ty}}}>::decode(input)?;
{{field.name}} = decoded.unwrap_or({{field.default}});
{{else}}
let (in1, decoded) = {{{field.ty}}}::decode(input)?;
{{field.name}} = decoded;
{{/if}}
input = in1;
count -= 1;
}
else {
{{#if field.default}}
{{field.name}} = {{field.default}};
{{else}}
return Err(AmqpParseError::RequiredFieldOmitted("{{field.name}}"))
{{/if}}
}
{{/if}}
{{/each}}
{{else}}
let mut remainder = &input[size..];
{{/if}}
{{#if list.transfer}}
let body = if remainder.is_empty() {
None
} else {
let b = Bytes::copy_from_slice(remainder);
remainder = &[];
Some(b.into())
};
{{/if}}
Ok((remainder, {{list.name}} {
{{#each list.fields as |field|}}
{{field.name}},
{{/each}}
{{#if list.transfer}}
body
{{/if}}
}))
}
fn encoded_size_{{snake list.name}}_inner(list: &{{list.name}}) -> usize {
#[allow(clippy::identity_op)]
let content_size = 0 {{#each list.fields as |field|}} + list.{{field.name}}.encoded_size(){{/each}};
// header: 0x00 0x53 <descriptor code> format_code size count
(if content_size + 1 > u8::MAX as usize { 12 } else { 6 })
+ content_size
{{#if list.transfer}}
+ list.body.as_ref().map(|b| b.len()).unwrap_or(0)
{{/if}}
}
fn encode_{{snake list.name}}_inner(list: &{{list.name}}, buf: &mut BytesMut) {
Descriptor::Ulong({{list.descriptor.code}}).encode(buf);
#[allow(clippy::identity_op)]
let content_size = 0 {{#each list.fields as |field|}} + list.{{field.name}}.encoded_size(){{/each}};
if content_size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_LIST32);
buf.put_u32((content_size + 4) as u32); // +4 for 4 byte count
buf.put_u32({{list.name}}::FIELD_COUNT as u32);
}
else {
buf.put_u8(codec::FORMATCODE_LIST8);
buf.put_u8((content_size + 1) as u8);
buf.put_u8({{list.name}}::FIELD_COUNT as u8);
}
{{#each list.fields as |field|}}
list.{{field.name}}.encode(buf);
{{/each}}
{{#if list.transfer}}
if let Some(ref body) = list.body {
body.encode(buf)
}
{{/if}}
}
impl DecodeFormatted for {{list.name}} {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_DESCRIBED);
let (input, descriptor) = Descriptor::decode(input)?;
let is_match = match descriptor {
Descriptor::Ulong(val) => val == {{list.descriptor.code}},
Descriptor::Symbol(ref sym) => sym.as_bytes() == b"{{list.descriptor.name}}",
};
if !is_match {
Err(AmqpParseError::InvalidDescriptor(descriptor))
} else {
decode_{{snake list.name}}_inner(input)
}
}
}
impl Encode for {{list.name}} {
fn encoded_size(&self) -> usize { encoded_size_{{snake list.name}}_inner(self) }
fn encode(&self, buf: &mut BytesMut) { encode_{{snake list.name}}_inner(self, buf) }
}
{{/each}}

478
actix-amqp/codec/codegen/mod.rs Executable file
View File

@ -0,0 +1,478 @@
use serde::{Deserialize, Deserializer};
use serde_json::from_str;
use std::collections::{HashMap, HashSet};
use std::str::{FromStr, ParseBoolError};
use std::sync::Mutex;
lazy_static! {
static ref PRIMITIVE_TYPES: HashMap<&'static str, &'static str> = {
let mut m = HashMap::new();
m.insert("*", "Variant");
m.insert("binary", "Bytes");
m.insert("string", "ByteString");
m.insert("ubyte", "u8");
m.insert("ushort", "u16");
m.insert("uint", "u32");
m.insert("ulong", "u64");
m.insert("boolean", "bool");
m
};
static ref STRING_TYPES: HashSet<&'static str> = ["string", "symbol"].iter().cloned().collect();
static ref REF_TYPES: Mutex<HashSet<String>> = Mutex::new(
[
"Bytes",
"ByteString",
"Symbol",
"Fields",
"Map",
"MessageId",
"Address",
"NodeProperties",
"Outcome",
"DeliveryState",
"FilterSet",
"DeliveryTag",
"Symbols",
"IetfLanguageTags",
"ErrorCondition",
"DistributionMode"
]
.iter()
.map(|s| s.to_string())
.collect()
);
static ref ENUM_TYPES: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
}
pub fn parse(spec: &str) -> Definitions {
let types = from_str::<Vec<_Type>>(spec).expect("Failed to parse AMQP spec.");
{
let mut ref_map = REF_TYPES.lock().unwrap();
let mut enum_map = ENUM_TYPES.lock().unwrap();
for t in types.iter() {
match *t {
_Type::Described(ref l) if l.source == "list" => {
ref_map.insert(camel_case(&*l.name));
}
_Type::Choice(ref e) => {
enum_map.insert(camel_case(&*e.name));
}
_ => {}
}
}
}
Definitions::from(types)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum _Type {
Choice(_Enum),
Described(_Described),
Alias(_Alias),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct _Descriptor {
name: String,
code: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct _Enum {
name: String,
source: String,
provides: Option<String>,
choice: Vec<EnumItem>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct _Described {
name: String,
class: String,
source: String,
provides: Option<String>,
descriptor: _Descriptor,
#[serde(default)]
field: Vec<_Field>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct _Field {
name: String,
#[serde(rename = "type")]
ty: String,
#[serde(default)]
#[serde(deserialize_with = "string_as_bool")]
mandatory: bool,
default: Option<String>,
#[serde(default)]
#[serde(deserialize_with = "string_as_bool")]
multiple: bool,
requires: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct _Alias {
name: String,
source: String,
provides: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Alias {
name: String,
source: String,
provides: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnumItem {
name: String,
value: String,
#[serde(default)]
value_len: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Definitions {
aliases: Vec<Alias>,
enums: Vec<Enum>,
lists: Vec<Described>,
described_restricted: Vec<Described>,
provides: Vec<ProvidesEnum>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProvidesEnum {
name: String,
described: bool,
options: Vec<ProvidesItem>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProvidesItem {
ty: String,
descriptor: Descriptor,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Enum {
name: String,
ty: String,
provides: Vec<String>,
items: Vec<EnumItem>,
is_symbol: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Described {
name: String,
ty: String,
provides: Vec<String>,
descriptor: Descriptor,
fields: Vec<Field>,
transfer: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Descriptor {
name: String,
domain: u32,
code: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Field {
name: String,
ty: String,
is_str: bool,
is_ref: bool,
optional: bool,
default: String,
multiple: bool,
}
impl Definitions {
fn from(types: Vec<_Type>) -> Definitions {
let mut aliases = vec![];
let mut enums = vec![];
let mut lists = vec![];
let mut described_restricted = vec![];
let mut provide_map: HashMap<String, Vec<ProvidesItem>> = HashMap::new();
for t in types.into_iter() {
match t {
_Type::Alias(ref a) if a.source != "map" => {
let al = Alias::from(a.clone());
Definitions::register_provides(&mut provide_map, &al.name, None, &al.provides);
aliases.push(al);
}
_Type::Choice(ref e) => {
let en = Enum::from(e.clone());
Definitions::register_provides(&mut provide_map, &en.name, None, &en.provides);
enums.push(en);
}
_Type::Described(ref d) if d.source == "list" && d.class != "restricted" => {
let ls = Described::list(d.clone());
Definitions::register_provides(
&mut provide_map,
&ls.name,
Some(ls.descriptor.clone()),
&ls.provides,
);
lists.push(ls);
}
_Type::Described(ref d) if d.class == "restricted" => {
let ls = Described::alias(d.clone());
Definitions::register_provides(
&mut provide_map,
&ls.name,
Some(ls.descriptor.clone()),
&ls.provides,
);
described_restricted.push(ls);
}
_ => {}
}
}
let provides = provide_map
.into_iter()
.filter_map(|(k, v)| {
if v.len() == 1 {
None
} else {
Some(ProvidesEnum {
described: if k == "Frame" {
false
} else {
v.iter().any(|v| v.descriptor.code != 0)
},
name: k,
options: v,
})
}
})
.collect();
Definitions {
aliases,
enums,
lists,
described_restricted,
provides,
}
}
fn register_provides(
map: &mut HashMap<String, Vec<ProvidesItem>>,
name: &str,
descriptor: Option<Descriptor>,
provides: &Vec<String>,
) {
for p in provides.iter() {
map.entry(p.clone())
.or_insert_with(|| vec![])
.push(ProvidesItem {
ty: name.to_string(),
descriptor: descriptor.clone().unwrap_or_else(|| Descriptor {
name: String::new(),
domain: 0,
code: 0,
}),
});
}
}
}
impl Alias {
fn from(a: _Alias) -> Alias {
Alias {
name: camel_case(&*a.name),
source: get_type_name(&*a.source, None),
provides: parse_provides(a.provides),
}
}
}
impl Enum {
fn from(e: _Enum) -> Enum {
let ty = get_type_name(&*e.source, None);
let is_symbol = ty == "Symbol";
Enum {
name: camel_case(&*e.name),
ty: ty.clone(),
provides: parse_provides(e.provides),
is_symbol,
items: e
.choice
.into_iter()
.map(|c| EnumItem {
name: camel_case(&*c.name),
value_len: c.value.len(),
value: c.value,
})
.collect(),
}
}
}
impl Described {
fn list(d: _Described) -> Described {
let transfer = d.name == "transfer";
Described {
name: camel_case(&d.name),
ty: String::new(),
provides: parse_provides(d.provides),
descriptor: Descriptor::from(d.descriptor),
fields: d.field.into_iter().map(|f| Field::from(f)).collect(),
transfer,
}
}
fn alias(d: _Described) -> Described {
Described {
name: camel_case(&d.name),
ty: get_type_name(&d.source, None),
provides: parse_provides(d.provides),
descriptor: Descriptor::from(d.descriptor),
fields: d.field.into_iter().map(|f| Field::from(f)).collect(),
transfer: false,
}
}
}
impl Descriptor {
fn from(d: _Descriptor) -> Descriptor {
let code_parts: Vec<u32> = d
.code
.split(":")
.map(|p| {
assert!(p.starts_with("0x"));
u32::from_str_radix(&p[2..], 16).expect("malformed descriptor code")
})
.collect();
Descriptor {
name: d.name,
domain: code_parts[0],
code: code_parts[1],
}
}
}
impl Field {
fn from(field: _Field) -> Field {
let mut ty = get_type_name(&*field.ty, field.requires);
if field.multiple {
ty.push('s');
}
let is_str = STRING_TYPES.contains(&*ty) && !field.multiple;
let is_ref = REF_TYPES.lock().unwrap().contains(&ty);
let default = Field::format_default(field.default, &ty);
Field {
name: snake_case(&*field.name),
ty: ty,
is_ref,
is_str,
optional: !field.mandatory && default.len() == 0,
multiple: field.multiple,
default,
}
}
fn format_default(default: Option<String>, ty: &str) -> String {
match default {
None => String::new(),
Some(def) => {
if ENUM_TYPES.lock().unwrap().contains(ty) {
format!("{}::{}", ty, camel_case(&*def))
} else {
def
}
}
}
}
}
fn get_type_name(ty: &str, req: Option<String>) -> String {
match req {
Some(t) => camel_case(&*t),
None => match PRIMITIVE_TYPES.get(ty) {
Some(p) => p.to_string(),
None => camel_case(&*ty),
},
}
}
fn parse_provides(p: Option<String>) -> Vec<String> {
p.map(|v| {
v.split_terminator(",")
.filter_map(|s| {
let s = s.trim();
if s == "" {
None
} else {
Some(camel_case(&s))
}
})
.collect()
})
.unwrap_or(vec![])
}
fn string_as_bool<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr<Err = ParseBoolError>,
D: Deserializer<'de>,
{
Ok(String::deserialize(deserializer)?
.parse::<T>()
.expect("Error parsing bool from string"))
}
pub fn camel_case(name: &str) -> String {
let mut new_word = true;
name.chars().fold("".to_string(), |mut result, ch| {
if ch == '-' || ch == '_' || ch == ' ' {
new_word = true;
result
} else {
result.push(if new_word {
ch.to_ascii_uppercase()
} else {
ch
});
new_word = false;
result
}
})
}
pub fn snake_case(name: &str) -> String {
match name {
"type" => "type_".to_string(),
"return" => "return_".to_string(),
name => {
let mut new_word = false;
let mut last_was_upper = false;
name.chars().fold("".to_string(), |mut result, ch| {
if ch == '-' || ch == '_' || ch == ' ' {
new_word = true;
result
} else {
let uppercase = ch.is_uppercase();
if new_word || (!last_was_upper && !result.is_empty() && uppercase) {
result.push('_');
new_word = false;
}
last_was_upper = uppercase;
result.push(if uppercase {
ch.to_ascii_lowercase()
} else {
ch
});
result
}
})
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,834 @@
use std::collections::HashMap;
use std::convert::TryFrom;
use std::hash::{BuildHasher, Hash};
use std::{char, str, u8};
use byteorder::{BigEndian, ByteOrder};
use bytes::Bytes;
use bytestring::ByteString;
use chrono::{DateTime, TimeZone, Utc};
use fxhash::FxHashMap;
use ordered_float::OrderedFloat;
use uuid::Uuid;
use crate::codec::{self, ArrayDecode, Decode, DecodeFormatted};
use crate::errors::AmqpParseError;
use crate::framing::{self, AmqpFrame, SaslFrame, HEADER_LEN};
use crate::protocol::{self, CompoundHeader};
use crate::types::{
Descriptor, List, Multiple, Str, Symbol, Variant, VariantMap, VecStringMap, VecSymbolMap,
};
macro_rules! be_read {
($input:ident, $fn:ident, $size:expr) => {{
decode_check_len!($input, $size);
let x = BigEndian::$fn($input);
Ok((&$input[$size..], x))
}};
}
fn read_u8(input: &[u8]) -> Result<(&[u8], u8), AmqpParseError> {
decode_check_len!(input, 1);
Ok((&input[1..], input[0]))
}
fn read_i8(input: &[u8]) -> Result<(&[u8], i8), AmqpParseError> {
decode_check_len!(input, 1);
Ok((&input[1..], input[0] as i8))
}
fn read_bytes_u8(input: &[u8]) -> Result<(&[u8], &[u8]), AmqpParseError> {
let (input, len) = read_u8(input)?;
let len = len as usize;
decode_check_len!(input, len);
let (bytes, input) = input.split_at(len);
Ok((input, bytes))
}
fn read_bytes_u32(input: &[u8]) -> Result<(&[u8], &[u8]), AmqpParseError> {
let result: Result<(&[u8], u32), AmqpParseError> = be_read!(input, read_u32, 4);
let (input, len) = result?;
let len = len as usize;
decode_check_len!(input, len);
let (bytes, input) = input.split_at(len);
Ok((input, bytes))
}
#[macro_export]
macro_rules! validate_code {
($fmt:ident, $code:expr) => {
if $fmt != $code {
return Err(AmqpParseError::InvalidFormatCode($fmt));
}
};
}
impl DecodeFormatted for bool {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_BOOLEAN => read_u8(input).map(|(i, o)| (i, o != 0)),
codec::FORMATCODE_BOOLEAN_TRUE => Ok((input, true)),
codec::FORMATCODE_BOOLEAN_FALSE => Ok((input, false)),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for u8 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_UBYTE);
read_u8(input)
}
}
impl DecodeFormatted for u16 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_USHORT);
be_read!(input, read_u16, 2)
}
}
impl DecodeFormatted for u32 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_UINT => be_read!(input, read_u32, 4),
codec::FORMATCODE_SMALLUINT => read_u8(input).map(|(i, o)| (i, u32::from(o))),
codec::FORMATCODE_UINT_0 => Ok((input, 0)),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for u64 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_ULONG => be_read!(input, read_u64, 8),
codec::FORMATCODE_SMALLULONG => read_u8(input).map(|(i, o)| (i, u64::from(o))),
codec::FORMATCODE_ULONG_0 => Ok((input, 0)),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for i8 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_BYTE);
read_i8(input)
}
}
impl DecodeFormatted for i16 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_SHORT);
be_read!(input, read_i16, 2)
}
}
impl DecodeFormatted for i32 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_INT => be_read!(input, read_i32, 4),
codec::FORMATCODE_SMALLINT => read_i8(input).map(|(i, o)| (i, i32::from(o))),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for i64 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_LONG => be_read!(input, read_i64, 8),
codec::FORMATCODE_SMALLLONG => read_i8(input).map(|(i, o)| (i, i64::from(o))),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for f32 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_FLOAT);
be_read!(input, read_f32, 4)
}
}
impl DecodeFormatted for f64 {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_DOUBLE);
be_read!(input, read_f64, 8)
}
}
impl DecodeFormatted for char {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_CHAR);
let result: Result<(&[u8], u32), AmqpParseError> = be_read!(input, read_u32, 4);
let (i, o) = result?;
if let Some(c) = char::from_u32(o) {
Ok((i, c))
} else {
Err(AmqpParseError::InvalidChar(o))
} // todo: replace with CharTryFromError once try_from is stabilized
}
}
impl DecodeFormatted for DateTime<Utc> {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_TIMESTAMP);
be_read!(input, read_i64, 8).map(|(i, o)| (i, datetime_from_millis(o)))
}
}
impl DecodeFormatted for Uuid {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
validate_code!(fmt, codec::FORMATCODE_UUID);
decode_check_len!(input, 16);
let uuid = Uuid::from_slice(&input[..16])?;
Ok((&input[16..], uuid))
}
}
impl DecodeFormatted for Bytes {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_BINARY8 => {
read_bytes_u8(input).map(|(i, o)| (i, Bytes::copy_from_slice(o)))
}
codec::FORMATCODE_BINARY32 => {
read_bytes_u32(input).map(|(i, o)| (i, Bytes::copy_from_slice(o)))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for ByteString {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_STRING8 => {
let (input, bytes) = read_bytes_u8(input)?;
Ok((input, ByteString::try_from(bytes)?))
}
codec::FORMATCODE_STRING32 => {
let (input, bytes) = read_bytes_u32(input)?;
Ok((input, ByteString::try_from(bytes)?))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for Str {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_STRING8 => {
let (input, bytes) = read_bytes_u8(input)?;
Ok((input, Str::from_str(str::from_utf8(bytes)?)))
}
codec::FORMATCODE_STRING32 => {
let (input, bytes) = read_bytes_u32(input)?;
Ok((input, Str::from_str(str::from_utf8(bytes)?)))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl DecodeFormatted for Symbol {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_SYMBOL8 => {
let (input, bytes) = read_bytes_u8(input)?;
Ok((input, Symbol::from_slice(str::from_utf8(bytes)?)))
}
codec::FORMATCODE_SYMBOL32 => {
let (input, bytes) = read_bytes_u32(input)?;
Ok((input, Symbol::from_slice(str::from_utf8(bytes)?)))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl ArrayDecode for Symbol {
fn array_decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> {
let (input, bytes) = read_bytes_u32(input)?;
Ok((input, Symbol::from_slice(str::from_utf8(bytes)?)))
}
}
impl<K: Decode + Eq + Hash, V: Decode, S: BuildHasher + Default> DecodeFormatted
for HashMap<K, V, S>
{
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, header) = decode_map_header(input, fmt)?;
let mut map_input = &input[..header.size as usize];
let count = header.count / 2;
let mut map: HashMap<K, V, S> =
HashMap::with_capacity_and_hasher(count as usize, Default::default());
for _ in 0..count {
let (input1, key) = K::decode(map_input)?;
let (input2, value) = V::decode(input1)?;
map_input = input2;
map.insert(key, value); // todo: ensure None returned?
}
// todo: validate map_input is empty
Ok((&input[header.size as usize..], map))
}
}
impl<T: DecodeFormatted> DecodeFormatted for Vec<T> {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, header) = decode_array_header(input, fmt)?;
let item_fmt = input[0]; // todo: support descriptor
let mut input = &input[1..];
let mut result: Vec<T> = Vec::with_capacity(header.count as usize);
for _ in 0..header.count {
let (new_input, decoded) = T::decode_with_format(input, item_fmt)?;
result.push(decoded);
input = new_input;
}
Ok((input, result))
}
}
impl DecodeFormatted for VecSymbolMap {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, header) = decode_map_header(input, fmt)?;
let mut map_input = &input[..header.size as usize];
let count = header.count / 2;
let mut map = Vec::with_capacity(count as usize);
for _ in 0..count {
let (input1, key) = Symbol::decode(map_input)?;
let (input2, value) = Variant::decode(input1)?;
map_input = input2;
map.push((key, value)); // todo: ensure None returned?
}
// todo: validate map_input is empty
Ok((&input[header.size as usize..], VecSymbolMap(map)))
}
}
impl DecodeFormatted for VecStringMap {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, header) = decode_map_header(input, fmt)?;
let mut map_input = &input[..header.size as usize];
let count = header.count / 2;
let mut map = Vec::with_capacity(count as usize);
for _ in 0..count {
let (input1, key) = Str::decode(map_input)?;
let (input2, value) = Variant::decode(input1)?;
map_input = input2;
map.push((key, value)); // todo: ensure None returned?
}
// todo: validate map_input is empty
Ok((&input[header.size as usize..], VecStringMap(map)))
}
}
impl<T: ArrayDecode + DecodeFormatted> DecodeFormatted for Multiple<T> {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_ARRAY8 | codec::FORMATCODE_ARRAY32 => {
let (input, items) = Vec::<T>::decode_with_format(input, fmt)?;
Ok((input, Multiple(items)))
}
_ => {
let (input, item) = T::decode_with_format(input, fmt)?;
Ok((input, Multiple(vec![item])))
}
}
}
}
impl DecodeFormatted for List {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (mut input, header) = decode_list_header(input, fmt)?;
let mut result: Vec<Variant> = Vec::with_capacity(header.count as usize);
for _ in 0..header.count {
let (new_input, decoded) = Variant::decode(input)?;
result.push(decoded);
input = new_input;
}
Ok((input, List(result)))
}
}
impl DecodeFormatted for Variant {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_NULL => Ok((input, Variant::Null)),
codec::FORMATCODE_BOOLEAN => {
bool::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Boolean(o)))
}
codec::FORMATCODE_BOOLEAN_FALSE => Ok((input, Variant::Boolean(false))),
codec::FORMATCODE_BOOLEAN_TRUE => Ok((input, Variant::Boolean(true))),
codec::FORMATCODE_UINT_0 => Ok((input, Variant::Uint(0))),
codec::FORMATCODE_ULONG_0 => Ok((input, Variant::Ulong(0))),
codec::FORMATCODE_UBYTE => {
u8::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ubyte(o)))
}
codec::FORMATCODE_USHORT => {
u16::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ushort(o)))
}
codec::FORMATCODE_UINT => {
u32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uint(o)))
}
codec::FORMATCODE_ULONG => {
u64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ulong(o)))
}
codec::FORMATCODE_BYTE => {
i8::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Byte(o)))
}
codec::FORMATCODE_SHORT => {
i16::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Short(o)))
}
codec::FORMATCODE_INT => {
i32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Int(o)))
}
codec::FORMATCODE_LONG => {
i64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Long(o)))
}
codec::FORMATCODE_SMALLUINT => {
u32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uint(o)))
}
codec::FORMATCODE_SMALLULONG => {
u64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Ulong(o)))
}
codec::FORMATCODE_SMALLINT => {
i32::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Int(o)))
}
codec::FORMATCODE_SMALLLONG => {
i64::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Long(o)))
}
codec::FORMATCODE_FLOAT => f32::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::Float(OrderedFloat(o)))),
codec::FORMATCODE_DOUBLE => f64::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::Double(OrderedFloat(o)))),
// codec::FORMATCODE_DECIMAL32 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))),
// codec::FORMATCODE_DECIMAL64 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))),
// codec::FORMATCODE_DECIMAL128 => x::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Decimal(o))),
codec::FORMATCODE_CHAR => {
char::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Char(o)))
}
codec::FORMATCODE_TIMESTAMP => DateTime::<Utc>::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::Timestamp(o))),
codec::FORMATCODE_UUID => {
Uuid::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Uuid(o)))
}
codec::FORMATCODE_BINARY8 => {
Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Binary(o)))
}
codec::FORMATCODE_BINARY32 => {
Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Binary(o)))
}
codec::FORMATCODE_STRING8 => ByteString::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::String(o.into()))),
codec::FORMATCODE_STRING32 => ByteString::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::String(o.into()))),
codec::FORMATCODE_SYMBOL8 => {
Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Symbol(o)))
}
codec::FORMATCODE_SYMBOL32 => {
Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Symbol(o)))
}
codec::FORMATCODE_LIST0 => Ok((input, Variant::List(List(vec![])))),
codec::FORMATCODE_LIST8 => {
List::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::List(o)))
}
codec::FORMATCODE_LIST32 => {
List::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::List(o)))
}
codec::FORMATCODE_MAP8 => FxHashMap::<Variant, Variant>::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::Map(VariantMap::new(o)))),
codec::FORMATCODE_MAP32 => {
FxHashMap::<Variant, Variant>::decode_with_format(input, fmt)
.map(|(i, o)| (i, Variant::Map(VariantMap::new(o))))
}
// codec::FORMATCODE_ARRAY8 => Vec::<Variant>::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Array(o))),
// codec::FORMATCODE_ARRAY32 => Vec::<Variant>::decode_with_format(input, fmt).map(|(i, o)| (i, Variant::Array(o))),
codec::FORMATCODE_DESCRIBED => {
let (input, descriptor) = Descriptor::decode(input)?;
let (input, value) = Variant::decode(input)?;
Ok((input, Variant::Described((descriptor, Box::new(value)))))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl<T: DecodeFormatted> DecodeFormatted for Option<T> {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_NULL => Ok((input, None)),
_ => T::decode_with_format(input, fmt).map(|(i, o)| (i, Some(o))),
}
}
}
impl DecodeFormatted for Descriptor {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_SMALLULONG => {
u64::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Ulong(o)))
}
codec::FORMATCODE_ULONG => {
u64::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Ulong(o)))
}
codec::FORMATCODE_SYMBOL8 => {
Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Symbol(o)))
}
codec::FORMATCODE_SYMBOL32 => {
Symbol::decode_with_format(input, fmt).map(|(i, o)| (i, Descriptor::Symbol(o)))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl Decode for AmqpFrame {
fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> {
let (input, channel_id) = decode_frame_header(input, framing::FRAME_TYPE_AMQP)?;
let (input, performative) = protocol::Frame::decode(input)?;
Ok((input, AmqpFrame::new(channel_id, performative)))
}
}
impl Decode for SaslFrame {
fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> {
let (input, _) = decode_frame_header(input, framing::FRAME_TYPE_SASL)?;
let (input, frame) = protocol::SaslFrameBody::decode(input)?;
Ok((input, SaslFrame { body: frame }))
}
}
fn decode_frame_header(
input: &[u8],
expected_frame_type: u8,
) -> Result<(&[u8], u16), AmqpParseError> {
decode_check_len!(input, 4);
let doff = input[0];
let frame_type = input[1];
if frame_type != expected_frame_type {
return Err(AmqpParseError::UnexpectedFrameType(frame_type));
}
let channel_id = BigEndian::read_u16(&input[2..]);
let doff = doff as usize * 4;
if doff < HEADER_LEN {
return Err(AmqpParseError::InvalidSize);
}
let ext_header_len = doff - HEADER_LEN;
decode_check_len!(input, ext_header_len + 4);
let input = &input[ext_header_len + 4..]; // skipping remaining two header bytes and ext header
Ok((input, channel_id))
}
fn decode_array_header(input: &[u8], fmt: u8) -> Result<(&[u8], CompoundHeader), AmqpParseError> {
match fmt {
codec::FORMATCODE_ARRAY8 => decode_compound8(input),
codec::FORMATCODE_ARRAY32 => decode_compound32(input),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
pub(crate) fn decode_list_header(
input: &[u8],
fmt: u8,
) -> Result<(&[u8], CompoundHeader), AmqpParseError> {
match fmt {
codec::FORMATCODE_LIST0 => Ok((input, CompoundHeader::empty())),
codec::FORMATCODE_LIST8 => decode_compound8(input),
codec::FORMATCODE_LIST32 => decode_compound32(input),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
pub(crate) fn decode_map_header(
input: &[u8],
fmt: u8,
) -> Result<(&[u8], CompoundHeader), AmqpParseError> {
match fmt {
codec::FORMATCODE_MAP8 => decode_compound8(input),
codec::FORMATCODE_MAP32 => decode_compound32(input),
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
fn decode_compound8(input: &[u8]) -> Result<(&[u8], CompoundHeader), AmqpParseError> {
decode_check_len!(input, 2);
let size = input[0] - 1; // -1 for 1 byte count
let count = input[1];
Ok((
&input[2..],
CompoundHeader {
size: u32::from(size),
count: u32::from(count),
},
))
}
fn decode_compound32(input: &[u8]) -> Result<(&[u8], CompoundHeader), AmqpParseError> {
decode_check_len!(input, 8);
let size = BigEndian::read_u32(input) - 4; // -4 for 4 byte count
let count = BigEndian::read_u32(&input[4..]);
Ok((&input[8..], CompoundHeader { size, count }))
}
fn datetime_from_millis(millis: i64) -> DateTime<Utc> {
let seconds = millis / 1000;
if seconds < 0 {
// In order to handle time before 1970 correctly, we need to subtract a second
// and use the nanoseconds field to add it back. This is a result of the nanoseconds
// parameter being u32
let nanoseconds = ((1000 + (millis - (seconds * 1000))) * 1_000_000).abs() as u32;
Utc.timestamp(seconds - 1, nanoseconds)
} else {
let nanoseconds = ((millis - (seconds * 1000)) * 1_000_000).abs() as u32;
Utc.timestamp(seconds, nanoseconds)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::Encode;
use bytes::{BufMut, BytesMut};
const LOREM: &str = include_str!("lorem.txt");
macro_rules! decode_tests {
($($name:ident: $kind:ident, $test:expr, $expected:expr,)*) => {
$(
#[test]
fn $name() {
let b1 = &mut BytesMut::with_capacity(($test).encoded_size());
($test).encode(b1);
assert_eq!($expected, unwrap_value($kind::decode(b1)));
}
)*
}
}
decode_tests! {
ubyte: u8, 255_u8, 255_u8,
ushort: u16, 350_u16, 350_u16,
uint_zero: u32, 0_u32, 0_u32,
uint_small: u32, 128_u32, 128_u32,
uint_big: u32, 2147483647_u32, 2147483647_u32,
ulong_zero: u64, 0_u64, 0_u64,
ulong_small: u64, 128_u64, 128_u64,
uulong_big: u64, 2147483649_u64, 2147483649_u64,
byte: i8, -128_i8, -128_i8,
short: i16, -255_i16, -255_i16,
int_zero: i32, 0_i32, 0_i32,
int_small: i32, -50000_i32, -50000_i32,
int_neg: i32, -128_i32, -128_i32,
long_zero: i64, 0_i64, 0_i64,
long_big: i64, -2147483647_i64, -2147483647_i64,
long_small: i64, -128_i64, -128_i64,
float: f32, 1.234_f32, 1.234_f32,
double: f64, 1.234_f64, 1.234_f64,
test_char: char, '💯', '💯',
uuid: Uuid, Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error"),
Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error"),
binary_short: Bytes, Bytes::from(&[4u8, 5u8][..]), Bytes::from(&[4u8, 5u8][..]),
binary_long: Bytes, Bytes::from(&[4u8; 500][..]), Bytes::from(&[4u8; 500][..]),
string_short: ByteString, ByteString::from("Hello there"), ByteString::from("Hello there"),
string_long: ByteString, ByteString::from(LOREM), ByteString::from(LOREM),
// symbol_short: Symbol, Symbol::from("Hello there"), Symbol::from("Hello there"),
// symbol_long: Symbol, Symbol::from(LOREM), Symbol::from(LOREM),
variant_ubyte: Variant, Variant::Ubyte(255_u8), Variant::Ubyte(255_u8),
variant_ushort: Variant, Variant::Ushort(350_u16), Variant::Ushort(350_u16),
variant_uint_zero: Variant, Variant::Uint(0_u32), Variant::Uint(0_u32),
variant_uint_small: Variant, Variant::Uint(128_u32), Variant::Uint(128_u32),
variant_uint_big: Variant, Variant::Uint(2147483647_u32), Variant::Uint(2147483647_u32),
variant_ulong_zero: Variant, Variant::Ulong(0_u64), Variant::Ulong(0_u64),
variant_ulong_small: Variant, Variant::Ulong(128_u64), Variant::Ulong(128_u64),
variant_ulong_big: Variant, Variant::Ulong(2147483649_u64), Variant::Ulong(2147483649_u64),
variant_byte: Variant, Variant::Byte(-128_i8), Variant::Byte(-128_i8),
variant_short: Variant, Variant::Short(-255_i16), Variant::Short(-255_i16),
variant_int_zero: Variant, Variant::Int(0_i32), Variant::Int(0_i32),
variant_int_small: Variant, Variant::Int(-50000_i32), Variant::Int(-50000_i32),
variant_int_neg: Variant, Variant::Int(-128_i32), Variant::Int(-128_i32),
variant_long_zero: Variant, Variant::Long(0_i64), Variant::Long(0_i64),
variant_long_big: Variant, Variant::Long(-2147483647_i64), Variant::Long(-2147483647_i64),
variant_long_small: Variant, Variant::Long(-128_i64), Variant::Long(-128_i64),
variant_float: Variant, Variant::Float(OrderedFloat(1.234_f32)), Variant::Float(OrderedFloat(1.234_f32)),
variant_double: Variant, Variant::Double(OrderedFloat(1.234_f64)), Variant::Double(OrderedFloat(1.234_f64)),
variant_char: Variant, Variant::Char('💯'), Variant::Char('💯'),
variant_uuid: Variant, Variant::Uuid(Uuid::from_slice(&[4, 54, 67, 12, 43, 2, 98, 76, 32, 50, 87, 5, 1, 33, 43, 87]).expect("parse error")),
Variant::Uuid(Uuid::parse_str("0436430c2b02624c2032570501212b57").expect("parse error")),
variant_binary_short: Variant, Variant::Binary(Bytes::from(&[4u8, 5u8][..])), Variant::Binary(Bytes::from(&[4u8, 5u8][..])),
variant_binary_long: Variant, Variant::Binary(Bytes::from(&[4u8; 500][..])), Variant::Binary(Bytes::from(&[4u8; 500][..])),
variant_string_short: Variant, Variant::String(ByteString::from("Hello there").into()), Variant::String(ByteString::from("Hello there").into()),
variant_string_long: Variant, Variant::String(ByteString::from(LOREM).into()), Variant::String(ByteString::from(LOREM).into()),
// variant_symbol_short: Variant, Variant::Symbol(Symbol::from("Hello there")), Variant::Symbol(Symbol::from("Hello there")),
// variant_symbol_long: Variant, Variant::Symbol(Symbol::from(LOREM)), Variant::Symbol(Symbol::from(LOREM)),
}
fn unwrap_value<T>(res: Result<(&[u8], T), AmqpParseError>) -> T {
let r = res.map(|(_i, o)| o);
assert!(r.is_ok());
r.unwrap()
}
#[test]
fn test_bool_true() {
let b1 = &mut BytesMut::with_capacity(0);
b1.put_u8(0x41);
assert_eq!(true, unwrap_value(bool::decode(b1)));
let b2 = &mut BytesMut::with_capacity(0);
b2.put_u8(0x56);
b2.put_u8(0x01);
assert_eq!(true, unwrap_value(bool::decode(b2)));
}
#[test]
fn test_bool_false() {
let b1 = &mut BytesMut::with_capacity(0);
b1.put_u8(0x42u8);
assert_eq!(false, unwrap_value(bool::decode(b1)));
let b2 = &mut BytesMut::with_capacity(0);
b2.put_u8(0x56);
b2.put_u8(0x00);
assert_eq!(false, unwrap_value(bool::decode(b2)));
}
/// UTC with a precision of milliseconds. For example, 1311704463521
/// represents the moment 2011-07-26T18:21:03.521Z.
#[test]
fn test_timestamp() {
let b1 = &mut BytesMut::with_capacity(0);
let datetime = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521);
datetime.encode(b1);
let expected = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521);
assert_eq!(expected, unwrap_value(DateTime::<Utc>::decode(b1)));
}
#[test]
fn test_timestamp_pre_unix() {
let b1 = &mut BytesMut::with_capacity(0);
let datetime = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521);
datetime.encode(b1);
let expected = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521);
assert_eq!(expected, unwrap_value(DateTime::<Utc>::decode(b1)));
}
#[test]
fn variant_null() {
let mut b = BytesMut::with_capacity(0);
Variant::Null.encode(&mut b);
let t = unwrap_value(Variant::decode(&mut b));
assert_eq!(Variant::Null, t);
}
#[test]
fn variant_bool_true() {
let b1 = &mut BytesMut::with_capacity(0);
b1.put_u8(0x41);
assert_eq!(Variant::Boolean(true), unwrap_value(Variant::decode(b1)));
let b2 = &mut BytesMut::with_capacity(0);
b2.put_u8(0x56);
b2.put_u8(0x01);
assert_eq!(Variant::Boolean(true), unwrap_value(Variant::decode(b2)));
}
#[test]
fn variant_bool_false() {
let b1 = &mut BytesMut::with_capacity(0);
b1.put_u8(0x42u8);
assert_eq!(Variant::Boolean(false), unwrap_value(Variant::decode(b1)));
let b2 = &mut BytesMut::with_capacity(0);
b2.put_u8(0x56);
b2.put_u8(0x00);
assert_eq!(Variant::Boolean(false), unwrap_value(Variant::decode(b2)));
}
/// UTC with a precision of milliseconds. For example, 1311704463521
/// represents the moment 2011-07-26T18:21:03.521Z.
#[test]
fn variant_timestamp() {
let b1 = &mut BytesMut::with_capacity(0);
let datetime = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521);
Variant::Timestamp(datetime).encode(b1);
let expected = Utc.ymd(2011, 7, 26).and_hms_milli(18, 21, 3, 521);
assert_eq!(
Variant::Timestamp(expected),
unwrap_value(Variant::decode(b1))
);
}
#[test]
fn variant_timestamp_pre_unix() {
let b1 = &mut BytesMut::with_capacity(0);
let datetime = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521);
Variant::Timestamp(datetime).encode(b1);
let expected = Utc.ymd(1968, 7, 26).and_hms_milli(18, 21, 3, 521);
assert_eq!(
Variant::Timestamp(expected),
unwrap_value(Variant::decode(b1))
);
}
#[test]
fn option_i8() {
let b1 = &mut BytesMut::with_capacity(0);
Some(42i8).encode(b1);
assert_eq!(Some(42), unwrap_value(Option::<i8>::decode(b1)));
let b2 = &mut BytesMut::with_capacity(0);
let o1: Option<i8> = None;
o1.encode(b2);
assert_eq!(None, unwrap_value(Option::<i8>::decode(b2)));
}
#[test]
fn option_string() {
let b1 = &mut BytesMut::with_capacity(0);
Some(ByteString::from("hello")).encode(b1);
assert_eq!(
Some(ByteString::from("hello")),
unwrap_value(Option::<ByteString>::decode(b1))
);
let b2 = &mut BytesMut::with_capacity(0);
let o1: Option<ByteString> = None;
o1.encode(b2);
assert_eq!(None, unwrap_value(Option::<ByteString>::decode(b2)));
}
}

View File

@ -0,0 +1,786 @@
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
use std::{i8, u8};
use bytes::{BufMut, Bytes, BytesMut};
use bytestring::ByteString;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::codec::{self, ArrayEncode, Encode};
use crate::framing::{self, AmqpFrame, SaslFrame};
use crate::types::{
Descriptor, List, Multiple, StaticSymbol, Str, Symbol, Variant, VecStringMap, VecSymbolMap,
};
fn encode_null(buf: &mut BytesMut) {
buf.put_u8(codec::FORMATCODE_NULL);
}
pub trait FixedEncode {}
impl<T: FixedEncode + ArrayEncode> Encode for T {
fn encoded_size(&self) -> usize {
self.array_encoded_size() + 1
}
fn encode(&self, buf: &mut BytesMut) {
buf.put_u8(T::ARRAY_FORMAT_CODE);
self.array_encode(buf);
}
}
impl Encode for bool {
fn encoded_size(&self) -> usize {
1
}
fn encode(&self, buf: &mut BytesMut) {
buf.put_u8(if *self {
codec::FORMATCODE_BOOLEAN_TRUE
} else {
codec::FORMATCODE_BOOLEAN_FALSE
});
}
}
impl ArrayEncode for bool {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BOOLEAN;
fn array_encoded_size(&self) -> usize {
1
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u8(if *self { 1 } else { 0 });
}
}
impl FixedEncode for u8 {}
impl ArrayEncode for u8 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UBYTE;
fn array_encoded_size(&self) -> usize {
1
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u8(*self);
}
}
impl FixedEncode for u16 {}
impl ArrayEncode for u16 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_USHORT;
fn array_encoded_size(&self) -> usize {
2
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u16(*self);
}
}
impl Encode for u32 {
fn encoded_size(&self) -> usize {
if *self == 0 {
1
} else if *self > u32::from(u8::MAX) {
5
} else {
2
}
}
fn encode(&self, buf: &mut BytesMut) {
if *self == 0 {
buf.put_u8(codec::FORMATCODE_UINT_0)
} else if *self > u32::from(u8::MAX) {
buf.put_u8(codec::FORMATCODE_UINT);
buf.put_u32(*self);
} else {
buf.put_u8(codec::FORMATCODE_SMALLUINT);
buf.put_u8(*self as u8);
}
}
}
impl ArrayEncode for u32 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UINT;
fn array_encoded_size(&self) -> usize {
4
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(*self);
}
}
impl Encode for u64 {
fn encoded_size(&self) -> usize {
if *self == 0 {
1
} else if *self > u64::from(u8::MAX) {
9
} else {
2
}
}
fn encode(&self, buf: &mut BytesMut) {
if *self == 0 {
buf.put_u8(codec::FORMATCODE_ULONG_0)
} else if *self > u64::from(u8::MAX) {
buf.put_u8(codec::FORMATCODE_ULONG);
buf.put_u64(*self);
} else {
buf.put_u8(codec::FORMATCODE_SMALLULONG);
buf.put_u8(*self as u8);
}
}
}
impl ArrayEncode for u64 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_ULONG;
fn array_encoded_size(&self) -> usize {
8
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u64(*self);
}
}
impl FixedEncode for i8 {}
impl ArrayEncode for i8 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BYTE;
fn array_encoded_size(&self) -> usize {
1
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_i8(*self);
}
}
impl FixedEncode for i16 {}
impl ArrayEncode for i16 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_SHORT;
fn array_encoded_size(&self) -> usize {
2
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_i16(*self);
}
}
impl Encode for i32 {
fn encoded_size(&self) -> usize {
if *self > i32::from(i8::MAX) || *self < i32::from(i8::MIN) {
5
} else {
2
}
}
fn encode(&self, buf: &mut BytesMut) {
if *self > i32::from(i8::MAX) || *self < i32::from(i8::MIN) {
buf.put_u8(codec::FORMATCODE_INT);
buf.put_i32(*self);
} else {
buf.put_u8(codec::FORMATCODE_SMALLINT);
buf.put_i8(*self as i8);
}
}
}
impl ArrayEncode for i32 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_INT;
fn array_encoded_size(&self) -> usize {
4
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_i32(*self);
}
}
impl Encode for i64 {
fn encoded_size(&self) -> usize {
if *self > i64::from(i8::MAX) || *self < i64::from(i8::MIN) {
9
} else {
2
}
}
fn encode(&self, buf: &mut BytesMut) {
if *self > i64::from(i8::MAX) || *self < i64::from(i8::MIN) {
buf.put_u8(codec::FORMATCODE_LONG);
buf.put_i64(*self);
} else {
buf.put_u8(codec::FORMATCODE_SMALLLONG);
buf.put_i8(*self as i8);
}
}
}
impl ArrayEncode for i64 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_LONG;
fn array_encoded_size(&self) -> usize {
8
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_i64(*self);
}
}
impl FixedEncode for f32 {}
impl ArrayEncode for f32 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_FLOAT;
fn array_encoded_size(&self) -> usize {
4
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_f32(*self);
}
}
impl FixedEncode for f64 {}
impl ArrayEncode for f64 {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_DOUBLE;
fn array_encoded_size(&self) -> usize {
8
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_f64(*self);
}
}
impl FixedEncode for char {}
impl ArrayEncode for char {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_CHAR;
fn array_encoded_size(&self) -> usize {
4
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(*self as u32);
}
}
impl FixedEncode for DateTime<Utc> {}
impl ArrayEncode for DateTime<Utc> {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_TIMESTAMP;
fn array_encoded_size(&self) -> usize {
8
}
fn array_encode(&self, buf: &mut BytesMut) {
let timestamp = self.timestamp() * 1000 + i64::from(self.timestamp_subsec_millis());
buf.put_i64(timestamp);
}
}
impl FixedEncode for Uuid {}
impl ArrayEncode for Uuid {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_UUID;
fn array_encoded_size(&self) -> usize {
16
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.extend_from_slice(self.as_bytes());
}
}
impl Encode for Bytes {
fn encoded_size(&self) -> usize {
let length = self.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_BINARY32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_BINARY8);
buf.put_u8(length as u8);
}
buf.put_slice(self);
}
}
impl ArrayEncode for Bytes {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_BINARY32;
fn array_encoded_size(&self) -> usize {
4 + self.len()
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(self.len() as u32);
buf.put_slice(&self);
}
}
impl Encode for ByteString {
fn encoded_size(&self) -> usize {
let length = self.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_STRING32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_STRING8);
buf.put_u8(length as u8);
}
buf.put_slice(self.as_bytes());
}
}
impl ArrayEncode for ByteString {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_STRING32;
fn array_encoded_size(&self) -> usize {
4 + self.len()
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(self.len() as u32);
buf.put_slice(self.as_bytes());
}
}
impl Encode for str {
fn encoded_size(&self) -> usize {
let length = self.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_STRING32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_STRING8);
buf.put_u8(length as u8);
}
buf.put_slice(self.as_bytes());
}
}
impl ArrayEncode for str {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_STRING32;
fn array_encoded_size(&self) -> usize {
4 + self.len()
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(self.len() as u32);
buf.put_slice(self.as_bytes());
}
}
impl Encode for Str {
fn encoded_size(&self) -> usize {
let length = self.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.as_str().len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_STRING32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_STRING8);
buf.put_u8(length as u8);
}
buf.put_slice(self.as_bytes());
}
}
impl Encode for Symbol {
fn encoded_size(&self) -> usize {
let length = self.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.as_str().len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_SYMBOL32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_SYMBOL8);
buf.put_u8(length as u8);
}
buf.put_slice(self.as_bytes());
}
}
impl ArrayEncode for Symbol {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_SYMBOL32;
fn array_encoded_size(&self) -> usize {
4 + self.len()
}
fn array_encode(&self, buf: &mut BytesMut) {
buf.put_u32(self.len() as u32);
buf.put_slice(self.as_bytes());
}
}
impl Encode for StaticSymbol {
fn encoded_size(&self) -> usize {
let length = self.0.len();
let size = if length > u8::MAX as usize { 5 } else { 2 };
size + length
}
fn encode(&self, buf: &mut BytesMut) {
let length = self.0.len();
if length > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_SYMBOL32);
buf.put_u32(length as u32);
} else {
buf.put_u8(codec::FORMATCODE_SYMBOL8);
buf.put_u8(length as u8);
}
buf.put_slice(self.0.as_bytes());
}
}
fn map_encoded_size<K: Hash + Eq + Encode, V: Encode, S: BuildHasher>(
map: &HashMap<K, V, S>,
) -> usize {
map.iter()
.fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size())
}
impl<K: Eq + Hash + Encode, V: Encode, S: BuildHasher> Encode for HashMap<K, V, S> {
fn encoded_size(&self) -> usize {
let size = map_encoded_size(self);
// f:1 + s:4 + c:4 vs f:1 + s:1 + c:1
let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 };
preamble + size
}
fn encode(&self, buf: &mut BytesMut) {
let count = self.len() * 2; // key-value pair accounts for two items in count
let size = map_encoded_size(self);
if size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_MAP32);
buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows
buf.put_u32(count as u32);
} else {
buf.put_u8(codec::FORMATCODE_MAP8);
buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows
buf.put_u8(count as u8);
}
for (k, v) in self {
k.encode(buf);
v.encode(buf);
}
}
}
impl<K: Eq + Hash + Encode, V: Encode> ArrayEncode for HashMap<K, V> {
const ARRAY_FORMAT_CODE: u8 = codec::FORMATCODE_MAP32;
fn array_encoded_size(&self) -> usize {
8 + map_encoded_size(self)
}
fn array_encode(&self, buf: &mut BytesMut) {
let count = self.len() * 2;
let size = map_encoded_size(self) + 4;
buf.put_u32(size as u32);
buf.put_u32(count as u32);
for (k, v) in self {
k.encode(buf);
v.encode(buf);
}
}
}
impl Encode for VecSymbolMap {
fn encoded_size(&self) -> usize {
let size = self
.0
.iter()
.fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size());
// f:1 + s:4 + c:4 vs f:1 + s:1 + c:1
let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 };
preamble + size
}
fn encode(&self, buf: &mut BytesMut) {
let count = self.len() * 2; // key-value pair accounts for two items in count
let size = self
.0
.iter()
.fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size());
if size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_MAP32);
buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows
buf.put_u32(count as u32);
} else {
buf.put_u8(codec::FORMATCODE_MAP8);
buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows
buf.put_u8(count as u8);
}
for (k, v) in self.iter() {
k.encode(buf);
v.encode(buf);
}
}
}
impl Encode for VecStringMap {
fn encoded_size(&self) -> usize {
let size = self
.0
.iter()
.fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size());
// f:1 + s:4 + c:4 vs f:1 + s:1 + c:1
let preamble = if size + 1 > u8::MAX as usize { 9 } else { 3 };
preamble + size
}
fn encode(&self, buf: &mut BytesMut) {
let count = self.len() * 2; // key-value pair accounts for two items in count
let size = self
.0
.iter()
.fold(0, |r, (k, v)| r + k.encoded_size() + v.encoded_size());
if size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_MAP32);
buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follows
buf.put_u32(count as u32);
} else {
buf.put_u8(codec::FORMATCODE_MAP8);
buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follows
buf.put_u8(count as u8);
}
for (k, v) in self.iter() {
k.encode(buf);
v.encode(buf);
}
}
}
fn array_encoded_size<T: ArrayEncode>(vec: &[T]) -> usize {
vec.iter().fold(0, |r, i| r + i.array_encoded_size())
}
impl<T: ArrayEncode> Encode for Vec<T> {
fn encoded_size(&self) -> usize {
let content_size = array_encoded_size(self);
// format_code + size + count + item constructor -- todo: support described ctor?
(if content_size + 1 > u8::MAX as usize {
10
} else {
4
}) // +1 for 1 byte count and 1 byte format code
+ content_size
}
fn encode(&self, buf: &mut BytesMut) {
let size = array_encoded_size(self);
if size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_ARRAY32);
buf.put_u32((size + 5) as u32); // +4 for 4 byte count and 1 byte item ctor that follow
buf.put_u32(self.len() as u32);
} else {
buf.put_u8(codec::FORMATCODE_ARRAY8);
buf.put_u8((size + 2) as u8); // +1 for 1 byte count and 1 byte item ctor that follow
buf.put_u8(self.len() as u8);
}
buf.put_u8(T::ARRAY_FORMAT_CODE);
for i in self {
i.array_encode(buf);
}
}
}
impl<T: Encode + ArrayEncode> Encode for Multiple<T> {
fn encoded_size(&self) -> usize {
let count = self.len();
if count == 1 {
// special case: single item is encoded without array encoding
self.0[0].encoded_size()
} else {
self.0.encoded_size()
}
}
fn encode(&self, buf: &mut BytesMut) {
let count = self.0.len();
if count == 1 {
// special case: single item is encoded without array encoding
self.0[0].encode(buf)
} else {
self.0.encode(buf)
}
}
}
fn list_encoded_size(vec: &List) -> usize {
vec.iter().fold(0, |r, i| r + i.encoded_size())
}
impl Encode for List {
fn encoded_size(&self) -> usize {
let content_size = list_encoded_size(self);
// format_code + size + count
(if content_size + 1 > u8::MAX as usize {
9
} else {
3
}) + content_size
}
fn encode(&self, buf: &mut BytesMut) {
let size = list_encoded_size(self);
if size + 1 > u8::MAX as usize {
buf.put_u8(codec::FORMATCODE_ARRAY32);
buf.put_u32((size + 4) as u32); // +4 for 4 byte count that follow
buf.put_u32(self.len() as u32);
} else {
buf.put_u8(codec::FORMATCODE_ARRAY8);
buf.put_u8((size + 1) as u8); // +1 for 1 byte count that follow
buf.put_u8(self.len() as u8);
}
for i in self.iter() {
i.encode(buf);
}
}
}
impl Encode for Variant {
fn encoded_size(&self) -> usize {
match *self {
Variant::Null => 1,
Variant::Boolean(b) => b.encoded_size(),
Variant::Ubyte(b) => b.encoded_size(),
Variant::Ushort(s) => s.encoded_size(),
Variant::Uint(i) => i.encoded_size(),
Variant::Ulong(l) => l.encoded_size(),
Variant::Byte(b) => b.encoded_size(),
Variant::Short(s) => s.encoded_size(),
Variant::Int(i) => i.encoded_size(),
Variant::Long(l) => l.encoded_size(),
Variant::Float(f) => f.encoded_size(),
Variant::Double(d) => d.encoded_size(),
Variant::Char(c) => c.encoded_size(),
Variant::Timestamp(ref t) => t.encoded_size(),
Variant::Uuid(ref u) => u.encoded_size(),
Variant::Binary(ref b) => b.encoded_size(),
Variant::String(ref s) => s.encoded_size(),
Variant::Symbol(ref s) => s.encoded_size(),
Variant::StaticSymbol(ref s) => s.encoded_size(),
Variant::List(ref l) => l.encoded_size(),
Variant::Map(ref m) => m.map.encoded_size(),
Variant::Described(ref dv) => dv.0.encoded_size() + dv.1.encoded_size(),
}
}
/// Encodes `Variant` into provided `BytesMut`
fn encode(&self, buf: &mut BytesMut) {
match *self {
Variant::Null => encode_null(buf),
Variant::Boolean(b) => b.encode(buf),
Variant::Ubyte(b) => b.encode(buf),
Variant::Ushort(s) => s.encode(buf),
Variant::Uint(i) => i.encode(buf),
Variant::Ulong(l) => l.encode(buf),
Variant::Byte(b) => b.encode(buf),
Variant::Short(s) => s.encode(buf),
Variant::Int(i) => i.encode(buf),
Variant::Long(l) => l.encode(buf),
Variant::Float(f) => f.encode(buf),
Variant::Double(d) => d.encode(buf),
Variant::Char(c) => c.encode(buf),
Variant::Timestamp(ref t) => t.encode(buf),
Variant::Uuid(ref u) => u.encode(buf),
Variant::Binary(ref b) => b.encode(buf),
Variant::String(ref s) => s.encode(buf),
Variant::Symbol(ref s) => s.encode(buf),
Variant::StaticSymbol(ref s) => s.encode(buf),
Variant::List(ref l) => l.encode(buf),
Variant::Map(ref m) => m.map.encode(buf),
Variant::Described(ref dv) => {
dv.0.encode(buf);
dv.1.encode(buf);
}
}
}
}
impl<T: Encode> Encode for Option<T> {
fn encoded_size(&self) -> usize {
self.as_ref().map_or(1, |v| v.encoded_size())
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
Some(ref e) => e.encode(buf),
None => encode_null(buf),
}
}
}
impl Encode for Descriptor {
fn encoded_size(&self) -> usize {
match *self {
Descriptor::Ulong(v) => 1 + v.encoded_size(),
Descriptor::Symbol(ref v) => 1 + v.encoded_size(),
}
}
fn encode(&self, buf: &mut BytesMut) {
buf.put_u8(codec::FORMATCODE_DESCRIBED);
match *self {
Descriptor::Ulong(v) => v.encode(buf),
Descriptor::Symbol(ref v) => v.encode(buf),
}
}
}
const WORD_LEN: usize = 4;
impl Encode for AmqpFrame {
fn encoded_size(&self) -> usize {
framing::HEADER_LEN + self.performative().encoded_size()
}
fn encode(&self, buf: &mut BytesMut) {
let doff: u8 = (framing::HEADER_LEN / WORD_LEN) as u8;
buf.put_u32(self.encoded_size() as u32);
buf.put_u8(doff);
buf.put_u8(framing::FRAME_TYPE_AMQP);
buf.put_u16(self.channel_id());
self.performative().encode(buf);
}
}
impl Encode for SaslFrame {
fn encoded_size(&self) -> usize {
framing::HEADER_LEN + self.body.encoded_size()
}
fn encode(&self, buf: &mut BytesMut) {
let doff: u8 = (framing::HEADER_LEN / WORD_LEN) as u8;
buf.put_u32(self.encoded_size() as u32);
buf.put_u8(doff);
buf.put_u8(framing::FRAME_TYPE_SASL);
buf.put_u16(0);
self.body.encode(buf);
}
}

View File

@ -0,0 +1 @@
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Donec accumsan iaculis ipsum sed convallis. Phasellus consectetur justo et odio maximus, vel vehicula sapien venenatis. Nunc ac viverra risus. Pellentesque elementum, mauris et viverra ultricies, lacus erat varius nulla, eget maximus nisl sed.

150
actix-amqp/codec/src/codec/mod.rs Executable file
View File

@ -0,0 +1,150 @@
use bytes::BytesMut;
use std::marker::Sized;
use crate::errors::AmqpParseError;
macro_rules! decode_check_len {
($buf:ident, $size:expr) => {
if $buf.len() < $size {
return Err(AmqpParseError::Incomplete(Some($size)));
}
};
}
#[macro_use]
mod decode;
mod encode;
pub(crate) use self::decode::decode_list_header;
pub trait Encode {
fn encoded_size(&self) -> usize;
fn encode(&self, buf: &mut BytesMut);
}
pub trait ArrayEncode {
const ARRAY_FORMAT_CODE: u8;
fn array_encoded_size(&self) -> usize;
fn array_encode(&self, buf: &mut BytesMut);
}
pub trait Decode
where
Self: Sized,
{
fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError>;
}
pub trait DecodeFormatted
where
Self: Sized,
{
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError>;
}
pub trait ArrayDecode: Sized {
fn array_decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError>;
}
impl<T: DecodeFormatted> Decode for T {
fn decode(input: &[u8]) -> Result<(&[u8], Self), AmqpParseError> {
let (input, fmt) = decode_format_code(input)?;
T::decode_with_format(input, fmt)
}
}
pub fn decode_format_code(input: &[u8]) -> Result<(&[u8], u8), AmqpParseError> {
decode_check_len!(input, 1);
Ok((&input[1..], input[0]))
}
pub const FORMATCODE_DESCRIBED: u8 = 0x00;
pub const FORMATCODE_NULL: u8 = 0x40; // fixed width --V
pub const FORMATCODE_BOOLEAN: u8 = 0x56;
pub const FORMATCODE_BOOLEAN_TRUE: u8 = 0x41;
pub const FORMATCODE_BOOLEAN_FALSE: u8 = 0x42;
pub const FORMATCODE_UINT_0: u8 = 0x43;
pub const FORMATCODE_ULONG_0: u8 = 0x44;
pub const FORMATCODE_UBYTE: u8 = 0x50;
pub const FORMATCODE_USHORT: u8 = 0x60;
pub const FORMATCODE_UINT: u8 = 0x70;
pub const FORMATCODE_ULONG: u8 = 0x80;
pub const FORMATCODE_BYTE: u8 = 0x51;
pub const FORMATCODE_SHORT: u8 = 0x61;
pub const FORMATCODE_INT: u8 = 0x71;
pub const FORMATCODE_LONG: u8 = 0x81;
pub const FORMATCODE_SMALLUINT: u8 = 0x52;
pub const FORMATCODE_SMALLULONG: u8 = 0x53;
pub const FORMATCODE_SMALLINT: u8 = 0x54;
pub const FORMATCODE_SMALLLONG: u8 = 0x55;
pub const FORMATCODE_FLOAT: u8 = 0x72;
pub const FORMATCODE_DOUBLE: u8 = 0x82;
// pub const FORMATCODE_DECIMAL32: u8 = 0x74;
// pub const FORMATCODE_DECIMAL64: u8 = 0x84;
// pub const FORMATCODE_DECIMAL128: u8 = 0x94;
pub const FORMATCODE_CHAR: u8 = 0x73;
pub const FORMATCODE_TIMESTAMP: u8 = 0x83;
pub const FORMATCODE_UUID: u8 = 0x98;
pub const FORMATCODE_BINARY8: u8 = 0xa0; // variable --V
pub const FORMATCODE_BINARY32: u8 = 0xb0;
pub const FORMATCODE_STRING8: u8 = 0xa1;
pub const FORMATCODE_STRING32: u8 = 0xb1;
pub const FORMATCODE_SYMBOL8: u8 = 0xa3;
pub const FORMATCODE_SYMBOL32: u8 = 0xb3;
pub const FORMATCODE_LIST0: u8 = 0x45; // compound --V
pub const FORMATCODE_LIST8: u8 = 0xc0;
pub const FORMATCODE_LIST32: u8 = 0xd0;
pub const FORMATCODE_MAP8: u8 = 0xc1;
pub const FORMATCODE_MAP32: u8 = 0xd1;
pub const FORMATCODE_ARRAY8: u8 = 0xe0;
pub const FORMATCODE_ARRAY32: u8 = 0xf0;
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use crate::codec::{Decode, Encode};
use crate::errors::AmqpCodecError;
use crate::framing::{AmqpFrame, SaslFrame};
use crate::protocol::SaslFrameBody;
#[test]
fn test_sasl_mechanisms() -> Result<(), AmqpCodecError> {
let data = b"\x02\x01\0\0\0S@\xc02\x01\xe0/\x04\xb3\0\0\0\x07MSSBCBS\0\0\0\x05PLAIN\0\0\0\tANONYMOUS\0\0\0\x08EXTERNAL";
let (remainder, frame) = SaslFrame::decode(data.as_ref())?;
assert!(remainder.is_empty());
match frame.body {
SaslFrameBody::SaslMechanisms(_) => (),
_ => panic!("error"),
}
let mut buf = BytesMut::new();
buf.reserve(frame.encoded_size());
frame.encode(&mut buf);
buf.split_to(4);
assert_eq!(Bytes::from_static(data), buf.freeze());
Ok(())
}
#[test]
fn test_disposition() -> Result<(), AmqpCodecError> {
let data = b"\x02\0\0\0\0S\x15\xc0\x0c\x06AC@A\0S$\xc0\x01\0B";
let (remainder, frame) = AmqpFrame::decode(data.as_ref())?;
assert!(remainder.is_empty());
assert_eq!(frame.performative().name(), "Disposition");
let mut buf = BytesMut::new();
buf.reserve(frame.encoded_size());
frame.encode(&mut buf);
buf.split_to(4);
assert_eq!(Bytes::from_static(data), buf.freeze());
Ok(())
}
}

73
actix-amqp/codec/src/errors.rs Executable file
View File

@ -0,0 +1,73 @@
use uuid;
use crate::protocol::ProtocolId;
use crate::types::Descriptor;
#[derive(Debug, Display, From, Clone)]
pub enum AmqpParseError {
#[display(fmt = "Loaded item size is invalid")]
InvalidSize,
#[display(fmt = "More data required during frame parsing: '{:?}'", "_0")]
Incomplete(Option<usize>),
#[from(ignore)]
#[display(fmt = "Unexpected format code: '{}'", "_0")]
InvalidFormatCode(u8),
#[display(fmt = "Invalid value converting to char: {}", "_0")]
InvalidChar(u32),
#[display(fmt = "Unexpected descriptor: '{:?}'", "_0")]
InvalidDescriptor(Descriptor),
#[from(ignore)]
#[display(fmt = "Unexpected frame type: '{:?}'", "_0")]
UnexpectedFrameType(u8),
#[from(ignore)]
#[display(fmt = "Required field '{:?}' was omitted.", "_0")]
RequiredFieldOmitted(&'static str),
#[from(ignore)]
#[display(fmt = "Unknown {:?} option.", "_0")]
UnknownEnumOption(&'static str),
UuidParseError(uuid::Error),
Utf8Error(std::str::Utf8Error),
}
#[derive(Debug, Display, From)]
pub enum AmqpCodecError {
ParseError(AmqpParseError),
#[display(fmt = "bytes left unparsed at the frame trail")]
UnparsedBytesLeft,
#[display(fmt = "max inbound frame size exceeded")]
MaxSizeExceeded,
#[display(fmt = "Io error: {:?}", _0)]
Io(Option<std::io::Error>),
}
impl Clone for AmqpCodecError {
fn clone(&self) -> AmqpCodecError {
match self {
AmqpCodecError::ParseError(err) => AmqpCodecError::ParseError(err.clone()),
AmqpCodecError::UnparsedBytesLeft => AmqpCodecError::UnparsedBytesLeft,
AmqpCodecError::MaxSizeExceeded => AmqpCodecError::MaxSizeExceeded,
AmqpCodecError::Io(_) => AmqpCodecError::Io(None),
}
}
}
impl From<std::io::Error> for AmqpCodecError {
fn from(err: std::io::Error) -> AmqpCodecError {
AmqpCodecError::Io(Some(err))
}
}
#[derive(Debug, Display, From)]
pub enum ProtocolIdError {
InvalidHeader,
Incompatible,
Unknown,
#[display(fmt = "Expected {:?} protocol id, seen {:?} instead.", exp, got)]
Unexpected {
exp: ProtocolId,
got: ProtocolId,
},
Disconnected,
#[display(fmt = "io error: {:?}", "_0")]
Io(std::io::Error),
}

80
actix-amqp/codec/src/framing.rs Executable file
View File

@ -0,0 +1,80 @@
use super::protocol;
/// Length in bytes of the fixed frame header
pub const HEADER_LEN: usize = 8;
/// AMQP Frame type marker (0)
pub const FRAME_TYPE_AMQP: u8 = 0x00;
pub const FRAME_TYPE_SASL: u8 = 0x01;
/// Represents an AMQP Frame
#[derive(Clone, Debug, PartialEq)]
pub struct AmqpFrame {
channel_id: u16,
performative: protocol::Frame,
}
impl AmqpFrame {
pub fn new(channel_id: u16, performative: protocol::Frame) -> AmqpFrame {
AmqpFrame {
channel_id,
performative,
}
}
#[inline]
pub fn channel_id(&self) -> u16 {
self.channel_id
}
#[inline]
pub fn performative(&self) -> &protocol::Frame {
&self.performative
}
#[inline]
pub fn into_parts(self) -> (u16, protocol::Frame) {
(self.channel_id, self.performative)
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct SaslFrame {
pub body: protocol::SaslFrameBody,
}
impl SaslFrame {
pub fn new(body: protocol::SaslFrameBody) -> SaslFrame {
SaslFrame { body }
}
}
impl From<protocol::SaslMechanisms> for SaslFrame {
fn from(item: protocol::SaslMechanisms) -> SaslFrame {
SaslFrame::new(protocol::SaslFrameBody::SaslMechanisms(item))
}
}
impl From<protocol::SaslInit> for SaslFrame {
fn from(item: protocol::SaslInit) -> SaslFrame {
SaslFrame::new(protocol::SaslFrameBody::SaslInit(item))
}
}
impl From<protocol::SaslChallenge> for SaslFrame {
fn from(item: protocol::SaslChallenge) -> SaslFrame {
SaslFrame::new(protocol::SaslFrameBody::SaslChallenge(item))
}
}
impl From<protocol::SaslResponse> for SaslFrame {
fn from(item: protocol::SaslResponse) -> SaslFrame {
SaslFrame::new(protocol::SaslFrameBody::SaslResponse(item))
}
}
impl From<protocol::SaslOutcome> for SaslFrame {
fn from(item: protocol::SaslOutcome) -> SaslFrame {
SaslFrame::new(protocol::SaslFrameBody::SaslOutcome(item))
}
}

160
actix-amqp/codec/src/io.rs Executable file
View File

@ -0,0 +1,160 @@
use std::marker::PhantomData;
use actix_codec::{Decoder, Encoder};
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use super::errors::{AmqpCodecError, ProtocolIdError};
use super::framing::HEADER_LEN;
use crate::codec::{Decode, Encode};
use crate::protocol::ProtocolId;
const SIZE_LOW_WM: usize = 4096;
const SIZE_HIGH_WM: usize = 32768;
#[derive(Debug)]
pub struct AmqpCodec<T: Decode + Encode> {
state: DecodeState,
max_size: usize,
phantom: PhantomData<T>,
}
#[derive(Debug, Clone, Copy)]
enum DecodeState {
FrameHeader,
Frame(usize),
}
impl<T: Decode + Encode> Default for AmqpCodec<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Decode + Encode> AmqpCodec<T> {
pub fn new() -> AmqpCodec<T> {
AmqpCodec {
state: DecodeState::FrameHeader,
max_size: 0,
phantom: PhantomData,
}
}
/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
/// By default max size is set to `0`
pub fn max_size(&mut self, size: usize) {
self.max_size = size;
}
}
impl<T: Decode + Encode> Decoder for AmqpCodec<T> {
type Item = T;
type Error = AmqpCodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
match self.state {
DecodeState::FrameHeader => {
let len = src.len();
if len < HEADER_LEN {
return Ok(None);
}
// read frame size
let size = BigEndian::read_u32(src.as_ref()) as usize;
if self.max_size != 0 && size > self.max_size {
return Err(AmqpCodecError::MaxSizeExceeded);
}
self.state = DecodeState::Frame(size - 4);
src.split_to(4);
if len < size {
// extend receiving buffer to fit the whole frame
if src.remaining_mut() < std::cmp::max(SIZE_LOW_WM, size + HEADER_LEN) {
src.reserve(SIZE_HIGH_WM);
}
return Ok(None);
}
}
DecodeState::Frame(size) => {
if src.len() < size {
return Ok(None);
}
let frame_buf = src.split_to(size);
let (remainder, frame) = T::decode(frame_buf.as_ref())?;
if !remainder.is_empty() {
// todo: could it really happen?
return Err(AmqpCodecError::UnparsedBytesLeft);
}
self.state = DecodeState::FrameHeader;
return Ok(Some(frame));
}
}
}
}
}
impl<T: Decode + Encode + ::std::fmt::Debug> Encoder for AmqpCodec<T> {
type Item = T;
type Error = AmqpCodecError;
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
let size = item.encoded_size();
let need = std::cmp::max(SIZE_LOW_WM, size);
if dst.remaining_mut() < need {
dst.reserve(std::cmp::max(need, SIZE_HIGH_WM));
}
item.encode(dst);
Ok(())
}
}
const PROTOCOL_HEADER_LEN: usize = 8;
const PROTOCOL_HEADER_PREFIX: &[u8] = b"AMQP";
const PROTOCOL_VERSION: &[u8] = &[1, 0, 0];
#[derive(Default, Debug)]
pub struct ProtocolIdCodec;
impl Decoder for ProtocolIdCodec {
type Item = ProtocolId;
type Error = ProtocolIdError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < PROTOCOL_HEADER_LEN {
Ok(None)
} else {
let src = src.split_to(8);
if &src[0..4] != PROTOCOL_HEADER_PREFIX {
Err(ProtocolIdError::InvalidHeader)
} else if &src[5..8] != PROTOCOL_VERSION {
Err(ProtocolIdError::Incompatible)
} else {
let protocol_id = src[4];
match protocol_id {
0 => Ok(Some(ProtocolId::Amqp)),
2 => Ok(Some(ProtocolId::AmqpTls)),
3 => Ok(Some(ProtocolId::AmqpSasl)),
_ => Err(ProtocolIdError::Unknown),
}
}
}
}
}
impl Encoder for ProtocolIdCodec {
type Item = ProtocolId;
type Error = ProtocolIdError;
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.reserve(PROTOCOL_HEADER_LEN);
dst.put_slice(PROTOCOL_HEADER_PREFIX);
dst.put_u8(item as u8);
dst.put_slice(PROTOCOL_VERSION);
Ok(())
}
}

17
actix-amqp/codec/src/lib.rs Executable file
View File

@ -0,0 +1,17 @@
#[macro_use]
extern crate derive_more;
#[macro_use]
mod codec;
mod errors;
mod framing;
mod io;
mod message;
pub mod protocol;
pub mod types;
pub use self::codec::{Decode, Encode};
pub use self::errors::{AmqpCodecError, AmqpParseError, ProtocolIdError};
pub use self::framing::{AmqpFrame, SaslFrame};
pub use self::io::{AmqpCodec, ProtocolIdCodec};
pub use self::message::{InMessage, MessageBody, OutMessage};

View File

@ -0,0 +1,89 @@
use bytes::{BufMut, Bytes, BytesMut};
use crate::codec::{Encode, FORMATCODE_BINARY32, FORMATCODE_BINARY8};
use crate::protocol::TransferBody;
use crate::types::{Descriptor, List, Variant};
use super::SECTION_PREFIX_LENGTH;
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MessageBody {
pub data: Vec<Bytes>,
pub sequence: Vec<List>,
pub messages: Vec<TransferBody>,
pub value: Option<Variant>,
}
impl MessageBody {
pub fn data(&self) -> Option<&Bytes> {
if self.data.is_empty() {
None
} else {
Some(&self.data[0])
}
}
pub fn value(&self) -> Option<&Variant> {
self.value.as_ref()
}
pub fn set_data(&mut self, data: Bytes) {
self.data.clear();
self.data.push(data);
}
}
impl Encode for MessageBody {
fn encoded_size(&self) -> usize {
let mut size = self
.data
.iter()
.fold(0, |a, d| a + d.encoded_size() + SECTION_PREFIX_LENGTH);
size += self
.sequence
.iter()
.fold(0, |a, seq| a + seq.encoded_size() + SECTION_PREFIX_LENGTH);
size += self.messages.iter().fold(0, |a, m| {
let length = m.encoded_size();
let size = length + if length > std::u8::MAX as usize { 5 } else { 2 };
a + size + SECTION_PREFIX_LENGTH
});
if let Some(ref val) = self.value {
size + val.encoded_size() + SECTION_PREFIX_LENGTH
} else {
size
}
}
fn encode(&self, dst: &mut BytesMut) {
self.data.iter().for_each(|d| {
Descriptor::Ulong(117).encode(dst);
d.encode(dst);
});
self.sequence.iter().for_each(|seq| {
Descriptor::Ulong(118).encode(dst);
seq.encode(dst)
});
if let Some(ref val) = self.value {
Descriptor::Ulong(119).encode(dst);
val.encode(dst);
}
// encode Message as nested Bytes object
self.messages.iter().for_each(|m| {
Descriptor::Ulong(117).encode(dst);
// Bytes prefix
let length = m.encoded_size();
if length > std::u8::MAX as usize {
dst.put_u8(FORMATCODE_BINARY32);
dst.put_u32(length as u32);
} else {
dst.put_u8(FORMATCODE_BINARY8);
dst.put_u8(length as u8);
}
// encode nested Message
m.encode(dst);
});
}
}

View File

@ -0,0 +1,406 @@
use std::cell::Cell;
use bytes::{BufMut, Bytes, BytesMut};
use fxhash::FxHashMap;
use crate::codec::{Decode, Encode, FORMATCODE_BINARY8};
use crate::errors::AmqpParseError;
use crate::protocol::{
Annotations, Header, MessageFormat, Properties, Section, StringVariantMap, TransferBody,
};
use crate::types::{Descriptor, Str, Variant};
use super::body::MessageBody;
use super::outmessage::OutMessage;
use super::SECTION_PREFIX_LENGTH;
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InMessage {
pub message_format: Option<MessageFormat>,
pub(super) header: Option<Header>,
pub(super) delivery_annotations: Option<Annotations>,
pub(super) message_annotations: Option<Annotations>,
pub(super) properties: Option<Properties>,
pub(super) application_properties: Option<StringVariantMap>,
pub(super) footer: Option<Annotations>,
pub(super) body: MessageBody,
pub(super) size: Cell<usize>,
}
impl InMessage {
/// Create new message and set body
pub fn with_body(body: Bytes) -> InMessage {
let mut msg = InMessage::default();
msg.body.data.push(body);
msg
}
/// Create new message and set messages as body
pub fn with_messages(messages: Vec<TransferBody>) -> InMessage {
let mut msg = InMessage::default();
msg.body.messages = messages;
msg
}
/// Header
pub fn header(&self) -> Option<&Header> {
self.header.as_ref()
}
/// Set message header
pub fn set_header(mut self, header: Header) -> Self {
self.header = Some(header);
self.size.set(0);
self
}
/// Message properties
pub fn properties(&self) -> Option<&Properties> {
self.properties.as_ref()
}
/// Add property
pub fn set_properties<F>(mut self, f: F) -> Self
where
F: Fn(&mut Properties),
{
if let Some(ref mut props) = self.properties {
f(props);
} else {
let mut props = Properties::default();
f(&mut props);
self.properties = Some(props);
}
self.size.set(0);
self
}
/// Get application property
pub fn app_property(&self, key: &str) -> Option<&Variant> {
if let Some(ref props) = self.application_properties {
props.get(key)
} else {
None
}
}
/// Get application properties
pub fn app_properties(&self) -> Option<&StringVariantMap> {
self.application_properties.as_ref()
}
/// Get message annotation
pub fn message_annotation(&self, key: &str) -> Option<&Variant> {
if let Some(ref props) = self.message_annotations {
props.get(key)
} else {
None
}
}
/// Add application property
pub fn set_app_property<K: Into<Str>, V: Into<Variant>>(mut self, key: K, value: V) -> Self {
if let Some(ref mut props) = self.application_properties {
props.insert(key.into(), value.into());
} else {
let mut props = FxHashMap::default();
props.insert(key.into(), value.into());
self.application_properties = Some(props);
}
self.size.set(0);
self
}
/// Call closure with message reference
pub fn update<F>(self, f: F) -> Self
where
F: Fn(Self) -> Self,
{
self.size.set(0);
f(self)
}
/// Call closure if value is Some value
pub fn if_some<T, F>(self, value: &Option<T>, f: F) -> Self
where
F: Fn(Self, &T) -> Self,
{
if let Some(ref val) = value {
self.size.set(0);
f(self, val)
} else {
self
}
}
/// Message body
pub fn body(&self) -> &MessageBody {
&self.body
}
/// Message value
pub fn value(&self) -> Option<&Variant> {
self.body.value.as_ref()
}
/// Set message body value
pub fn set_value<V: Into<Variant>>(mut self, v: V) -> Self {
self.body.value = Some(v.into());
self
}
/// Set message body
pub fn set_body<F>(mut self, f: F) -> Self
where
F: Fn(&mut MessageBody),
{
f(&mut self.body);
self.size.set(0);
self
}
/// Create new message and set `correlation_id` property
pub fn reply_message(&self) -> OutMessage {
let mut msg = OutMessage::default().if_some(&self.properties, |mut msg, data| {
msg.set_properties(|props| props.correlation_id = data.message_id.clone());
msg
});
msg.message_format = self.message_format;
msg
}
}
impl Decode for InMessage {
fn decode(mut input: &[u8]) -> Result<(&[u8], InMessage), AmqpParseError> {
let mut message = InMessage::default();
loop {
let (buf, sec) = Section::decode(input)?;
match sec {
Section::Header(val) => {
message.header = Some(val);
}
Section::DeliveryAnnotations(val) => {
message.delivery_annotations = Some(val);
}
Section::MessageAnnotations(val) => {
message.message_annotations = Some(val);
}
Section::ApplicationProperties(val) => {
message.application_properties = Some(val);
}
Section::Footer(val) => {
message.footer = Some(val);
}
Section::Properties(val) => {
message.properties = Some(val);
}
// body
Section::AmqpSequence(val) => {
message.body.sequence.push(val);
}
Section::AmqpValue(val) => {
message.body.value = Some(val);
}
Section::Data(val) => {
message.body.data.push(val);
}
}
if buf.is_empty() {
break;
}
input = buf;
}
Ok((input, message))
}
}
impl Encode for InMessage {
fn encoded_size(&self) -> usize {
let size = self.size.get();
if size != 0 {
return size;
}
// body size, always add empty body if needed
let body_size = self.body.encoded_size();
let mut size = if body_size == 0 {
// empty bytes
SECTION_PREFIX_LENGTH + 2
} else {
body_size
};
if let Some(ref h) = self.header {
size += h.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref da) = self.delivery_annotations {
size += da.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref ma) = self.message_annotations {
size += ma.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref p) = self.properties {
size += p.encoded_size();
}
if let Some(ref ap) = self.application_properties {
size += ap.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref f) = self.footer {
size += f.encoded_size() + SECTION_PREFIX_LENGTH;
}
self.size.set(size);
size
}
fn encode(&self, dst: &mut BytesMut) {
if let Some(ref h) = self.header {
h.encode(dst);
}
if let Some(ref da) = self.delivery_annotations {
Descriptor::Ulong(113).encode(dst);
da.encode(dst);
}
if let Some(ref ma) = self.message_annotations {
Descriptor::Ulong(114).encode(dst);
ma.encode(dst);
}
if let Some(ref p) = self.properties {
p.encode(dst);
}
if let Some(ref ap) = self.application_properties {
Descriptor::Ulong(116).encode(dst);
ap.encode(dst);
}
// message body
if self.body.encoded_size() == 0 {
// special treatment for empty body
Descriptor::Ulong(117).encode(dst);
dst.put_u8(FORMATCODE_BINARY8);
dst.put_u8(0);
} else {
self.body.encode(dst);
}
// message footer, always last item
if let Some(ref f) = self.footer {
Descriptor::Ulong(120).encode(dst);
f.encode(dst);
}
}
}
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use crate::codec::{Decode, Encode};
use crate::errors::AmqpCodecError;
use crate::protocol::Header;
use crate::types::Variant;
use super::InMessage;
#[test]
fn test_properties() -> Result<(), AmqpCodecError> {
let msg =
InMessage::with_body(Bytes::from_static(b"Hello world")).set_properties(|props| {
props.message_id = Some(Bytes::from_static(b"msg1").into());
props.content_type = Some("text".to_string().into());
props.correlation_id = Some(Bytes::from_static(b"no1").into());
props.content_encoding = Some("utf8+1".to_string().into());
});
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = InMessage::decode(&buf)?.1;
let props = msg2.properties.as_ref().unwrap();
assert_eq!(props.message_id, Some(Bytes::from_static(b"msg1").into()));
assert_eq!(
props.correlation_id,
Some(Bytes::from_static(b"no1").into())
);
Ok(())
}
#[test]
fn test_app_properties() -> Result<(), AmqpCodecError> {
let msg = InMessage::default().set_app_property(ByteString::from("test"), 1);
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = InMessage::decode(&buf)?.1;
let props = msg2.application_properties.as_ref().unwrap();
assert_eq!(*props.get("test").unwrap(), Variant::from(1));
Ok(())
}
#[test]
fn test_header() -> Result<(), AmqpCodecError> {
let hdr = Header {
durable: false,
priority: 1,
ttl: None,
first_acquirer: false,
delivery_count: 1,
};
let msg = InMessage::default().set_header(hdr.clone());
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = InMessage::decode(&buf)?.1;
assert_eq!(msg2.header().unwrap(), &hdr);
Ok(())
}
#[test]
fn test_data() -> Result<(), AmqpCodecError> {
let data = Bytes::from_static(b"test data");
let msg = InMessage::default().set_body(|body| body.set_data(data.clone()));
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = InMessage::decode(&buf)?.1;
assert_eq!(msg2.body.data().unwrap(), &data);
Ok(())
}
#[test]
fn test_data_empty() -> Result<(), AmqpCodecError> {
let msg = InMessage::default();
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = InMessage::decode(&buf)?.1;
assert_eq!(msg2.body.data().unwrap(), &Bytes::from_static(b""));
Ok(())
}
#[test]
fn test_messages() -> Result<(), AmqpCodecError> {
let msg1 = InMessage::default().set_properties(|props| props.message_id = Some(1.into()));
let msg2 = InMessage::default().set_properties(|props| props.message_id = Some(2.into()));
let msg = InMessage::default().set_body(|body| {
body.messages.push(msg1.clone().into());
body.messages.push(msg2.clone().into());
});
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg3 = InMessage::decode(&buf)?.1;
let msg4 = InMessage::decode(&msg3.body.data().unwrap())?.1;
assert_eq!(msg1.properties, msg4.properties);
let msg5 = InMessage::decode(&msg3.body.data[1])?.1;
assert_eq!(msg2.properties, msg5.properties);
Ok(())
}
}

View File

@ -0,0 +1,9 @@
mod body;
mod inmessage;
mod outmessage;
pub use self::body::MessageBody;
pub use self::inmessage::InMessage;
pub use self::outmessage::OutMessage;
pub(self) const SECTION_PREFIX_LENGTH: usize = 3;

View File

@ -0,0 +1,440 @@
use std::cell::Cell;
use bytes::{BufMut, Bytes, BytesMut};
use crate::codec::{Decode, Encode, FORMATCODE_BINARY8};
use crate::errors::AmqpParseError;
use crate::protocol::{Annotations, Header, MessageFormat, Properties, Section, TransferBody};
use crate::types::{Descriptor, Str, Symbol, Variant, VecStringMap, VecSymbolMap};
use super::body::MessageBody;
use super::inmessage::InMessage;
use super::SECTION_PREFIX_LENGTH;
#[derive(Debug, Clone, Default, PartialEq)]
pub struct OutMessage {
pub message_format: Option<MessageFormat>,
header: Option<Header>,
delivery_annotations: Option<Annotations>,
message_annotations: Option<VecSymbolMap>,
properties: Option<Properties>,
application_properties: Option<VecStringMap>,
footer: Option<Annotations>,
body: MessageBody,
size: Cell<usize>,
}
impl OutMessage {
/// Create new message and set body
pub fn with_body(body: Bytes) -> OutMessage {
let mut msg = OutMessage::default();
msg.body.data.push(body);
msg.message_format = Some(0);
msg
}
/// Create new message and set messages as body
pub fn with_messages(messages: Vec<TransferBody>) -> OutMessage {
let mut msg = OutMessage::default();
msg.body.messages = messages;
msg.message_format = Some(0);
msg
}
/// Header
pub fn header(&self) -> Option<&Header> {
self.header.as_ref()
}
/// Set message header
pub fn set_header(&mut self, header: Header) -> &mut Self {
self.header = Some(header);
self.size.set(0);
self
}
/// Message properties
pub fn properties(&self) -> Option<&Properties> {
self.properties.as_ref()
}
/// Mutable reference to properties
pub fn properties_mut(&mut self) -> &mut Properties {
if self.properties.is_none() {
self.properties = Some(Properties::default());
}
self.size.set(0);
self.properties.as_mut().unwrap()
}
/// Add property
pub fn set_properties<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&mut Properties),
{
if let Some(ref mut props) = self.properties {
f(props);
} else {
let mut props = Properties::default();
f(&mut props);
self.properties = Some(props);
}
self.size.set(0);
self
}
/// Get application property
pub fn app_properties(&self) -> Option<&VecStringMap> {
self.application_properties.as_ref()
}
/// Add application property
pub fn set_app_property<K, V>(&mut self, key: K, value: V) -> &mut Self
where
K: Into<Str>,
V: Into<Variant>,
{
if let Some(ref mut props) = self.application_properties {
props.push((key.into(), value.into()));
} else {
let mut props = VecStringMap::default();
props.push((key.into(), value.into()));
self.application_properties = Some(props);
}
self.size.set(0);
self
}
/// Add message annotation
pub fn add_message_annotation<K, V>(&mut self, key: K, value: V) -> &mut Self
where
K: Into<Symbol>,
V: Into<Variant>,
{
if let Some(ref mut props) = self.message_annotations {
props.push((key.into(), value.into()));
} else {
let mut props = VecSymbolMap::default();
props.push((key.into(), value.into()));
self.message_annotations = Some(props);
}
self.size.set(0);
self
}
/// Call closure with message reference
pub fn update<F>(self, f: F) -> Self
where
F: Fn(Self) -> Self,
{
self.size.set(0);
f(self)
}
/// Call closure if value is Some value
pub fn if_some<T, F>(self, value: &Option<T>, f: F) -> Self
where
F: Fn(Self, &T) -> Self,
{
if let Some(ref val) = value {
self.size.set(0);
f(self, val)
} else {
self
}
}
/// Message body
pub fn body(&self) -> &MessageBody {
&self.body
}
/// Message value
pub fn value(&self) -> Option<&Variant> {
self.body.value.as_ref()
}
/// Set message body value
pub fn set_value<V: Into<Variant>>(&mut self, v: V) -> &mut Self {
self.body.value = Some(v.into());
self
}
/// Set message body
pub fn set_body<F>(&mut self, f: F) -> &mut Self
where
F: FnOnce(&mut MessageBody),
{
f(&mut self.body);
self.size.set(0);
self
}
/// Create new message and set `correlation_id` property
pub fn reply_message(&self) -> OutMessage {
OutMessage::default().if_some(&self.properties, |mut msg, data| {
msg.set_properties(|props| props.correlation_id = data.message_id.clone());
msg
})
}
}
impl From<InMessage> for OutMessage {
fn from(from: InMessage) -> Self {
let mut msg = OutMessage {
message_format: from.message_format,
header: from.header,
properties: from.properties,
delivery_annotations: from.delivery_annotations,
message_annotations: from.message_annotations.map(|v| v.into()),
application_properties: from.application_properties.map(|v| v.into()),
footer: from.footer,
body: from.body,
size: Cell::new(0),
};
if let Some(ref mut props) = msg.properties {
props.correlation_id = props.message_id.clone();
};
msg
}
}
impl Decode for OutMessage {
fn decode(mut input: &[u8]) -> Result<(&[u8], OutMessage), AmqpParseError> {
let mut message = OutMessage::default();
loop {
let (buf, sec) = Section::decode(input)?;
match sec {
Section::Header(val) => {
message.header = Some(val);
}
Section::DeliveryAnnotations(val) => {
message.delivery_annotations = Some(val);
}
Section::MessageAnnotations(val) => {
message.message_annotations = Some(VecSymbolMap(
val.into_iter().map(|(k, v)| (k.0.into(), v)).collect(),
));
}
Section::ApplicationProperties(val) => {
message.application_properties = Some(VecStringMap(
val.into_iter().map(|(k, v)| (k.into(), v)).collect(),
));
}
Section::Footer(val) => {
message.footer = Some(val);
}
Section::Properties(val) => {
message.properties = Some(val);
}
// body
Section::AmqpSequence(val) => {
message.body.sequence.push(val);
}
Section::AmqpValue(val) => {
message.body.value = Some(val);
}
Section::Data(val) => {
message.body.data.push(val);
}
}
if buf.is_empty() {
break;
}
input = buf;
}
Ok((input, message))
}
}
impl Encode for OutMessage {
fn encoded_size(&self) -> usize {
let size = self.size.get();
if size != 0 {
return size;
}
// body size, always add empty body if needed
let body_size = self.body.encoded_size();
let mut size = if body_size == 0 {
// empty bytes
SECTION_PREFIX_LENGTH + 2
} else {
body_size
};
if let Some(ref h) = self.header {
size += h.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref da) = self.delivery_annotations {
size += da.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref ma) = self.message_annotations {
size += ma.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref p) = self.properties {
size += p.encoded_size();
}
if let Some(ref ap) = self.application_properties {
size += ap.encoded_size() + SECTION_PREFIX_LENGTH;
}
if let Some(ref f) = self.footer {
size += f.encoded_size() + SECTION_PREFIX_LENGTH;
}
self.size.set(size);
size
}
fn encode(&self, dst: &mut BytesMut) {
if let Some(ref h) = self.header {
h.encode(dst);
}
if let Some(ref da) = self.delivery_annotations {
Descriptor::Ulong(113).encode(dst);
da.encode(dst);
}
if let Some(ref ma) = self.message_annotations {
Descriptor::Ulong(114).encode(dst);
ma.encode(dst);
}
if let Some(ref p) = self.properties {
p.encode(dst);
}
if let Some(ref ap) = self.application_properties {
Descriptor::Ulong(116).encode(dst);
ap.encode(dst);
}
// message body
if self.body.encoded_size() == 0 {
// special treatment for empty body
Descriptor::Ulong(117).encode(dst);
dst.put_u8(FORMATCODE_BINARY8);
dst.put_u8(0);
} else {
self.body.encode(dst);
}
// message footer, always last item
if let Some(ref f) = self.footer {
Descriptor::Ulong(120).encode(dst);
f.encode(dst);
}
}
}
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use crate::codec::{Decode, Encode};
use crate::errors::AmqpCodecError;
use crate::protocol::Header;
use crate::types::Variant;
use super::OutMessage;
#[test]
fn test_properties() -> Result<(), AmqpCodecError> {
let mut msg = OutMessage::default();
msg.set_properties(|props| props.message_id = Some(1.into()));
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = OutMessage::decode(&buf)?.1;
let props = msg2.properties.as_ref().unwrap();
assert_eq!(props.message_id, Some(1.into()));
Ok(())
}
#[test]
fn test_app_properties() -> Result<(), AmqpCodecError> {
let mut msg = OutMessage::default();
msg.set_app_property(ByteString::from("test"), 1);
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = OutMessage::decode(&buf)?.1;
let props = msg2.application_properties.as_ref().unwrap();
assert_eq!(props[0].0.as_str(), "test");
assert_eq!(props[0].1, Variant::from(1));
Ok(())
}
#[test]
fn test_header() -> Result<(), AmqpCodecError> {
let hdr = Header {
durable: false,
priority: 1,
ttl: None,
first_acquirer: false,
delivery_count: 1,
};
let mut msg = OutMessage::default();
msg.set_header(hdr.clone());
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = OutMessage::decode(&buf)?.1;
assert_eq!(msg2.header().unwrap(), &hdr);
Ok(())
}
#[test]
fn test_data() -> Result<(), AmqpCodecError> {
let data = Bytes::from_static(b"test data");
let mut msg = OutMessage::default();
msg.set_body(|body| body.set_data(data.clone()));
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = OutMessage::decode(&buf)?.1;
assert_eq!(msg2.body.data().unwrap(), &data);
Ok(())
}
#[test]
fn test_data_empty() -> Result<(), AmqpCodecError> {
let msg = OutMessage::default();
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg2 = OutMessage::decode(&buf)?.1;
assert_eq!(msg2.body.data().unwrap(), &Bytes::from_static(b""));
Ok(())
}
#[test]
fn test_messages() -> Result<(), AmqpCodecError> {
let mut msg1 = OutMessage::default();
msg1.set_properties(|props| props.message_id = Some(1.into()));
let mut msg2 = OutMessage::default();
msg2.set_properties(|props| props.message_id = Some(2.into()));
let mut msg = OutMessage::default();
msg.set_body(|body| {
body.messages.push(msg1.clone().into());
body.messages.push(msg2.clone().into());
});
let mut buf = BytesMut::with_capacity(msg.encoded_size());
msg.encode(&mut buf);
let msg3 = OutMessage::decode(&buf)?.1;
let msg4 = OutMessage::decode(&msg3.body.data().unwrap())?.1;
assert_eq!(msg1.properties, msg4.properties);
let msg5 = OutMessage::decode(&msg3.body.data[1])?.1;
assert_eq!(msg2.properties, msg5.properties);
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,268 @@
use std::fmt;
use bytes::{BufMut, Bytes, BytesMut};
use bytestring::ByteString;
use chrono::{DateTime, Utc};
use derive_more::From;
use fxhash::FxHashMap;
use uuid::Uuid;
use super::codec::{self, DecodeFormatted, Encode};
use super::errors::AmqpParseError;
use super::message::{InMessage, OutMessage};
use super::types::*;
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self)
}
}
#[derive(Debug)]
pub(crate) struct CompoundHeader {
pub size: u32,
pub count: u32,
}
impl CompoundHeader {
pub fn empty() -> CompoundHeader {
CompoundHeader { size: 0, count: 0 }
}
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ProtocolId {
Amqp = 0,
AmqpTls = 2,
AmqpSasl = 3,
}
pub type Map = FxHashMap<Variant, Variant>;
pub type StringVariantMap = FxHashMap<Str, Variant>;
pub type Fields = FxHashMap<Symbol, Variant>;
pub type FilterSet = FxHashMap<Symbol, Option<ByteString>>;
pub type Timestamp = DateTime<Utc>;
pub type Symbols = Multiple<Symbol>;
pub type IetfLanguageTags = Multiple<IetfLanguageTag>;
pub type Annotations = FxHashMap<Symbol, Variant>;
mod definitions;
pub use self::definitions::*;
#[derive(Debug, Eq, PartialEq, Clone, From, Display)]
pub enum MessageId {
#[display(fmt = "{}", _0)]
Ulong(u64),
#[display(fmt = "{}", _0)]
Uuid(Uuid),
#[display(fmt = "{:?}", _0)]
Binary(Bytes),
#[display(fmt = "{}", _0)]
String(ByteString),
}
impl From<usize> for MessageId {
fn from(id: usize) -> MessageId {
MessageId::Ulong(id as u64)
}
}
impl From<i32> for MessageId {
fn from(id: i32) -> MessageId {
MessageId::Ulong(id as u64)
}
}
impl DecodeFormatted for MessageId {
fn decode_with_format(input: &[u8], fmt: u8) -> Result<(&[u8], Self), AmqpParseError> {
match fmt {
codec::FORMATCODE_SMALLULONG | codec::FORMATCODE_ULONG | codec::FORMATCODE_ULONG_0 => {
u64::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Ulong(o)))
}
codec::FORMATCODE_UUID => {
Uuid::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Uuid(o)))
}
codec::FORMATCODE_BINARY8 | codec::FORMATCODE_BINARY32 => {
Bytes::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::Binary(o)))
}
codec::FORMATCODE_STRING8 | codec::FORMATCODE_STRING32 => {
ByteString::decode_with_format(input, fmt).map(|(i, o)| (i, MessageId::String(o)))
}
_ => Err(AmqpParseError::InvalidFormatCode(fmt)),
}
}
}
impl Encode for MessageId {
fn encoded_size(&self) -> usize {
match *self {
MessageId::Ulong(v) => v.encoded_size(),
MessageId::Uuid(ref v) => v.encoded_size(),
MessageId::Binary(ref v) => v.encoded_size(),
MessageId::String(ref v) => v.encoded_size(),
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
MessageId::Ulong(v) => v.encode(buf),
MessageId::Uuid(ref v) => v.encode(buf),
MessageId::Binary(ref v) => v.encode(buf),
MessageId::String(ref v) => v.encode(buf),
}
}
}
#[derive(Clone, Debug, PartialEq, From)]
pub enum ErrorCondition {
AmqpError(AmqpError),
ConnectionError(ConnectionError),
SessionError(SessionError),
LinkError(LinkError),
Custom(Symbol),
}
impl DecodeFormatted for ErrorCondition {
#[inline]
fn decode_with_format(input: &[u8], format: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, result) = Symbol::decode_with_format(input, format)?;
if let Ok(r) = AmqpError::try_from(&result) {
return Ok((input, ErrorCondition::AmqpError(r)));
}
if let Ok(r) = ConnectionError::try_from(&result) {
return Ok((input, ErrorCondition::ConnectionError(r)));
}
if let Ok(r) = SessionError::try_from(&result) {
return Ok((input, ErrorCondition::SessionError(r)));
}
if let Ok(r) = LinkError::try_from(&result) {
return Ok((input, ErrorCondition::LinkError(r)));
}
Ok((input, ErrorCondition::Custom(result)))
}
}
impl Encode for ErrorCondition {
fn encoded_size(&self) -> usize {
match *self {
ErrorCondition::AmqpError(ref v) => v.encoded_size(),
ErrorCondition::ConnectionError(ref v) => v.encoded_size(),
ErrorCondition::SessionError(ref v) => v.encoded_size(),
ErrorCondition::LinkError(ref v) => v.encoded_size(),
ErrorCondition::Custom(ref v) => v.encoded_size(),
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
ErrorCondition::AmqpError(ref v) => v.encode(buf),
ErrorCondition::ConnectionError(ref v) => v.encode(buf),
ErrorCondition::SessionError(ref v) => v.encode(buf),
ErrorCondition::LinkError(ref v) => v.encode(buf),
ErrorCondition::Custom(ref v) => v.encode(buf),
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum DistributionMode {
Move,
Copy,
Custom(Symbol),
}
impl DecodeFormatted for DistributionMode {
fn decode_with_format(input: &[u8], format: u8) -> Result<(&[u8], Self), AmqpParseError> {
let (input, result) = Symbol::decode_with_format(input, format)?;
let result = match result.as_str() {
"move" => DistributionMode::Move,
"copy" => DistributionMode::Copy,
_ => DistributionMode::Custom(result),
};
Ok((input, result))
}
}
impl Encode for DistributionMode {
fn encoded_size(&self) -> usize {
match *self {
DistributionMode::Move => 6,
DistributionMode::Copy => 6,
DistributionMode::Custom(ref v) => v.encoded_size(),
}
}
fn encode(&self, buf: &mut BytesMut) {
match *self {
DistributionMode::Move => Symbol::from("move").encode(buf),
DistributionMode::Copy => Symbol::from("copy").encode(buf),
DistributionMode::Custom(ref v) => v.encode(buf),
}
}
}
impl SaslInit {
pub fn prepare_response(authz_id: &str, authn_id: &str, password: &str) -> Bytes {
Bytes::from(format!("{}\x00{}\x00{}", authz_id, authn_id, password))
}
}
impl Default for Properties {
fn default() -> Properties {
Properties {
message_id: None,
user_id: None,
to: None,
subject: None,
reply_to: None,
correlation_id: None,
content_type: None,
content_encoding: None,
absolute_expiry_time: None,
creation_time: None,
group_id: None,
group_sequence: None,
reply_to_group_id: None,
}
}
}
#[derive(Debug, Clone, From, PartialEq)]
pub enum TransferBody {
Data(Bytes),
MessageIn(InMessage),
MessageOut(OutMessage),
}
impl TransferBody {
#[inline]
pub fn len(&self) -> usize {
self.encoded_size()
}
#[inline]
pub fn message_format(&self) -> Option<MessageFormat> {
match self {
TransferBody::Data(_) => None,
TransferBody::MessageIn(ref data) => data.message_format,
TransferBody::MessageOut(ref data) => data.message_format,
}
}
}
impl Encode for TransferBody {
fn encoded_size(&self) -> usize {
match self {
TransferBody::Data(ref data) => data.len(),
TransferBody::MessageIn(ref data) => data.encoded_size(),
TransferBody::MessageOut(ref data) => data.encoded_size(),
}
}
fn encode(&self, dst: &mut BytesMut) {
match *self {
TransferBody::Data(ref data) => dst.put_slice(&data),
TransferBody::MessageIn(ref data) => data.encode(dst),
TransferBody::MessageOut(ref data) => data.encode(dst),
}
}
}

194
actix-amqp/codec/src/types/mod.rs Executable file
View File

@ -0,0 +1,194 @@
use std::{borrow, fmt, hash, ops};
use bytestring::ByteString;
mod symbol;
mod variant;
pub use self::symbol::{StaticSymbol, Symbol};
pub use self::variant::{Variant, VariantMap, VecStringMap, VecSymbolMap};
#[derive(Debug, PartialEq, Eq, Clone, Hash, Display)]
pub enum Descriptor {
Ulong(u64),
Symbol(Symbol),
}
#[derive(Debug, PartialEq, Eq, Clone, Hash, From)]
pub struct Multiple<T>(pub Vec<T>);
impl<T> Multiple<T> {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn iter(&self) -> ::std::slice::Iter<T> {
self.0.iter()
}
}
impl<T> Default for Multiple<T> {
fn default() -> Multiple<T> {
Multiple(Vec::new())
}
}
impl<T> ops::Deref for Multiple<T> {
type Target = Vec<T>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> ops::DerefMut for Multiple<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct List(pub Vec<Variant>);
impl List {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn iter(&self) -> ::std::slice::Iter<Variant> {
self.0.iter()
}
}
#[derive(Display, Clone, Eq, Ord, PartialOrd)]
pub enum Str {
String(String),
ByteStr(ByteString),
Static(&'static str),
}
impl Str {
pub fn from_str(s: &str) -> Str {
Str::ByteStr(ByteString::from(s))
}
pub fn as_bytes(&self) -> &[u8] {
match self {
Str::String(s) => s.as_ref(),
Str::ByteStr(s) => s.as_ref(),
Str::Static(s) => s.as_bytes(),
}
}
pub fn as_str(&self) -> &str {
match self {
Str::String(s) => s.as_str(),
Str::ByteStr(s) => s.as_ref(),
Str::Static(s) => s,
}
}
pub fn to_bytes_str(&self) -> ByteString {
match self {
Str::String(s) => ByteString::from(s.as_str()),
Str::ByteStr(s) => s.clone(),
Str::Static(s) => ByteString::from_static(s),
}
}
pub fn len(&self) -> usize {
match self {
Str::String(s) => s.len(),
Str::ByteStr(s) => s.len(),
Str::Static(s) => s.len(),
}
}
}
impl From<&'static str> for Str {
fn from(s: &'static str) -> Str {
Str::Static(s)
}
}
impl From<ByteString> for Str {
fn from(s: ByteString) -> Str {
Str::ByteStr(s)
}
}
impl From<String> for Str {
fn from(s: String) -> Str {
Str::String(s)
}
}
impl hash::Hash for Str {
fn hash<H: hash::Hasher>(&self, state: &mut H) {
match self {
Str::String(s) => (&*s).hash(state),
Str::ByteStr(s) => (&*s).hash(state),
Str::Static(s) => s.hash(state),
}
}
}
impl borrow::Borrow<str> for Str {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl PartialEq<Str> for Str {
fn eq(&self, other: &Str) -> bool {
match self {
Str::String(s) => match other {
Str::String(o) => s == o,
Str::ByteStr(o) => o == s.as_str(),
Str::Static(o) => s == *o,
},
Str::ByteStr(s) => match other {
Str::String(o) => s == o.as_str(),
Str::ByteStr(o) => s == o,
Str::Static(o) => s == *o,
},
Str::Static(s) => match other {
Str::String(o) => o == s,
Str::ByteStr(o) => o == *s,
Str::Static(o) => s == o,
},
}
}
}
impl PartialEq<str> for Str {
fn eq(&self, other: &str) -> bool {
match self {
Str::String(ref s) => s == other,
Str::ByteStr(ref s) => {
// workaround for possible compiler bug
let t: &str = &*s;
t == other
}
Str::Static(s) => *s == other,
}
}
}
impl fmt::Debug for Str {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Str::String(s) => write!(f, "ST:\"{}\"", s),
Str::ByteStr(s) => write!(f, "B:\"{}\"", &*s),
Str::Static(s) => write!(f, "S:\"{}\"", s),
}
}
}

View File

@ -0,0 +1,75 @@
use std::{borrow, str};
use bytestring::ByteString;
use super::Str;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Display)]
pub struct Symbol(pub Str);
impl Symbol {
pub fn from_slice(s: &str) -> Symbol {
Symbol(Str::ByteStr(ByteString::from(s)))
}
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn to_bytes_str(&self) -> ByteString {
self.0.to_bytes_str()
}
pub fn len(&self) -> usize {
self.0.len()
}
}
impl From<&'static str> for Symbol {
fn from(s: &'static str) -> Symbol {
Symbol(Str::Static(s))
}
}
impl From<Str> for Symbol {
fn from(s: Str) -> Symbol {
Symbol(s)
}
}
impl From<std::string::String> for Symbol {
fn from(s: std::string::String) -> Symbol {
Symbol(Str::from(s))
}
}
impl From<ByteString> for Symbol {
fn from(s: ByteString) -> Symbol {
Symbol(Str::ByteStr(s))
}
}
impl borrow::Borrow<str> for Symbol {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl PartialEq<str> for Symbol {
fn eq(&self, other: &str) -> bool {
self.0 == *other
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Display)]
pub struct StaticSymbol(pub &'static str);
impl From<&'static str> for StaticSymbol {
fn from(s: &'static str) -> StaticSymbol {
StaticSymbol(s)
}
}

View File

@ -0,0 +1,268 @@
use std::hash::{Hash, Hasher};
use bytes::Bytes;
use bytestring::ByteString;
use chrono::{DateTime, Utc};
use fxhash::FxHashMap;
use ordered_float::OrderedFloat;
use uuid::Uuid;
use crate::protocol::Annotations;
use crate::types::{Descriptor, List, StaticSymbol, Str, Symbol};
/// Represents an AMQP type for use in polymorphic collections
#[derive(Debug, Eq, PartialEq, Hash, Clone, Display, From)]
pub enum Variant {
/// Indicates an empty value.
Null,
/// Represents a true or false value.
Boolean(bool),
/// Integer in the range 0 to 2^8 - 1 inclusive.
Ubyte(u8),
/// Integer in the range 0 to 2^16 - 1 inclusive.
Ushort(u16),
/// Integer in the range 0 to 2^32 - 1 inclusive.
Uint(u32),
/// Integer in the range 0 to 2^64 - 1 inclusive.
Ulong(u64),
/// Integer in the range 0 to 2^7 - 1 inclusive.
Byte(i8),
/// Integer in the range 0 to 2^15 - 1 inclusive.
Short(i16),
/// Integer in the range 0 to 2^32 - 1 inclusive.
Int(i32),
/// Integer in the range 0 to 2^64 - 1 inclusive.
Long(i64),
/// 32-bit floating point number (IEEE 754-2008 binary32).
Float(OrderedFloat<f32>),
/// 64-bit floating point number (IEEE 754-2008 binary64).
Double(OrderedFloat<f64>),
// Decimal32(d32),
// Decimal64(d64),
// Decimal128(d128),
/// A single Unicode character.
Char(char),
/// An absolute point in time.
/// Represents an approximate point in time using the Unix time encoding of
/// UTC with a precision of milliseconds. For example, 1311704463521
/// represents the moment 2011-07-26T18:21:03.521Z.
Timestamp(DateTime<Utc>),
/// A universally unique identifier as defined by RFC-4122 section 4.1.2
Uuid(Uuid),
/// A sequence of octets.
#[display(fmt = "Binary({:?})", _0)]
Binary(Bytes),
/// A sequence of Unicode characters
String(Str),
/// Symbolic values from a constrained domain.
Symbol(Symbol),
/// Same as Symbol but for static refs
StaticSymbol(StaticSymbol),
/// List
#[display(fmt = "List({:?})", _0)]
List(List),
/// Map
Map(VariantMap),
/// Described value
#[display(fmt = "Described{:?}", _0)]
Described((Descriptor, Box<Variant>)),
}
impl From<ByteString> for Variant {
fn from(s: ByteString) -> Self {
Str::from(s).into()
}
}
impl From<String> for Variant {
fn from(s: String) -> Self {
Str::from(ByteString::from(s)).into()
}
}
impl From<&'static str> for Variant {
fn from(s: &'static str) -> Self {
Str::from(s).into()
}
}
impl PartialEq<str> for Variant {
fn eq(&self, other: &str) -> bool {
match self {
Variant::String(s) => s == other,
Variant::Symbol(s) => s == other,
_ => false,
}
}
}
impl Variant {
pub fn as_str(&self) -> Option<&str> {
match self {
Variant::String(s) => Some(s.as_str()),
Variant::Symbol(s) => Some(s.as_str()),
_ => None,
}
}
pub fn as_int(&self) -> Option<i32> {
match self {
Variant::Int(v) => Some(*v as i32),
_ => None,
}
}
pub fn as_long(&self) -> Option<i64> {
match self {
Variant::Ubyte(v) => Some(*v as i64),
Variant::Ushort(v) => Some(*v as i64),
Variant::Uint(v) => Some(*v as i64),
Variant::Ulong(v) => Some(*v as i64),
Variant::Byte(v) => Some(*v as i64),
Variant::Short(v) => Some(*v as i64),
Variant::Int(v) => Some(*v as i64),
Variant::Long(v) => Some(*v as i64),
_ => None,
}
}
pub fn to_bytes_str(&self) -> Option<ByteString> {
match self {
Variant::String(s) => Some(s.to_bytes_str()),
Variant::Symbol(s) => Some(s.to_bytes_str()),
_ => None,
}
}
}
#[derive(PartialEq, Eq, Clone, Debug, Display)]
#[display(fmt = "{:?}", map)]
pub struct VariantMap {
pub map: FxHashMap<Variant, Variant>,
}
impl VariantMap {
pub fn new(map: FxHashMap<Variant, Variant>) -> VariantMap {
VariantMap { map }
}
}
impl Hash for VariantMap {
fn hash<H: Hasher>(&self, _state: &mut H) {
unimplemented!()
}
}
#[derive(PartialEq, Clone, Debug, Display)]
#[display(fmt = "{:?}", _0)]
pub struct VecSymbolMap(pub Vec<(Symbol, Variant)>);
impl Default for VecSymbolMap {
fn default() -> Self {
VecSymbolMap(Vec::with_capacity(8))
}
}
impl From<Annotations> for VecSymbolMap {
fn from(anns: Annotations) -> VecSymbolMap {
VecSymbolMap(anns.into_iter().collect())
}
}
impl std::ops::Deref for VecSymbolMap {
type Target = Vec<(Symbol, Variant)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for VecSymbolMap {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(PartialEq, Clone, Debug, Display)]
#[display(fmt = "{:?}", _0)]
pub struct VecStringMap(pub Vec<(Str, Variant)>);
impl Default for VecStringMap {
fn default() -> Self {
VecStringMap(Vec::with_capacity(8))
}
}
impl From<FxHashMap<Str, Variant>> for VecStringMap {
fn from(map: FxHashMap<Str, Variant>) -> VecStringMap {
VecStringMap(map.into_iter().collect())
}
}
impl std::ops::Deref for VecStringMap {
type Target = Vec<(Str, Variant)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for VecStringMap {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn bytes_eq() {
let bytes1 = Variant::Binary(Bytes::from(&b"hello"[..]));
let bytes2 = Variant::Binary(Bytes::from(&b"hello"[..]));
let bytes3 = Variant::Binary(Bytes::from(&b"world"[..]));
assert_eq!(bytes1, bytes2);
assert!(bytes1 != bytes3);
}
#[test]
fn string_eq() {
let a = Variant::String(ByteString::from("hello").into());
let b = Variant::String(ByteString::from("world!").into());
assert_eq!(Variant::String(ByteString::from("hello").into()), a);
assert!(a != b);
}
#[test]
fn symbol_eq() {
let a = Variant::Symbol(Symbol::from("hello"));
let b = Variant::Symbol(Symbol::from("world!"));
assert_eq!(Variant::Symbol(Symbol::from("hello")), a);
assert!(a != b);
}
}

72
actix-amqp/src/cell.rs Executable file
View File

@ -0,0 +1,72 @@
//! Custom cell impl
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::rc::{Rc, Weak};
pub(crate) struct Cell<T> {
inner: Rc<UnsafeCell<T>>,
}
pub(crate) struct WeakCell<T> {
inner: Weak<UnsafeCell<T>>,
}
impl<T> Clone for Cell<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Deref for Cell<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.get_ref()
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for Cell<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.inner.fmt(f)
}
}
impl<T> Cell<T> {
pub fn new(inner: T) -> Self {
Self {
inner: Rc::new(UnsafeCell::new(inner)),
}
}
pub fn downgrade(&self) -> WeakCell<T> {
WeakCell {
inner: Rc::downgrade(&self.inner),
}
}
pub fn get_ref(&self) -> &T {
unsafe { &*self.inner.as_ref().get() }
}
pub fn get_mut(&self) -> &mut T {
unsafe { &mut *self.inner.as_ref().get() }
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for WeakCell<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.inner.fmt(f)
}
}
impl<T> WeakCell<T> {
pub fn upgrade(&self) -> Option<Cell<T>> {
if let Some(inner) = self.inner.upgrade() {
Some(Cell { inner })
} else {
None
}
}
}

View File

@ -0,0 +1,11 @@
use actix_codec::Framed;
use crate::Configuration;
trait IntoFramed<T, U: Default> {
fn into_framed(self) -> Framed<T, U>;
}
pub struct Handshake {
_cfg: Configuration,
}

5
actix-amqp/src/client/mod.rs Executable file
View File

@ -0,0 +1,5 @@
mod connect;
mod protocol;
pub use self::connect::Handshake;
pub use self::protocol::ProtocolNegotiation;

View File

@ -0,0 +1,78 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::Service;
use futures::{Future, SinkExt, StreamExt};
use amqp_codec::protocol::ProtocolId;
use amqp_codec::{ProtocolIdCodec, ProtocolIdError};
pub struct ProtocolNegotiation<Io> {
proto: ProtocolId,
_r: PhantomData<Io>,
}
impl<Io> Clone for ProtocolNegotiation<Io> {
fn clone(&self) -> Self {
ProtocolNegotiation {
proto: self.proto.clone(),
_r: PhantomData,
}
}
}
impl<Io> ProtocolNegotiation<Io> {
pub fn new(proto: ProtocolId) -> Self {
ProtocolNegotiation {
proto,
_r: PhantomData,
}
}
pub fn framed(stream: Io) -> Framed<Io, ProtocolIdCodec>
where
Io: AsyncRead + AsyncWrite,
{
Framed::new(stream, ProtocolIdCodec)
}
}
impl<Io> Default for ProtocolNegotiation<Io> {
fn default() -> Self {
Self::new(ProtocolId::Amqp)
}
}
impl<Io> Service for ProtocolNegotiation<Io>
where
Io: AsyncRead + AsyncWrite + 'static,
{
type Request = Framed<Io, ProtocolIdCodec>;
type Response = Framed<Io, ProtocolIdCodec>;
type Error = ProtocolIdError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut framed: Framed<Io, ProtocolIdCodec>) -> Self::Future {
let proto = self.proto;
Box::pin(async move {
framed.send(proto).await?;
let protocol = framed.next().await.ok_or(ProtocolIdError::Disconnected)??;
if proto == protocol {
Ok(framed)
} else {
Err(ProtocolIdError::Unexpected {
exp: proto,
got: protocol,
})
}
})
}
}

577
actix-amqp/src/connection.rs Executable file
View File

@ -0,0 +1,577 @@
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_utils::oneshot;
use actix_utils::task::LocalWaker;
use actix_utils::time::LowResTimeService;
use futures::future::{err, Either};
use futures::{future, Sink, Stream};
use fxhash::FxHashMap;
use amqp_codec::protocol::{Begin, Close, End, Error, Frame};
use amqp_codec::{AmqpCodec, AmqpCodecError, AmqpFrame};
use crate::cell::{Cell, WeakCell};
use crate::errors::AmqpTransportError;
use crate::hb::{Heartbeat, HeartbeatAction};
use crate::session::{Session, SessionInner};
use crate::Configuration;
pub struct Connection<T: AsyncRead + AsyncWrite> {
inner: Cell<ConnectionInner>,
framed: Framed<T, AmqpCodec<AmqpFrame>>,
hb: Heartbeat,
}
pub(crate) enum ChannelState {
Opening(Option<oneshot::Sender<Session>>, WeakCell<ConnectionInner>),
Established(Cell<SessionInner>),
Closing(Option<oneshot::Sender<Result<(), AmqpTransportError>>>),
}
impl ChannelState {
fn is_opening(&self) -> bool {
match self {
ChannelState::Opening(_, _) => true,
_ => false,
}
}
}
pub(crate) struct ConnectionInner {
local: Configuration,
remote: Configuration,
write_queue: VecDeque<AmqpFrame>,
write_task: LocalWaker,
sessions: slab::Slab<ChannelState>,
sessions_map: FxHashMap<u16, usize>,
error: Option<AmqpTransportError>,
state: State,
}
#[derive(PartialEq)]
enum State {
Normal,
Closing,
RemoteClose,
Drop,
}
impl<T: AsyncRead + AsyncWrite> Connection<T> {
pub fn new(
framed: Framed<T, AmqpCodec<AmqpFrame>>,
local: Configuration,
remote: Configuration,
time: Option<LowResTimeService>,
) -> Connection<T> {
Connection {
framed,
hb: Heartbeat::new(
local.timeout().unwrap(),
remote.timeout(),
time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
),
inner: Cell::new(ConnectionInner::new(local, remote)),
}
}
pub(crate) fn new_server(
framed: Framed<T, AmqpCodec<AmqpFrame>>,
inner: Cell<ConnectionInner>,
time: Option<LowResTimeService>,
) -> Connection<T> {
let l_timeout = inner.get_ref().local.timeout().unwrap();
let r_timeout = inner.get_ref().remote.timeout();
Connection {
framed,
inner,
hb: Heartbeat::new(
l_timeout,
r_timeout,
time.unwrap_or_else(|| LowResTimeService::with(Duration::from_secs(1))),
),
}
}
/// Connection controller
pub fn controller(&self) -> ConnectionController {
ConnectionController(self.inner.clone())
}
/// Get remote configuration
pub fn remote_config(&self) -> &Configuration {
&self.inner.get_ref().remote
}
/// Gracefully close connection
pub fn close(&mut self) -> impl Future<Output = Result<(), AmqpTransportError>> {
future::ok(())
}
// TODO: implement
/// Close connection with error
pub fn close_with_error(
&mut self,
_err: Error,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
future::ok(())
}
/// Opens the session
pub fn open_session(&mut self) -> impl Future<Output = Result<Session, AmqpTransportError>> {
let cell = self.inner.downgrade();
let inner = self.inner.clone();
async move {
let inner = inner.get_mut();
if let Some(ref e) = inner.error {
Err(e.clone())
} else {
let (tx, rx) = oneshot::channel();
let entry = inner.sessions.vacant_entry();
let token = entry.key();
if token >= inner.local.channel_max {
Err(AmqpTransportError::TooManyChannels)
} else {
entry.insert(ChannelState::Opening(Some(tx), cell));
let begin = Begin {
remote_channel: None,
next_outgoing_id: 1,
incoming_window: std::u32::MAX,
outgoing_window: std::u32::MAX,
handle_max: std::u32::MAX,
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
rx.await.map_err(|_| AmqpTransportError::Disconnected)
}
}
}
}
/// Get session by id. This method panics if session does not exists or in opening/closing state.
pub(crate) fn get_session(&self, id: usize) -> Cell<SessionInner> {
if let Some(channel) = self.inner.get_ref().sessions.get(id) {
if let ChannelState::Established(ref session) = channel {
return session.clone();
}
}
panic!("Session not found: {}", id);
}
pub(crate) fn register_remote_session(&mut self, channel_id: u16, begin: &Begin) {
trace!("remote session opened: {:?}", channel_id);
let cell = self.inner.clone();
let inner = self.inner.get_mut();
let entry = inner.sessions.vacant_entry();
let token = entry.key();
let session = Cell::new(SessionInner::new(
token,
false,
ConnectionController(cell),
token as u16,
begin.next_outgoing_id(),
begin.incoming_window(),
begin.outgoing_window(),
));
entry.insert(ChannelState::Established(session));
inner.sessions_map.insert(channel_id, token);
let begin = Begin {
remote_channel: Some(channel_id),
next_outgoing_id: 1,
incoming_window: std::u32::MAX,
outgoing_window: begin.incoming_window(),
handle_max: std::u32::MAX,
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
inner.post_frame(AmqpFrame::new(token as u16, begin.into()));
}
pub(crate) fn send_frame(&mut self, frame: AmqpFrame) {
self.inner.get_mut().post_frame(frame)
}
pub(crate) fn register_write_task(&self, cx: &mut Context) {
self.inner.write_task.register(cx.waker());
}
pub(crate) fn poll_outgoing(&mut self, cx: &mut Context) -> Poll<Result<(), AmqpCodecError>> {
let inner = self.inner.get_mut();
let mut update = false;
loop {
while !self.framed.is_write_buf_full() {
if let Some(frame) = inner.pop_next_frame() {
trace!("outgoing: {:#?}", frame);
update = true;
if let Err(e) = self.framed.write(frame) {
inner.set_error(e.clone().into());
return Poll::Ready(Err(e));
}
} else {
break;
}
}
if !self.framed.is_write_buf_empty() {
match self.framed.flush(cx) {
Poll::Pending => break,
Poll::Ready(Err(e)) => {
trace!("error sending data: {}", e);
inner.set_error(e.clone().into());
return Poll::Ready(Err(e));
}
Poll::Ready(_) => (),
}
} else {
break;
}
}
self.hb.update_remote(update);
if inner.state == State::Drop {
Poll::Ready(Ok(()))
} else if inner.state == State::RemoteClose
&& inner.write_queue.is_empty()
&& self.framed.is_write_buf_empty()
{
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
pub(crate) fn poll_incoming(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<AmqpFrame, AmqpCodecError>>> {
let inner = self.inner.get_mut();
let mut update = false;
loop {
match Pin::new(&mut self.framed).poll_next(cx) {
Poll::Ready(Some(Ok(frame))) => {
trace!("incoming: {:#?}", frame);
update = true;
if let Frame::Empty = frame.performative() {
self.hb.update_local(update);
continue;
}
// handle connection close
if let Frame::Close(ref close) = frame.performative() {
inner.set_error(AmqpTransportError::Closed(close.error.clone()));
if inner.state == State::Closing {
inner.sessions.clear();
return Poll::Ready(None);
} else {
let close = Close { error: None };
inner.post_frame(AmqpFrame::new(0, close.into()));
inner.state = State::RemoteClose;
}
}
if inner.error.is_some() {
error!("connection closed but new framed is received: {:?}", frame);
return Poll::Ready(None);
}
// get local session id
let channel_id =
if let Some(token) = inner.sessions_map.get(&frame.channel_id()) {
*token
} else {
// we dont have channel info, only Begin frame is allowed on new channel
if let Frame::Begin(ref begin) = frame.performative() {
if begin.remote_channel().is_some() {
inner.complete_session_creation(frame.channel_id(), begin);
} else {
return Poll::Ready(Some(Ok(frame)));
}
} else {
warn!("Unexpected frame: {:#?}", frame);
}
continue;
};
// handle session frames
if let Some(channel) = inner.sessions.get_mut(channel_id) {
match channel {
ChannelState::Opening(_, _) => {
error!("Unexpected opening state: {}", channel_id);
}
ChannelState::Established(ref mut session) => {
match frame.performative() {
Frame::Attach(attach) => {
let cell = session.clone();
if !session.get_mut().handle_attach(attach, cell) {
return Poll::Ready(Some(Ok(frame)));
}
}
Frame::Flow(_) | Frame::Detach(_) => {
return Poll::Ready(Some(Ok(frame)));
}
Frame::End(remote_end) => {
trace!("Remote session end: {}", frame.channel_id());
let end = End { error: None };
session.get_mut().set_error(
AmqpTransportError::SessionEnded(
remote_end.error.clone(),
),
);
let id = session.get_mut().id();
inner.post_frame(AmqpFrame::new(id, end.into()));
inner.sessions.remove(channel_id);
inner.sessions_map.remove(&frame.channel_id());
}
_ => session.get_mut().handle_frame(frame.into_parts().1),
}
}
ChannelState::Closing(ref mut tx) => match frame.performative() {
Frame::End(_) => {
if let Some(tx) = tx.take() {
let _ = tx.send(Ok(()));
}
inner.sessions.remove(channel_id);
inner.sessions_map.remove(&frame.channel_id());
}
frm => trace!("Got frame after initiated session end: {:?}", frm),
},
}
} else {
error!("Can not find channel: {}", channel_id);
continue;
}
}
Poll::Ready(None) => {
inner.set_error(AmqpTransportError::Disconnected);
return Poll::Ready(None);
}
Poll::Pending => {
self.hb.update_local(update);
break;
}
Poll::Ready(Some(Err(e))) => {
trace!("error reading: {:?}", e);
inner.set_error(e.clone().into());
return Poll::Ready(Some(Err(e.into())));
}
}
}
Poll::Pending
}
}
impl<T: AsyncRead + AsyncWrite> Drop for Connection<T> {
fn drop(&mut self) {
self.inner
.get_mut()
.set_error(AmqpTransportError::Disconnected);
}
}
impl<T: AsyncRead + AsyncWrite> Future for Connection<T> {
type Output = Result<(), AmqpCodecError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
// connection heartbeat
match self.hb.poll(cx) {
Ok(act) => match act {
HeartbeatAction::None => (),
HeartbeatAction::Close => {
self.inner.get_mut().set_error(AmqpTransportError::Timeout);
return Poll::Ready(Ok(()));
}
HeartbeatAction::Heartbeat => {
self.inner
.get_mut()
.write_queue
.push_back(AmqpFrame::new(0, Frame::Empty));
}
},
Err(e) => {
self.inner.get_mut().set_error(e);
return Poll::Ready(Ok(()));
}
}
loop {
match self.poll_incoming(cx) {
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Ok(frame))) => {
if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
if let ChannelState::Established(ref session) = channel {
session.get_mut().handle_frame(frame.into_parts().1);
continue;
}
}
warn!("Unexpected frame: {:?}", frame);
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
Poll::Pending => break,
}
}
let _ = self.poll_outgoing(cx)?;
self.register_write_task(cx);
match self.poll_incoming(cx) {
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Ok(frame))) => {
if let Some(channel) = self.inner.sessions.get(frame.channel_id() as usize) {
if let ChannelState::Established(ref session) = channel {
session.get_mut().handle_frame(frame.into_parts().1);
return Poll::Pending;
}
}
warn!("Unexpected frame: {:?}", frame);
}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
Poll::Pending => (),
}
Poll::Pending
}
}
#[derive(Clone)]
pub struct ConnectionController(pub(crate) Cell<ConnectionInner>);
impl ConnectionController {
pub(crate) fn new(local: Configuration) -> ConnectionController {
ConnectionController(Cell::new(ConnectionInner {
local,
remote: Configuration::default(),
write_queue: VecDeque::new(),
write_task: LocalWaker::new(),
sessions: slab::Slab::with_capacity(8),
sessions_map: FxHashMap::default(),
error: None,
state: State::Normal,
}))
}
pub(crate) fn set_remote(&mut self, remote: Configuration) {
self.0.get_mut().remote = remote;
}
#[inline]
/// Get remote connection configuration
pub fn remote_config(&self) -> &Configuration {
&self.0.get_ref().remote
}
#[inline]
/// Drop connection
pub fn drop_connection(&mut self) {
let inner = self.0.get_mut();
inner.state = State::Drop;
inner.write_task.wake()
}
pub(crate) fn post_frame(&mut self, frame: AmqpFrame) {
self.0.get_mut().post_frame(frame)
}
pub(crate) fn drop_session_copy(&mut self, _id: usize) {}
}
impl ConnectionInner {
pub(crate) fn new(local: Configuration, remote: Configuration) -> ConnectionInner {
ConnectionInner {
local,
remote,
write_queue: VecDeque::new(),
write_task: LocalWaker::new(),
sessions: slab::Slab::with_capacity(8),
sessions_map: FxHashMap::default(),
error: None,
state: State::Normal,
}
}
fn set_error(&mut self, err: AmqpTransportError) {
for (_, channel) in self.sessions.iter_mut() {
match channel {
ChannelState::Opening(_, _) | ChannelState::Closing(_) => (),
ChannelState::Established(ref mut ses) => {
ses.get_mut().set_error(err.clone());
}
}
}
self.sessions.clear();
self.sessions_map.clear();
self.error = Some(err);
}
fn pop_next_frame(&mut self) -> Option<AmqpFrame> {
self.write_queue.pop_front()
}
fn post_frame(&mut self, frame: AmqpFrame) {
// trace!("POST-FRAME: {:#?}", frame.performative());
self.write_queue.push_back(frame);
self.write_task.wake();
}
fn complete_session_creation(&mut self, channel_id: u16, begin: &Begin) {
trace!(
"session opened: {:?} {:?}",
channel_id,
begin.remote_channel()
);
let id = begin.remote_channel().unwrap() as usize;
if let Some(channel) = self.sessions.get_mut(id) {
if channel.is_opening() {
if let ChannelState::Opening(tx, cell) = channel {
let cell = cell.upgrade().unwrap();
let session = Cell::new(SessionInner::new(
id,
true,
ConnectionController(cell),
channel_id,
begin.next_outgoing_id(),
begin.incoming_window(),
begin.outgoing_window(),
));
self.sessions_map.insert(channel_id, id);
if tx
.take()
.unwrap()
.send(Session::new(session.clone()))
.is_err()
{
// todo: send end session
}
*channel = ChannelState::Established(session)
}
} else {
// send error response
}
} else {
// todo: rogue begin right now - do nothing. in future might indicate incoming attach
}
}
}

31
actix-amqp/src/errors.rs Executable file
View File

@ -0,0 +1,31 @@
use amqp_codec::{protocol, AmqpCodecError, ProtocolIdError};
#[derive(Debug, Display, Clone)]
pub enum AmqpTransportError {
Codec(AmqpCodecError),
TooManyChannels,
Disconnected,
Timeout,
#[display(fmt = "Connection closed, error: {:?}", _0)]
Closed(Option<protocol::Error>),
#[display(fmt = "Session ended, error: {:?}", _0)]
SessionEnded(Option<protocol::Error>),
#[display(fmt = "Link detached, error: {:?}", _0)]
LinkDetached(Option<protocol::Error>),
}
impl From<AmqpCodecError> for AmqpTransportError {
fn from(err: AmqpCodecError) -> Self {
AmqpTransportError::Codec(err)
}
}
#[derive(Debug, Display, From)]
pub enum SaslConnectError {
Protocol(ProtocolIdError),
AmqpError(AmqpCodecError),
#[display(fmt = "Sasl error code: {:?}", _0)]
Sasl(protocol::SaslCode),
ExpectedOpenFrame,
Disconnected,
}

93
actix-amqp/src/hb.rs Executable file
View File

@ -0,0 +1,93 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_rt::time::{delay_until, Delay, Instant};
use actix_utils::time::LowResTimeService;
use crate::errors::AmqpTransportError;
pub(crate) enum HeartbeatAction {
None,
Heartbeat,
Close,
}
pub(crate) struct Heartbeat {
expire_local: Instant,
expire_remote: Instant,
local: Duration,
remote: Option<Duration>,
time: LowResTimeService,
delay: Delay,
}
impl Heartbeat {
pub(crate) fn new(local: Duration, remote: Option<Duration>, time: LowResTimeService) -> Self {
let now = Instant::from_std(time.now());
let delay = if let Some(remote) = remote {
delay_until(now + std::cmp::min(local, remote))
} else {
delay_until(now + local)
};
Heartbeat {
expire_local: now,
expire_remote: now,
local,
remote,
time,
delay,
}
}
pub(crate) fn update_local(&mut self, update: bool) {
if update {
self.expire_local = Instant::from_std(self.time.now());
}
}
pub(crate) fn update_remote(&mut self, update: bool) {
if update && self.remote.is_some() {
self.expire_remote = Instant::from_std(self.time.now());
}
}
fn next_expire(&self) -> Instant {
if let Some(remote) = self.remote {
let t1 = self.expire_local + self.local;
let t2 = self.expire_remote + remote;
if t1 < t2 {
t1
} else {
t2
}
} else {
self.expire_local + self.local
}
}
pub(crate) fn poll(&mut self, cx: &mut Context) -> Result<HeartbeatAction, AmqpTransportError> {
match Pin::new(&mut self.delay).poll(cx) {
Poll::Ready(_) => {
let mut act = HeartbeatAction::None;
let dl = self.delay.deadline();
if dl >= self.expire_local + self.local {
// close connection
return Ok(HeartbeatAction::Close);
}
if let Some(remote) = self.remote {
if dl >= self.expire_remote + remote {
// send heartbeat
act = HeartbeatAction::Heartbeat;
}
}
self.delay.reset(self.next_expire());
let _ = Pin::new(&mut self.delay).poll(cx);
Ok(act)
}
Poll::Pending => Ok(HeartbeatAction::None),
}
}
}

167
actix-amqp/src/lib.rs Executable file
View File

@ -0,0 +1,167 @@
#![allow(unused_imports, dead_code)]
#[macro_use]
extern crate derive_more;
#[macro_use]
extern crate log;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_utils::oneshot;
use amqp_codec::protocol::{Disposition, Handle, Milliseconds, Open};
use bytes::Bytes;
use bytestring::ByteString;
use uuid::Uuid;
mod cell;
pub mod client;
mod connection;
mod errors;
mod hb;
mod rcvlink;
pub mod sasl;
pub mod server;
mod service;
mod session;
mod sndlink;
pub use self::connection::{Connection, ConnectionController};
pub use self::errors::AmqpTransportError;
pub use self::rcvlink::{ReceiverLink, ReceiverLinkBuilder};
pub use self::session::Session;
pub use self::sndlink::{SenderLink, SenderLinkBuilder};
pub enum Delivery {
Resolved(Result<Disposition, AmqpTransportError>),
Pending(oneshot::Receiver<Result<Disposition, AmqpTransportError>>),
Gone,
}
type DeliveryPromise = oneshot::Sender<Result<Disposition, AmqpTransportError>>;
impl Future for Delivery {
type Output = Result<Disposition, AmqpTransportError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Delivery::Pending(ref mut receiver) = *self {
return match Pin::new(receiver).poll(cx) {
Poll::Ready(Ok(r)) => Poll::Ready(r.map(|state| state)),
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
trace!("delivery oneshot is gone: {:?}", e);
Poll::Ready(Err(AmqpTransportError::Disconnected))
}
};
}
let old_v = ::std::mem::replace(&mut *self, Delivery::Gone);
if let Delivery::Resolved(r) = old_v {
return match r {
Ok(state) => Poll::Ready(Ok(state)),
Err(e) => Poll::Ready(Err(e)),
};
}
panic!("Polling Delivery after it was polled as ready is an error.");
}
}
/// Amqp1 transport configuration.
#[derive(Debug, Clone)]
pub struct Configuration {
pub max_frame_size: u32,
pub channel_max: usize,
pub idle_time_out: Option<Milliseconds>,
pub hostname: Option<ByteString>,
}
impl Default for Configuration {
fn default() -> Self {
Self::new()
}
}
impl Configuration {
/// Create connection configuration.
pub fn new() -> Self {
Configuration {
max_frame_size: std::u16::MAX as u32,
channel_max: 1024,
idle_time_out: Some(120000),
hostname: None,
}
}
/// The channel-max value is the highest channel number that
/// may be used on the Connection. This value plus one is the maximum
/// number of Sessions that can be simultaneously active on the Connection.
///
/// By default channel max value is set to 1024
pub fn channel_max(&mut self, num: u16) -> &mut Self {
self.channel_max = num as usize;
self
}
/// Set max frame size for the connection.
///
/// By default max size is set to 64kb
pub fn max_frame_size(&mut self, size: u32) -> &mut Self {
self.max_frame_size = size;
self
}
/// Get max frame size for the connection.
pub fn get_max_frame_size(&self) -> usize {
self.max_frame_size as usize
}
/// Set idle time-out for the connection in milliseconds
///
/// By default idle time-out is set to 120000 milliseconds
pub fn idle_timeout(&mut self, timeout: u32) -> &mut Self {
self.idle_time_out = Some(timeout as Milliseconds);
self
}
/// Set connection hostname
///
/// Hostname is not set by default
pub fn hostname(&mut self, hostname: &str) -> &mut Self {
self.hostname = Some(ByteString::from(hostname));
self
}
/// Create `Open` performative for this configuration.
pub fn to_open(&self) -> Open {
Open {
container_id: ByteString::from(Uuid::new_v4().to_simple().to_string()),
hostname: self.hostname.clone(),
max_frame_size: self.max_frame_size,
channel_max: self.channel_max as u16,
idle_time_out: self.idle_time_out,
outgoing_locales: None,
incoming_locales: None,
offered_capabilities: None,
desired_capabilities: None,
properties: None,
}
}
pub(crate) fn timeout(&self) -> Option<Duration> {
self.idle_time_out
.map(|v| Duration::from_millis(((v as f32) * 0.8) as u64))
}
}
impl<'a> From<&'a Open> for Configuration {
fn from(open: &'a Open) -> Self {
Configuration {
max_frame_size: open.max_frame_size,
channel_max: open.channel_max as usize,
idle_time_out: open.idle_time_out,
hostname: open.hostname.clone(),
}
}
}

263
actix-amqp/src/rcvlink.rs Executable file
View File

@ -0,0 +1,263 @@
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::u32;
use actix_utils::oneshot;
use actix_utils::task::LocalWaker;
use amqp_codec::protocol::{
Attach, DeliveryNumber, Disposition, Error, Handle, LinkError, ReceiverSettleMode, Role,
SenderSettleMode, Source, TerminusDurability, TerminusExpiryPolicy, Transfer,
};
use bytes::Bytes;
use bytestring::ByteString;
use futures::Stream;
use crate::cell::Cell;
use crate::errors::AmqpTransportError;
use crate::session::{Session, SessionInner};
use crate::Configuration;
#[derive(Clone, Debug)]
pub struct ReceiverLink {
pub(crate) inner: Cell<ReceiverLinkInner>,
}
impl ReceiverLink {
pub(crate) fn new(inner: Cell<ReceiverLinkInner>) -> ReceiverLink {
ReceiverLink { inner }
}
pub fn handle(&self) -> Handle {
self.inner.get_ref().handle as Handle
}
pub fn credit(&self) -> u32 {
self.inner.get_ref().credit
}
pub fn session(&self) -> &Session {
&self.inner.get_ref().session
}
pub fn session_mut(&mut self) -> &mut Session {
&mut self.inner.get_mut().session
}
pub fn frame(&self) -> &Attach {
&self.inner.get_ref().attach
}
pub fn open(&mut self) {
let inner = self.inner.get_mut();
inner
.session
.inner
.get_mut()
.confirm_receiver_link(inner.handle, &inner.attach);
}
pub fn set_link_credit(&mut self, credit: u32) {
self.inner.get_mut().set_link_credit(credit);
}
/// Send disposition frame
pub fn send_disposition(&mut self, disp: Disposition) {
self.inner
.get_mut()
.session
.inner
.get_mut()
.post_frame(disp.into());
}
/// Wait for disposition with specified number
pub fn wait_disposition(
&mut self,
id: DeliveryNumber,
) -> impl Future<Output = Result<Disposition, AmqpTransportError>> {
self.inner.get_mut().session.wait_disposition(id)
}
pub fn close(&self) -> impl Future<Output = Result<(), AmqpTransportError>> {
self.inner.get_mut().close(None)
}
pub fn close_with_error(
&self,
error: Error,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
self.inner.get_mut().close(Some(error))
}
#[inline]
/// Get remote connection configuration
pub fn remote_config(&self) -> &Configuration {
&self.inner.session.remote_config()
}
}
impl Stream for ReceiverLink {
type Item = Result<Transfer, AmqpTransportError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let inner = self.inner.get_mut();
if let Some(tr) = inner.queue.pop_front() {
Poll::Ready(Some(Ok(tr)))
} else {
if inner.closed {
Poll::Ready(None)
} else {
inner.reader_task.register(cx.waker());
Poll::Pending
}
}
}
}
#[derive(Debug)]
pub(crate) struct ReceiverLinkInner {
handle: Handle,
attach: Attach,
session: Session,
closed: bool,
reader_task: LocalWaker,
queue: VecDeque<Transfer>,
credit: u32,
delivery_count: u32,
}
impl ReceiverLinkInner {
pub(crate) fn new(
session: Cell<SessionInner>,
handle: Handle,
attach: Attach,
) -> ReceiverLinkInner {
ReceiverLinkInner {
handle,
session: Session::new(session),
closed: false,
reader_task: LocalWaker::new(),
queue: VecDeque::with_capacity(4),
credit: 0,
delivery_count: attach.initial_delivery_count().unwrap_or(0),
attach,
}
}
pub fn name(&self) -> &ByteString {
&self.attach.name
}
pub fn close(
&mut self,
error: Option<Error>,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
let (tx, rx) = oneshot::channel();
if self.closed {
let _ = tx.send(Ok(()));
} else {
self.session
.inner
.get_mut()
.detach_receiver_link(self.handle, true, error, tx);
}
async move {
match rx.await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(e),
Err(_) => Err(AmqpTransportError::Disconnected),
}
}
}
pub fn set_link_credit(&mut self, credit: u32) {
self.credit += credit;
self.session
.inner
.get_mut()
.rcv_link_flow(self.handle as u32, self.delivery_count, credit);
}
pub fn handle_transfer(&mut self, transfer: Transfer) {
if self.credit == 0 {
// check link credit
let err = Error {
condition: LinkError::TransferLimitExceeded.into(),
description: None,
info: None,
};
let _ = self.close(Some(err));
} else {
self.credit -= 1;
self.delivery_count += 1;
self.queue.push_back(transfer);
if self.queue.len() == 1 {
self.reader_task.wake()
}
}
}
}
pub struct ReceiverLinkBuilder {
frame: Attach,
session: Cell<SessionInner>,
}
impl ReceiverLinkBuilder {
pub(crate) fn new(name: ByteString, address: ByteString, session: Cell<SessionInner>) -> Self {
let source = Source {
address: Some(address),
durable: TerminusDurability::None,
expiry_policy: TerminusExpiryPolicy::SessionEnd,
timeout: 0,
dynamic: false,
dynamic_node_properties: None,
distribution_mode: None,
filter: None,
default_outcome: None,
outcomes: None,
capabilities: None,
};
let frame = Attach {
name,
handle: 0 as Handle,
role: Role::Receiver,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: Some(source),
target: None,
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: None,
max_message_size: Some(65536 * 4),
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
ReceiverLinkBuilder { frame, session }
}
pub fn max_message_size(mut self, size: u64) -> Self {
self.frame.max_message_size = Some(size);
self
}
pub async fn open(self) -> Result<ReceiverLink, AmqpTransportError> {
let cell = self.session.clone();
let res = self
.session
.get_mut()
.open_local_receiver_link(cell, self.frame)
.await;
match res {
Ok(Ok(res)) => Ok(res),
Ok(Err(err)) => Err(err),
Err(_) => Err(AmqpTransportError::Disconnected),
}
}
}

190
actix-amqp/src/sasl.rs Executable file
View File

@ -0,0 +1,190 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_connect::{Connect as TcpConnect, Connection as TcpConnection};
use actix_service::{apply_fn, pipeline, IntoService, Service};
use actix_utils::time::LowResTimeService;
use bytestring::ByteString;
use either::Either;
use futures::future::{ok, Future};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use http::Uri;
use amqp_codec::protocol::{Frame, ProtocolId, SaslCode, SaslFrameBody, SaslInit};
use amqp_codec::types::Symbol;
use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, SaslFrame};
use crate::connection::Connection;
use crate::service::ProtocolNegotiation;
use super::Configuration;
pub use crate::errors::SaslConnectError;
#[derive(Debug)]
/// Sasl connect request
pub struct SaslConnect {
pub uri: Uri,
pub config: Configuration,
pub auth: SaslAuth,
pub time: Option<LowResTimeService>,
}
#[derive(Debug)]
/// Sasl authentication parameters
pub struct SaslAuth {
pub authz_id: String,
pub authn_id: String,
pub password: String,
}
/// Create service that connects to amqp server and authenticate itself via sasl.
/// This service uses supplied connector service. Service resolves to
/// a `Connection<_>` instance.
pub fn connect_service<T, Io>(
connector: T,
) -> impl Service<
Request = SaslConnect,
Response = Connection<Io>,
Error = either::Either<SaslConnectError, T::Error>,
>
where
T: Service<Request = TcpConnect<Uri>, Response = TcpConnection<Uri, Io>>,
T::Error: 'static,
Io: AsyncRead + AsyncWrite + 'static,
{
pipeline(|connect: SaslConnect| {
let SaslConnect {
uri,
config,
auth,
time,
} = connect;
ok::<_, either::Either<SaslConnectError, T::Error>>((uri, config, auth, time))
})
// connect to host
.and_then(apply_fn(
connector.map_err(|e| either::Right(e)),
|(uri, config, auth, time): (Uri, Configuration, _, _), srv| {
let fut = srv.call(uri.clone().into());
async move {
fut.await.map(|stream| {
let (io, _) = stream.into_parts();
(io, uri, config, auth, time)
})
}
},
))
// sasl protocol negotiation
.and_then(apply_fn(
ProtocolNegotiation::new(ProtocolId::AmqpSasl)
.map_err(|e| Either::Left(SaslConnectError::from(e))),
|(io, uri, config, auth, time): (Io, _, _, _, _), srv| {
let framed = Framed::new(io, ProtocolIdCodec);
let fut = srv.call(framed);
async move {
fut.await
.map(move |framed| (framed, uri, config, auth, time))
}
},
))
// sasl auth
.and_then(apply_fn(
sasl_connect.into_service().map_err(Either::Left),
|(framed, uri, config, auth, time): (_, Uri, _, _, _), srv| {
let fut = srv.call((framed, uri.clone(), auth));
async move { fut.await.map(move |framed| (uri, config, framed, time)) }
},
))
// re-negotiate amqp protocol negotiation
.and_then(apply_fn(
ProtocolNegotiation::new(ProtocolId::Amqp)
.map_err(|e| Either::Left(SaslConnectError::from(e))),
|(uri, config, framed, time): (_, _, Framed<Io, ProtocolIdCodec>, _), srv| {
let fut = srv.call(framed);
async move { fut.await.map(move |framed| (uri, config, framed, time)) }
},
))
// open connection
.and_then(
|(uri, mut config, framed, time): (Uri, Configuration, Framed<Io, ProtocolIdCodec>, _)| {
async move {
let mut framed = framed.into_framed(AmqpCodec::<AmqpFrame>::new());
if let Some(hostname) = uri.host() {
config.hostname(hostname);
}
let open = config.to_open();
trace!("Open connection: {:?}", open);
framed
.send(AmqpFrame::new(0, Frame::Open(open)))
.await
.map_err(|e| Either::Left(SaslConnectError::from(e)))
.map(move |_| (config, framed, time))
}
},
)
// read open frame
.and_then(
move |(config, mut framed, time): (Configuration, Framed<_, AmqpCodec<AmqpFrame>>, _)| {
async move {
let frame = framed
.next()
.await
.ok_or(Either::Left(SaslConnectError::Disconnected))?
.map_err(|e| Either::Left(SaslConnectError::from(e)))?;
if let Frame::Open(open) = frame.performative() {
trace!("Open confirmed: {:?}", open);
Ok(Connection::new(framed, config, open.into(), time))
} else {
Err(Either::Left(SaslConnectError::ExpectedOpenFrame))
}
}
},
)
}
async fn sasl_connect<Io: AsyncRead + AsyncWrite>(
(framed, uri, auth): (Framed<Io, ProtocolIdCodec>, Uri, SaslAuth),
) -> Result<Framed<Io, ProtocolIdCodec>, SaslConnectError> {
let mut sasl_io = framed.into_framed(AmqpCodec::<SaslFrame>::new());
// processing sasl-mechanisms
let _ = sasl_io
.next()
.await
.ok_or(SaslConnectError::Disconnected)?
.map_err(SaslConnectError::from)?;
let initial_response =
SaslInit::prepare_response(&auth.authz_id, &auth.authn_id, &auth.password);
let hostname = uri.host().map(|host| ByteString::from(host));
let sasl_init = SaslInit {
hostname,
mechanism: Symbol::from("PLAIN"),
initial_response: Some(initial_response),
};
sasl_io
.send(sasl_init.into())
.await
.map_err(SaslConnectError::from)?;
// processing sasl-outcome
let sasl_frame = sasl_io
.next()
.await
.ok_or(SaslConnectError::Disconnected)?
.map_err(SaslConnectError::from)?;
if let SaslFrame {
body: SaslFrameBody::SaslOutcome(outcome),
} = sasl_frame
{
if outcome.code() != SaslCode::Ok {
return Err(SaslConnectError::Sasl(outcome.code()));
}
} else {
return Err(SaslConnectError::Disconnected);
}
Ok(sasl_io.into_framed(ProtocolIdCodec))
}

263
actix-amqp/src/server/app.rs Executable file
View File

@ -0,0 +1,263 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_router::{IntoPattern, Router};
use actix_service::{boxed, fn_factory_with_config, IntoServiceFactory, Service, ServiceFactory};
use amqp_codec::protocol::{DeliveryNumber, DeliveryState, Disposition, Error, Rejected, Role};
use futures::future::{err, ok, Either, Ready};
use futures::{Stream, StreamExt};
use crate::cell::Cell;
use crate::rcvlink::ReceiverLink;
use super::errors::LinkError;
use super::link::Link;
use super::message::{Message, Outcome};
use super::State;
type Handle<S> = boxed::BoxServiceFactory<Link<S>, Message<S>, Outcome, Error, Error>;
pub struct App<S = ()>(Vec<(Vec<String>, Handle<S>)>);
impl<S: 'static> App<S> {
pub fn new() -> App<S> {
App(Vec::new())
}
pub fn service<T, F, U: 'static>(mut self, address: T, service: F) -> Self
where
T: IntoPattern,
F: IntoServiceFactory<U>,
U: ServiceFactory<Config = Link<S>, Request = Message<S>, Response = Outcome>,
U::Error: Into<Error>,
U::InitError: Into<Error>,
{
self.0.push((
address.patterns(),
boxed::factory(
service
.into_factory()
.map_init_err(|e| e.into())
.map_err(|e| e.into()),
),
));
self
}
pub fn finish(
self,
) -> impl ServiceFactory<
Config = State<S>,
Request = Link<S>,
Response = (),
Error = Error,
InitError = Error,
> {
let mut router = Router::build();
for (addr, hnd) in self.0 {
router.path(addr, hnd);
}
let router = Cell::new(router.finish());
fn_factory_with_config(move |_: State<S>| {
ok(AppService {
router: router.clone(),
})
})
}
}
struct AppService<S> {
router: Cell<Router<Handle<S>>>,
}
impl<S: 'static> Service for AppService<S> {
type Request = Link<S>;
type Response = ();
type Error = Error;
type Future = Either<Ready<Result<(), Error>>, AppServiceResponse<S>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut link: Link<S>) -> Self::Future {
let path = link
.frame()
.target
.as_ref()
.and_then(|target| target.address.as_ref().map(|addr| addr.clone()));
if let Some(path) = path {
link.path_mut().set(path);
if let Some((hnd, _info)) = self.router.recognize(link.path_mut()) {
let fut = hnd.new_service(link.clone());
Either::Right(AppServiceResponse {
link: link.link.clone(),
app_state: link.state.clone(),
state: AppServiceResponseState::NewService(fut),
// has_credit: true,
})
} else {
Either::Left(err(LinkError::force_detach()
.description(format!(
"Target address is not supported: {}",
link.path().get_ref()
))
.into()))
}
} else {
Either::Left(err(LinkError::force_detach()
.description("Target address is required")
.into()))
}
}
}
struct AppServiceResponse<S> {
link: ReceiverLink,
app_state: State<S>,
state: AppServiceResponseState<S>,
// has_credit: bool,
}
enum AppServiceResponseState<S> {
Service(boxed::BoxService<Message<S>, Outcome, Error>),
NewService(
Pin<Box<dyn Future<Output = Result<boxed::BoxService<Message<S>, Outcome, Error>, Error>>>>,
),
}
impl<S> Future for AppServiceResponse<S> {
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut this = self.as_mut();
let mut link = this.link.clone();
let app_state = this.app_state.clone();
loop {
match this.state {
AppServiceResponseState::Service(ref mut srv) => {
// check readiness
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
let _ = this.link.close_with_error(
LinkError::force_detach()
.description(format!("error: {}", e))
.into(),
);
return Poll::Ready(Ok(()));
}
}
match Pin::new(&mut link).poll_next(cx) {
Poll::Ready(Some(Ok(transfer))) => {
// #2.7.5 delivery_id MUST be set. batching is not supported atm
if transfer.delivery_id.is_none() {
let _ = this.link.close_with_error(
LinkError::force_detach()
.description("delivery_id MUST be set")
.into(),
);
return Poll::Ready(Ok(()));
}
if link.credit() == 0 {
// self.has_credit = self.link.credit() != 0;
link.set_link_credit(50);
}
let delivery_id = transfer.delivery_id.unwrap();
let msg = Message::new(app_state.clone(), transfer, link.clone());
let mut fut = srv.call(msg);
match Pin::new(&mut fut).poll(cx) {
Poll::Ready(Ok(outcome)) => settle(
&mut this.link,
delivery_id,
outcome.into_delivery_state(),
),
Poll::Pending => {
actix_rt::spawn(HandleMessage {
fut,
delivery_id,
link: this.link.clone(),
});
}
Poll::Ready(Err(e)) => settle(
&mut this.link,
delivery_id,
DeliveryState::Rejected(Rejected { error: Some(e) }),
),
}
}
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(_))) => {
let _ = this.link.close_with_error(LinkError::force_detach().into());
return Poll::Ready(Ok(()));
}
}
}
AppServiceResponseState::NewService(ref mut fut) => match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(srv)) => {
this.link.open();
this.link.set_link_credit(50);
this.state = AppServiceResponseState::Service(srv);
continue;
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
},
}
}
}
}
struct HandleMessage {
link: ReceiverLink,
delivery_id: DeliveryNumber,
fut: Pin<Box<dyn Future<Output = Result<Outcome, Error>>>>,
}
impl Future for HandleMessage {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut this = self.as_mut();
match Pin::new(&mut this.fut).poll(cx) {
Poll::Ready(Ok(outcome)) => {
let delivery_id = this.delivery_id;
settle(&mut this.link, delivery_id, outcome.into_delivery_state());
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let delivery_id = this.delivery_id;
settle(
&mut this.link,
delivery_id,
DeliveryState::Rejected(Rejected { error: Some(e) }),
);
Poll::Ready(())
}
}
}
}
fn settle(link: &mut ReceiverLink, id: DeliveryNumber, state: DeliveryState) {
let disposition = Disposition {
state: Some(state),
role: Role::Receiver,
first: id,
last: None,
settled: true,
batchable: false,
};
link.send_disposition(disposition);
}

120
actix-amqp/src/server/connect.rs Executable file
View File

@ -0,0 +1,120 @@
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use amqp_codec::protocol::{Frame, Open};
use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec};
use futures::{Future, StreamExt};
use super::errors::ServerError;
use crate::connection::ConnectionController;
/// Open new connection
pub struct Connect<Io> {
conn: Framed<Io, ProtocolIdCodec>,
controller: ConnectionController,
}
impl<Io> Connect<Io> {
pub(crate) fn new(conn: Framed<Io, ProtocolIdCodec>, controller: ConnectionController) -> Self {
Self { conn, controller }
}
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.conn.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.conn.get_mut()
}
}
impl<Io: AsyncRead + AsyncWrite> Connect<Io> {
/// Wait for connection open frame
pub async fn open(self) -> Result<ConnectOpened<Io>, ServerError<()>> {
let mut framed = self.conn.into_framed(AmqpCodec::<AmqpFrame>::new());
let mut controller = self.controller;
let frame = framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(ServerError::from)?;
let frame = frame.into_parts().1;
match frame {
Frame::Open(frame) => {
trace!("Got open frame: {:?}", frame);
controller.set_remote((&frame).into());
Ok(ConnectOpened {
frame,
framed,
controller,
})
}
frame => Err(ServerError::Unexpected(frame)),
}
}
}
/// Connection is opened
pub struct ConnectOpened<Io> {
frame: Open,
framed: Framed<Io, AmqpCodec<AmqpFrame>>,
controller: ConnectionController,
}
impl<Io> ConnectOpened<Io> {
pub(crate) fn new(
frame: Open,
framed: Framed<Io, AmqpCodec<AmqpFrame>>,
controller: ConnectionController,
) -> Self {
ConnectOpened {
frame,
framed,
controller,
}
}
/// Get reference to remote `Open` frame
pub fn frame(&self) -> &Open {
&self.frame
}
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.framed.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.framed.get_mut()
}
/// Connection controller
pub fn connection(&self) -> &ConnectionController {
&self.controller
}
/// Ack connect message and set state
pub fn ack<St>(self, state: St) -> ConnectAck<Io, St> {
ConnectAck {
state,
framed: self.framed,
controller: self.controller,
}
}
}
/// Ack connect message
pub struct ConnectAck<Io, St> {
state: St,
framed: Framed<Io, AmqpCodec<AmqpFrame>>,
controller: ConnectionController,
}
impl<Io, St> ConnectAck<Io, St> {
pub(crate) fn into_inner(self) -> (St, Framed<Io, AmqpCodec<AmqpFrame>>, ConnectionController) {
(self.state, self.framed, self.controller)
}
}

View File

@ -0,0 +1,69 @@
use actix_service::boxed::{BoxService, BoxServiceFactory};
use amqp_codec::protocol;
use crate::cell::Cell;
use crate::rcvlink::ReceiverLink;
use crate::session::Session;
use crate::sndlink::SenderLink;
use super::errors::LinkError;
use super::State;
pub(crate) type ControlFrameService<St> = BoxService<ControlFrame<St>, (), LinkError>;
pub(crate) type ControlFrameNewService<St> =
BoxServiceFactory<(), ControlFrame<St>, (), LinkError, ()>;
pub struct ControlFrame<St>(pub(super) Cell<FrameInner<St>>);
pub(super) struct FrameInner<St> {
pub(super) kind: ControlFrameKind,
pub(super) state: State<St>,
pub(super) session: Session,
}
#[derive(Debug)]
pub enum ControlFrameKind {
Attach(protocol::Attach),
Flow(protocol::Flow, SenderLink),
DetachSender(protocol::Detach, SenderLink),
DetachReceiver(protocol::Detach, ReceiverLink),
}
impl<St> ControlFrame<St> {
pub(crate) fn new(state: State<St>, session: Session, kind: ControlFrameKind) -> Self {
ControlFrame(Cell::new(FrameInner {
state,
session,
kind,
}))
}
pub(crate) fn clone(&self) -> Self {
ControlFrame(self.0.clone())
}
#[inline]
pub fn state(&self) -> &St {
self.0.state.get_ref()
}
#[inline]
pub fn state_mut(&mut self) -> &mut St {
self.0.get_mut().state.get_mut()
}
#[inline]
pub fn session(&self) -> &Session {
&self.0.session
}
#[inline]
pub fn session_mut(&mut self) -> &mut Session {
&mut self.0.get_mut().session
}
#[inline]
pub fn frame(&self) -> &ControlFrameKind {
&self.0.kind
}
}

View File

@ -0,0 +1,287 @@
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_service::Service;
use amqp_codec::protocol::{Error, Frame, Role};
use amqp_codec::AmqpCodecError;
use slab::Slab;
use crate::cell::Cell;
use crate::connection::{ChannelState, Connection};
use crate::rcvlink::ReceiverLink;
use crate::session::Session;
use super::control::{ControlFrame, ControlFrameKind, ControlFrameService};
use super::errors::LinkError;
use super::{Link, State};
/// Amqp server connection dispatcher.
#[pin_project::pin_project]
pub struct Dispatcher<Io, St, Sr>
where
Io: AsyncRead + AsyncWrite,
Sr: Service<Request = Link<St>, Response = ()>,
{
conn: Connection<Io>,
state: State<St>,
service: Sr,
control_srv: Option<ControlFrameService<St>>,
control_frame: Option<ControlFrame<St>>,
#[pin]
control_fut: Option<<ControlFrameService<St> as Service>::Future>,
receivers: Vec<(ReceiverLink, Sr::Future)>,
_channels: slab::Slab<ChannelState>,
}
enum IncomingResult {
Control,
Done,
Disconnect,
}
impl<Io, St, Sr> Dispatcher<Io, St, Sr>
where
Io: AsyncRead + AsyncWrite,
Sr: Service<Request = Link<St>, Response = ()>,
Sr::Error: fmt::Display + Into<Error>,
{
pub(crate) fn new(
conn: Connection<Io>,
state: State<St>,
service: Sr,
control_srv: Option<ControlFrameService<St>>,
) -> Self {
Dispatcher {
conn,
service,
state,
control_srv,
control_frame: None,
control_fut: None,
receivers: Vec::with_capacity(16),
_channels: Slab::with_capacity(16),
}
}
fn handle_control_fut(&mut self, cx: &mut Context) -> bool {
// process control frame
if let Some(ref mut fut) = self.control_fut {
match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(_)) => {
self.control_fut.take();
let frame = self.control_frame.take().unwrap();
self.handle_control_frame(&frame, None);
}
Poll::Pending => return false,
Poll::Ready(Err(e)) => {
let _ = self.control_fut.take();
let frame = self.control_frame.take().unwrap();
self.handle_control_frame(&frame, Some(e));
}
}
}
true
}
fn handle_control_frame(&self, frame: &ControlFrame<St>, err: Option<LinkError>) {
if let Some(e) = err {
error!("Error in link handler: {}", e);
} else {
match frame.0.kind {
ControlFrameKind::Attach(ref frm) => {
let cell = frame.0.session.inner.clone();
frame
.0
.session
.inner
.get_mut()
.confirm_sender_link(cell, &frm);
}
ControlFrameKind::Flow(ref frm, ref link) => {
if let Some(err) = err {
let _ = link.close_with_error(err.into());
} else {
frame.0.session.inner.get_mut().apply_flow(frm);
}
}
ControlFrameKind::DetachSender(ref frm, ref link) => {
if let Some(err) = err {
let _ = link.close_with_error(err.into());
} else {
frame.0.session.inner.get_mut().handle_detach(&frm);
}
}
ControlFrameKind::DetachReceiver(ref frm, ref link) => {
if let Some(err) = err {
let _ = link.close_with_error(err.into());
} else {
frame.0.session.inner.get_mut().handle_detach(frm);
}
}
}
}
}
fn poll_incoming(&mut self, cx: &mut Context) -> Result<IncomingResult, AmqpCodecError> {
loop {
// handle remote begin and attach
match self.conn.poll_incoming(cx) {
Poll::Ready(Some(Ok(frame))) => {
let (channel_id, frame) = frame.into_parts();
let channel_id = channel_id as usize;
match frame {
Frame::Begin(frm) => {
self.conn.register_remote_session(channel_id as u16, &frm);
}
Frame::Flow(frm) => {
// apply flow to specific link
let session = self.conn.get_session(channel_id);
if self.control_srv.is_some() {
if let Some(link) =
session.get_sender_link_by_handle(frm.handle.unwrap())
{
self.control_frame = Some(ControlFrame::new(
self.state.clone(),
Session::new(session.clone()),
ControlFrameKind::Flow(frm, link.clone()),
));
return Ok(IncomingResult::Control);
}
}
session.get_mut().apply_flow(&frm);
}
Frame::Attach(attach) => match attach.role {
Role::Receiver => {
// remotly opened sender link
let session = self.conn.get_session(channel_id);
let cell = session.clone();
if self.control_srv.is_some() {
self.control_frame = Some(ControlFrame::new(
self.state.clone(),
Session::new(cell.clone()),
ControlFrameKind::Attach(attach),
));
return Ok(IncomingResult::Control);
} else {
session.get_mut().confirm_sender_link(cell, &attach);
}
}
Role::Sender => {
// receiver link
let session = self.conn.get_session(channel_id);
let cell = session.clone();
let link = session.get_mut().open_receiver_link(cell, attach);
let fut = self
.service
.call(Link::new(link.clone(), self.state.clone()));
self.receivers.push((link, fut));
}
},
Frame::Detach(frm) => {
let session = self.conn.get_session(channel_id);
let cell = session.clone();
if self.control_srv.is_some() {
if let Some(link) = session.get_sender_link_by_handle(frm.handle) {
self.control_frame = Some(ControlFrame::new(
self.state.clone(),
Session::new(cell.clone()),
ControlFrameKind::DetachSender(frm, link.clone()),
));
return Ok(IncomingResult::Control);
} else if let Some(link) =
session.get_receiver_link_by_handle(frm.handle)
{
self.control_frame = Some(ControlFrame::new(
self.state.clone(),
Session::new(cell.clone()),
ControlFrameKind::DetachReceiver(frm, link.clone()),
));
return Ok(IncomingResult::Control);
}
}
session.get_mut().handle_frame(Frame::Detach(frm));
}
_ => {
trace!("Unexpected frame {:#?}", frame);
}
}
}
Poll::Pending => break,
Poll::Ready(None) => return Ok(IncomingResult::Disconnect),
Poll::Ready(Some(Err(e))) => return Err(e),
}
}
Ok(IncomingResult::Done)
}
}
impl<Io, St, Sr> Future for Dispatcher<Io, St, Sr>
where
Io: AsyncRead + AsyncWrite,
Sr: Service<Request = Link<St>, Response = ()>,
Sr::Error: fmt::Display + Into<Error>,
{
type Output = Result<(), AmqpCodecError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
// process control frame
if !self.handle_control_fut(cx) {
return Poll::Pending;
}
// check control frames service
if self.control_frame.is_some() {
let srv = self.control_srv.as_mut().unwrap();
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let frame = self.control_frame.as_ref().unwrap().clone();
let srv = self.control_srv.as_mut().unwrap();
self.control_fut = Some(srv.call(frame));
if !self.handle_control_fut(cx) {
return Poll::Pending;
}
}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
let frame = self.control_frame.take().unwrap();
self.handle_control_frame(&frame, Some(e));
}
}
}
match self.poll_incoming(cx)? {
IncomingResult::Control => return self.poll(cx),
IncomingResult::Disconnect => return Poll::Ready(Ok(())),
IncomingResult::Done => (),
}
// process service responses
let mut idx = 0;
while idx < self.receivers.len() {
match unsafe { Pin::new_unchecked(&mut self.receivers[idx].1) }.poll(cx) {
Poll::Ready(Ok(_detach)) => {
let (link, _) = self.receivers.swap_remove(idx);
let _ = link.close();
}
Poll::Pending => idx += 1,
Poll::Ready(Err(e)) => {
let (link, _) = self.receivers.swap_remove(idx);
error!("Error in link handler: {}", e);
let _ = link.close_with_error(e.into());
}
}
}
let res = self.conn.poll_outgoing(cx);
self.conn.register_write_task(cx);
res
}
}

185
actix-amqp/src/server/errors.rs Executable file
View File

@ -0,0 +1,185 @@
use std::io;
use amqp_codec::{protocol, AmqpCodecError, ProtocolIdError, SaslFrame};
use bytestring::ByteString;
use derive_more::Display;
pub use amqp_codec::protocol::Error;
/// Errors which can occur when attempting to handle amqp connection.
#[derive(Debug, Display)]
pub enum ServerError<E> {
#[display(fmt = "Message handler service error")]
/// Message handler service error
Service(E),
/// Control service init error
ControlServiceInit,
#[display(fmt = "Amqp error: {}", _0)]
/// Amqp error
Amqp(AmqpError),
#[display(fmt = "Protocol negotiation error: {}", _0)]
/// Amqp protocol negotiation error
Handshake(ProtocolIdError),
/// Amqp handshake timeout
HandshakeTimeout,
/// Amqp codec error
#[display(fmt = "Amqp codec error: {:?}", _0)]
Protocol(AmqpCodecError),
#[display(fmt = "Protocol error: {}", _0)]
/// Amqp protocol error
ProtocolError(Error),
#[display(fmt = "Expected open frame, got: {:?}", _0)]
Unexpected(protocol::Frame),
#[display(fmt = "Unexpected sasl frame: {:?}", _0)]
UnexpectedSaslFrame(SaslFrame),
#[display(fmt = "Unexpected sasl frame body: {:?}", _0)]
UnexpectedSaslBodyFrame(protocol::SaslFrameBody),
/// Peer disconnect
Disconnected,
/// Unexpected io error
Io(io::Error),
}
impl<E> Into<protocol::Error> for ServerError<E> {
fn into(self) -> protocol::Error {
protocol::Error {
condition: protocol::AmqpError::InternalError.into(),
description: Some(ByteString::from(format!("{}", self))),
info: None,
}
}
}
impl<E> From<AmqpError> for ServerError<E> {
fn from(err: AmqpError) -> Self {
ServerError::Amqp(err)
}
}
impl<E> From<AmqpCodecError> for ServerError<E> {
fn from(err: AmqpCodecError) -> Self {
ServerError::Protocol(err)
}
}
impl<E> From<ProtocolIdError> for ServerError<E> {
fn from(err: ProtocolIdError) -> Self {
ServerError::Handshake(err)
}
}
impl<E> From<SaslFrame> for ServerError<E> {
fn from(err: SaslFrame) -> Self {
ServerError::UnexpectedSaslFrame(err)
}
}
impl<E> From<io::Error> for ServerError<E> {
fn from(err: io::Error) -> Self {
ServerError::Io(err)
}
}
#[derive(Debug, Display)]
#[display(fmt = "Amqp error: {:?} {:?} ({:?})", err, description, info)]
pub struct AmqpError {
err: protocol::AmqpError,
description: Option<ByteString>,
info: Option<protocol::Fields>,
}
impl AmqpError {
pub fn new(err: protocol::AmqpError) -> Self {
AmqpError {
err,
description: None,
info: None,
}
}
pub fn internal_error() -> Self {
Self::new(protocol::AmqpError::InternalError)
}
pub fn not_found() -> Self {
Self::new(protocol::AmqpError::NotFound)
}
pub fn unauthorized_access() -> Self {
Self::new(protocol::AmqpError::UnauthorizedAccess)
}
pub fn decode_error() -> Self {
Self::new(protocol::AmqpError::DecodeError)
}
pub fn invalid_field() -> Self {
Self::new(protocol::AmqpError::InvalidField)
}
pub fn not_allowed() -> Self {
Self::new(protocol::AmqpError::NotAllowed)
}
pub fn not_implemented() -> Self {
Self::new(protocol::AmqpError::NotImplemented)
}
pub fn description<T: AsRef<str>>(mut self, text: T) -> Self {
self.description = Some(ByteString::from(text.as_ref()));
self
}
pub fn set_description(mut self, text: ByteString) -> Self {
self.description = Some(text);
self
}
}
impl Into<protocol::Error> for AmqpError {
fn into(self) -> protocol::Error {
protocol::Error {
condition: self.err.into(),
description: self.description,
info: self.info,
}
}
}
#[derive(Debug, Display)]
#[display(fmt = "Link error: {:?} {:?} ({:?})", err, description, info)]
pub struct LinkError {
err: protocol::LinkError,
description: Option<ByteString>,
info: Option<protocol::Fields>,
}
impl LinkError {
pub fn force_detach() -> Self {
LinkError {
err: protocol::LinkError::DetachForced,
description: None,
info: None,
}
}
pub fn description<T: AsRef<str>>(mut self, text: T) -> Self {
self.description = Some(ByteString::from(text.as_ref()));
self
}
pub fn set_description(mut self, text: ByteString) -> Self {
self.description = Some(text);
self
}
}
impl Into<protocol::Error> for LinkError {
fn into(self) -> protocol::Error {
protocol::Error {
condition: self.err.into(),
description: self.description,
info: self.info,
}
}
}

View File

@ -0,0 +1,50 @@
use actix_service::{IntoServiceFactory, ServiceFactory};
use super::connect::ConnectAck;
pub fn handshake<Io, St, A, F>(srv: F) -> Handshake<Io, St, A>
where
F: IntoServiceFactory<A>,
A: ServiceFactory<Config = (), Response = ConnectAck<Io, St>>,
{
Handshake::new(srv)
}
pub struct Handshake<Io, St, A> {
a: A,
_t: std::marker::PhantomData<(Io, St)>,
}
impl<Io, St, A> Handshake<Io, St, A>
where
A: ServiceFactory<Config = ()>,
{
pub fn new<F>(srv: F) -> Handshake<Io, St, A>
where
F: IntoServiceFactory<A>,
{
Handshake {
a: srv.into_factory(),
_t: std::marker::PhantomData,
}
}
}
impl<Io, St, A> Handshake<Io, St, A>
where
A: ServiceFactory<Config = (), Response = ConnectAck<Io, St>>,
{
pub fn sasl<F, B>(self, srv: F) -> actix_utils::either::Either<A, B>
where
F: IntoServiceFactory<B>,
B: ServiceFactory<
Config = (),
Response = A::Response,
Error = A::Error,
InitError = A::InitError,
>,
B::Error: Into<amqp_codec::protocol::Error>,
{
actix_utils::either::Either::new(self.a, srv.into_factory())
}
}

87
actix-amqp/src/server/link.rs Executable file
View File

@ -0,0 +1,87 @@
use std::fmt;
use actix_router::Path;
use amqp_codec::protocol::Attach;
use bytestring::ByteString;
use crate::cell::Cell;
use crate::rcvlink::ReceiverLink;
use crate::server::State;
use crate::session::Session;
use crate::{Configuration, Handle};
pub struct Link<S> {
pub(crate) state: State<S>,
pub(crate) link: ReceiverLink,
pub(crate) path: Path<ByteString>,
}
impl<S> Link<S> {
pub(crate) fn new(link: ReceiverLink, state: State<S>) -> Self {
Link {
state,
link,
path: Path::new(ByteString::from_static("")),
}
}
pub fn path(&self) -> &Path<ByteString> {
&self.path
}
pub fn path_mut(&mut self) -> &mut Path<ByteString> {
&mut self.path
}
pub fn frame(&self) -> &Attach {
self.link.frame()
}
pub fn state(&self) -> &S {
self.state.get_ref()
}
pub fn state_mut(&mut self) -> &mut S {
self.state.get_mut()
}
pub fn handle(&self) -> Handle {
self.link.handle()
}
pub fn session(&self) -> &Session {
self.link.session()
}
pub fn session_mut(&mut self) -> &mut Session {
self.link.session_mut()
}
pub fn link_credit(mut self, credit: u32) {
self.link.set_link_credit(credit);
}
#[inline]
/// Get remote connection configuration
pub fn remote_config(&self) -> &Configuration {
&self.link.remote_config()
}
}
impl<S> Clone for Link<S> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
link: self.link.clone(),
path: self.path.clone(),
}
}
}
impl<S> fmt::Debug for Link<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Link<S>")
.field("frame", self.link.frame())
.finish()
}
}

View File

@ -0,0 +1,96 @@
use std::fmt;
use amqp_codec::protocol::{Accepted, DeliveryState, Error, Rejected, Transfer, TransferBody};
use amqp_codec::Decode;
use bytes::Bytes;
use crate::rcvlink::ReceiverLink;
use crate::session::Session;
use super::errors::AmqpError;
use super::State;
pub struct Message<S> {
state: State<S>,
frame: Transfer,
link: ReceiverLink,
}
#[derive(Debug)]
pub enum Outcome {
Accept,
Reject,
Error(Error),
}
impl<T> From<T> for Outcome
where
T: Into<Error>,
{
fn from(err: T) -> Self {
Outcome::Error(err.into())
}
}
impl Outcome {
pub(crate) fn into_delivery_state(self) -> DeliveryState {
match self {
Outcome::Accept => DeliveryState::Accepted(Accepted {}),
Outcome::Reject => DeliveryState::Rejected(Rejected { error: None }),
Outcome::Error(e) => DeliveryState::Rejected(Rejected { error: Some(e) }),
}
}
}
impl<S> Message<S> {
pub(crate) fn new(state: State<S>, frame: Transfer, link: ReceiverLink) -> Self {
Message { state, frame, link }
}
pub fn state(&self) -> &S {
self.state.get_ref()
}
pub fn state_mut(&mut self) -> &mut S {
self.state.get_mut()
}
pub fn session(&self) -> &Session {
self.link.session()
}
pub fn session_mut(&mut self) -> &mut Session {
self.link.session_mut()
}
pub fn frame(&self) -> &Transfer {
&self.frame
}
pub fn body(&self) -> Option<&Bytes> {
match self.frame.body {
Some(TransferBody::Data(ref b)) => Some(b),
_ => None,
}
}
pub fn load_message<T: Decode>(&self) -> Result<T, AmqpError> {
if let Some(TransferBody::Data(ref b)) = self.frame.body {
if let Ok((_, msg)) = T::decode(b) {
Ok(msg)
} else {
Err(AmqpError::decode_error().description("Can not decode message"))
}
} else {
Err(AmqpError::invalid_field().description("Unknown body"))
}
}
}
impl<S> fmt::Debug for Message<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Message<S>")
.field("frame", &self.frame)
.finish()
}
}

55
actix-amqp/src/server/mod.rs Executable file
View File

@ -0,0 +1,55 @@
mod app;
mod connect;
mod control;
mod dispatcher;
pub mod errors;
mod handshake;
mod link;
mod message;
pub mod sasl;
mod service;
pub use self::app::App;
pub use self::connect::{Connect, ConnectAck, ConnectOpened};
pub use self::control::{ControlFrame, ControlFrameKind};
pub use self::handshake::{handshake, Handshake};
pub use self::link::Link;
pub use self::message::{Message, Outcome};
pub use self::sasl::Sasl;
pub use self::service::Server;
use crate::cell::Cell;
pub struct State<St>(Cell<St>);
impl<St> State<St> {
pub(crate) fn new(st: St) -> Self {
State(Cell::new(st))
}
pub(crate) fn clone(&self) -> Self {
State(self.0.clone())
}
pub fn get_ref(&self) -> &St {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut St {
self.0.get_mut()
}
}
impl<St> std::ops::Deref for State<St> {
type Target = St;
fn deref(&self) -> &Self::Target {
self.get_ref()
}
}
impl<St> std::ops::DerefMut for State<St> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.get_mut()
}
}

338
actix-amqp/src/server/sasl.rs Executable file
View File

@ -0,0 +1,338 @@
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{Service, ServiceFactory};
use amqp_codec::protocol::{
self, ProtocolId, SaslChallenge, SaslCode, SaslFrameBody, SaslMechanisms, SaslOutcome, Symbols,
};
use amqp_codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, SaslFrame};
use bytes::Bytes;
use bytestring::ByteString;
use futures::future::{err, ok, Either, Ready};
use futures::{SinkExt, StreamExt};
use super::connect::{ConnectAck, ConnectOpened};
use super::errors::{AmqpError, ServerError};
use crate::connection::ConnectionController;
pub struct Sasl<Io> {
framed: Framed<Io, ProtocolIdCodec>,
mechanisms: Symbols,
controller: ConnectionController,
}
impl<Io> fmt::Debug for Sasl<Io> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("SaslAuth")
.field("mechanisms", &self.mechanisms)
.finish()
}
}
impl<Io> Sasl<Io> {
pub(crate) fn new(
framed: Framed<Io, ProtocolIdCodec>,
controller: ConnectionController,
) -> Self {
Sasl {
framed,
controller,
mechanisms: Symbols::default(),
}
}
}
impl<Io> Sasl<Io>
where
Io: AsyncRead + AsyncWrite,
{
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.framed.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.framed.get_mut()
}
/// Add supported sasl mechanism
pub fn mechanism<U: Into<String>>(mut self, symbol: U) -> Self {
self.mechanisms.push(ByteString::from(symbol.into()).into());
self
}
/// Initialize sasl auth procedure
pub async fn init(self) -> Result<Init<Io>, ServerError<()>> {
let Sasl {
framed,
mechanisms,
controller,
..
} = self;
let mut framed = framed.into_framed(AmqpCodec::<SaslFrame>::new());
let frame = SaslMechanisms {
sasl_server_mechanisms: mechanisms,
}
.into();
framed.send(frame).await.map_err(ServerError::from)?;
let frame = framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(ServerError::from)?;
match frame.body {
SaslFrameBody::SaslInit(frame) => Ok(Init {
frame,
framed,
controller,
}),
body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
}
}
}
/// Initialization stage of sasl negotiation
pub struct Init<Io> {
frame: protocol::SaslInit,
framed: Framed<Io, AmqpCodec<SaslFrame>>,
controller: ConnectionController,
}
impl<Io> fmt::Debug for Init<Io> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("SaslInit")
.field("frame", &self.frame)
.finish()
}
}
impl<Io> Init<Io>
where
Io: AsyncRead + AsyncWrite,
{
/// Sasl mechanism
pub fn mechanism(&self) -> &str {
self.frame.mechanism.as_str()
}
/// Sasl initial response
pub fn initial_response(&self) -> Option<&[u8]> {
self.frame.initial_response.as_ref().map(|b| b.as_ref())
}
/// Sasl initial response
pub fn hostname(&self) -> Option<&str> {
self.frame.hostname.as_ref().map(|b| b.as_ref())
}
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.framed.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.framed.get_mut()
}
/// Initiate sasl challenge
pub async fn challenge(self) -> Result<Response<Io>, ServerError<()>> {
self.challenge_with(Bytes::new()).await
}
/// Initiate sasl challenge with challenge payload
pub async fn challenge_with(self, challenge: Bytes) -> Result<Response<Io>, ServerError<()>> {
let mut framed = self.framed;
let controller = self.controller;
let frame = SaslChallenge { challenge }.into();
framed.send(frame).await.map_err(ServerError::from)?;
let frame = framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(ServerError::from)?;
match frame.body {
SaslFrameBody::SaslResponse(frame) => Ok(Response {
frame,
framed,
controller,
}),
body => Err(ServerError::UnexpectedSaslBodyFrame(body)),
}
}
/// Sasl challenge outcome
pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
let mut framed = self.framed;
let controller = self.controller;
let frame = SaslOutcome {
code,
additional_data: None,
}
.into();
framed.send(frame).await.map_err(ServerError::from)?;
Ok(Success { framed, controller })
}
}
pub struct Response<Io> {
frame: protocol::SaslResponse,
framed: Framed<Io, AmqpCodec<SaslFrame>>,
controller: ConnectionController,
}
impl<Io> fmt::Debug for Response<Io> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("SaslResponse")
.field("frame", &self.frame)
.finish()
}
}
impl<Io> Response<Io>
where
Io: AsyncRead + AsyncWrite,
{
/// Client response payload
pub fn response(&self) -> &[u8] {
&self.frame.response[..]
}
/// Sasl challenge outcome
pub async fn outcome(self, code: SaslCode) -> Result<Success<Io>, ServerError<()>> {
let mut framed = self.framed;
let controller = self.controller;
let frame = SaslOutcome {
code,
additional_data: None,
}
.into();
framed.send(frame).await.map_err(ServerError::from)?;
framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(|res| ServerError::from(res))?;
Ok(Success { framed, controller })
}
}
pub struct Success<Io> {
framed: Framed<Io, AmqpCodec<SaslFrame>>,
controller: ConnectionController,
}
impl<Io> Success<Io>
where
Io: AsyncRead + AsyncWrite,
{
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.framed.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.framed.get_mut()
}
/// Wait for connection open frame
pub async fn open(self) -> Result<ConnectOpened<Io>, ServerError<()>> {
let mut framed = self.framed.into_framed(ProtocolIdCodec);
let mut controller = self.controller;
let protocol = framed
.next()
.await
.ok_or(ServerError::from(ProtocolIdError::Disconnected))?
.map_err(ServerError::from)?;
match protocol {
ProtocolId::Amqp => {
// confirm protocol
framed
.send(ProtocolId::Amqp)
.await
.map_err(ServerError::from)?;
// Wait for connection open frame
let mut framed = framed.into_framed(AmqpCodec::<AmqpFrame>::new());
let frame = framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(ServerError::from)?;
let frame = frame.into_parts().1;
match frame {
protocol::Frame::Open(frame) => {
trace!("Got open frame: {:?}", frame);
controller.set_remote((&frame).into());
Ok(ConnectOpened::new(frame, framed, controller))
}
frame => Err(ServerError::Unexpected(frame)),
}
}
proto => Err(ProtocolIdError::Unexpected {
exp: ProtocolId::Amqp,
got: proto,
}
.into()),
}
}
}
/// Create service factory with disabled sasl support
pub fn no_sasl<Io, St, E>() -> NoSaslService<Io, St, E> {
NoSaslService::default()
}
pub struct NoSaslService<Io, St, E>(std::marker::PhantomData<(Io, St, E)>);
impl<Io, St, E> Default for NoSaslService<Io, St, E> {
fn default() -> Self {
NoSaslService(std::marker::PhantomData)
}
}
impl<Io, St, E> ServiceFactory for NoSaslService<Io, St, E> {
type Config = ();
type Request = Sasl<Io>;
type Response = ConnectAck<Io, St>;
type Error = AmqpError;
type InitError = E;
type Service = NoSaslService<Io, St, E>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: ()) -> Self::Future {
ok(NoSaslService(std::marker::PhantomData))
}
}
impl<Io, St, E> Service for NoSaslService<Io, St, E> {
type Request = Sasl<Io>;
type Response = ConnectAck<Io, St>;
type Error = AmqpError;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: Self::Request) -> Self::Future {
err(AmqpError::not_implemented())
}
}

360
actix-amqp/src/server/service.rs Executable file
View File

@ -0,0 +1,360 @@
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, time};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::{boxed, IntoServiceFactory, Service, ServiceFactory};
use amqp_codec::protocol::{Error, ProtocolId};
use amqp_codec::{AmqpCodecError, AmqpFrame, ProtocolIdCodec, ProtocolIdError};
use futures::future::{err, poll_fn, Either};
use futures::{FutureExt, SinkExt, StreamExt};
use crate::cell::Cell;
use crate::connection::{Connection, ConnectionController};
use crate::Configuration;
use super::connect::{Connect, ConnectAck};
use super::control::{ControlFrame, ControlFrameNewService};
use super::dispatcher::Dispatcher;
use super::errors::{LinkError, ServerError};
use super::link::Link;
use super::sasl::Sasl;
use super::State;
/// Amqp connection type
pub type AmqpConnect<Io> = either::Either<Connect<Io>, Sasl<Io>>;
/// Server dispatcher factory
pub struct Server<Io, St, Cn: ServiceFactory> {
connect: Cn,
config: Configuration,
control: Option<ControlFrameNewService<St>>,
disconnect: Option<Box<dyn Fn(&mut St, Option<&ServerError<Cn::Error>>)>>,
max_size: usize,
handshake_timeout: u64,
_t: PhantomData<(Io, St)>,
}
pub(super) struct ServerInner<St, Cn: ServiceFactory, Pb> {
connect: Cn,
publish: Pb,
config: Configuration,
control: Option<ControlFrameNewService<St>>,
disconnect: Option<Box<dyn Fn(&mut St, Option<&ServerError<Cn::Error>>)>>,
max_size: usize,
handshake_timeout: u64,
}
impl<Io, St, Cn> Server<Io, St, Cn>
where
St: 'static,
Io: AsyncRead + AsyncWrite + 'static,
Cn: ServiceFactory<Config = (), Request = AmqpConnect<Io>, Response = ConnectAck<Io, St>>
+ 'static,
{
/// Create server factory and provide connect service
pub fn new<F>(connect: F) -> Self
where
F: IntoServiceFactory<Cn>,
{
Self {
connect: connect.into_factory(),
config: Configuration::default(),
control: None,
disconnect: None,
max_size: 0,
handshake_timeout: 0,
_t: PhantomData,
}
}
/// Provide connection configuration
pub fn config(mut self, config: Configuration) -> Self {
self.config = config;
self
}
/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
/// By default max size is set to `0`
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Set handshake timeout in millis.
///
/// By default handshake timeuot is disabled.
pub fn handshake_timeout(mut self, timeout: u64) -> Self {
self.handshake_timeout = timeout;
self
}
/// Service to call with control frames
pub fn control<F, S>(self, f: F) -> Self
where
F: IntoServiceFactory<S>,
S: ServiceFactory<Config = (), Request = ControlFrame<St>, Response = (), InitError = ()>
+ 'static,
S::Error: Into<LinkError>,
{
Server {
connect: self.connect,
config: self.config,
disconnect: self.disconnect,
control: Some(boxed::factory(
f.into_factory()
.map_err(|e| e.into())
.map_init_err(|e| e.into()),
)),
max_size: self.max_size,
handshake_timeout: self.handshake_timeout,
_t: PhantomData,
}
}
/// Callback to execute on disconnect
///
/// Second parameter indicates error occured during disconnect.
pub fn disconnect<F, Out>(self, disconnect: F) -> Self
where
F: Fn(&mut St, Option<&ServerError<Cn::Error>>) -> Out + 'static,
Out: Future + 'static,
{
Server {
connect: self.connect,
config: self.config,
control: self.control,
disconnect: Some(Box::new(move |st, err| {
let fut = disconnect(st, err);
actix_rt::spawn(fut.map(|_| ()));
})),
max_size: self.max_size,
handshake_timeout: self.handshake_timeout,
_t: PhantomData,
}
}
/// Set service to execute for incoming links and create service factory
pub fn finish<F, Pb>(
self,
service: F,
) -> impl ServiceFactory<Config = (), Request = Io, Response = (), Error = ServerError<Cn::Error>>
where
F: IntoServiceFactory<Pb>,
Pb: ServiceFactory<Config = State<St>, Request = Link<St>, Response = ()> + 'static,
Pb::Error: fmt::Display + Into<Error>,
Pb::InitError: fmt::Display + Into<Error>,
{
ServerImpl {
inner: Cell::new(ServerInner {
connect: self.connect,
config: self.config,
publish: service.into_factory(),
control: self.control,
disconnect: self.disconnect,
max_size: self.max_size,
handshake_timeout: self.handshake_timeout,
}),
_t: PhantomData,
}
}
}
struct ServerImpl<Io, St, Cn: ServiceFactory, Pb> {
inner: Cell<ServerInner<St, Cn, Pb>>,
_t: PhantomData<(Io,)>,
}
impl<Io, St, Cn, Pb> ServiceFactory for ServerImpl<Io, St, Cn, Pb>
where
St: 'static,
Io: AsyncRead + AsyncWrite + 'static,
Cn: ServiceFactory<Config = (), Request = AmqpConnect<Io>, Response = ConnectAck<Io, St>>
+ 'static,
Pb: ServiceFactory<Config = State<St>, Request = Link<St>, Response = ()> + 'static,
Pb::Error: fmt::Display + Into<Error>,
Pb::InitError: fmt::Display + Into<Error>,
{
type Config = ();
type Request = Io;
type Response = ();
type Error = ServerError<Cn::Error>;
type Service = ServerImplService<Io, St, Cn, Pb>;
type InitError = Cn::InitError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Cn::InitError>>>>;
fn new_service(&self, _: ()) -> Self::Future {
let inner = self.inner.clone();
Box::pin(async move {
inner
.connect
.new_service(())
.await
.map(move |connect| ServerImplService {
inner,
connect: Cell::new(connect),
_t: PhantomData,
})
})
}
}
struct ServerImplService<Io, St, Cn: ServiceFactory, Pb> {
connect: Cell<Cn::Service>,
inner: Cell<ServerInner<St, Cn, Pb>>,
_t: PhantomData<(Io,)>,
}
impl<Io, St, Cn, Pb> Service for ServerImplService<Io, St, Cn, Pb>
where
St: 'static,
Io: AsyncRead + AsyncWrite + 'static,
Cn: ServiceFactory<Config = (), Request = AmqpConnect<Io>, Response = ConnectAck<Io, St>>
+ 'static,
Pb: ServiceFactory<Config = State<St>, Request = Link<St>, Response = ()> + 'static,
Pb::Error: fmt::Display + Into<Error>,
Pb::InitError: fmt::Display + Into<Error>,
{
type Request = Io;
type Response = ();
type Error = ServerError<Cn::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.connect
.get_mut()
.poll_ready(cx)
.map(|res| res.map_err(|e| ServerError::Service(e)))
}
fn call(&mut self, req: Self::Request) -> Self::Future {
let timeout = self.inner.handshake_timeout;
if timeout == 0 {
Box::pin(handshake(
self.inner.max_size,
self.connect.clone(),
self.inner.clone(),
req,
))
} else {
Box::pin(
actix_rt::time::timeout(
time::Duration::from_millis(timeout),
handshake(
self.inner.max_size,
self.connect.clone(),
self.inner.clone(),
req,
),
)
.map(|res| match res {
Ok(res) => res,
Err(_) => Err(ServerError::HandshakeTimeout),
}),
)
}
}
}
async fn handshake<Io, St, Cn: ServiceFactory, Pb>(
max_size: usize,
connect: Cell<Cn::Service>,
inner: Cell<ServerInner<St, Cn, Pb>>,
io: Io,
) -> Result<(), ServerError<Cn::Error>>
where
St: 'static,
Io: AsyncRead + AsyncWrite + 'static,
Cn: ServiceFactory<Config = (), Request = AmqpConnect<Io>, Response = ConnectAck<Io, St>>,
Pb: ServiceFactory<Config = State<St>, Request = Link<St>, Response = ()> + 'static,
Pb::Error: fmt::Display + Into<Error>,
Pb::InitError: fmt::Display + Into<Error>,
{
let inner2 = inner.clone();
let mut framed = Framed::new(io, ProtocolIdCodec);
let protocol = framed
.next()
.await
.ok_or(ServerError::Disconnected)?
.map_err(ServerError::Handshake)?;
let (st, srv, conn) = match protocol {
// start amqp processing
ProtocolId::Amqp | ProtocolId::AmqpSasl => {
framed.send(protocol).await.map_err(ServerError::from)?;
let cfg = inner.get_ref().config.clone();
let controller = ConnectionController::new(cfg.clone());
let ack = connect
.get_mut()
.call(if protocol == ProtocolId::Amqp {
either::Either::Left(Connect::new(framed, controller))
} else {
either::Either::Right(Sasl::new(framed, controller))
})
.await
.map_err(|e| ServerError::Service(e))?;
let (st, mut framed, controller) = ack.into_inner();
let st = State::new(st);
framed.get_codec_mut().max_size(max_size);
// confirm Open
let local = cfg.to_open();
framed
.send(AmqpFrame::new(0, local.into()))
.await
.map_err(ServerError::from)?;
let conn = Connection::new_server(framed, controller.0, None);
// create publish service
let srv = inner.publish.new_service(st.clone()).await.map_err(|e| {
error!("Can not construct app service");
ServerError::ProtocolError(e.into())
})?;
(st, srv, conn)
}
ProtocolId::AmqpTls => {
return Err(ServerError::from(ProtocolIdError::Unexpected {
exp: ProtocolId::Amqp,
got: ProtocolId::AmqpTls,
}))
}
};
let mut st2 = st.clone();
if let Some(ref control_srv) = inner2.control {
let control = control_srv
.new_service(())
.await
.map_err(|_| ServerError::ControlServiceInit)?;
let res = Dispatcher::new(conn, st, srv, Some(control))
.await
.map_err(ServerError::from);
if inner2.disconnect.is_some() {
(*inner2.get_mut().disconnect.as_mut().unwrap())(st2.get_mut(), res.as_ref().err())
}
res
} else {
let res = Dispatcher::new(conn, st, srv, None)
.await
.map_err(ServerError::from);
if inner2.disconnect.is_some() {
(*inner2.get_mut().disconnect.as_mut().unwrap())(st2.get_mut(), res.as_ref().err())
}
res
}
}

65
actix-amqp/src/service.rs Executable file
View File

@ -0,0 +1,65 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_service::Service;
use futures::{Future, SinkExt, StreamExt};
use amqp_codec::protocol::ProtocolId;
use amqp_codec::{ProtocolIdCodec, ProtocolIdError};
pub struct ProtocolNegotiation<T> {
proto: ProtocolId,
_r: PhantomData<T>,
}
impl<T> Clone for ProtocolNegotiation<T> {
fn clone(&self) -> Self {
ProtocolNegotiation {
proto: self.proto.clone(),
_r: PhantomData,
}
}
}
impl<T> ProtocolNegotiation<T> {
pub fn new(proto: ProtocolId) -> Self {
ProtocolNegotiation {
proto,
_r: PhantomData,
}
}
}
impl<T> Service for ProtocolNegotiation<T>
where
T: AsyncRead + AsyncWrite + 'static,
{
type Request = Framed<T, ProtocolIdCodec>;
type Response = Framed<T, ProtocolIdCodec>;
type Error = ProtocolIdError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut framed: Framed<T, ProtocolIdCodec>) -> Self::Future {
let proto = self.proto;
Box::pin(async move {
framed.send(proto).await?;
let protocol = framed.next().await.ok_or(ProtocolIdError::Disconnected)??;
if proto == protocol {
Ok(framed)
} else {
Err(ProtocolIdError::Unexpected {
exp: proto,
got: protocol,
})
}
})
}
}

911
actix-amqp/src/session.rs Executable file
View File

@ -0,0 +1,911 @@
use std::collections::VecDeque;
use std::future::Future;
use actix_utils::oneshot;
use bytes::{BufMut, Bytes, BytesMut};
use bytestring::ByteString;
use either::Either;
use futures::future::ok;
use fxhash::FxHashMap;
use slab::Slab;
use amqp_codec::protocol::{
Accepted, Attach, DeliveryNumber, DeliveryState, Detach, Disposition, Error, Flow, Frame,
Handle, ReceiverSettleMode, Role, SenderSettleMode, Transfer, TransferBody, TransferNumber,
};
use amqp_codec::AmqpFrame;
use crate::cell::Cell;
use crate::connection::ConnectionController;
use crate::errors::AmqpTransportError;
use crate::rcvlink::{ReceiverLink, ReceiverLinkBuilder, ReceiverLinkInner};
use crate::sndlink::{SenderLink, SenderLinkBuilder, SenderLinkInner};
use crate::{Configuration, DeliveryPromise};
const INITIAL_OUTGOING_ID: TransferNumber = 0;
#[derive(Clone)]
pub struct Session {
pub(crate) inner: Cell<SessionInner>,
}
impl Drop for Session {
fn drop(&mut self) {
self.inner.get_mut().drop_session()
}
}
impl std::fmt::Debug for Session {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("Session").finish()
}
}
impl Session {
pub(crate) fn new(inner: Cell<SessionInner>) -> Session {
Session { inner }
}
#[inline]
/// Get remote connection configuration
pub fn remote_config(&self) -> &Configuration {
self.inner.connection.remote_config()
}
pub fn close(&self) -> impl Future<Output = Result<(), AmqpTransportError>> {
ok(())
}
pub fn get_sender_link(&self, name: &str) -> Option<&SenderLink> {
let inner = self.inner.get_ref();
if let Some(id) = inner.links_by_name.get(name) {
if let Some(Either::Left(SenderLinkState::Established(ref link))) = inner.links.get(*id)
{
return Some(link);
}
}
None
}
pub fn get_sender_link_by_handle(&self, hnd: Handle) -> Option<&SenderLink> {
self.inner.get_ref().get_sender_link_by_handle(hnd)
}
pub fn get_receiver_link_by_handle(&self, hnd: Handle) -> Option<&ReceiverLink> {
self.inner.get_ref().get_receiver_link_by_handle(hnd)
}
/// Open sender link
pub fn build_sender_link<T: Into<String>, U: Into<String>>(
&mut self,
name: U,
address: T,
) -> SenderLinkBuilder {
let name = ByteString::from(name.into());
let address = ByteString::from(address.into());
SenderLinkBuilder::new(name, address, self.inner.clone())
}
/// Open receiver link
pub fn build_receiver_link<T: Into<String>, U: Into<String>>(
&mut self,
name: U,
address: T,
) -> ReceiverLinkBuilder {
let name = ByteString::from(name.into());
let address = ByteString::from(address.into());
ReceiverLinkBuilder::new(name, address, self.inner.clone())
}
/// Detach receiver link
pub fn detach_receiver_link(
&mut self,
handle: Handle,
error: Option<Error>,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
let (tx, rx) = oneshot::channel();
self.inner
.get_mut()
.detach_receiver_link(handle, false, error, tx);
async move {
match rx.await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(e),
Err(_) => Err(AmqpTransportError::Disconnected),
}
}
}
pub fn wait_disposition(
&mut self,
id: DeliveryNumber,
) -> impl Future<Output = Result<Disposition, AmqpTransportError>> {
self.inner.get_mut().wait_disposition(id)
}
}
#[derive(Debug)]
enum SenderLinkState {
Opening(oneshot::Sender<SenderLink>),
Established(SenderLink),
Closing(Option<oneshot::Sender<Result<(), AmqpTransportError>>>),
}
#[derive(Debug)]
enum ReceiverLinkState {
Opening(Option<Cell<ReceiverLinkInner>>),
OpeningLocal(
Option<(
Cell<ReceiverLinkInner>,
oneshot::Sender<Result<ReceiverLink, AmqpTransportError>>,
)>,
),
Established(ReceiverLink),
Closing(Option<oneshot::Sender<Result<(), AmqpTransportError>>>),
}
impl SenderLinkState {
fn is_opening(&self) -> bool {
match self {
SenderLinkState::Opening(_) => true,
_ => false,
}
}
}
impl ReceiverLinkState {
fn is_opening(&self) -> bool {
match self {
ReceiverLinkState::OpeningLocal(_) => true,
_ => false,
}
}
}
pub(crate) struct SessionInner {
id: usize,
connection: ConnectionController,
next_outgoing_id: TransferNumber,
local: bool,
remote_channel_id: u16,
next_incoming_id: TransferNumber,
remote_outgoing_window: u32,
remote_incoming_window: u32,
unsettled_deliveries: FxHashMap<DeliveryNumber, DeliveryPromise>,
links: Slab<Either<SenderLinkState, ReceiverLinkState>>,
links_by_name: FxHashMap<ByteString, usize>,
remote_handles: FxHashMap<Handle, usize>,
pending_transfers: VecDeque<PendingTransfer>,
disposition_subscribers: FxHashMap<DeliveryNumber, oneshot::Sender<Disposition>>,
error: Option<AmqpTransportError>,
}
struct PendingTransfer {
link_handle: Handle,
idx: u32,
body: Option<TransferBody>,
promise: DeliveryPromise,
tag: Option<Bytes>,
settled: Option<bool>,
}
impl SessionInner {
pub fn new(
id: usize,
local: bool,
connection: ConnectionController,
remote_channel_id: u16,
next_incoming_id: DeliveryNumber,
remote_incoming_window: u32,
remote_outgoing_window: u32,
) -> SessionInner {
SessionInner {
id,
local,
connection,
next_incoming_id,
remote_channel_id,
remote_incoming_window,
remote_outgoing_window,
next_outgoing_id: INITIAL_OUTGOING_ID,
unsettled_deliveries: FxHashMap::default(),
links: Slab::new(),
links_by_name: FxHashMap::default(),
remote_handles: FxHashMap::default(),
pending_transfers: VecDeque::new(),
disposition_subscribers: FxHashMap::default(),
error: None,
}
}
/// Local channel id
pub fn id(&self) -> u16 {
self.id as u16
}
/// Set error. New operations will return error.
pub(crate) fn set_error(&mut self, err: AmqpTransportError) {
// drop pending transfers
for tr in self.pending_transfers.drain(..) {
let _ = tr.promise.send(Err(err.clone()));
}
// drop links
self.links_by_name.clear();
for (_, st) in self.links.iter_mut() {
match st {
Either::Left(SenderLinkState::Opening(_)) => (),
Either::Left(SenderLinkState::Established(ref mut link)) => {
link.inner.get_mut().set_error(err.clone())
}
Either::Left(SenderLinkState::Closing(ref mut link)) => {
if let Some(tx) = link.take() {
let _ = tx.send(Err(err.clone()));
}
}
_ => (),
}
}
self.links.clear();
self.error = Some(err);
}
fn drop_session(&mut self) {
self.connection.drop_session_copy(self.id);
}
fn wait_disposition(
&mut self,
id: DeliveryNumber,
) -> impl Future<Output = Result<Disposition, AmqpTransportError>> {
let (tx, rx) = oneshot::channel();
self.disposition_subscribers.insert(id, tx);
async move { rx.await.map_err(|_| AmqpTransportError::Disconnected) }
}
/// Register remote sender link
pub(crate) fn confirm_sender_link(&mut self, cell: Cell<SessionInner>, attach: &Attach) {
trace!("Remote sender link opened: {:?}", attach.name());
let entry = self.links.vacant_entry();
let token = entry.key();
let delivery_count = attach.initial_delivery_count.unwrap_or(0);
let mut name = None;
if let Some(ref source) = attach.source {
if let Some(ref addr) = source.address {
name = Some(addr.clone());
self.links_by_name.insert(addr.clone(), token);
}
}
self.remote_handles.insert(attach.handle(), token);
let link = Cell::new(SenderLinkInner::new(
token,
name.unwrap_or_else(|| ByteString::default()),
attach.handle(),
delivery_count,
cell,
));
entry.insert(Either::Left(SenderLinkState::Established(SenderLink::new(
link,
))));
let attach = Attach {
name: attach.name.clone(),
handle: token as Handle,
role: Role::Sender,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: attach.source.clone(),
target: attach.target.clone(),
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: Some(delivery_count),
max_message_size: Some(65536),
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
self.post_frame(attach.into());
}
/// Register receiver link
pub(crate) fn open_receiver_link(
&mut self,
cell: Cell<SessionInner>,
attach: Attach,
) -> ReceiverLink {
let handle = attach.handle();
let entry = self.links.vacant_entry();
let token = entry.key();
let inner = Cell::new(ReceiverLinkInner::new(cell, token as u32, attach));
entry.insert(Either::Right(ReceiverLinkState::Opening(Some(
inner.clone(),
))));
self.remote_handles.insert(handle, token);
ReceiverLink::new(inner)
}
pub(crate) fn open_local_receiver_link(
&mut self,
cell: Cell<SessionInner>,
mut frame: Attach,
) -> oneshot::Receiver<Result<ReceiverLink, AmqpTransportError>> {
let (tx, rx) = oneshot::channel();
let entry = self.links.vacant_entry();
let token = entry.key();
let inner = Cell::new(ReceiverLinkInner::new(cell, token as u32, frame.clone()));
entry.insert(Either::Right(ReceiverLinkState::OpeningLocal(Some((
inner.clone(),
tx,
)))));
frame.handle = token as Handle;
self.links_by_name.insert(frame.name.clone(), token);
self.post_frame(Frame::Attach(frame));
rx
}
pub(crate) fn confirm_receiver_link(&mut self, token: Handle, attach: &Attach) {
if let Some(Either::Right(link)) = self.links.get_mut(token as usize) {
match link {
ReceiverLinkState::Opening(l) => {
let attach = Attach {
name: attach.name.clone(),
handle: token as Handle,
role: Role::Receiver,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: attach.source.clone(),
target: attach.target.clone(),
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: Some(0),
max_message_size: Some(65536),
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
*link = ReceiverLinkState::Established(ReceiverLink::new(l.take().unwrap()));
self.post_frame(attach.into());
}
_ => error!("Unexpected receiver link state"),
}
}
}
/// Close receiver link
pub(crate) fn detach_receiver_link(
&mut self,
id: Handle,
closed: bool,
error: Option<Error>,
tx: oneshot::Sender<Result<(), AmqpTransportError>>,
) {
if let Some(Either::Right(link)) = self.links.get_mut(id as usize) {
match link {
ReceiverLinkState::Opening(inner) => {
let attach = Attach {
name: inner.as_ref().unwrap().get_ref().name().clone(),
handle: id as Handle,
role: Role::Sender,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: None,
target: None,
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: None,
max_message_size: None,
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
let detach = Detach {
handle: id,
closed,
error,
};
*link = ReceiverLinkState::Closing(Some(tx));
self.post_frame(attach.into());
self.post_frame(detach.into());
}
ReceiverLinkState::Established(_) => {
let detach = Detach {
handle: id,
closed,
error,
};
*link = ReceiverLinkState::Closing(Some(tx));
self.post_frame(detach.into());
}
ReceiverLinkState::Closing(_) => {
let _ = tx.send(Ok(()));
error!("Unexpected receiver link state: closing - {}", id);
}
ReceiverLinkState::OpeningLocal(_inner) => unimplemented!(),
}
} else {
let _ = tx.send(Ok(()));
error!("Receiver link does not exist while detaching: {}", id);
}
}
pub(crate) fn detach_sender_link(
&mut self,
id: usize,
closed: bool,
error: Option<Error>,
tx: oneshot::Sender<Result<(), AmqpTransportError>>,
) {
if let Some(Either::Left(link)) = self.links.get_mut(id) {
match link {
SenderLinkState::Opening(_) => {
let detach = Detach {
handle: id as u32,
closed,
error,
};
*link = SenderLinkState::Closing(Some(tx));
self.post_frame(detach.into());
}
SenderLinkState::Established(_) => {
let detach = Detach {
handle: id as u32,
closed,
error,
};
*link = SenderLinkState::Closing(Some(tx));
self.post_frame(detach.into());
}
SenderLinkState::Closing(_) => {
let _ = tx.send(Ok(()));
error!("Unexpected receiver link state: closing - {}", id);
}
}
} else {
let _ = tx.send(Ok(()));
error!("Receiver link does not exist while detaching: {}", id);
}
}
pub(crate) fn get_sender_link_by_handle(&self, hnd: Handle) -> Option<&SenderLink> {
if let Some(id) = self.remote_handles.get(&hnd) {
if let Some(Either::Left(SenderLinkState::Established(ref link))) = self.links.get(*id)
{
return Some(link);
}
}
None
}
pub(crate) fn get_receiver_link_by_handle(&self, hnd: Handle) -> Option<&ReceiverLink> {
if let Some(id) = self.remote_handles.get(&hnd) {
if let Some(Either::Right(ReceiverLinkState::Established(ref link))) =
self.links.get(*id)
{
return Some(link);
}
}
None
}
pub fn handle_frame(&mut self, frame: Frame) {
if self.error.is_none() {
match frame {
Frame::Flow(flow) => self.apply_flow(&flow),
Frame::Disposition(disp) => {
if let Some(sender) = self.disposition_subscribers.remove(&disp.first) {
let _ = sender.send(disp);
} else {
self.settle_deliveries(disp);
}
}
Frame::Transfer(transfer) => {
let idx = if let Some(idx) = self.remote_handles.get(&transfer.handle()) {
*idx
} else {
error!("Transfer's link {:?} is unknown", transfer.handle());
return;
};
if let Some(link) = self.links.get_mut(idx) {
match link {
Either::Left(_) => error!("Got trasfer from sender link"),
Either::Right(link) => match link {
ReceiverLinkState::Opening(_) => {
error!(
"Got transfer for opening link: {} -> {}",
transfer.handle(),
idx
);
}
ReceiverLinkState::OpeningLocal(_) => {
error!(
"Got transfer for opening link: {} -> {}",
transfer.handle(),
idx
);
}
ReceiverLinkState::Established(link) => {
// self.outgoing_window -= 1;
let _ = self.next_incoming_id.wrapping_add(1);
link.inner.get_mut().handle_transfer(transfer);
}
ReceiverLinkState::Closing(_) => (),
},
}
} else {
error!(
"Remote link handle mapped to non-existing link: {} -> {}",
transfer.handle(),
idx
);
}
}
Frame::Detach(detach) => {
self.handle_detach(&detach);
}
frame => error!("Unexpected frame: {:?}", frame),
}
}
}
/// Handle `Attach` frame. return false if attach frame is remote and can not be handled
pub fn handle_attach(&mut self, attach: &Attach, cell: Cell<SessionInner>) -> bool {
let name = attach.name();
if let Some(index) = self.links_by_name.get(name) {
match self.links.get_mut(*index) {
Some(Either::Left(item)) => {
if item.is_opening() {
trace!(
"sender link opened: {:?} {} -> {}",
name,
index,
attach.handle()
);
self.remote_handles.insert(attach.handle(), *index);
let delivery_count = attach.initial_delivery_count.unwrap_or(0);
let link = Cell::new(SenderLinkInner::new(
*index,
name.clone(),
attach.handle(),
delivery_count,
cell,
));
let local_sender = std::mem::replace(
item,
SenderLinkState::Established(SenderLink::new(link.clone())),
);
if let SenderLinkState::Opening(tx) = local_sender {
let _ = tx.send(SenderLink::new(link));
}
}
}
Some(Either::Right(item)) => {
if item.is_opening() {
trace!(
"receiver link opened: {:?} {} -> {}",
name,
index,
attach.handle()
);
if let ReceiverLinkState::OpeningLocal(opt_item) = item {
let (link, tx) = opt_item.take().unwrap();
self.remote_handles.insert(attach.handle(), *index);
*item = ReceiverLinkState::Established(ReceiverLink::new(link.clone()));
let _ = tx.send(Ok(ReceiverLink::new(link)));
}
}
}
_ => {
// TODO: error in proto, have to close connection
}
}
true
} else {
// cannot handle remote attach
false
}
}
/// Handle `Detach` frame.
pub fn handle_detach(&mut self, detach: &Detach) {
// get local link instance
let idx = if let Some(idx) = self.remote_handles.get(&detach.handle()) {
*idx
} else {
// should not happen, error
return;
};
let remove = if let Some(link) = self.links.get_mut(idx) {
match link {
Either::Left(link) => match link {
SenderLinkState::Opening(_) => true,
SenderLinkState::Established(link) => {
// detach from remote endpoint
let detach = Detach {
handle: link.inner.get_ref().id(),
closed: true,
error: detach.error.clone(),
};
let err = AmqpTransportError::LinkDetached(detach.error.clone());
// remove name
self.links_by_name.remove(link.inner.name());
// drop pending transfers
let mut idx = 0;
let handle = link.inner.get_ref().remote_handle();
while idx < self.pending_transfers.len() {
if self.pending_transfers[idx].link_handle == handle {
let tr = self.pending_transfers.remove(idx).unwrap();
let _ = tr.promise.send(Err(err.clone()));
} else {
idx += 1;
}
}
// detach snd link
link.inner.get_mut().detached(err);
self.connection
.post_frame(AmqpFrame::new(self.remote_channel_id, detach.into()));
true
}
SenderLinkState::Closing(_) => true,
},
Either::Right(link) => match link {
ReceiverLinkState::Opening(_) => false,
ReceiverLinkState::OpeningLocal(_) => false,
ReceiverLinkState::Established(link) => {
// detach from remote endpoint
let detach = Detach {
handle: link.handle(),
closed: true,
error: None,
};
// detach rcv link
self.connection
.post_frame(AmqpFrame::new(self.remote_channel_id, detach.into()));
true
}
ReceiverLinkState::Closing(tx) => {
// detach confirmation
if let Some(tx) = tx.take() {
if let Some(err) = detach.error.clone() {
let _ = tx.send(Err(AmqpTransportError::LinkDetached(Some(err))));
} else {
let _ = tx.send(Ok(()));
}
}
true
}
},
}
} else {
false
};
if remove {
self.links.remove(idx);
self.remote_handles.remove(&detach.handle());
}
}
fn settle_deliveries(&mut self, disposition: Disposition) {
trace!("settle delivery: {:#?}", disposition);
let from = disposition.first;
let to = disposition.last.unwrap_or(from);
if from == to {
let _ = self
.unsettled_deliveries
.remove(&from)
.unwrap()
.send(Ok(disposition));
} else {
for k in from..=to {
let _ = self
.unsettled_deliveries
.remove(&k)
.unwrap()
.send(Ok(disposition.clone()));
}
}
}
pub(crate) fn apply_flow(&mut self, flow: &Flow) {
// # AMQP1.0 2.5.6
self.next_incoming_id = flow.next_outgoing_id();
self.remote_outgoing_window = flow.outgoing_window();
self.remote_incoming_window = flow
.next_incoming_id()
.unwrap_or(INITIAL_OUTGOING_ID)
.saturating_add(flow.incoming_window())
.saturating_sub(self.next_outgoing_id);
trace!(
"session received credit. window: {}, pending: {}",
self.remote_outgoing_window,
self.pending_transfers.len()
);
while let Some(t) = self.pending_transfers.pop_front() {
self.send_transfer(t.link_handle, t.idx, t.body, t.promise, t.tag, t.settled);
if self.remote_outgoing_window == 0 {
break;
}
}
// apply link flow
if let Some(Either::Left(link)) = flow.handle().and_then(|h| self.links.get_mut(h as usize))
{
match link {
SenderLinkState::Established(ref mut link) => {
link.inner.get_mut().apply_flow(&flow);
}
_ => warn!("Received flow frame"),
}
}
if flow.echo() {
self.send_flow();
}
}
fn send_flow(&mut self) {
let flow = Flow {
next_incoming_id: if self.local {
Some(self.next_incoming_id)
} else {
None
},
incoming_window: std::u32::MAX,
next_outgoing_id: self.next_outgoing_id,
outgoing_window: self.remote_incoming_window,
handle: None,
delivery_count: None,
link_credit: None,
available: None,
drain: false,
echo: false,
properties: None,
};
self.post_frame(flow.into());
}
pub(crate) fn rcv_link_flow(&mut self, handle: u32, delivery_count: u32, credit: u32) {
let flow = Flow {
next_incoming_id: if self.local {
Some(self.next_incoming_id)
} else {
None
},
incoming_window: std::u32::MAX,
next_outgoing_id: self.next_outgoing_id,
outgoing_window: self.remote_incoming_window,
handle: Some(handle),
delivery_count: Some(delivery_count),
link_credit: Some(credit),
available: None,
drain: false,
echo: false,
properties: None,
};
self.post_frame(flow.into());
}
pub fn post_frame(&mut self, frame: Frame) {
self.connection
.post_frame(AmqpFrame::new(self.remote_channel_id, frame));
}
pub(crate) fn open_sender_link(&mut self, mut frame: Attach) -> oneshot::Receiver<SenderLink> {
let (tx, rx) = oneshot::channel();
let entry = self.links.vacant_entry();
let token = entry.key();
entry.insert(Either::Left(SenderLinkState::Opening(tx)));
frame.handle = token as Handle;
self.links_by_name.insert(frame.name.clone(), token);
self.post_frame(Frame::Attach(frame));
rx
}
pub fn send_transfer(
&mut self,
link_handle: Handle,
idx: u32,
body: Option<TransferBody>,
promise: DeliveryPromise,
tag: Option<Bytes>,
settled: Option<bool>,
) {
if self.remote_incoming_window == 0 {
self.pending_transfers.push_back(PendingTransfer {
link_handle,
idx,
body,
promise,
tag,
settled,
});
return;
}
let frame = self.prepare_transfer(link_handle, body, promise, tag, settled);
self.post_frame(frame);
}
pub fn prepare_transfer(
&mut self,
link_handle: Handle,
body: Option<TransferBody>,
promise: DeliveryPromise,
delivery_tag: Option<Bytes>,
settled: Option<bool>,
) -> Frame {
let delivery_id = self.next_outgoing_id;
let tag = if let Some(tag) = delivery_tag {
tag
} else {
let mut buf = BytesMut::new();
buf.put_u32(delivery_id);
buf.freeze()
};
self.next_outgoing_id += 1;
self.remote_incoming_window -= 1;
let message_format = if let Some(ref body) = body {
body.message_format()
} else {
None
};
let settled2 = settled.clone().unwrap_or(false);
let state = if settled2 {
Some(DeliveryState::Accepted(Accepted {}))
} else {
None
};
let transfer = Transfer {
settled,
message_format,
handle: link_handle,
delivery_id: Some(delivery_id),
delivery_tag: Some(tag),
more: false,
rcv_settle_mode: None,
state, //: Some(DeliveryState::Accepted(Accepted {})),
resume: false,
aborted: false,
batchable: false,
body: body,
};
self.unsettled_deliveries.insert(delivery_id, promise);
Frame::Transfer(transfer)
}
}

326
actix-amqp/src/sndlink.rs Executable file
View File

@ -0,0 +1,326 @@
use std::collections::VecDeque;
use std::future::Future;
use actix_utils::oneshot;
use amqp_codec::protocol::{
Attach, DeliveryNumber, DeliveryState, Disposition, Error, Flow, ReceiverSettleMode, Role,
SenderSettleMode, SequenceNo, Target, TerminusDurability, TerminusExpiryPolicy, TransferBody,
};
use bytes::Bytes;
use bytestring::ByteString;
use futures::future::{ok, Either};
use crate::cell::Cell;
use crate::errors::AmqpTransportError;
use crate::session::{Session, SessionInner};
use crate::{Delivery, DeliveryPromise, Handle};
#[derive(Clone)]
pub struct SenderLink {
pub(crate) inner: Cell<SenderLinkInner>,
}
impl std::fmt::Debug for SenderLink {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_tuple("SenderLink")
.field(&std::ops::Deref::deref(&self.inner.get_ref().name))
.finish()
}
}
pub(crate) struct SenderLinkInner {
id: usize,
idx: u32,
name: ByteString,
session: Session,
remote_handle: Handle,
delivery_count: SequenceNo,
link_credit: u32,
pending_transfers: VecDeque<PendingTransfer>,
error: Option<AmqpTransportError>,
closed: bool,
}
struct PendingTransfer {
idx: u32,
tag: Option<Bytes>,
body: Option<TransferBody>,
promise: DeliveryPromise,
settle: Option<bool>,
}
impl SenderLink {
pub(crate) fn new(inner: Cell<SenderLinkInner>) -> SenderLink {
SenderLink { inner }
}
pub fn id(&self) -> u32 {
self.inner.id as u32
}
pub fn name(&self) -> &ByteString {
&self.inner.name
}
pub fn remote_handle(&self) -> Handle {
self.inner.remote_handle
}
pub fn session(&self) -> &Session {
&self.inner.get_ref().session
}
pub fn session_mut(&mut self) -> &mut Session {
&mut self.inner.get_mut().session
}
pub fn send<T>(&self, body: T) -> impl Future<Output = Result<Disposition, AmqpTransportError>>
where
T: Into<TransferBody>,
{
self.inner.get_mut().send(body, None)
}
pub fn send_with_tag<T>(
&self,
body: T,
tag: Bytes,
) -> impl Future<Output = Result<Disposition, AmqpTransportError>>
where
T: Into<TransferBody>,
{
self.inner.get_mut().send(body, Some(tag))
}
pub fn settle_message(&self, id: DeliveryNumber, state: DeliveryState) {
self.inner.get_mut().settle_message(id, state)
}
pub fn close(&self) -> impl Future<Output = Result<(), AmqpTransportError>> {
self.inner.get_mut().close(None)
}
pub fn close_with_error(
&self,
error: Error,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
self.inner.get_mut().close(Some(error))
}
}
impl SenderLinkInner {
pub(crate) fn new(
id: usize,
name: ByteString,
handle: Handle,
delivery_count: SequenceNo,
session: Cell<SessionInner>,
) -> SenderLinkInner {
SenderLinkInner {
id,
name,
delivery_count,
idx: 0,
session: Session::new(session),
remote_handle: handle,
link_credit: 0,
pending_transfers: VecDeque::new(),
error: None,
closed: false,
}
}
pub fn id(&self) -> u32 {
self.id as u32
}
pub fn remote_handle(&self) -> Handle {
self.remote_handle
}
pub(crate) fn name(&self) -> &ByteString {
&self.name
}
pub(crate) fn detached(&mut self, err: AmqpTransportError) {
// drop pending transfers
for tr in self.pending_transfers.drain(..) {
let _ = tr.promise.send(Err(err.clone()));
}
self.error = Some(err);
}
pub(crate) fn close(
&self,
error: Option<Error>,
) -> impl Future<Output = Result<(), AmqpTransportError>> {
if self.closed {
Either::Left(ok(()))
} else {
let (tx, rx) = oneshot::channel();
self.session
.inner
.get_mut()
.detach_sender_link(self.id, true, error, tx);
Either::Right(async move {
match rx.await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(e),
Err(_) => Err(AmqpTransportError::Disconnected),
}
})
}
}
pub(crate) fn set_error(&mut self, err: AmqpTransportError) {
// drop pending transfers
for tr in self.pending_transfers.drain(..) {
let _ = tr.promise.send(Err(err.clone()));
}
self.error = Some(err);
}
pub fn apply_flow(&mut self, flow: &Flow) {
// #2.7.6
if let Some(credit) = flow.link_credit() {
let delta = flow
.delivery_count
.unwrap_or(0)
.saturating_add(credit)
.saturating_sub(self.delivery_count);
let session = self.session.inner.get_mut();
// credit became available => drain pending_transfers
self.link_credit += delta;
while self.link_credit > 0 {
if let Some(transfer) = self.pending_transfers.pop_front() {
self.link_credit -= 1;
let _ = self.delivery_count.saturating_add(1);
session.send_transfer(
self.remote_handle,
transfer.idx,
transfer.body,
transfer.promise,
transfer.tag,
transfer.settle,
);
} else {
break;
}
}
}
if flow.echo() {
// todo: send flow
}
}
pub fn send<T: Into<TransferBody>>(&mut self, body: T, tag: Option<Bytes>) -> Delivery {
if let Some(ref err) = self.error {
Delivery::Resolved(Err(err.clone()))
} else {
let body = body.into();
let (delivery_tx, delivery_rx) = oneshot::channel();
if self.link_credit == 0 {
self.pending_transfers.push_back(PendingTransfer {
tag,
settle: Some(false),
body: Some(body),
idx: self.idx,
promise: delivery_tx,
});
} else {
let session = self.session.inner.get_mut();
self.link_credit -= 1;
let _ = self.delivery_count.saturating_add(1);
session.send_transfer(
self.remote_handle,
self.idx,
Some(body),
delivery_tx,
tag,
Some(false),
);
}
let _ = self.idx.saturating_add(1);
Delivery::Pending(delivery_rx)
}
}
pub fn settle_message(&mut self, id: DeliveryNumber, state: DeliveryState) {
let _ = self.delivery_count.saturating_add(1);
let disp = Disposition {
role: Role::Sender,
first: id,
last: None,
settled: true,
state: Some(state),
batchable: false,
};
self.session.inner.get_mut().post_frame(disp.into());
}
}
pub struct SenderLinkBuilder {
frame: Attach,
session: Cell<SessionInner>,
}
impl SenderLinkBuilder {
pub(crate) fn new(name: ByteString, address: ByteString, session: Cell<SessionInner>) -> Self {
let target = Target {
address: Some(address),
durable: TerminusDurability::None,
expiry_policy: TerminusExpiryPolicy::SessionEnd,
timeout: 0,
dynamic: false,
dynamic_node_properties: None,
capabilities: None,
};
let frame = Attach {
name,
handle: 0 as Handle,
role: Role::Sender,
snd_settle_mode: SenderSettleMode::Mixed,
rcv_settle_mode: ReceiverSettleMode::First,
source: None,
target: Some(target),
unsettled: None,
incomplete_unsettled: false,
initial_delivery_count: None,
max_message_size: Some(65536 * 4),
offered_capabilities: None,
desired_capabilities: None,
properties: None,
};
SenderLinkBuilder { frame, session }
}
pub fn max_message_size(mut self, size: u64) -> Self {
self.frame.max_message_size = Some(size);
self
}
pub fn with_frame<F>(mut self, f: F) -> Self
where
F: FnOnce(&mut Attach),
{
f(&mut self.frame);
self
}
pub async fn open(self) -> Result<SenderLink, AmqpTransportError> {
self.session
.get_mut()
.open_sender_link(self.frame)
.await
.map_err(|_e| AmqpTransportError::Disconnected)
}
}

131
actix-amqp/tests/test_server.rs Executable file
View File

@ -0,0 +1,131 @@
use std::convert::TryFrom;
use actix_amqp::server::{self, errors};
use actix_amqp::{sasl, Configuration};
use actix_codec::{AsyncRead, AsyncWrite};
use actix_connect::{default_connector, TcpConnector};
use actix_service::{fn_factory_with_config, pipeline_factory, Service};
use actix_testing::TestServer;
use futures::future::{err, Ready};
use futures::Future;
use http::Uri;
fn server(
link: server::Link<()>,
) -> impl Future<
Output = Result<
Box<
dyn Service<
Request = server::Message<()>,
Response = server::Outcome,
Error = errors::AmqpError,
Future = Ready<Result<server::Message<()>, server::Outcome>>,
> + 'static,
>,
errors::LinkError,
>,
> {
println!("OPEN LINK: {:?}", link);
err(errors::LinkError::force_detach().description("unimplemented"))
}
#[actix_rt::test]
async fn test_simple() -> std::io::Result<()> {
std::env::set_var(
"RUST_LOG",
"actix_codec=info,actix_server=trace,actix_connector=trace,amqp_transport=trace",
);
env_logger::init();
let srv = TestServer::with(|| {
server::Server::new(
server::Handshake::new(|conn: server::Connect<_>| async move {
let conn = conn.open().await.unwrap();
Ok::<_, errors::AmqpError>(conn.ack(()))
})
.sasl(server::sasl::no_sasl()),
)
.finish(
server::App::<()>::new()
.service("test", fn_factory_with_config(server))
.finish(),
)
});
let uri = Uri::try_from(format!("amqp://{}:{}", srv.host(), srv.port())).unwrap();
let mut sasl_srv = sasl::connect_service(default_connector());
let req = sasl::SaslConnect {
uri,
config: Configuration::default(),
time: None,
auth: sasl::SaslAuth {
authz_id: "".to_string(),
authn_id: "user1".to_string(),
password: "password1".to_string(),
},
};
let res = sasl_srv.call(req).await;
println!("E: {:?}", res.err());
Ok(())
}
async fn sasl_auth<Io: AsyncRead + AsyncWrite>(
auth: server::Sasl<Io>,
) -> Result<server::ConnectAck<Io, ()>, server::errors::ServerError<()>> {
let init = auth
.mechanism("PLAIN")
.mechanism("ANONYMOUS")
.mechanism("MSSBCBS")
.mechanism("AMQPCBS")
.init()
.await?;
if init.mechanism() == "PLAIN" {
if let Some(resp) = init.initial_response() {
if resp == b"\0user1\0password1" {
let succ = init.outcome(amqp_codec::protocol::SaslCode::Ok).await?;
return Ok(succ.open().await?.ack(()));
}
}
}
let succ = init.outcome(amqp_codec::protocol::SaslCode::Auth).await?;
Ok(succ.open().await?.ack(()))
}
#[actix_rt::test]
async fn test_sasl() -> std::io::Result<()> {
let srv = TestServer::with(|| {
server::Server::new(
server::Handshake::new(|conn: server::Connect<_>| async move {
let conn = conn.open().await.unwrap();
Ok::<_, errors::Error>(conn.ack(()))
})
.sasl(pipeline_factory(sasl_auth).map_err(|e| e.into())),
)
.finish(
server::App::<()>::new()
.service("test", fn_factory_with_config(server))
.finish(),
)
});
let uri = Uri::try_from(format!("amqp://{}:{}", srv.host(), srv.port())).unwrap();
let mut sasl_srv = sasl::connect_service(TcpConnector::new());
let req = sasl::SaslConnect {
uri,
config: Configuration::default(),
time: None,
auth: sasl::SaslAuth {
authz_id: "".to_string(),
authn_id: "user1".to_string(),
password: "password1".to_string(),
},
};
let res = sasl_srv.call(req).await;
println!("E: {:?}", res.err());
Ok(())
}

24
actix-mqtt/CHANGES.md Executable file
View File

@ -0,0 +1,24 @@
# Changes
## Unreleased - 2020-xx-xx
## 0.2.3 - 2020-03-10
* Add server handshake timeout
## 0.2.2 - 2020-02-04
* Fix server keep-alive impl
## 0.2.1 - 2019-12-25
* Allow to specify multi-pattern for topics
## 0.2.0 - 2019-12-11
* Migrate to `std::future`
* Support publish with QoS 1
## 0.1.0 - 2019-09-25
* Initial release

38
actix-mqtt/Cargo.toml Executable file
View File

@ -0,0 +1,38 @@
[package]
name = "actix-mqtt"
version = "0.2.3"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "MQTT v3.1.1 Client/Server framework"
documentation = "https://docs.rs/actix-mqtt"
repository = "https://github.com/actix/actix-extras.git"
categories = ["network-programming"]
keywords = ["MQTT", "IoT", "messaging"]
license = "MIT OR Apache-2.0"
edition = "2018"
[dependencies]
mqtt-codec = "0.3.0"
actix-codec = "0.2.0"
actix-service = "1.0.1"
actix-utils = "1.0.4"
actix-router = "0.2.2"
actix-ioframe = "0.4.1"
actix-rt = "1.0.0"
derive_more = "0.99.2"
bytes = "0.5.3"
either = "1.5.2"
futures = "0.3.1"
pin-project = "0.4.6"
log = "0.4"
bytestring = "0.1.2"
serde = "1.0"
serde_json = "1.0"
uuid = { version = "0.8", features = ["v4"] }
[dev-dependencies]
env_logger = "0.6"
actix-connect = "1.0.1"
actix-server = "1.0.0"
actix-testing = "1.0.0"
actix-rt = "1.0.0"

201
actix-mqtt/LICENSE-APACHE Executable file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019-NOW Nikolay Kim
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

25
actix-mqtt/LICENSE-MIT Executable file
View File

@ -0,0 +1,25 @@
Copyright (c) 2019 Nikolay Kim
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

1
actix-mqtt/README.md Executable file
View File

@ -0,0 +1 @@
# MQTT 3.1.1 Client/Server framework [![Build Status](https://travis-ci.org/actix/actix-mqtt.svg?branch=master)](https://travis-ci.org/actix/actix-mqtt) [![codecov](https://codecov.io/gh/actix/actix-mqtt/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-mqtt) [![crates.io](https://meritbadge.herokuapp.com/actix-mqtt)](https://crates.io/crates/actix-mqtt) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

17
actix-mqtt/codec/CHANGES.md Executable file
View File

@ -0,0 +1,17 @@
# Changes
## Unreleased - 2020-xx-xx
## 0.3.0 - 2019-12-11
* Use `bytestring` instead of string
* Upgrade actix-codec to 0.2.0
## 0.2.0 - 2019-09-09
* Remove `Packet::Empty`
* Add `max frame size` config
## 0.1.0 - 2019-06-18
* Initial release

21
actix-mqtt/codec/Cargo.toml Executable file
View File

@ -0,0 +1,21 @@
[package]
name = "mqtt-codec"
version = "0.3.0"
authors = [
"Max Gortman <mgortman@microsoft.com>",
"Nikolay Kim <fafhrd91@gmail.com>",
"Flier Lu <flier.lu@gmail.com>",
]
description = "MQTT v3.1.1 Codec"
documentation = "https://docs.rs/mqtt-codec"
repository = "https://github.com/actix/actix-mqtt.git"
readme = "README.md"
keywords = ["MQTT", "IoT", "messaging"]
license = "MIT/Apache-2.0"
edition = "2018"
[dependencies]
bitflags = "1.0"
bytes = "0.5.2"
bytestring = "0.1.0"
actix-codec = "0.2.0"

201
actix-mqtt/codec/LICENSE-APACHE Executable file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019-NOW Nikolay Kim
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

25
actix-mqtt/codec/LICENSE-MIT Executable file
View File

@ -0,0 +1,25 @@
Copyright (c) 2019 Nikolay Kim
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

3
actix-mqtt/codec/README.md Executable file
View File

@ -0,0 +1,3 @@
# MQTT v3.1 Codec
MQTT v3.1 Codec implementation

View File

@ -0,0 +1,616 @@
use std::convert::TryFrom;
use std::num::{NonZeroU16, NonZeroU32};
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use crate::error::ParseError;
use crate::packet::*;
use crate::proto::*;
use super::{ConnectAckFlags, ConnectFlags, FixedHeader, WILL_QOS_SHIFT};
pub(crate) fn read_packet(mut src: Bytes, header: FixedHeader) -> Result<Packet, ParseError> {
match header.packet_type {
CONNECT => decode_connect_packet(&mut src),
CONNACK => decode_connect_ack_packet(&mut src),
PUBLISH => decode_publish_packet(&mut src, header),
PUBACK => Ok(Packet::PublishAck {
packet_id: NonZeroU16::parse(&mut src)?,
}),
PUBREC => Ok(Packet::PublishReceived {
packet_id: NonZeroU16::parse(&mut src)?,
}),
PUBREL => Ok(Packet::PublishRelease {
packet_id: NonZeroU16::parse(&mut src)?,
}),
PUBCOMP => Ok(Packet::PublishComplete {
packet_id: NonZeroU16::parse(&mut src)?,
}),
SUBSCRIBE => decode_subscribe_packet(&mut src),
SUBACK => decode_subscribe_ack_packet(&mut src),
UNSUBSCRIBE => decode_unsubscribe_packet(&mut src),
UNSUBACK => Ok(Packet::UnsubscribeAck {
packet_id: NonZeroU16::parse(&mut src)?,
}),
PINGREQ => Ok(Packet::PingRequest),
PINGRESP => Ok(Packet::PingResponse),
DISCONNECT => Ok(Packet::Disconnect),
_ => Err(ParseError::UnsupportedPacketType),
}
}
macro_rules! check_flag {
($flags:expr, $flag:expr) => {
($flags & $flag.bits()) == $flag.bits()
};
}
macro_rules! ensure {
($cond:expr, $e:expr) => {
if !($cond) {
return Err($e);
}
};
($cond:expr, $fmt:expr, $($arg:tt)+) => {
if !($cond) {
return Err($fmt, $($arg)+);
}
};
}
pub fn decode_variable_length(src: &[u8]) -> Result<Option<(usize, usize)>, ParseError> {
if let Some((len, consumed, more)) = src
.iter()
.enumerate()
.scan((0, true), |state, (idx, x)| {
if !state.1 || idx > 3 {
return None;
}
state.0 += ((x & 0x7F) as usize) << (idx * 7);
state.1 = x & 0x80 != 0;
Some((state.0, idx + 1, state.1))
})
.last()
{
ensure!(!more || consumed < 4, ParseError::InvalidLength);
return Ok(Some((len, consumed)));
}
Ok(None)
}
fn decode_connect_packet(src: &mut Bytes) -> Result<Packet, ParseError> {
ensure!(src.remaining() >= 10, ParseError::InvalidLength);
let len = src.get_u16();
ensure!(
len == 4 && &src.bytes()[0..4] == b"MQTT",
ParseError::InvalidProtocol
);
src.advance(4);
let level = src.get_u8();
ensure!(
level == DEFAULT_MQTT_LEVEL,
ParseError::UnsupportedProtocolLevel
);
let flags = src.get_u8();
ensure!((flags & 0x01) == 0, ParseError::ConnectReservedFlagSet);
let keep_alive = src.get_u16();
let client_id = decode_utf8_str(src)?;
ensure!(
!client_id.is_empty() || check_flag!(flags, ConnectFlags::CLEAN_SESSION),
ParseError::InvalidClientId
);
let topic = if check_flag!(flags, ConnectFlags::WILL) {
Some(decode_utf8_str(src)?)
} else {
None
};
let message = if check_flag!(flags, ConnectFlags::WILL) {
Some(decode_length_bytes(src)?)
} else {
None
};
let username = if check_flag!(flags, ConnectFlags::USERNAME) {
Some(decode_utf8_str(src)?)
} else {
None
};
let password = if check_flag!(flags, ConnectFlags::PASSWORD) {
Some(decode_length_bytes(src)?)
} else {
None
};
let last_will = if let Some(topic) = topic {
Some(LastWill {
qos: QoS::from((flags & ConnectFlags::WILL_QOS.bits()) >> WILL_QOS_SHIFT),
retain: check_flag!(flags, ConnectFlags::WILL_RETAIN),
topic,
message: message.unwrap(),
})
} else {
None
};
Ok(Packet::Connect(Connect {
protocol: Protocol::MQTT(level),
clean_session: check_flag!(flags, ConnectFlags::CLEAN_SESSION),
keep_alive,
client_id,
last_will,
username,
password,
}))
}
fn decode_connect_ack_packet(src: &mut Bytes) -> Result<Packet, ParseError> {
ensure!(src.remaining() >= 2, ParseError::InvalidLength);
let flags = src.get_u8();
ensure!(
(flags & 0b1111_1110) == 0,
ParseError::ConnAckReservedFlagSet
);
let return_code = src.get_u8();
Ok(Packet::ConnectAck {
session_present: check_flag!(flags, ConnectAckFlags::SESSION_PRESENT),
return_code: ConnectCode::from(return_code),
})
}
fn decode_publish_packet(src: &mut Bytes, header: FixedHeader) -> Result<Packet, ParseError> {
let topic = decode_utf8_str(src)?;
let qos = QoS::from((header.packet_flags & 0b0110) >> 1);
let packet_id = if qos == QoS::AtMostOnce {
None
} else {
Some(NonZeroU16::parse(src)?)
};
let len = src.remaining();
let payload = src.split_to(len);
Ok(Packet::Publish(Publish {
dup: (header.packet_flags & 0b1000) == 0b1000,
qos,
retain: (header.packet_flags & 0b0001) == 0b0001,
topic,
packet_id,
payload,
}))
}
fn decode_subscribe_packet(src: &mut Bytes) -> Result<Packet, ParseError> {
let packet_id = NonZeroU16::parse(src)?;
let mut topic_filters = Vec::new();
while src.remaining() > 0 {
let topic = decode_utf8_str(src)?;
ensure!(src.remaining() >= 1, ParseError::InvalidLength);
let qos = QoS::from(src.get_u8() & 0x03);
topic_filters.push((topic, qos));
}
Ok(Packet::Subscribe {
packet_id,
topic_filters,
})
}
fn decode_subscribe_ack_packet(src: &mut Bytes) -> Result<Packet, ParseError> {
let packet_id = NonZeroU16::parse(src)?;
let status = src
.bytes()
.iter()
.map(|code| {
if *code == 0x80 {
SubscribeReturnCode::Failure
} else {
SubscribeReturnCode::Success(QoS::from(code & 0x03))
}
})
.collect();
Ok(Packet::SubscribeAck { packet_id, status })
}
fn decode_unsubscribe_packet(src: &mut Bytes) -> Result<Packet, ParseError> {
let packet_id = NonZeroU16::parse(src)?;
let mut topic_filters = Vec::new();
while src.remaining() > 0 {
topic_filters.push(decode_utf8_str(src)?);
}
Ok(Packet::Unsubscribe {
packet_id,
topic_filters,
})
}
fn decode_length_bytes(src: &mut Bytes) -> Result<Bytes, ParseError> {
let len: u16 = NonZeroU16::parse(src)?.into();
ensure!(src.remaining() >= len as usize, ParseError::InvalidLength);
Ok(src.split_to(len as usize))
}
fn decode_utf8_str(src: &mut Bytes) -> Result<ByteString, ParseError> {
Ok(ByteString::try_from(decode_length_bytes(src)?)?)
}
pub(crate) trait ByteBuf: Buf {
fn inner_mut(&mut self) -> &mut Bytes;
}
impl ByteBuf for Bytes {
fn inner_mut(&mut self) -> &mut Bytes {
self
}
}
impl ByteBuf for bytes::buf::ext::Take<&mut Bytes> {
fn inner_mut(&mut self) -> &mut Bytes {
self.get_mut()
}
}
pub(crate) trait Parse: Sized {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError>;
}
impl Parse for bool {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
ensure!(src.has_remaining(), ParseError::InvalidLength); // expected more data within the field
let v = src.get_u8();
ensure!(v <= 0x1, ParseError::MalformedPacket); // value is invalid
Ok(v == 0x1)
}
}
impl Parse for u16 {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
ensure!(src.remaining() >= 2, ParseError::InvalidLength);
Ok(src.get_u16())
}
}
impl Parse for u32 {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
ensure!(src.remaining() >= 4, ParseError::InvalidLength); // expected more data within the field
let val = src.get_u32();
Ok(val)
}
}
impl Parse for NonZeroU32 {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
let val = NonZeroU32::new(u32::parse(src)?).ok_or(ParseError::MalformedPacket)?;
Ok(val)
}
}
impl Parse for NonZeroU16 {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
Ok(NonZeroU16::new(u16::parse(src)?).ok_or(ParseError::MalformedPacket)?)
}
}
impl Parse for Bytes {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
let len = u16::parse(src)? as usize;
ensure!(src.remaining() >= len, ParseError::InvalidLength);
Ok(src.inner_mut().split_to(len))
}
}
pub(crate) type ByteStr = ByteString;
impl Parse for ByteStr {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
let bytes = Bytes::parse(src)?;
Ok(ByteString::try_from(bytes)?)
}
}
impl Parse for (ByteStr, ByteStr) {
fn parse<B: ByteBuf>(src: &mut B) -> Result<Self, ParseError> {
let key = ByteStr::parse(src)?;
let val = ByteStr::parse(src)?;
Ok((key, val))
}
}
#[cfg(test)]
mod tests {
use super::*;
macro_rules! assert_decode_packet (
($bytes:expr, $res:expr) => {{
let fixed = $bytes.as_ref()[0];
let (_len, consumned) = decode_variable_length(&$bytes[1..]).unwrap().unwrap();
let hdr = FixedHeader {
packet_type: fixed >> 4,
packet_flags: fixed & 0xF,
remaining_length: $bytes.len() - consumned - 1,
};
let cur = Bytes::from_static(&$bytes[consumned + 1..]);
assert_eq!(read_packet(cur, hdr), Ok($res));
}};
);
fn packet_id(v: u16) -> NonZeroU16 {
NonZeroU16::new(v).unwrap()
}
#[test]
fn test_decode_variable_length() {
macro_rules! assert_variable_length (
($bytes:expr, $res:expr) => {{
assert_eq!(decode_variable_length($bytes), Ok(Some($res)));
}};
($bytes:expr, $res:expr, $rest:expr) => {{
assert_eq!(decode_variable_length($bytes), Ok(Some($res)));
}};
);
assert_variable_length!(b"\x7f\x7f", (127, 1), b"\x7f");
//assert_eq!(decode_variable_length(b"\xff\xff\xff"), Ok(None));
assert_eq!(
decode_variable_length(b"\xff\xff\xff\xff\xff\xff"),
Err(ParseError::InvalidLength)
);
assert_variable_length!(b"\x00", (0, 1));
assert_variable_length!(b"\x7f", (127, 1));
assert_variable_length!(b"\x80\x01", (128, 2));
assert_variable_length!(b"\xff\x7f", (16383, 2));
assert_variable_length!(b"\x80\x80\x01", (16384, 3));
assert_variable_length!(b"\xff\xff\x7f", (2097151, 3));
assert_variable_length!(b"\x80\x80\x80\x01", (2097152, 4));
assert_variable_length!(b"\xff\xff\xff\x7f", (268435455, 4));
}
// #[test]
// fn test_decode_header() {
// assert_eq!(
// decode_header(b"\x20\x7f"),
// Done(
// &b""[..],
// FixedHeader {
// packet_type: CONNACK,
// packet_flags: 0,
// remaining_length: 127,
// }
// )
// );
// assert_eq!(
// decode_header(b"\x3C\x82\x7f"),
// Done(
// &b""[..],
// FixedHeader {
// packet_type: PUBLISH,
// packet_flags: 0x0C,
// remaining_length: 16258,
// }
// )
// );
// assert_eq!(decode_header(b"\x20"), Incomplete(Needed::Unknown));
// }
#[test]
fn test_decode_connect_packets() {
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(
b"\x00\x04MQTT\x04\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass"
)),
Ok(Packet::Connect(Connect {
protocol: Protocol::MQTT(4),
clean_session: false,
keep_alive: 60,
client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(),
last_will: None,
username: Some(ByteString::try_from(Bytes::from_static(b"user")).unwrap()),
password: Some(Bytes::from(&b"pass"[..])),
}))
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(
b"\x00\x04MQTT\x04\x14\x00\x3C\x00\x0512345\x00\x05topic\x00\x07message"
)),
Ok(Packet::Connect(Connect {
protocol: Protocol::MQTT(4),
clean_session: false,
keep_alive: 60,
client_id: ByteString::try_from(Bytes::from_static(b"12345")).unwrap(),
last_will: Some(LastWill {
qos: QoS::ExactlyOnce,
retain: false,
topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
message: Bytes::from(&b"message"[..]),
}),
username: None,
password: None,
}))
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(b"\x00\x02MQ00000000000000000000")),
Err(ParseError::InvalidProtocol),
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(b"\x00\x10MQ00000000000000000000")),
Err(ParseError::InvalidProtocol),
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(b"\x00\x04MQAA00000000000000000000")),
Err(ParseError::InvalidProtocol),
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(
b"\x00\x04MQTT\x0300000000000000000000"
)),
Err(ParseError::UnsupportedProtocolLevel),
);
assert_eq!(
decode_connect_packet(&mut Bytes::from_static(
b"\x00\x04MQTT\x04\xff00000000000000000000"
)),
Err(ParseError::ConnectReservedFlagSet)
);
assert_eq!(
decode_connect_ack_packet(&mut Bytes::from_static(b"\x01\x04")),
Ok(Packet::ConnectAck {
session_present: true,
return_code: ConnectCode::BadUserNameOrPassword
})
);
assert_eq!(
decode_connect_ack_packet(&mut Bytes::from_static(b"\x03\x04")),
Err(ParseError::ConnAckReservedFlagSet)
);
assert_decode_packet!(
b"\x20\x02\x01\x04",
Packet::ConnectAck {
session_present: true,
return_code: ConnectCode::BadUserNameOrPassword,
}
);
assert_decode_packet!(b"\xe0\x00", Packet::Disconnect);
}
#[test]
fn test_decode_publish_packets() {
//assert_eq!(
// decode_publish_packet(b"\x00\x05topic\x12\x34"),
// Done(&b""[..], ("topic".to_owned(), 0x1234))
//);
assert_decode_packet!(
b"\x3d\x0D\x00\x05topic\x43\x21data",
Packet::Publish(Publish {
dup: true,
retain: true,
qos: QoS::ExactlyOnce,
topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
packet_id: Some(packet_id(0x4321)),
payload: Bytes::from_static(b"data"),
})
);
assert_decode_packet!(
b"\x30\x0b\x00\x05topicdata",
Packet::Publish(Publish {
dup: false,
retain: false,
qos: QoS::AtMostOnce,
topic: ByteString::try_from(Bytes::from_static(b"topic")).unwrap(),
packet_id: None,
payload: Bytes::from_static(b"data"),
})
);
assert_decode_packet!(
b"\x40\x02\x43\x21",
Packet::PublishAck {
packet_id: packet_id(0x4321),
}
);
assert_decode_packet!(
b"\x50\x02\x43\x21",
Packet::PublishReceived {
packet_id: packet_id(0x4321),
}
);
assert_decode_packet!(
b"\x60\x02\x43\x21",
Packet::PublishRelease {
packet_id: packet_id(0x4321),
}
);
assert_decode_packet!(
b"\x70\x02\x43\x21",
Packet::PublishComplete {
packet_id: packet_id(0x4321),
}
);
}
#[test]
fn test_decode_subscribe_packets() {
let p = Packet::Subscribe {
packet_id: packet_id(0x1234),
topic_filters: vec![
(
ByteString::try_from(Bytes::from_static(b"test")).unwrap(),
QoS::AtLeastOnce,
),
(
ByteString::try_from(Bytes::from_static(b"filter")).unwrap(),
QoS::ExactlyOnce,
),
],
};
assert_eq!(
decode_subscribe_packet(&mut Bytes::from_static(
b"\x12\x34\x00\x04test\x01\x00\x06filter\x02"
)),
Ok(p.clone())
);
assert_decode_packet!(b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02", p);
let p = Packet::SubscribeAck {
packet_id: packet_id(0x1234),
status: vec![
SubscribeReturnCode::Success(QoS::AtLeastOnce),
SubscribeReturnCode::Failure,
SubscribeReturnCode::Success(QoS::ExactlyOnce),
],
};
assert_eq!(
decode_subscribe_ack_packet(&mut Bytes::from_static(b"\x12\x34\x01\x80\x02")),
Ok(p.clone())
);
assert_decode_packet!(b"\x90\x05\x12\x34\x01\x80\x02", p);
let p = Packet::Unsubscribe {
packet_id: packet_id(0x1234),
topic_filters: vec![
ByteString::try_from(Bytes::from_static(b"test")).unwrap(),
ByteString::try_from(Bytes::from_static(b"filter")).unwrap(),
],
};
assert_eq!(
decode_unsubscribe_packet(&mut Bytes::from_static(
b"\x12\x34\x00\x04test\x00\x06filter"
)),
Ok(p.clone())
);
assert_decode_packet!(b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter", p);
assert_decode_packet!(
b"\xb0\x02\x43\x21",
Packet::UnsubscribeAck {
packet_id: packet_id(0x4321),
}
);
}
#[test]
fn test_decode_ping_packets() {
assert_decode_packet!(b"\xc0\x00", Packet::PingRequest);
assert_decode_packet!(b"\xd0\x00", Packet::PingResponse);
}
}

View File

@ -0,0 +1,446 @@
use bytes::{BufMut, BytesMut};
use crate::packet::*;
use crate::proto::*;
use super::{ConnectFlags, WILL_QOS_SHIFT};
pub fn write_packet(packet: &Packet, dst: &mut BytesMut, content_size: usize) {
write_fixed_header(packet, dst, content_size);
write_content(packet, dst);
}
pub fn get_encoded_size(packet: &Packet) -> usize {
match *packet {
Packet::Connect ( ref connect ) => {
match *connect {
Connect {ref last_will, ref client_id, ref username, ref password, ..} =>
{
// Protocol Name + Protocol Level + Connect Flags + Keep Alive
let mut n = 2 + 4 + 1 + 1 + 2;
// Client Id
n += 2 + client_id.len();
// Will Topic + Will Message
if let Some(LastWill { ref topic, ref message, .. }) = *last_will {
n += 2 + topic.len() + 2 + message.len();
}
if let Some(ref s) = *username {
n += 2 + s.len();
}
if let Some(ref s) = *password {
n += 2 + s.len();
}
n
}
}
}
Packet::Publish( Publish{ qos, ref topic, ref payload, .. }) => {
// Topic + Packet Id + Payload
if qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce {
4 + topic.len() + payload.len()
} else {
2 + topic.len() + payload.len()
}
}
Packet::ConnectAck { .. } | // Flags + Return Code
Packet::PublishAck { .. } | // Packet Id
Packet::PublishReceived { .. } | // Packet Id
Packet::PublishRelease { .. } | // Packet Id
Packet::PublishComplete { .. } | // Packet Id
Packet::UnsubscribeAck { .. } => 2, // Packet Id
Packet::Subscribe { ref topic_filters, .. } => {
2 + topic_filters.iter().fold(0, |acc, &(ref filter, _)| acc + 2 + filter.len() + 1)
}
Packet::SubscribeAck { ref status, .. } => 2 + status.len(),
Packet::Unsubscribe { ref topic_filters, .. } => {
2 + topic_filters.iter().fold(0, |acc, filter| acc + 2 + filter.len())
}
Packet::PingRequest | Packet::PingResponse | Packet::Disconnect => 0,
}
}
#[inline]
fn write_fixed_header(packet: &Packet, dst: &mut BytesMut, content_size: usize) {
dst.put_u8((packet.packet_type() << 4) | packet.packet_flags());
write_variable_length(content_size, dst);
}
fn write_content(packet: &Packet, dst: &mut BytesMut) {
match *packet {
Packet::Connect(ref connect) => match *connect {
Connect {
protocol,
clean_session,
keep_alive,
ref last_will,
ref client_id,
ref username,
ref password,
} => {
write_slice(protocol.name().as_bytes(), dst);
let mut flags = ConnectFlags::empty();
if username.is_some() {
flags |= ConnectFlags::USERNAME;
}
if password.is_some() {
flags |= ConnectFlags::PASSWORD;
}
if let Some(LastWill { qos, retain, .. }) = *last_will {
flags |= ConnectFlags::WILL;
if retain {
flags |= ConnectFlags::WILL_RETAIN;
}
let b: u8 = qos as u8;
flags |= ConnectFlags::from_bits_truncate(b << WILL_QOS_SHIFT);
}
if clean_session {
flags |= ConnectFlags::CLEAN_SESSION;
}
dst.put_slice(&[protocol.level(), flags.bits()]);
dst.put_u16(keep_alive);
write_slice(client_id.as_bytes(), dst);
if let Some(LastWill {
ref topic,
ref message,
..
}) = *last_will
{
write_slice(topic.as_bytes(), dst);
write_slice(&message, dst);
}
if let Some(ref s) = *username {
write_slice(s.as_bytes(), dst);
}
if let Some(ref s) = *password {
write_slice(s, dst);
}
}
},
Packet::ConnectAck {
session_present,
return_code,
} => {
dst.put_slice(&[if session_present { 0x01 } else { 0x00 }, return_code as u8]);
}
Packet::Publish(Publish {
qos,
ref topic,
packet_id,
ref payload,
..
}) => {
write_slice(topic.as_bytes(), dst);
if qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce {
dst.put_u16(packet_id.unwrap().into());
}
dst.put(payload.as_ref());
}
Packet::PublishAck { packet_id }
| Packet::PublishReceived { packet_id }
| Packet::PublishRelease { packet_id }
| Packet::PublishComplete { packet_id }
| Packet::UnsubscribeAck { packet_id } => {
dst.put_u16(packet_id.into());
}
Packet::Subscribe {
packet_id,
ref topic_filters,
} => {
dst.put_u16(packet_id.into());
for &(ref filter, qos) in topic_filters {
write_slice(filter.as_ref(), dst);
dst.put_slice(&[qos as u8]);
}
}
Packet::SubscribeAck {
packet_id,
ref status,
} => {
dst.put_u16(packet_id.into());
let buf: Vec<u8> = status
.iter()
.map(|s| {
if let SubscribeReturnCode::Success(qos) = *s {
qos as u8
} else {
0x80
}
})
.collect();
dst.put_slice(&buf);
}
Packet::Unsubscribe {
packet_id,
ref topic_filters,
} => {
dst.put_u16(packet_id.into());
for filter in topic_filters {
write_slice(filter.as_ref(), dst);
}
}
Packet::PingRequest | Packet::PingResponse | Packet::Disconnect => {}
}
}
#[inline]
fn write_slice(r: &[u8], dst: &mut BytesMut) {
dst.put_u16(r.len() as u16);
dst.put_slice(r);
}
#[inline]
fn write_variable_length(size: usize, dst: &mut BytesMut) {
// todo: verify at higher level
// if size > MAX_VARIABLE_LENGTH {
// Err(Error::new(ErrorKind::Other, "out of range"))
if size <= 127 {
dst.put_u8(size as u8);
} else if size <= 16383 {
// 127 + 127 << 7
dst.put_slice(&[((size % 128) | 0x80) as u8, (size >> 7) as u8]);
} else if size <= 2_097_151 {
// 127 + 127 << 7 + 127 << 14
dst.put_slice(&[
((size % 128) | 0x80) as u8,
(((size >> 7) % 128) | 0x80) as u8,
(size >> 14) as u8,
]);
} else {
dst.put_slice(&[
((size % 128) | 0x80) as u8,
(((size >> 7) % 128) | 0x80) as u8,
(((size >> 14) % 128) | 0x80) as u8,
(size >> 21) as u8,
]);
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use bytestring::ByteString;
use std::num::NonZeroU16;
use super::*;
fn packet_id(v: u16) -> NonZeroU16 {
NonZeroU16::new(v).unwrap()
}
#[test]
fn test_encode_variable_length() {
let mut v = BytesMut::new();
write_variable_length(123, &mut v);
assert_eq!(v, [123].as_ref());
v.clear();
write_variable_length(129, &mut v);
assert_eq!(v, b"\x81\x01".as_ref());
v.clear();
write_variable_length(16383, &mut v);
assert_eq!(v, b"\xff\x7f".as_ref());
v.clear();
write_variable_length(2097151, &mut v);
assert_eq!(v, b"\xff\xff\x7f".as_ref());
v.clear();
write_variable_length(268435455, &mut v);
assert_eq!(v, b"\xff\xff\xff\x7f".as_ref());
// assert!(v.write_variable_length(MAX_VARIABLE_LENGTH + 1).is_err())
}
#[test]
fn test_encode_fixed_header() {
let mut v = BytesMut::new();
let p = Packet::PingRequest;
assert_eq!(get_encoded_size(&p), 0);
write_fixed_header(&p, &mut v, 0);
assert_eq!(v, b"\xc0\x00".as_ref());
v.clear();
let p = Packet::Publish(Publish {
dup: true,
retain: true,
qos: QoS::ExactlyOnce,
topic: ByteString::from_static("topic"),
packet_id: Some(packet_id(0x4321)),
payload: (0..255).collect::<Vec<u8>>().into(),
});
assert_eq!(get_encoded_size(&p), 264);
write_fixed_header(&p, &mut v, 264);
assert_eq!(v, b"\x3d\x88\x02".as_ref());
}
macro_rules! assert_packet {
($p:expr, $data:expr) => {
let mut v = BytesMut::with_capacity(1024);
write_packet(&$p, &mut v, get_encoded_size($p));
assert_eq!(v.len(), $data.len());
assert_eq!(v, &$data[..]);
// assert_eq!(read_packet($data.cursor()).unwrap(), (&b""[..], $p));
};
}
#[test]
fn test_encode_connect_packets() {
assert_packet!(
&Packet::Connect(Connect {
protocol: Protocol::MQTT(4),
clean_session: false,
keep_alive: 60,
client_id: ByteString::from_static("12345"),
last_will: None,
username: Some(ByteString::from_static("user")),
password: Some(Bytes::from_static(b"pass")),
}),
&b"\x10\x1D\x00\x04MQTT\x04\xC0\x00\x3C\x00\
\x0512345\x00\x04user\x00\x04pass"[..]
);
assert_packet!(
&Packet::Connect(Connect {
protocol: Protocol::MQTT(4),
clean_session: false,
keep_alive: 60,
client_id: ByteString::from_static("12345"),
last_will: Some(LastWill {
qos: QoS::ExactlyOnce,
retain: false,
topic: ByteString::from_static("topic"),
message: Bytes::from_static(b"message"),
}),
username: None,
password: None,
}),
&b"\x10\x21\x00\x04MQTT\x04\x14\x00\x3C\x00\
\x0512345\x00\x05topic\x00\x07message"[..]
);
assert_packet!(&Packet::Disconnect, b"\xe0\x00");
}
#[test]
fn test_encode_publish_packets() {
assert_packet!(
&Packet::Publish(Publish {
dup: true,
retain: true,
qos: QoS::ExactlyOnce,
topic: ByteString::from_static("topic"),
packet_id: Some(packet_id(0x4321)),
payload: Bytes::from_static(b"data"),
}),
b"\x3d\x0D\x00\x05topic\x43\x21data"
);
assert_packet!(
&Packet::Publish(Publish {
dup: false,
retain: false,
qos: QoS::AtMostOnce,
topic: ByteString::from_static("topic"),
packet_id: None,
payload: Bytes::from_static(b"data"),
}),
b"\x30\x0b\x00\x05topicdata"
);
}
#[test]
fn test_encode_subscribe_packets() {
assert_packet!(
&Packet::Subscribe {
packet_id: packet_id(0x1234),
topic_filters: vec![
(ByteString::from_static("test"), QoS::AtLeastOnce),
(ByteString::from_static("filter"), QoS::ExactlyOnce)
],
},
b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02"
);
assert_packet!(
&Packet::SubscribeAck {
packet_id: packet_id(0x1234),
status: vec![
SubscribeReturnCode::Success(QoS::AtLeastOnce),
SubscribeReturnCode::Failure,
SubscribeReturnCode::Success(QoS::ExactlyOnce)
],
},
b"\x90\x05\x12\x34\x01\x80\x02"
);
assert_packet!(
&Packet::Unsubscribe {
packet_id: packet_id(0x1234),
topic_filters: vec![
ByteString::from_static("test"),
ByteString::from_static("filter"),
],
},
b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter"
);
assert_packet!(
&Packet::UnsubscribeAck {
packet_id: packet_id(0x4321)
},
b"\xb0\x02\x43\x21"
);
}
#[test]
fn test_encode_ping_packets() {
assert_packet!(&Packet::PingRequest, b"\xc0\x00");
assert_packet!(&Packet::PingResponse, b"\xd0\x00");
}
}

161
actix-mqtt/codec/src/codec/mod.rs Executable file
View File

@ -0,0 +1,161 @@
use actix_codec::{Decoder, Encoder};
use bytes::{buf::Buf, BytesMut};
use crate::error::ParseError;
use crate::proto::QoS;
use crate::{Packet, Publish};
mod decode;
mod encode;
use self::decode::*;
use self::encode::*;
bitflags! {
pub struct ConnectFlags: u8 {
const USERNAME = 0b1000_0000;
const PASSWORD = 0b0100_0000;
const WILL_RETAIN = 0b0010_0000;
const WILL_QOS = 0b0001_1000;
const WILL = 0b0000_0100;
const CLEAN_SESSION = 0b0000_0010;
}
}
pub const WILL_QOS_SHIFT: u8 = 3;
bitflags! {
pub struct ConnectAckFlags: u8 {
const SESSION_PRESENT = 0b0000_0001;
}
}
#[derive(Debug)]
pub struct Codec {
state: DecodeState,
max_size: usize,
}
#[derive(Debug, Clone, Copy)]
enum DecodeState {
FrameHeader,
Frame(FixedHeader),
}
impl Codec {
/// Create `Codec` instance
pub fn new() -> Self {
Codec {
state: DecodeState::FrameHeader,
max_size: 0,
}
}
/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
/// By default max size is set to `0`
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
}
impl Default for Codec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for Codec {
type Item = Packet;
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, ParseError> {
loop {
match self.state {
DecodeState::FrameHeader => {
if src.len() < 2 {
return Ok(None);
}
let fixed = src.as_ref()[0];
match decode_variable_length(&src.as_ref()[1..])? {
Some((remaining_length, consumed)) => {
// check max message size
if self.max_size != 0 && self.max_size < remaining_length {
return Err(ParseError::MaxSizeExceeded);
}
src.advance(consumed + 1);
self.state = DecodeState::Frame(FixedHeader {
packet_type: fixed >> 4,
packet_flags: fixed & 0xF,
remaining_length,
});
// todo: validate remaining_length against max frame size config
if src.len() < remaining_length {
// todo: subtract?
src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
return Ok(None);
}
}
None => {
return Ok(None);
}
}
}
DecodeState::Frame(fixed) => {
if src.len() < fixed.remaining_length {
return Ok(None);
}
let packet_buf = src.split_to(fixed.remaining_length);
let packet = read_packet(packet_buf.freeze(), fixed)?;
self.state = DecodeState::FrameHeader;
src.reserve(2);
return Ok(Some(packet));
}
}
}
}
}
impl Encoder for Codec {
type Item = Packet;
type Error = ParseError;
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), ParseError> {
if let Packet::Publish(Publish { qos, packet_id, .. }) = item {
if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() {
return Err(ParseError::PacketIdRequired);
}
}
let content_size = get_encoded_size(&item);
dst.reserve(content_size + 5);
write_packet(&item, dst, content_size);
Ok(())
}
}
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) struct FixedHeader {
/// MQTT Control Packet type
pub packet_type: u8,
/// Flags specific to each MQTT Control Packet type
pub packet_flags: u8,
/// the number of bytes remaining within the current packet,
/// including data in the variable header and the payload.
pub remaining_length: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_max_size() {
let mut codec = Codec::new().max_size(5);
let mut buf = BytesMut::new();
buf.extend_from_slice(b"\0\x09");
assert_eq!(codec.decode(&mut buf), Err(ParseError::MaxSizeExceeded));
}
}

57
actix-mqtt/codec/src/error.rs Executable file
View File

@ -0,0 +1,57 @@
use std::{io, str};
#[derive(Debug)]
pub enum ParseError {
InvalidProtocol,
InvalidLength,
MalformedPacket,
UnsupportedProtocolLevel,
ConnectReservedFlagSet,
ConnAckReservedFlagSet,
InvalidClientId,
UnsupportedPacketType,
PacketIdRequired,
MaxSizeExceeded,
IoError(io::Error),
Utf8Error(str::Utf8Error),
}
impl PartialEq for ParseError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(ParseError::InvalidProtocol, ParseError::InvalidProtocol) => true,
(ParseError::InvalidLength, ParseError::InvalidLength) => true,
(ParseError::UnsupportedProtocolLevel, ParseError::UnsupportedProtocolLevel) => {
true
}
(ParseError::ConnectReservedFlagSet, ParseError::ConnectReservedFlagSet) => true,
(ParseError::ConnAckReservedFlagSet, ParseError::ConnAckReservedFlagSet) => true,
(ParseError::InvalidClientId, ParseError::InvalidClientId) => true,
(ParseError::UnsupportedPacketType, ParseError::UnsupportedPacketType) => true,
(ParseError::PacketIdRequired, ParseError::PacketIdRequired) => true,
(ParseError::MaxSizeExceeded, ParseError::MaxSizeExceeded) => true,
(ParseError::MalformedPacket, ParseError::MalformedPacket) => true,
(ParseError::IoError(_), _) => false,
(ParseError::Utf8Error(_), _) => false,
_ => false,
}
}
}
impl From<io::Error> for ParseError {
fn from(err: io::Error) -> Self {
ParseError::IoError(err)
}
}
impl From<str::Utf8Error> for ParseError {
fn from(err: str::Utf8Error) -> Self {
ParseError::Utf8Error(err)
}
}
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum TopicError {
InvalidTopic,
InvalidLevel,
}

22
actix-mqtt/codec/src/lib.rs Executable file
View File

@ -0,0 +1,22 @@
#[macro_use]
extern crate bitflags;
extern crate bytestring;
mod error;
#[macro_use]
mod topic;
#[macro_use]
mod proto;
mod codec;
mod packet;
pub use self::codec::Codec;
pub use self::error::{ParseError, TopicError};
pub use self::packet::{Connect, ConnectCode, LastWill, Packet, Publish, SubscribeReturnCode};
pub use self::proto::{Protocol, QoS};
pub use self::topic::{Level, Topic};
// http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
pub const TCP_PORT: u16 = 1883;
pub const SSL_PORT: u16 = 8883;

251
actix-mqtt/codec/src/packet.rs Executable file
View File

@ -0,0 +1,251 @@
use bytes::Bytes;
use bytestring::ByteString;
use std::num::NonZeroU16;
use crate::proto::{Protocol, QoS};
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
/// Connect Return Code
pub enum ConnectCode {
/// Connection accepted
ConnectionAccepted = 0,
/// Connection Refused, unacceptable protocol version
UnacceptableProtocolVersion = 1,
/// Connection Refused, identifier rejected
IdentifierRejected = 2,
/// Connection Refused, Server unavailable
ServiceUnavailable = 3,
/// Connection Refused, bad user name or password
BadUserNameOrPassword = 4,
/// Connection Refused, not authorized
NotAuthorized = 5,
/// Reserved
Reserved = 6,
}
const_enum!(ConnectCode: u8);
impl ConnectCode {
pub fn reason(self) -> &'static str {
match self {
ConnectCode::ConnectionAccepted => "Connection Accepted",
ConnectCode::UnacceptableProtocolVersion => {
"Connection Refused, unacceptable protocol version"
}
ConnectCode::IdentifierRejected => "Connection Refused, identifier rejected",
ConnectCode::ServiceUnavailable => "Connection Refused, Server unavailable",
ConnectCode::BadUserNameOrPassword => {
"Connection Refused, bad user name or password"
}
ConnectCode::NotAuthorized => "Connection Refused, not authorized",
_ => "Connection Refused",
}
}
}
#[derive(Debug, PartialEq, Clone)]
/// Connection Will
pub struct LastWill {
/// the QoS level to be used when publishing the Will Message.
pub qos: QoS,
/// the Will Message is to be Retained when it is published.
pub retain: bool,
/// the Will Topic
pub topic: ByteString,
/// defines the Application Message that is to be published to the Will Topic
pub message: Bytes,
}
#[derive(Debug, PartialEq, Clone)]
/// Connect packet content
pub struct Connect {
/// mqtt protocol version
pub protocol: Protocol,
/// the handling of the Session state.
pub clean_session: bool,
/// a time interval measured in seconds.
pub keep_alive: u16,
/// Will Message be stored on the Server and associated with the Network Connection.
pub last_will: Option<LastWill>,
/// identifies the Client to the Server.
pub client_id: ByteString,
/// username can be used by the Server for authentication and authorization.
pub username: Option<ByteString>,
/// password can be used by the Server for authentication and authorization.
pub password: Option<Bytes>,
}
#[derive(Debug, PartialEq, Clone)]
/// Publish message
pub struct Publish {
/// this might be re-delivery of an earlier attempt to send the Packet.
pub dup: bool,
pub retain: bool,
/// the level of assurance for delivery of an Application Message.
pub qos: QoS,
/// the information channel to which payload data is published.
pub topic: ByteString,
/// only present in PUBLISH Packets where the QoS level is 1 or 2.
pub packet_id: Option<NonZeroU16>,
/// the Application Message that is being published.
pub payload: Bytes,
}
#[derive(Debug, PartialEq, Copy, Clone)]
/// Subscribe Return Code
pub enum SubscribeReturnCode {
Success(QoS),
Failure,
}
#[derive(Debug, PartialEq, Clone)]
/// MQTT Control Packets
pub enum Packet {
/// Client request to connect to Server
Connect(Connect),
/// Connect acknowledgment
ConnectAck {
/// enables a Client to establish whether the Client and Server have a consistent view
/// about whether there is already stored Session state.
session_present: bool,
return_code: ConnectCode,
},
/// Publish message
Publish(Publish),
/// Publish acknowledgment
PublishAck {
/// Packet Identifier
packet_id: NonZeroU16,
},
/// Publish received (assured delivery part 1)
PublishReceived {
/// Packet Identifier
packet_id: NonZeroU16,
},
/// Publish release (assured delivery part 2)
PublishRelease {
/// Packet Identifier
packet_id: NonZeroU16,
},
/// Publish complete (assured delivery part 3)
PublishComplete {
/// Packet Identifier
packet_id: NonZeroU16,
},
/// Client subscribe request
Subscribe {
/// Packet Identifier
packet_id: NonZeroU16,
/// the list of Topic Filters and QoS to which the Client wants to subscribe.
topic_filters: Vec<(ByteString, QoS)>,
},
/// Subscribe acknowledgment
SubscribeAck {
packet_id: NonZeroU16,
/// corresponds to a Topic Filter in the SUBSCRIBE Packet being acknowledged.
status: Vec<SubscribeReturnCode>,
},
/// Unsubscribe request
Unsubscribe {
/// Packet Identifier
packet_id: NonZeroU16,
/// the list of Topic Filters that the Client wishes to unsubscribe from.
topic_filters: Vec<ByteString>,
},
/// Unsubscribe acknowledgment
UnsubscribeAck {
/// Packet Identifier
packet_id: NonZeroU16,
},
/// PING request
PingRequest,
/// PING response
PingResponse,
/// Client is disconnecting
Disconnect,
}
impl Packet {
#[inline]
/// MQTT Control Packet type
pub fn packet_type(&self) -> u8 {
match *self {
Packet::Connect { .. } => CONNECT,
Packet::ConnectAck { .. } => CONNACK,
Packet::Publish { .. } => PUBLISH,
Packet::PublishAck { .. } => PUBACK,
Packet::PublishReceived { .. } => PUBREC,
Packet::PublishRelease { .. } => PUBREL,
Packet::PublishComplete { .. } => PUBCOMP,
Packet::Subscribe { .. } => SUBSCRIBE,
Packet::SubscribeAck { .. } => SUBACK,
Packet::Unsubscribe { .. } => UNSUBSCRIBE,
Packet::UnsubscribeAck { .. } => UNSUBACK,
Packet::PingRequest => PINGREQ,
Packet::PingResponse => PINGRESP,
Packet::Disconnect => DISCONNECT,
}
}
/// Flags specific to each MQTT Control Packet type
pub fn packet_flags(&self) -> u8 {
match *self {
Packet::Publish(Publish {
dup, qos, retain, ..
}) => {
let mut b = qos as u8;
b <<= 1;
if dup {
b |= 0b1000;
}
if retain {
b |= 0b0001;
}
b
}
Packet::PublishRelease { .. }
| Packet::Subscribe { .. }
| Packet::Unsubscribe { .. } => 0b0010,
_ => 0,
}
}
}
impl From<Connect> for Packet {
fn from(val: Connect) -> Packet {
Packet::Connect(val)
}
}
impl From<Publish> for Packet {
fn from(val: Publish) -> Packet {
Packet::Publish(val)
}
}
pub const CONNECT: u8 = 1;
pub const CONNACK: u8 = 2;
pub const PUBLISH: u8 = 3;
pub const PUBACK: u8 = 4;
pub const PUBREC: u8 = 5;
pub const PUBREL: u8 = 6;
pub const PUBCOMP: u8 = 7;
pub const SUBSCRIBE: u8 = 8;
pub const SUBACK: u8 = 9;
pub const UNSUBSCRIBE: u8 = 10;
pub const UNSUBACK: u8 = 11;
pub const PINGREQ: u8 = 12;
pub const PINGRESP: u8 = 13;
pub const DISCONNECT: u8 = 14;

64
actix-mqtt/codec/src/proto.rs Executable file
View File

@ -0,0 +1,64 @@
#[macro_export]
macro_rules! const_enum {
($name:ty : $repr:ty) => {
impl ::std::convert::From<$repr> for $name {
fn from(u: $repr) -> Self {
unsafe { ::std::mem::transmute(u) }
}
}
};
}
pub const DEFAULT_MQTT_LEVEL: u8 = 4;
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
MQTT(u8),
}
impl Protocol {
pub fn name(self) -> &'static str {
match self {
Protocol::MQTT(_) => "MQTT",
}
}
pub fn level(self) -> u8 {
match self {
Protocol::MQTT(level) => level,
}
}
}
impl Default for Protocol {
fn default() -> Self {
Protocol::MQTT(DEFAULT_MQTT_LEVEL)
}
}
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Quality of Service levels
pub enum QoS {
/// At most once delivery
///
/// The message is delivered according to the capabilities of the underlying network.
/// No response is sent by the receiver and no retry is performed by the sender.
/// The message arrives at the receiver either once or not at all.
AtMostOnce = 0,
/// At least once delivery
///
/// This quality of service ensures that the message arrives at the receiver at least once.
/// A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header
/// and is acknowledged by a PUBACK Packet.
AtLeastOnce = 1,
/// Exactly once delivery
///
/// This is the highest quality of service,
/// for use when neither loss nor duplication of messages are acceptable.
/// There is an increased overhead associated with this quality of service.
ExactlyOnce = 2,
}
const_enum!(QoS: u8);

520
actix-mqtt/codec/src/topic.rs Executable file
View File

@ -0,0 +1,520 @@
use std::fmt::{self, Write};
use std::{io, ops, str::FromStr};
use crate::error::TopicError;
#[inline]
fn is_metadata<T: AsRef<str>>(s: T) -> bool {
s.as_ref().chars().nth(0) == Some('$')
}
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum Level {
Normal(String),
Metadata(String), // $SYS
Blank,
SingleWildcard, // Single level wildcard +
MultiWildcard, // Multi-level wildcard #
}
impl Level {
pub fn parse<T: AsRef<str>>(s: T) -> Result<Level, TopicError> {
Level::from_str(s.as_ref())
}
pub fn normal<T: AsRef<str>>(s: T) -> Level {
if s.as_ref().contains(|c| c == '+' || c == '#') {
panic!("invalid normal level `{}` contains +|#", s.as_ref());
}
if s.as_ref().chars().nth(0) == Some('$') {
panic!("invalid normal level `{}` starts with $", s.as_ref())
}
Level::Normal(String::from(s.as_ref()))
}
pub fn metadata<T: AsRef<str>>(s: T) -> Level {
if s.as_ref().contains(|c| c == '+' || c == '#') {
panic!("invalid metadata level `{}` contains +|#", s.as_ref());
}
if s.as_ref().chars().nth(0) != Some('$') {
panic!("invalid metadata level `{}` not starts with $", s.as_ref())
}
Level::Metadata(String::from(s.as_ref()))
}
#[inline]
pub fn value(&self) -> Option<&str> {
match *self {
Level::Normal(ref s) | Level::Metadata(ref s) => Some(s),
_ => None,
}
}
#[inline]
pub fn is_normal(&self) -> bool {
if let Level::Normal(_) = *self {
true
} else {
false
}
}
#[inline]
pub fn is_metadata(&self) -> bool {
if let Level::Metadata(_) = *self {
true
} else {
false
}
}
#[inline]
pub fn is_valid(&self) -> bool {
match *self {
Level::Normal(ref s) => {
s.chars().nth(0) != Some('$') && !s.contains(|c| c == '+' || c == '#')
}
Level::Metadata(ref s) => {
s.chars().nth(0) == Some('$') && !s.contains(|c| c == '+' || c == '#')
}
_ => true,
}
}
}
#[derive(Debug, Eq, Clone)]
pub struct Topic(Vec<Level>);
impl Topic {
#[inline]
pub fn levels(&self) -> &Vec<Level> {
&self.0
}
#[inline]
pub fn is_valid(&self) -> bool {
self.0
.iter()
.position(|level| !level.is_valid())
.or_else(|| {
self.0
.iter()
.enumerate()
.position(|(pos, level)| match *level {
Level::MultiWildcard => pos != self.0.len() - 1,
Level::Metadata(_) => pos != 0,
_ => false,
})
})
.is_none()
}
}
macro_rules! match_topic {
($topic:expr, $levels:expr) => {{
let mut lhs = $topic.0.iter();
for rhs in $levels {
match lhs.next() {
Some(&Level::SingleWildcard) => {
if !rhs.match_level(&Level::SingleWildcard) {
break;
}
}
Some(&Level::MultiWildcard) => {
return rhs.match_level(&Level::MultiWildcard);
}
Some(level) if rhs.match_level(level) => continue,
_ => return false,
}
}
match lhs.next() {
Some(&Level::MultiWildcard) => true,
Some(_) => false,
None => true,
}
}};
}
impl PartialEq for Topic {
fn eq(&self, other: &Topic) -> bool {
match_topic!(self, &other.0)
}
}
impl<T: AsRef<str>> PartialEq<T> for Topic {
fn eq(&self, other: &T) -> bool {
match_topic!(self, other.as_ref().split('/'))
}
}
impl<'a> From<&'a [Level]> for Topic {
fn from(s: &[Level]) -> Self {
let mut v = vec![];
v.extend_from_slice(s);
Topic(v)
}
}
impl From<Vec<Level>> for Topic {
fn from(v: Vec<Level>) -> Self {
Topic(v)
}
}
impl Into<Vec<Level>> for Topic {
fn into(self) -> Vec<Level> {
self.0
}
}
impl ops::Deref for Topic {
type Target = Vec<Level>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ops::DerefMut for Topic {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[macro_export]
macro_rules! topic {
($s:expr) => {
$s.parse::<Topic>().unwrap()
};
}
pub trait MatchLevel {
fn match_level(&self, level: &Level) -> bool;
}
impl MatchLevel for Level {
fn match_level(&self, level: &Level) -> bool {
match *level {
Level::Normal(ref lhs) => {
if let Level::Normal(ref rhs) = *self {
lhs == rhs
} else {
false
}
}
Level::Metadata(ref lhs) => {
if let Level::Metadata(ref rhs) = *self {
lhs == rhs
} else {
false
}
}
Level::Blank => true,
Level::SingleWildcard | Level::MultiWildcard => !self.is_metadata(),
}
}
}
impl<T: AsRef<str>> MatchLevel for T {
fn match_level(&self, level: &Level) -> bool {
match *level {
Level::Normal(ref lhs) => !is_metadata(self) && lhs == self.as_ref(),
Level::Metadata(ref lhs) => is_metadata(self) && lhs == self.as_ref(),
Level::Blank => self.as_ref().is_empty(),
Level::SingleWildcard | Level::MultiWildcard => !is_metadata(self),
}
}
}
impl FromStr for Level {
type Err = TopicError;
#[inline]
fn from_str(s: &str) -> Result<Self, TopicError> {
match s {
"+" => Ok(Level::SingleWildcard),
"#" => Ok(Level::MultiWildcard),
"" => Ok(Level::Blank),
_ => {
if s.contains(|c| c == '+' || c == '#') {
Err(TopicError::InvalidLevel)
} else if is_metadata(s) {
Ok(Level::Metadata(String::from(s)))
} else {
Ok(Level::Normal(String::from(s)))
}
}
}
}
}
impl FromStr for Topic {
type Err = TopicError;
#[inline]
fn from_str(s: &str) -> Result<Self, TopicError> {
s.split('/')
.map(|level| Level::from_str(level))
.collect::<Result<Vec<_>, TopicError>>()
.map(Topic)
.and_then(|topic| {
if topic.is_valid() {
Ok(topic)
} else {
Err(TopicError::InvalidTopic)
}
})
}
}
impl fmt::Display for Level {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Level::Normal(ref s) | Level::Metadata(ref s) => f.write_str(s.as_str()),
Level::Blank => Ok(()),
Level::SingleWildcard => f.write_char('+'),
Level::MultiWildcard => f.write_char('#'),
}
}
}
impl fmt::Display for Topic {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut first = true;
for level in &self.0 {
if first {
first = false;
} else {
f.write_char('/')?;
}
level.fmt(f)?;
}
Ok(())
}
}
pub trait WriteTopicExt: io::Write {
fn write_level(&mut self, level: &Level) -> io::Result<usize> {
match *level {
Level::Normal(ref s) | Level::Metadata(ref s) => self.write(s.as_str().as_bytes()),
Level::Blank => Ok(0),
Level::SingleWildcard => self.write(b"+"),
Level::MultiWildcard => self.write(b"#"),
}
}
fn write_topic(&mut self, topic: &Topic) -> io::Result<usize> {
let mut n = 0;
let mut first = true;
for level in topic.levels() {
if first {
first = false;
} else {
n += self.write(b"/")?;
}
n += self.write_level(level)?;
}
Ok(n)
}
}
impl<W: io::Write + ?Sized> WriteTopicExt for W {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_level() {
assert!(Level::normal("sport").is_normal());
assert!(Level::metadata("$SYS").is_metadata());
assert_eq!(Level::normal("sport").value(), Some("sport"));
assert_eq!(Level::metadata("$SYS").value(), Some("$SYS"));
assert_eq!(Level::normal("sport"), "sport".parse().unwrap());
assert_eq!(Level::metadata("$SYS"), "$SYS".parse().unwrap());
assert!(Level::Normal(String::from("sport")).is_valid());
assert!(Level::Metadata(String::from("$SYS")).is_valid());
assert!(!Level::Normal(String::from("$sport")).is_valid());
assert!(!Level::Metadata(String::from("SYS")).is_valid());
assert!(!Level::Normal(String::from("sport#")).is_valid());
assert!(!Level::Metadata(String::from("SYS+")).is_valid());
}
#[test]
fn test_valid_topic() {
assert!(Topic(vec![
Level::normal("sport"),
Level::normal("tennis"),
Level::normal("player1")
])
.is_valid());
assert!(Topic(vec![
Level::normal("sport"),
Level::normal("tennis"),
Level::MultiWildcard
])
.is_valid());
assert!(Topic(vec![
Level::metadata("$SYS"),
Level::normal("tennis"),
Level::MultiWildcard
])
.is_valid());
assert!(Topic(vec![
Level::normal("sport"),
Level::SingleWildcard,
Level::normal("player1")
])
.is_valid());
assert!(!Topic(vec![
Level::normal("sport"),
Level::MultiWildcard,
Level::normal("player1")
])
.is_valid());
assert!(!Topic(vec![
Level::normal("sport"),
Level::metadata("$SYS"),
Level::normal("player1")
])
.is_valid());
}
#[test]
fn test_parse_topic() {
assert_eq!(
topic!("sport/tennis/player1"),
Topic::from(vec![
Level::normal("sport"),
Level::normal("tennis"),
Level::normal("player1")
])
);
assert_eq!(topic!(""), Topic(vec![Level::Blank]));
assert_eq!(
topic!("/finance"),
Topic::from(vec![Level::Blank, Level::normal("finance")])
);
assert_eq!(topic!("$SYS"), Topic::from(vec![Level::metadata("$SYS")]));
assert!("sport/$SYS".parse::<Topic>().is_err());
}
#[test]
fn test_multi_wildcard_topic() {
assert_eq!(
topic!("sport/tennis/#"),
Topic::from(vec![
Level::normal("sport"),
Level::normal("tennis"),
Level::MultiWildcard
])
);
assert_eq!(topic!("#"), Topic::from(vec![Level::MultiWildcard]));
assert!("sport/tennis#".parse::<Topic>().is_err());
assert!("sport/tennis/#/ranking".parse::<Topic>().is_err());
}
#[test]
fn test_single_wildcard_topic() {
assert_eq!(topic!("+"), Topic::from(vec![Level::SingleWildcard]));
assert_eq!(
topic!("+/tennis/#"),
Topic::from(vec![
Level::SingleWildcard,
Level::normal("tennis"),
Level::MultiWildcard
])
);
assert_eq!(
topic!("sport/+/player1"),
Topic::from(vec![
Level::normal("sport"),
Level::SingleWildcard,
Level::normal("player1")
])
);
assert!("sport+".parse::<Topic>().is_err());
}
#[test]
fn test_write_topic() {
let mut v = vec![];
let t = vec![
Level::SingleWildcard,
Level::normal("tennis"),
Level::MultiWildcard,
]
.into();
assert_eq!(v.write_topic(&t).unwrap(), 10);
assert_eq!(v, b"+/tennis/#");
assert_eq!(format!("{}", t), "+/tennis/#");
assert_eq!(t.to_string(), "+/tennis/#");
}
#[test]
fn test_match_topic() {
assert!("test".match_level(&Level::normal("test")));
assert!("$SYS".match_level(&Level::metadata("$SYS")));
let t: Topic = "sport/tennis/player1/#".parse().unwrap();
assert_eq!(t, "sport/tennis/player1");
assert_eq!(t, "sport/tennis/player1/ranking");
assert_eq!(t, "sport/tennis/player1/score/wimbledon");
assert_eq!(Topic::from_str("sport/#").unwrap(), "sport");
let t: Topic = "sport/tennis/+".parse().unwrap();
assert_eq!(t, "sport/tennis/player1");
assert_eq!(t, "sport/tennis/player2");
assert!(t != "sport/tennis/player1/ranking");
let t: Topic = "sport/+".parse().unwrap();
assert!(t != "sport");
assert_eq!(t, "sport/");
assert_eq!(Topic::from_str("+/+").unwrap(), "/finance");
assert_eq!(Topic::from_str("/+").unwrap(), "/finance",);
assert!(Topic::from_str("+").unwrap() != "/finance",);
assert!(Topic::from_str("#").unwrap() != "$SYS");
assert!(Topic::from_str("+/monitor/Clients").unwrap() != "$SYS/monitor/Clients");
assert_eq!(Topic::from_str(&"$SYS/#").unwrap(), "$SYS/");
assert_eq!(
Topic::from_str("$SYS/monitor/+").unwrap(),
"$SYS/monitor/Clients",
);
}
}

35
actix-mqtt/examples/basic.rs Executable file
View File

@ -0,0 +1,35 @@
use actix_mqtt::{Connect, ConnectAck, MqttServer, Publish};
#[derive(Clone)]
struct Session;
async fn connect<Io>(connect: Connect<Io>) -> Result<ConnectAck<Io, Session>, ()> {
log::info!("new connection: {:?}", connect);
Ok(connect.ack(Session, false))
}
async fn publish(publish: Publish<Session>) -> Result<(), ()> {
log::info!(
"incoming publish: {:?} -> {:?}",
publish.id(),
publish.topic()
);
Ok(())
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
std::env::set_var(
"RUST_LOG",
"actix_server=trace,actix_mqtt=trace,basic=trace",
);
env_logger::init();
actix_server::Server::build()
.bind("mqtt", "127.0.0.1:1883", || {
MqttServer::new(connect).finish(publish)
})?
.workers(1)
.run()
.await
}

64
actix-mqtt/src/cell.rs Executable file
View File

@ -0,0 +1,64 @@
//! Custom cell impl
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service;
pub(crate) struct Cell<T> {
inner: Rc<UnsafeCell<T>>,
}
impl<T> Clone for Cell<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Deref for Cell<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.get_ref()
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for Cell<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.inner.fmt(f)
}
}
impl<T> Cell<T> {
pub fn new(inner: T) -> Self {
Self {
inner: Rc::new(UnsafeCell::new(inner)),
}
}
pub fn get_ref(&self) -> &T {
unsafe { &*self.inner.as_ref().get() }
}
pub fn get_mut(&mut self) -> &mut T {
unsafe { &mut *self.inner.as_ref().get() }
}
}
impl<T: Service> Service for Cell<T> {
type Request = T::Request;
type Response = T::Response;
type Error = T::Error;
type Future = T::Future;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), T::Error>> {
self.get_mut().poll_ready(cx)
}
fn call(&mut self, req: T::Request) -> T::Future {
self.get_mut().call(req)
}
}

406
actix-mqtt/src/client.rs Executable file
View File

@ -0,0 +1,406 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_ioframe as ioframe;
use actix_service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory};
use bytes::Bytes;
use bytestring::ByteString;
use futures::future::{FutureExt, LocalBoxFuture};
use futures::{Sink, SinkExt, Stream, StreamExt};
use mqtt_codec as mqtt;
use crate::cell::Cell;
use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
use crate::dispatcher::{dispatcher, MqttState};
use crate::error::MqttError;
use crate::publish::Publish;
use crate::sink::MqttSink;
use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
/// Mqtt client
#[derive(Clone)]
pub struct Client<Io, St> {
client_id: ByteString,
clean_session: bool,
protocol: mqtt::Protocol,
keep_alive: u16,
last_will: Option<mqtt::LastWill>,
username: Option<ByteString>,
password: Option<Bytes>,
inflight: usize,
_t: PhantomData<(Io, St)>,
}
impl<Io, St> Client<Io, St>
where
St: 'static,
{
/// Create new client and provide client id
pub fn new(client_id: ByteString) -> Self {
Client {
client_id,
clean_session: true,
protocol: mqtt::Protocol::default(),
keep_alive: 30,
last_will: None,
username: None,
password: None,
inflight: 15,
_t: PhantomData,
}
}
/// Mqtt protocol version
pub fn protocol(mut self, val: mqtt::Protocol) -> Self {
self.protocol = val;
self
}
/// The handling of the Session state.
pub fn clean_session(mut self, val: bool) -> Self {
self.clean_session = val;
self
}
/// A time interval measured in seconds.
///
/// keep-alive is set to 30 seconds by default.
pub fn keep_alive(mut self, val: u16) -> Self {
self.keep_alive = val;
self
}
/// Will Message be stored on the Server and associated with the Network Connection.
///
/// by default last will value is not set
pub fn last_will(mut self, val: mqtt::LastWill) -> Self {
self.last_will = Some(val);
self
}
/// Username can be used by the Server for authentication and authorization.
pub fn username(mut self, val: ByteString) -> Self {
self.username = Some(val);
self
}
/// Password can be used by the Server for authentication and authorization.
pub fn password(mut self, val: Bytes) -> Self {
self.password = Some(val);
self
}
/// Number of in-flight concurrent messages.
///
/// in-flight is set to 15 messages
pub fn inflight(mut self, val: usize) -> Self {
self.inflight = val;
self
}
/// Set state service
///
/// State service verifies connect ack packet and construct connection state.
pub fn state<C, F>(self, state: F) -> ServiceBuilder<Io, St, C>
where
F: IntoService<C>,
Io: AsyncRead + AsyncWrite,
C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>>,
C::Error: 'static,
{
ServiceBuilder {
state: Cell::new(state.into_service()),
packet: mqtt::Connect {
client_id: self.client_id,
clean_session: self.clean_session,
protocol: self.protocol,
keep_alive: self.keep_alive,
last_will: self.last_will,
username: self.username,
password: self.password,
},
subscribe: Rc::new(boxed::factory(SubsNotImplemented::default())),
unsubscribe: Rc::new(boxed::factory(UnsubsNotImplemented::default())),
disconnect: None,
keep_alive: self.keep_alive.into(),
inflight: self.inflight,
_t: PhantomData,
}
}
}
pub struct ServiceBuilder<Io, St, C: Service> {
state: Cell<C>,
packet: mqtt::Connect,
subscribe: Rc<
boxed::BoxServiceFactory<
St,
Subscribe<St>,
SubscribeResult,
MqttError<C::Error>,
MqttError<C::Error>,
>,
>,
unsubscribe: Rc<
boxed::BoxServiceFactory<
St,
Unsubscribe<St>,
(),
MqttError<C::Error>,
MqttError<C::Error>,
>,
>,
disconnect: Option<Cell<boxed::BoxService<St, (), MqttError<C::Error>>>>,
keep_alive: u64,
inflight: usize,
_t: PhantomData<(Io, St, C)>,
}
impl<Io, St, C> ServiceBuilder<Io, St, C>
where
St: Clone + 'static,
Io: AsyncRead + AsyncWrite + 'static,
C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
C::Error: 'static,
{
/// Service to execute on disconnect
pub fn disconnect<UF, U>(mut self, srv: UF) -> Self
where
UF: IntoService<U>,
U: Service<Request = St, Response = (), Error = C::Error> + 'static,
{
self.disconnect = Some(Cell::new(boxed::service(
srv.into_service().map_err(MqttError::Service),
)));
self
}
pub fn finish<F, T>(
self,
service: F,
) -> impl Service<Request = Io, Response = (), Error = MqttError<C::Error>>
where
F: IntoServiceFactory<T>,
T: ServiceFactory<
Config = St,
Request = Publish<St>,
Response = (),
Error = C::Error,
InitError = C::Error,
> + 'static,
{
ioframe::Builder::new()
.service(ConnectService {
connect: self.state,
packet: self.packet,
keep_alive: self.keep_alive,
inflight: self.inflight,
_t: PhantomData,
})
.finish(dispatcher(
service
.into_factory()
.map_err(MqttError::Service)
.map_init_err(MqttError::Service),
self.subscribe,
self.unsubscribe,
))
.map_err(|e| match e {
ioframe::ServiceError::Service(e) => e,
ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
})
}
}
struct ConnectService<Io, St, C> {
connect: Cell<C>,
packet: mqtt::Connect,
keep_alive: u64,
inflight: usize,
_t: PhantomData<(Io, St)>,
}
impl<Io, St, C> Service for ConnectService<Io, St, C>
where
St: 'static,
Io: AsyncRead + AsyncWrite + 'static,
C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
C::Error: 'static,
{
type Request = ioframe::Connect<Io, mqtt::Codec>;
type Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>;
type Error = MqttError<C::Error>;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.connect
.get_mut()
.poll_ready(cx)
.map_err(MqttError::Service)
}
fn call(&mut self, req: Self::Request) -> Self::Future {
let mut srv = self.connect.clone();
let packet = self.packet.clone();
let keep_alive = Duration::from_secs(self.keep_alive as u64);
let inflight = self.inflight;
// send Connect packet
async move {
let mut framed = req.codec(mqtt::Codec::new());
framed
.send(mqtt::Packet::Connect(packet))
.await
.map_err(MqttError::Protocol)?;
let packet = framed
.next()
.await
.ok_or(MqttError::Disconnected)
.and_then(|res| res.map_err(MqttError::Protocol))?;
match packet {
mqtt::Packet::ConnectAck {
session_present,
return_code,
} => {
let sink = MqttSink::new(framed.sink().clone());
let ack = ConnectAck {
sink,
session_present,
return_code,
keep_alive,
inflight,
io: framed,
};
Ok(srv
.get_mut()
.call(ack)
.await
.map_err(MqttError::Service)
.map(|ack| ack.io.state(ack.state))?)
}
p => Err(MqttError::Unexpected(p, "Expected CONNECT-ACK packet")),
}
}
.boxed_local()
}
}
pub struct ConnectAck<Io> {
io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
sink: MqttSink,
session_present: bool,
return_code: mqtt::ConnectCode,
keep_alive: Duration,
inflight: usize,
}
impl<Io> ConnectAck<Io> {
#[inline]
/// Indicates whether there is already stored Session state
pub fn session_present(&self) -> bool {
self.session_present
}
#[inline]
/// Connect return code
pub fn return_code(&self) -> mqtt::ConnectCode {
self.return_code
}
#[inline]
/// Mqtt client sink object
pub fn sink(&self) -> &MqttSink {
&self.sink
}
#[inline]
/// Set connection state and create result object
pub fn state<St>(self, state: St) -> ConnectAckResult<Io, St> {
ConnectAckResult {
io: self.io,
state: MqttState::new(state, self.sink, self.keep_alive, self.inflight),
}
}
}
impl<Io> Stream for ConnectAck<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<mqtt::Packet, mqtt::ParseError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.io).poll_next(cx)
}
}
impl<Io> Sink<mqtt::Packet> for ConnectAck<Io>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Error = mqtt::ParseError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
Pin::new(&mut self.io).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_close(cx)
}
}
#[pin_project::pin_project]
pub struct ConnectAckResult<Io, St> {
state: MqttState<St>,
io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
}
impl<Io, St> Stream for ConnectAckResult<Io, St>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<mqtt::Packet, mqtt::ParseError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.io).poll_next(cx)
}
}
impl<Io, St> Sink<mqtt::Packet> for ConnectAckResult<Io, St>
where
Io: AsyncRead + AsyncWrite + Unpin,
{
type Error = mqtt::ParseError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
Pin::new(&mut self.io).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.io).poll_close(cx)
}
}

150
actix-mqtt/src/connect.rs Executable file
View File

@ -0,0 +1,150 @@
use std::fmt;
use std::ops::Deref;
use std::time::Duration;
use actix_ioframe as ioframe;
use mqtt_codec as mqtt;
use crate::sink::MqttSink;
/// Connect message
pub struct Connect<Io> {
connect: mqtt::Connect,
sink: MqttSink,
keep_alive: Duration,
inflight: usize,
io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
}
impl<Io> Connect<Io> {
pub(crate) fn new(
connect: mqtt::Connect,
io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
sink: MqttSink,
inflight: usize,
) -> Self {
Self {
keep_alive: Duration::from_secs(connect.keep_alive as u64),
connect,
io,
sink,
inflight,
}
}
/// Returns reference to io object
pub fn get_ref(&self) -> &Io {
self.io.get_ref()
}
/// Returns mutable reference to io object
pub fn get_mut(&mut self) -> &mut Io {
self.io.get_mut()
}
/// Returns mqtt server sink
pub fn sink(&self) -> &MqttSink {
&self.sink
}
/// Ack connect message and set state
pub fn ack<St>(self, st: St, session_present: bool) -> ConnectAck<Io, St> {
ConnectAck::new(self.io, st, session_present, self.keep_alive, self.inflight)
}
/// Create connect ack object with `identifier rejected` return code
pub fn identifier_rejected<St>(self) -> ConnectAck<Io, St> {
ConnectAck {
io: self.io,
session: None,
session_present: false,
return_code: mqtt::ConnectCode::IdentifierRejected,
keep_alive: Duration::from_secs(5),
inflight: 15,
}
}
/// Create connect ack object with `bad user name or password` return code
pub fn bad_username_or_pwd<St>(self) -> ConnectAck<Io, St> {
ConnectAck {
io: self.io,
session: None,
session_present: false,
return_code: mqtt::ConnectCode::BadUserNameOrPassword,
keep_alive: Duration::from_secs(5),
inflight: 15,
}
}
/// Create connect ack object with `not authorized` return code
pub fn not_authorized<St>(self) -> ConnectAck<Io, St> {
ConnectAck {
io: self.io,
session: None,
session_present: false,
return_code: mqtt::ConnectCode::NotAuthorized,
keep_alive: Duration::from_secs(5),
inflight: 15,
}
}
}
impl<Io> Deref for Connect<Io> {
type Target = mqtt::Connect;
fn deref(&self) -> &Self::Target {
&self.connect
}
}
impl<T> fmt::Debug for Connect<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.connect.fmt(f)
}
}
/// Ack connect message
pub struct ConnectAck<Io, St> {
pub(crate) io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
pub(crate) session: Option<St>,
pub(crate) session_present: bool,
pub(crate) return_code: mqtt::ConnectCode,
pub(crate) keep_alive: Duration,
pub(crate) inflight: usize,
}
impl<Io, St> ConnectAck<Io, St> {
/// Create connect ack, `session_present` indicates that previous session is presents
pub(crate) fn new(
io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
session: St,
session_present: bool,
keep_alive: Duration,
inflight: usize,
) -> Self {
Self {
io,
session_present,
keep_alive,
inflight,
session: Some(session),
return_code: mqtt::ConnectCode::ConnectionAccepted,
}
}
/// Set idle time-out for the connection in milliseconds
///
/// By default idle time-out is set to 300000 milliseconds
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.keep_alive = timeout;
self
}
/// Set in-flight count. Total number of `in-flight` packets
///
/// By default in-flight count is set to 15
pub fn in_flight(mut self, in_flight: usize) -> Self {
self.inflight = in_flight;
self
}
}

125
actix-mqtt/src/default.rs Executable file
View File

@ -0,0 +1,125 @@
use std::marker::PhantomData;
use std::task::{Context, Poll};
use actix_service::{Service, ServiceFactory};
use futures::future::{ok, Ready};
use crate::publish::Publish;
use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
/// Not implemented publish service
pub struct NotImplemented<S, E>(PhantomData<(S, E)>);
impl<S, E> Default for NotImplemented<S, E> {
fn default() -> Self {
NotImplemented(PhantomData)
}
}
impl<S, E> ServiceFactory for NotImplemented<S, E> {
type Config = S;
type Request = Publish<S>;
type Response = ();
type Error = E;
type InitError = E;
type Service = NotImplemented<S, E>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: S) -> Self::Future {
ok(NotImplemented(PhantomData))
}
}
impl<S, E> Service for NotImplemented<S, E> {
type Request = Publish<S>;
type Response = ();
type Error = E;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: Publish<S>) -> Self::Future {
log::warn!("MQTT Publish is not implemented");
ok(())
}
}
/// Not implemented subscribe service
pub struct SubsNotImplemented<S, E>(PhantomData<(S, E)>);
impl<S, E> Default for SubsNotImplemented<S, E> {
fn default() -> Self {
SubsNotImplemented(PhantomData)
}
}
impl<S, E> ServiceFactory for SubsNotImplemented<S, E> {
type Config = S;
type Request = Subscribe<S>;
type Response = SubscribeResult;
type Error = E;
type InitError = E;
type Service = SubsNotImplemented<S, E>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: S) -> Self::Future {
ok(SubsNotImplemented(PhantomData))
}
}
impl<S, E> Service for SubsNotImplemented<S, E> {
type Request = Subscribe<S>;
type Response = SubscribeResult;
type Error = E;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, subs: Subscribe<S>) -> Self::Future {
log::warn!("MQTT Subscribe is not implemented");
ok(subs.into_result())
}
}
/// Not implemented unsubscribe service
pub struct UnsubsNotImplemented<S, E>(PhantomData<(S, E)>);
impl<S, E> Default for UnsubsNotImplemented<S, E> {
fn default() -> Self {
UnsubsNotImplemented(PhantomData)
}
}
impl<S, E> ServiceFactory for UnsubsNotImplemented<S, E> {
type Config = S;
type Request = Unsubscribe<S>;
type Response = ();
type Error = E;
type InitError = E;
type Service = UnsubsNotImplemented<S, E>;
type Future = Ready<Result<Self::Service, Self::InitError>>;
fn new_service(&self, _: S) -> Self::Future {
ok(UnsubsNotImplemented(PhantomData))
}
}
impl<S, E> Service for UnsubsNotImplemented<S, E> {
type Request = Unsubscribe<S>;
type Response = ();
type Error = E;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: Unsubscribe<S>) -> Self::Future {
log::warn!("MQTT Unsubscribe is not implemented");
ok(())
}
}

286
actix-mqtt/src/dispatcher.rs Executable file
View File

@ -0,0 +1,286 @@
use std::future::Future;
use std::marker::PhantomData;
use std::num::NonZeroU16;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::Duration;
use actix_ioframe as ioframe;
use actix_service::{boxed, fn_factory_with_config, pipeline, Service, ServiceFactory};
use actix_utils::inflight::InFlightService;
use actix_utils::keepalive::KeepAliveService;
use actix_utils::order::{InOrder, InOrderError};
use actix_utils::time::LowResTimeService;
use futures::future::{join3, ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::ready;
use mqtt_codec as mqtt;
use crate::cell::Cell;
use crate::error::MqttError;
use crate::publish::Publish;
use crate::sink::MqttSink;
use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
pub(crate) struct MqttState<St> {
inner: Cell<MqttStateInner<St>>,
}
struct MqttStateInner<St> {
pub(crate) st: St,
pub(crate) sink: MqttSink,
pub(self) timeout: Duration,
pub(self) in_flight: usize,
}
impl<St> Clone for MqttState<St> {
fn clone(&self) -> Self {
MqttState {
inner: self.inner.clone(),
}
}
}
impl<St> MqttState<St> {
pub(crate) fn new(st: St, sink: MqttSink, timeout: Duration, in_flight: usize) -> Self {
MqttState {
inner: Cell::new(MqttStateInner {
st,
sink,
timeout,
in_flight,
}),
}
}
pub(crate) fn sink(&self) -> &MqttSink {
&self.inner.sink
}
pub(crate) fn session(&self) -> &St {
&self.inner.get_ref().st
}
pub(crate) fn session_mut(&mut self) -> &mut St {
&mut self.inner.get_mut().st
}
}
// dispatcher factory
pub(crate) fn dispatcher<St, T, E>(
publish: T,
subscribe: Rc<
boxed::BoxServiceFactory<
St,
Subscribe<St>,
SubscribeResult,
MqttError<E>,
MqttError<E>,
>,
>,
unsubscribe: Rc<
boxed::BoxServiceFactory<St, Unsubscribe<St>, (), MqttError<E>, MqttError<E>>,
>,
) -> impl ServiceFactory<
Config = MqttState<St>,
Request = ioframe::Item<MqttState<St>, mqtt::Codec>,
Response = Option<mqtt::Packet>,
Error = MqttError<E>,
InitError = MqttError<E>,
>
where
E: 'static,
St: Clone + 'static,
T: ServiceFactory<
Config = St,
Request = Publish<St>,
Response = (),
Error = MqttError<E>,
InitError = MqttError<E>,
> + 'static,
{
let time = LowResTimeService::with(Duration::from_secs(1));
fn_factory_with_config(move |cfg: MqttState<St>| {
let time = time.clone();
let state = cfg.session().clone();
let timeout = cfg.inner.timeout;
let inflight = cfg.inner.in_flight;
// create services
let fut = join3(
publish.new_service(state.clone()),
subscribe.new_service(state.clone()),
unsubscribe.new_service(state.clone()),
);
async move {
let (publish, subscribe, unsubscribe) = fut.await;
// mqtt dispatcher
Ok(Dispatcher::new(
// keep-alive connection
pipeline(KeepAliveService::new(timeout, time, || {
MqttError::KeepAliveTimeout
}))
.and_then(
// limit number of in-flight messages
InFlightService::new(
inflight,
// mqtt spec requires ack ordering, so enforce response ordering
InOrder::service(publish?).map_err(|e| match e {
InOrderError::Service(e) => e,
InOrderError::Disconnected => MqttError::Disconnected,
}),
),
),
subscribe?,
unsubscribe?,
))
}
})
}
/// PUBLIS/SUBSCRIBER/UNSUBSCRIBER packets dispatcher
pub(crate) struct Dispatcher<St, T: Service> {
publish: T,
subscribe: boxed::BoxService<Subscribe<St>, SubscribeResult, T::Error>,
unsubscribe: boxed::BoxService<Unsubscribe<St>, (), T::Error>,
}
impl<St, T> Dispatcher<St, T>
where
T: Service<Request = Publish<St>, Response = ()>,
{
pub(crate) fn new(
publish: T,
subscribe: boxed::BoxService<Subscribe<St>, SubscribeResult, T::Error>,
unsubscribe: boxed::BoxService<Unsubscribe<St>, (), T::Error>,
) -> Self {
Self {
publish,
subscribe,
unsubscribe,
}
}
}
impl<St, T> Service for Dispatcher<St, T>
where
T: Service<Request = Publish<St>, Response = ()>,
T::Error: 'static,
{
type Request = ioframe::Item<MqttState<St>, mqtt::Codec>;
type Response = Option<mqtt::Packet>;
type Error = T::Error;
type Future = Either<
Either<
Ready<Result<Self::Response, T::Error>>,
LocalBoxFuture<'static, Result<Self::Response, T::Error>>,
>,
PublishResponse<T::Future, T::Error>,
>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let res1 = self.publish.poll_ready(cx)?;
let res2 = self.subscribe.poll_ready(cx)?;
let res3 = self.unsubscribe.poll_ready(cx)?;
if res1.is_pending() || res2.is_pending() || res3.is_pending() {
Poll::Pending
} else {
Poll::Ready(Ok(()))
}
}
fn call(&mut self, req: ioframe::Item<MqttState<St>, mqtt::Codec>) -> Self::Future {
let (mut state, _, packet) = req.into_parts();
log::trace!("Dispatch packet: {:#?}", packet);
match packet {
mqtt::Packet::PingRequest => {
Either::Left(Either::Left(ok(Some(mqtt::Packet::PingResponse))))
}
mqtt::Packet::Disconnect => Either::Left(Either::Left(ok(None))),
mqtt::Packet::Publish(publish) => {
let packet_id = publish.packet_id;
Either::Right(PublishResponse {
packet_id,
fut: self.publish.call(Publish::new(state, publish)),
_t: PhantomData,
})
}
mqtt::Packet::PublishAck { packet_id } => {
state.inner.get_mut().sink.complete_publish_qos1(packet_id);
Either::Left(Either::Left(ok(None)))
}
mqtt::Packet::Subscribe {
packet_id,
topic_filters,
} => Either::Left(Either::Right(
SubscribeResponse {
packet_id,
fut: self.subscribe.call(Subscribe::new(state, topic_filters)),
}
.boxed_local(),
)),
mqtt::Packet::Unsubscribe {
packet_id,
topic_filters,
} => Either::Left(Either::Right(
self.unsubscribe
.call(Unsubscribe::new(state, topic_filters))
.map(move |_| Ok(Some(mqtt::Packet::UnsubscribeAck { packet_id })))
.boxed_local(),
)),
_ => Either::Left(Either::Left(ok(None))),
}
}
}
/// Publish service response future
#[pin_project::pin_project]
pub(crate) struct PublishResponse<T, E> {
#[pin]
fut: T,
packet_id: Option<NonZeroU16>,
_t: PhantomData<E>,
}
impl<T, E> Future for PublishResponse<T, E>
where
T: Future<Output = Result<(), E>>,
{
type Output = Result<Option<mqtt::Packet>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
ready!(this.fut.poll(cx))?;
if let Some(packet_id) = this.packet_id {
Poll::Ready(Ok(Some(mqtt::Packet::PublishAck {
packet_id: *packet_id,
})))
} else {
Poll::Ready(Ok(None))
}
}
}
/// Subscribe service response future
pub(crate) struct SubscribeResponse<E> {
fut: LocalBoxFuture<'static, Result<SubscribeResult, E>>,
packet_id: NonZeroU16,
}
impl<E> Future for SubscribeResponse<E> {
type Output = Result<Option<mqtt::Packet>, E>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let res = ready!(Pin::new(&mut self.fut).poll(cx))?;
Poll::Ready(Ok(Some(mqtt::Packet::SubscribeAck {
status: res.codes,
packet_id: self.packet_id,
})))
}
}

34
actix-mqtt/src/error.rs Executable file
View File

@ -0,0 +1,34 @@
use std::io;
/// Errors which can occur when attempting to handle mqtt connection.
#[derive(Debug)]
pub enum MqttError<E> {
/// Message handler service error
Service(E),
/// Mqtt parse error
Protocol(mqtt_codec::ParseError),
/// Unexpected packet
Unexpected(mqtt_codec::Packet, &'static str),
/// "SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier [MQTT-2.3.1-1]."
PacketIdRequired,
/// Keep alive timeout
KeepAliveTimeout,
/// Handshake timeout
HandshakeTimeout,
/// Peer disconnect
Disconnected,
/// Unexpected io error
Io(io::Error),
}
impl<E> From<mqtt_codec::ParseError> for MqttError<E> {
fn from(err: mqtt_codec::ParseError) -> Self {
MqttError::Protocol(err)
}
}
impl<E> From<io::Error> for MqttError<E> {
fn from(err: io::Error) -> Self {
MqttError::Io(err)
}
}

23
actix-mqtt/src/lib.rs Executable file
View File

@ -0,0 +1,23 @@
#![allow(clippy::type_complexity, clippy::new_ret_no_self)]
//! MQTT v3.1 Server framework
mod cell;
pub mod client;
mod connect;
mod default;
mod dispatcher;
mod error;
mod publish;
mod router;
mod server;
mod sink;
mod subs;
pub use self::client::Client;
pub use self::connect::{Connect, ConnectAck};
pub use self::error::MqttError;
pub use self::publish::Publish;
pub use self::router::Router;
pub use self::server::MqttServer;
pub use self::sink::MqttSink;
pub use self::subs::{Subscribe, SubscribeIter, SubscribeResult, Subscription, Unsubscribe};

137
actix-mqtt/src/publish.rs Executable file
View File

@ -0,0 +1,137 @@
use std::convert::TryFrom;
use std::num::NonZeroU16;
use actix_router::Path;
use bytes::Bytes;
use bytestring::ByteString;
use mqtt_codec as mqtt;
use serde::de::DeserializeOwned;
use serde_json::Error as JsonError;
use crate::dispatcher::MqttState;
use crate::sink::MqttSink;
/// Publish message
pub struct Publish<S> {
publish: mqtt::Publish,
sink: MqttSink,
state: MqttState<S>,
topic: Path<ByteString>,
query: Option<ByteString>,
}
impl<S> Publish<S> {
pub(crate) fn new(state: MqttState<S>, publish: mqtt::Publish) -> Self {
let (topic, query) = if let Some(pos) = publish.topic.find('?') {
(
ByteString::try_from(publish.topic.get_ref().slice(0..pos)).unwrap(),
Some(
ByteString::try_from(
publish.topic.get_ref().slice(pos + 1..publish.topic.len()),
)
.unwrap(),
),
)
} else {
(publish.topic.clone(), None)
};
let topic = Path::new(topic);
let sink = state.sink().clone();
Self {
sink,
publish,
state,
topic,
query,
}
}
#[inline]
/// this might be re-delivery of an earlier attempt to send the Packet.
pub fn dup(&self) -> bool {
self.publish.dup
}
#[inline]
pub fn retain(&self) -> bool {
self.publish.retain
}
#[inline]
/// the level of assurance for delivery of an Application Message.
pub fn qos(&self) -> mqtt::QoS {
self.publish.qos
}
#[inline]
/// the information channel to which payload data is published.
pub fn publish_topic(&self) -> &str {
&self.publish.topic
}
#[inline]
/// returns reference to a connection session
pub fn session(&self) -> &S {
self.state.session()
}
#[inline]
/// returns mutable reference to a connection session
pub fn session_mut(&mut self) -> &mut S {
self.state.session_mut()
}
#[inline]
/// only present in PUBLISH Packets where the QoS level is 1 or 2.
pub fn id(&self) -> Option<NonZeroU16> {
self.publish.packet_id
}
#[inline]
pub fn topic(&self) -> &Path<ByteString> {
&self.topic
}
#[inline]
pub fn topic_mut(&mut self) -> &mut Path<ByteString> {
&mut self.topic
}
#[inline]
pub fn query(&self) -> &str {
self.query.as_ref().map(|s| s.as_ref()).unwrap_or("")
}
#[inline]
pub fn packet(&self) -> &mqtt::Publish {
&self.publish
}
#[inline]
/// the Application Message that is being published.
pub fn payload(&self) -> &Bytes {
&self.publish.payload
}
/// Extract Bytes from packet payload
pub fn take_payload(&self) -> Bytes {
self.publish.payload.clone()
}
#[inline]
/// Mqtt client sink object
pub fn sink(&self) -> &MqttSink {
&self.sink
}
/// Loads and parse `application/json` encoded body.
pub fn json<T: DeserializeOwned>(&mut self) -> Result<T, JsonError> {
serde_json::from_slice(&self.publish.payload)
}
}
impl<S> std::fmt::Debug for Publish<S> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.publish.fmt(f)
}
}

206
actix-mqtt/src/router.rs Executable file
View File

@ -0,0 +1,206 @@
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_router::{IntoPattern, RouterBuilder};
use actix_service::boxed::{self, BoxService, BoxServiceFactory};
use actix_service::{fn_service, IntoServiceFactory, Service, ServiceFactory};
use futures::future::{join_all, ok, JoinAll, LocalBoxFuture};
use crate::publish::Publish;
type Handler<S, E> = BoxServiceFactory<S, Publish<S>, (), E, E>;
type HandlerService<S, E> = BoxService<Publish<S>, (), E>;
/// Router - structure that follows the builder pattern
/// for building publish packet router instances for mqtt server.
pub struct Router<S, E> {
router: RouterBuilder<usize>,
handlers: Vec<Handler<S, E>>,
default: Handler<S, E>,
}
impl<S, E> Router<S, E>
where
S: Clone + 'static,
E: 'static,
{
/// Create mqtt application.
///
/// **Note** Default service acks all publish packets
pub fn new() -> Self {
Router {
router: actix_router::Router::build(),
handlers: Vec::new(),
default: boxed::factory(
fn_service(|p: Publish<S>| {
log::warn!("Unknown topic {:?}", p.publish_topic());
ok::<_, E>(())
})
.map_init_err(|_| panic!()),
),
}
}
/// Configure mqtt resource for a specific topic.
pub fn resource<T, F, U: 'static>(mut self, address: T, service: F) -> Self
where
T: IntoPattern,
F: IntoServiceFactory<U>,
U: ServiceFactory<Config = S, Request = Publish<S>, Response = (), Error = E>,
E: From<U::InitError>,
{
self.router.path(address, self.handlers.len());
self.handlers
.push(boxed::factory(service.into_factory().map_init_err(E::from)));
self
}
/// Default service to be used if no matching resource could be found.
pub fn default_resource<F, U: 'static>(mut self, service: F) -> Self
where
F: IntoServiceFactory<U>,
U: ServiceFactory<
Config = S,
Request = Publish<S>,
Response = (),
Error = E,
InitError = E,
>,
{
self.default = boxed::factory(service.into_factory());
self
}
}
impl<S, E> IntoServiceFactory<RouterFactory<S, E>> for Router<S, E>
where
S: Clone + 'static,
E: 'static,
{
fn into_factory(self) -> RouterFactory<S, E> {
RouterFactory {
router: Rc::new(self.router.finish()),
handlers: self.handlers,
default: self.default,
}
}
}
pub struct RouterFactory<S, E> {
router: Rc<actix_router::Router<usize>>,
handlers: Vec<Handler<S, E>>,
default: Handler<S, E>,
}
impl<S, E> ServiceFactory for RouterFactory<S, E>
where
S: Clone + 'static,
E: 'static,
{
type Config = S;
type Request = Publish<S>;
type Response = ();
type Error = E;
type InitError = E;
type Service = RouterService<S, E>;
type Future = RouterFactoryFut<S, E>;
fn new_service(&self, session: S) -> Self::Future {
let fut: Vec<_> = self
.handlers
.iter()
.map(|h| h.new_service(session.clone()))
.collect();
RouterFactoryFut {
router: self.router.clone(),
handlers: join_all(fut),
default: Some(either::Either::Left(self.default.new_service(session))),
}
}
}
pub struct RouterFactoryFut<S, E> {
router: Rc<actix_router::Router<usize>>,
handlers: JoinAll<LocalBoxFuture<'static, Result<HandlerService<S, E>, E>>>,
default: Option<
either::Either<
LocalBoxFuture<'static, Result<HandlerService<S, E>, E>>,
HandlerService<S, E>,
>,
>,
}
impl<S, E> Future for RouterFactoryFut<S, E> {
type Output = Result<RouterService<S, E>, E>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let res = match self.default.as_mut().unwrap() {
either::Either::Left(ref mut fut) => {
let default = match futures::ready!(Pin::new(fut).poll(cx)) {
Ok(default) => default,
Err(e) => return Poll::Ready(Err(e)),
};
self.default = Some(either::Either::Right(default));
return self.poll(cx);
}
either::Either::Right(_) => futures::ready!(Pin::new(&mut self.handlers).poll(cx)),
};
let mut handlers = Vec::new();
for handler in res {
match handler {
Ok(h) => handlers.push(h),
Err(e) => return Poll::Ready(Err(e)),
}
}
Poll::Ready(Ok(RouterService {
handlers,
router: self.router.clone(),
default: self.default.take().unwrap().right().unwrap(),
}))
}
}
pub struct RouterService<S, E> {
router: Rc<actix_router::Router<usize>>,
handlers: Vec<BoxService<Publish<S>, (), E>>,
default: BoxService<Publish<S>, (), E>,
}
impl<S, E> Service for RouterService<S, E>
where
S: 'static,
E: 'static,
{
type Request = Publish<S>;
type Response = ();
type Error = E;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let mut not_ready = false;
for hnd in &mut self.handlers {
if let Poll::Pending = hnd.poll_ready(cx)? {
not_ready = true;
}
}
if not_ready {
Poll::Pending
} else {
Poll::Ready(Ok(()))
}
}
fn call(&mut self, mut req: Publish<S>) -> Self::Future {
if let Some((idx, _info)) = self.router.recognize(req.topic_mut()) {
self.handlers[*idx].call(req)
} else {
self.default.call(req)
}
}
}

331
actix-mqtt/src/server.rs Executable file
View File

@ -0,0 +1,331 @@
use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;
use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite};
use actix_ioframe as ioframe;
use actix_service::{apply, apply_fn, boxed, fn_factory, pipeline_factory, unit_config};
use actix_service::{IntoServiceFactory, Service, ServiceFactory};
use actix_utils::timeout::{Timeout, TimeoutError};
use futures::{FutureExt, SinkExt, StreamExt};
use mqtt_codec as mqtt;
use crate::cell::Cell;
use crate::connect::{Connect, ConnectAck};
use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
use crate::dispatcher::{dispatcher, MqttState};
use crate::error::MqttError;
use crate::publish::Publish;
use crate::sink::MqttSink;
use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
/// Mqtt Server
pub struct MqttServer<Io, St, C: ServiceFactory, U> {
connect: C,
subscribe: boxed::BoxServiceFactory<
St,
Subscribe<St>,
SubscribeResult,
MqttError<C::Error>,
MqttError<C::Error>,
>,
unsubscribe: boxed::BoxServiceFactory<
St,
Unsubscribe<St>,
(),
MqttError<C::Error>,
MqttError<C::Error>,
>,
disconnect: U,
max_size: usize,
inflight: usize,
handshake_timeout: u64,
_t: PhantomData<(Io, St)>,
}
fn default_disconnect<St>(_: St, _: bool) {}
impl<Io, St, C> MqttServer<Io, St, C, ()>
where
St: 'static,
C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
+ 'static,
{
/// Create server factory and provide connect service
pub fn new<F>(connect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
where
F: IntoServiceFactory<C>,
{
MqttServer {
connect: connect.into_factory(),
subscribe: boxed::factory(
pipeline_factory(SubsNotImplemented::default())
.map_err(MqttError::Service)
.map_init_err(MqttError::Service),
),
unsubscribe: boxed::factory(
pipeline_factory(UnsubsNotImplemented::default())
.map_err(MqttError::Service)
.map_init_err(MqttError::Service),
),
max_size: 0,
inflight: 15,
disconnect: default_disconnect,
handshake_timeout: 0,
_t: PhantomData,
}
}
}
impl<Io, St, C, U> MqttServer<Io, St, C, U>
where
St: Clone + 'static,
U: Fn(St, bool) + 'static,
C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
+ 'static,
{
/// Set handshake timeout in millis.
///
/// Handshake includes `connect` packet and response `connect-ack`.
/// By default handshake timeuot is disabled.
pub fn handshake_timeout(mut self, timeout: u64) -> Self {
self.handshake_timeout = timeout;
self
}
/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
/// By default max size is set to `0`
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Number of in-flight concurrent messages.
///
/// in-flight is set to 15 messages
pub fn inflight(mut self, val: usize) -> Self {
self.inflight = val;
self
}
/// Service to execute for subscribe packet
pub fn subscribe<F, Srv>(mut self, subscribe: F) -> Self
where
F: IntoServiceFactory<Srv>,
Srv: ServiceFactory<Config = St, Request = Subscribe<St>, Response = SubscribeResult>
+ 'static,
C::Error: From<Srv::Error> + From<Srv::InitError>,
{
self.subscribe = boxed::factory(
subscribe
.into_factory()
.map_err(|e| MqttError::Service(e.into()))
.map_init_err(|e| MqttError::Service(e.into())),
);
self
}
/// Service to execute for unsubscribe packet
pub fn unsubscribe<F, Srv>(mut self, unsubscribe: F) -> Self
where
F: IntoServiceFactory<Srv>,
Srv: ServiceFactory<Config = St, Request = Unsubscribe<St>, Response = ()> + 'static,
C::Error: From<Srv::Error> + From<Srv::InitError>,
{
self.unsubscribe = boxed::factory(
unsubscribe
.into_factory()
.map_err(|e| MqttError::Service(e.into()))
.map_init_err(|e| MqttError::Service(e.into())),
);
self
}
/// Callback to execute on disconnect
///
/// Second parameter indicates error occured during disconnect.
pub fn disconnect<F, Out>(self, disconnect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
where
F: Fn(St, bool) -> Out,
Out: Future + 'static,
{
MqttServer {
connect: self.connect,
subscribe: self.subscribe,
unsubscribe: self.unsubscribe,
max_size: self.max_size,
inflight: self.inflight,
handshake_timeout: self.handshake_timeout,
disconnect: move |st: St, err| {
let fut = disconnect(st, err);
actix_rt::spawn(fut.map(|_| ()));
},
_t: PhantomData,
}
}
/// Set service to execute for publish packet and create service factory
pub fn finish<F, P>(
self,
publish: F,
) -> impl ServiceFactory<Config = (), Request = Io, Response = (), Error = MqttError<C::Error>>
where
Io: AsyncRead + AsyncWrite + 'static,
F: IntoServiceFactory<P>,
P: ServiceFactory<Config = St, Request = Publish<St>, Response = ()> + 'static,
C::Error: From<P::Error> + From<P::InitError>,
{
let connect = self.connect;
let max_size = self.max_size;
let handshake_timeout = self.handshake_timeout;
let disconnect = self.disconnect;
let publish = boxed::factory(
publish
.into_factory()
.map_err(|e| MqttError::Service(e.into()))
.map_init_err(|e| MqttError::Service(e.into())),
);
unit_config(
ioframe::Builder::new()
.factory(connect_service_factory(
connect,
max_size,
self.inflight,
handshake_timeout,
))
.disconnect(move |cfg, err| disconnect(cfg.session().clone(), err))
.finish(dispatcher(
publish,
Rc::new(self.subscribe),
Rc::new(self.unsubscribe),
))
.map_err(|e| match e {
ioframe::ServiceError::Service(e) => e,
ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
}),
)
}
}
fn connect_service_factory<Io, St, C>(
factory: C,
max_size: usize,
inflight: usize,
handshake_timeout: u64,
) -> impl ServiceFactory<
Config = (),
Request = ioframe::Connect<Io, mqtt::Codec>,
Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>,
Error = MqttError<C::Error>,
>
where
Io: AsyncRead + AsyncWrite,
C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>,
{
apply(
Timeout::new(Duration::from_millis(handshake_timeout)),
fn_factory(move || {
let fut = factory.new_service(());
async move {
let service = Cell::new(fut.await?);
Ok::<_, C::InitError>(apply_fn(
service.map_err(MqttError::Service),
move |conn: ioframe::Connect<Io, mqtt::Codec>, service| {
let mut srv = service.clone();
let mut framed = conn.codec(mqtt::Codec::new().max_size(max_size));
async move {
// read first packet
let packet = framed
.next()
.await
.ok_or(MqttError::Disconnected)
.and_then(|res| res.map_err(|e| MqttError::Protocol(e)))?;
match packet {
mqtt::Packet::Connect(connect) => {
let sink = MqttSink::new(framed.sink().clone());
// authenticate mqtt connection
let mut ack = srv
.call(Connect::new(
connect,
framed,
sink.clone(),
inflight,
))
.await?;
match ack.session {
Some(session) => {
log::trace!(
"Sending: {:#?}",
mqtt::Packet::ConnectAck {
session_present: ack.session_present,
return_code:
mqtt::ConnectCode::ConnectionAccepted,
}
);
ack.io
.send(mqtt::Packet::ConnectAck {
session_present: ack.session_present,
return_code:
mqtt::ConnectCode::ConnectionAccepted,
})
.await?;
Ok(ack.io.state(MqttState::new(
session,
sink,
ack.keep_alive,
ack.inflight,
)))
}
None => {
log::trace!(
"Sending: {:#?}",
mqtt::Packet::ConnectAck {
session_present: false,
return_code: ack.return_code,
}
);
ack.io
.send(mqtt::Packet::ConnectAck {
session_present: false,
return_code: ack.return_code,
})
.await?;
Err(MqttError::Disconnected)
}
}
}
packet => {
log::info!(
"MQTT-3.1.0-1: Expected CONNECT packet, received {}",
packet.packet_type()
);
Err(MqttError::Unexpected(
packet,
"MQTT-3.1.0-1: Expected CONNECT packet",
))
}
}
}
},
))
}
}),
)
.map_err(|e| match e {
TimeoutError::Service(e) => e,
TimeoutError::Timeout => MqttError::HandshakeTimeout,
})
}

107
actix-mqtt/src/sink.rs Executable file
View File

@ -0,0 +1,107 @@
use std::collections::VecDeque;
use std::fmt;
use std::num::NonZeroU16;
use actix_ioframe::Sink;
use actix_utils::oneshot;
use bytes::Bytes;
use bytestring::ByteString;
use futures::future::{Future, TryFutureExt};
use mqtt_codec as mqtt;
use crate::cell::Cell;
#[derive(Clone)]
pub struct MqttSink {
sink: Sink<mqtt::Packet>,
pub(crate) inner: Cell<MqttSinkInner>,
}
#[derive(Default)]
pub(crate) struct MqttSinkInner {
pub(crate) idx: u16,
pub(crate) queue: VecDeque<(u16, oneshot::Sender<()>)>,
}
impl MqttSink {
pub(crate) fn new(sink: Sink<mqtt::Packet>) -> Self {
MqttSink {
sink,
inner: Cell::new(MqttSinkInner::default()),
}
}
/// Close mqtt connection
pub fn close(&self) {
self.sink.close();
}
/// Send publish packet with qos set to 0
pub fn publish_qos0(&self, topic: ByteString, payload: Bytes, dup: bool) {
log::trace!("Publish (QoS0) to {:?}", topic);
let publish = mqtt::Publish {
topic,
payload,
dup,
retain: false,
qos: mqtt::QoS::AtMostOnce,
packet_id: None,
};
self.sink.send(mqtt::Packet::Publish(publish));
}
/// Send publish packet
pub fn publish_qos1(
&mut self,
topic: ByteString,
payload: Bytes,
dup: bool,
) -> impl Future<Output = Result<(), ()>> {
let (tx, rx) = oneshot::channel();
let inner = self.inner.get_mut();
inner.idx += 1;
if inner.idx == 0 {
inner.idx = 1
}
inner.queue.push_back((inner.idx, tx));
let publish = mqtt::Packet::Publish(mqtt::Publish {
topic,
payload,
dup,
retain: false,
qos: mqtt::QoS::AtLeastOnce,
packet_id: NonZeroU16::new(inner.idx),
});
log::trace!("Publish (QoS1) to {:#?}", publish);
self.sink.send(publish);
rx.map_err(|_| ())
}
pub(crate) fn complete_publish_qos1(&mut self, packet_id: NonZeroU16) {
if let Some((idx, tx)) = self.inner.get_mut().queue.pop_front() {
if idx != packet_id.get() {
log::trace!(
"MQTT protocol error, packet_id order does not match, expected {}, got: {}",
idx,
packet_id
);
self.close();
} else {
log::trace!("Ack publish packet with id: {}", packet_id);
let _ = tx.send(());
}
} else {
log::trace!("Unexpected PublishAck packet");
self.close();
}
}
}
impl fmt::Debug for MqttSink {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("MqttSink").finish()
}
}

191
actix-mqtt/src/subs.rs Executable file
View File

@ -0,0 +1,191 @@
use std::marker::PhantomData;
use bytestring::ByteString;
use mqtt_codec as mqtt;
use crate::dispatcher::MqttState;
use crate::sink::MqttSink;
/// Subscribe message
pub struct Subscribe<S> {
topics: Vec<(ByteString, mqtt::QoS)>,
codes: Vec<mqtt::SubscribeReturnCode>,
state: MqttState<S>,
}
/// Result of a subscribe message
pub struct SubscribeResult {
pub(crate) codes: Vec<mqtt::SubscribeReturnCode>,
}
impl<S> Subscribe<S> {
pub(crate) fn new(state: MqttState<S>, topics: Vec<(ByteString, mqtt::QoS)>) -> Self {
let mut codes = Vec::with_capacity(topics.len());
(0..topics.len()).for_each(|_| codes.push(mqtt::SubscribeReturnCode::Failure));
Self {
topics,
state,
codes,
}
}
#[inline]
/// returns reference to a connection session
pub fn session(&self) -> &S {
self.state.session()
}
#[inline]
/// returns mutable reference to a connection session
pub fn session_mut(&mut self) -> &mut S {
self.state.session_mut()
}
#[inline]
/// Mqtt client sink object
pub fn sink(&self) -> MqttSink {
self.state.sink().clone()
}
#[inline]
/// returns iterator over subscription topics
pub fn iter_mut(&mut self) -> SubscribeIter<S> {
SubscribeIter {
subs: self as *const _ as *mut _,
entry: 0,
lt: PhantomData,
}
}
#[inline]
/// convert subscription to a result
pub fn into_result(self) -> SubscribeResult {
SubscribeResult { codes: self.codes }
}
}
impl<'a, S> IntoIterator for &'a mut Subscribe<S> {
type Item = Subscription<'a, S>;
type IntoIter = SubscribeIter<'a, S>;
fn into_iter(self) -> SubscribeIter<'a, S> {
self.iter_mut()
}
}
/// Iterator over subscription topics
pub struct SubscribeIter<'a, S> {
subs: *mut Subscribe<S>,
entry: usize,
lt: PhantomData<&'a mut Subscribe<S>>,
}
impl<'a, S> SubscribeIter<'a, S> {
fn next_unsafe(&mut self) -> Option<Subscription<'a, S>> {
let subs = unsafe { &mut *self.subs };
if self.entry < subs.topics.len() {
let s = Subscription {
topic: &subs.topics[self.entry].0,
qos: subs.topics[self.entry].1,
state: subs.state.clone(),
code: &mut subs.codes[self.entry],
};
self.entry += 1;
Some(s)
} else {
None
}
}
}
impl<'a, S> Iterator for SubscribeIter<'a, S> {
type Item = Subscription<'a, S>;
#[inline]
fn next(&mut self) -> Option<Subscription<'a, S>> {
self.next_unsafe()
}
}
/// Subscription topic
pub struct Subscription<'a, S> {
topic: &'a ByteString,
state: MqttState<S>,
qos: mqtt::QoS,
code: &'a mut mqtt::SubscribeReturnCode,
}
impl<'a, S> Subscription<'a, S> {
#[inline]
/// reference to a connection session
pub fn session(&self) -> &S {
self.state.session()
}
#[inline]
/// mutable reference to a connection session
pub fn session_mut(&mut self) -> &mut S {
self.state.session_mut()
}
#[inline]
/// subscription topic
pub fn topic(&self) -> &'a ByteString {
&self.topic
}
#[inline]
/// the level of assurance for delivery of an Application Message.
pub fn qos(&self) -> mqtt::QoS {
self.qos
}
#[inline]
/// fail to subscribe to the topic
pub fn fail(&mut self) {
*self.code = mqtt::SubscribeReturnCode::Failure
}
#[inline]
/// subscribe to a topic with specific qos
pub fn subscribe(&mut self, qos: mqtt::QoS) {
*self.code = mqtt::SubscribeReturnCode::Success(qos)
}
}
/// Unsubscribe message
pub struct Unsubscribe<S> {
state: MqttState<S>,
topics: Vec<ByteString>,
}
impl<S> Unsubscribe<S> {
pub(crate) fn new(state: MqttState<S>, topics: Vec<ByteString>) -> Self {
Self { topics, state }
}
#[inline]
/// reference to a connection session
pub fn session(&self) -> &S {
self.state.session()
}
#[inline]
/// mutable reference to a connection session
pub fn session_mut(&mut self) -> &mut S {
self.state.session_mut()
}
#[inline]
/// Mqtt client sink object
pub fn sink(&self) -> MqttSink {
self.state.sink().clone()
}
/// returns iterator over unsubscribe topics
pub fn iter(&self) -> impl Iterator<Item = &ByteString> {
self.topics.iter()
}
}

52
actix-mqtt/tests/test_server.rs Executable file
View File

@ -0,0 +1,52 @@
use actix_service::Service;
use actix_testing::TestServer;
use bytes::Bytes;
use bytestring::ByteString;
use futures::future::ok;
use actix_mqtt::{client, Connect, ConnectAck, MqttServer, Publish};
#[derive(Clone)]
struct Session;
async fn connect<Io>(packet: Connect<Io>) -> Result<ConnectAck<Io, Session>, ()> {
println!("CONNECT: {:?}", packet);
Ok(packet.ack(Session, false))
}
#[actix_rt::test]
async fn test_simple() -> std::io::Result<()> {
std::env::set_var(
"RUST_LOG",
"actix_codec=info,actix_server=trace,actix_connector=trace",
);
env_logger::init();
let srv = TestServer::with(|| MqttServer::new(connect).finish(|_t| ok(())));
#[derive(Clone)]
struct ClientSession;
let mut client = client::Client::new(ByteString::from_static("user"))
.state(|ack: client::ConnectAck<_>| async move {
ack.sink()
.publish_qos0(ByteString::from_static("#"), Bytes::new(), false);
ack.sink().close();
Ok(ack.state(ClientSession))
})
.finish(|_t: Publish<_>| {
async {
// t.sink().close();
Ok(())
}
});
let conn = actix_connect::default_connector()
.call(actix_connect::Connect::with(String::new(), srv.addr()))
.await
.unwrap();
client.call(conn.into_parts().0).await.unwrap();
Ok(())
}