use bitflags::bitflags;
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use tokio_util::codec;
use tracing::error;

#[cfg(feature = "compress-ws-deflate")]
use super::deflate::{
    DeflateCompressionContext, DeflateContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG,
};
use super::{
    frame::Parser,
    proto::{CloseReason, OpCode, RsvBits},
    ProtocolError,
};

/// A complete WebSocket message, as accepted by the [`Encoder`].
#[derive(Debug, PartialEq, Eq)]
pub enum Message {
    /// UTF-8 text message, written as a single final `Text` frame.
    Text(ByteString),

    /// Binary message, written as a single final `Binary` frame.
    Binary(Bytes),

    /// One piece of a fragmented message; see [`Item`].
    Continuation(Item),

    /// Ping control message carrying the given payload.
    Ping(Bytes),

    /// Pong control message carrying the given payload.
    Pong(Bytes),

    /// Close control message, optionally carrying a close reason.
    Close(Option<CloseReason>),

    /// No-op: the encoder writes nothing. Useful for low-level services.
    Nop,
}

/// A single WebSocket frame, as produced by the [`Decoder`].
#[derive(Debug, PartialEq, Eq)]
pub enum Frame {
    /// Final text frame. Note that the codec does not validate UTF-8 encoding.
    Text(Bytes),

    /// Final binary frame.
    Binary(Bytes),

    /// A fragment of a fragmented message; see [`Item`].
    Continuation(Item),

    /// Ping control frame with its payload.
    Ping(Bytes),

    /// Pong control frame with its payload.
    Pong(Bytes),

    /// Close control frame; the reason is parsed from the payload when one is present.
    Close(Option<CloseReason>),
}

/// One fragment of a fragmented (multi-frame) WebSocket message.
#[derive(Debug, PartialEq, Eq)]
pub enum Item {
    /// First fragment of a text message (`Text` opcode, not final).
    FirstText(Bytes),
    /// First fragment of a binary message (`Binary` opcode, not final).
    FirstBinary(Bytes),
    /// Intermediate fragment (`Continue` opcode, not final).
    Continue(Bytes),
    /// Final fragment (`Continue` opcode, final).
    Last(Bytes),
}

bitflags! {
    /// Internal state bits shared by the frame encoder and decoder.
    #[derive(Debug, Clone, Copy)]
    struct Flags: u8 {
        /// Codec operates on the server side; cleared by `set_client_mode`.
        const SERVER         = 1 << 0;
        /// Decoder is in the middle of receiving a fragmented message.
        const CONTINUATION   = 1 << 1;
        /// Encoder is in the middle of writing a fragmented message.
        const W_CONTINUATION = 1 << 2;
    }
}

/// Encoder turning [`Message`]s into WebSocket frames.
#[derive(Debug, Clone)]
pub struct Encoder {
    /// Server/client mode plus fragmented-write state.
    flags: Flags,

    /// Compression state for `permessage-deflate`; `None` when the extension is off.
    #[cfg(feature = "compress-ws-deflate")]
    deflate_compress: Option<DeflateCompressionContext>,
}

impl Encoder {
    /// Builds a server-mode encoder without any extensions.
    pub const fn new() -> Encoder {
        Encoder {
            flags: Flags::SERVER,

            #[cfg(feature = "compress-ws-deflate")]
            deflate_compress: None,
        }
    }

    /// Builds a server-mode encoder that compresses data payloads with the
    /// `permessage-deflate` extension.
    #[cfg(feature = "compress-ws-deflate")]
    pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder {
        Encoder {
            deflate_compress: Some(compress),
            ..Encoder::new()
        }
    }

    /// Switches to client mode by clearing every flag (including `SERVER`).
    fn set_client_mode(self) -> Self {
        Encoder {
            flags: Flags::empty(),
            ..self
        }
    }

    /// Rebuilds the compression context (if any) with the given parameters.
    #[cfg(feature = "compress-ws-deflate")]
    fn set_client_mode_deflate(
        mut self,
        remote_no_context_takeover: bool,
        remote_max_window_bits: u8,
    ) -> Self {
        if let Some(context) = self.deflate_compress.take() {
            self.deflate_compress =
                Some(context.reset_with(remote_no_context_takeover, remote_max_window_bits));
        }
        self
    }

    /// Compresses `bytes` when deflate is active. Returns the payload to write and the
    /// RSV bits that mark it.
    #[cfg(feature = "compress-ws-deflate")]
    fn process_payload(
        &mut self,
        fin: bool,
        bytes: Bytes,
    ) -> Result<(Bytes, RsvBits), ProtocolError> {
        match self.deflate_compress.as_mut() {
            Some(context) => {
                let compressed = context.compress(fin, bytes)?;
                Ok((compressed, RSV_BIT_DEFLATE_FLAG))
            }
            None => Ok((bytes, RsvBits::empty())),
        }
    }

    /// Without deflate support the payload passes through untouched, with no RSV bits.
    #[cfg(not(feature = "compress-ws-deflate"))]
    fn process_payload(
        &mut self,
        _fin: bool,
        bytes: Bytes,
    ) -> Result<(Bytes, RsvBits), ProtocolError> {
        let rsv_bits = RsvBits::empty();
        Ok((bytes, rsv_bits))
    }
}

impl Default for Encoder {
    fn default() -> Self {
        Self::new()
    }
}

impl codec::Encoder<Message> for Encoder {
    type Error = ProtocolError;

    /// Serializes `item` as one WebSocket frame appended to `dst`.
    ///
    /// Text/binary payloads, including every fragment of a fragmented message, are passed
    /// through `process_payload`, which compresses them when `permessage-deflate` is
    /// enabled. Per RFC 7692 §6.1 the "Per-Message Compressed" bit (RSV1) may only be set
    /// on the *first* frame of a message. Control frames and non-first fragments
    /// (`Item::Continue`, `Item::Last`) are therefore always written with empty RSV bits.
    ///
    /// `Message::Nop` writes nothing.
    ///
    /// # Errors
    /// - `ProtocolError::ContinuationStarted` if a fragmented message is started while
    ///   another one is still being written.
    /// - `ProtocolError::ContinuationNotStarted` if `Item::Continue`/`Item::Last` is sent
    ///   without a preceding `Item::FirstText`/`Item::FirstBinary`.
    /// - Any error produced while compressing the payload.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Client-mode encoders mask outgoing frames (RFC 6455 §5.3).
        let mask = !self.flags.contains(Flags::SERVER);

        match item {
            Message::Text(txt) => {
                let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?;
                Parser::write_message(dst, bytes, OpCode::Text, rsv_bits, true, mask);
            }
            Message::Binary(bin) => {
                let (bin, rsv_bits) = self.process_payload(true, bin)?;
                Parser::write_message(dst, bin, OpCode::Binary, rsv_bits, true, mask);
            }
            // Control frames are never compressed, so they carry no RSV bits.
            Message::Ping(txt) => {
                Parser::write_message(dst, txt, OpCode::Ping, RsvBits::empty(), true, mask);
            }
            Message::Pong(txt) => {
                Parser::write_message(dst, txt, OpCode::Pong, RsvBits::empty(), true, mask);
            }
            Message::Close(reason) => {
                Parser::write_close(dst, reason, RsvBits::empty(), mask);
            }
            Message::Continuation(cont) => match cont {
                Item::FirstText(data) => {
                    if self.flags.contains(Flags::W_CONTINUATION) {
                        return Err(ProtocolError::ContinuationStarted);
                    }
                    let (data, rsv_bits) = self.process_payload(false, data)?;
                    self.flags.insert(Flags::W_CONTINUATION);
                    Parser::write_message(dst, data, OpCode::Text, rsv_bits, false, mask);
                }
                Item::FirstBinary(data) => {
                    if self.flags.contains(Flags::W_CONTINUATION) {
                        return Err(ProtocolError::ContinuationStarted);
                    }
                    let (data, rsv_bits) = self.process_payload(false, data)?;
                    self.flags.insert(Flags::W_CONTINUATION);
                    Parser::write_message(dst, data, OpCode::Binary, rsv_bits, false, mask);
                }
                Item::Continue(data) => {
                    if !self.flags.contains(Flags::W_CONTINUATION) {
                        return Err(ProtocolError::ContinuationNotStarted);
                    }
                    // The fragment still goes through the compressor so the DEFLATE stream
                    // stays continuous. However, RSV1 must only be set on the first frame of
                    // a message (RFC 7692 §6.1), so the returned RSV bits are discarded.
                    let (data, _) = self.process_payload(false, data)?;
                    Parser::write_message(
                        dst,
                        data,
                        OpCode::Continue,
                        RsvBits::empty(),
                        false,
                        mask,
                    );
                }
                Item::Last(data) => {
                    if !self.flags.contains(Flags::W_CONTINUATION) {
                        return Err(ProtocolError::ContinuationNotStarted);
                    }
                    self.flags.remove(Flags::W_CONTINUATION);

                    // Same as `Item::Continue`: non-first fragments never carry RSV1.
                    let (data, _) = self.process_payload(true, data)?;
                    Parser::write_message(
                        dst,
                        data,
                        OpCode::Continue,
                        RsvBits::empty(),
                        true,
                        mask,
                    );
                }
            },
            Message::Nop => {}
        }
        Ok(())
    }
}

/// Decoder turning raw bytes into WebSocket [`Frame`]s.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Server/client mode plus fragmented-read state.
    flags: Flags,
    /// Maximum frame size handed to the frame parser.
    max_size: usize,

    /// Decompression state for `permessage-deflate`; `None` when the extension is off.
    #[cfg(feature = "compress-ws-deflate")]
    deflate_decompress: Option<DeflateDecompressionContext>,
}

impl Decoder {
    /// Builds a server-mode decoder with a 64KiB max frame size and no extensions.
    pub const fn new() -> Decoder {
        Decoder {
            flags: Flags::SERVER,
            max_size: 65_536,

            #[cfg(feature = "compress-ws-deflate")]
            deflate_decompress: None,
        }
    }

    /// Builds a server-mode decoder (64KiB max frame size) that decompresses payloads
    /// with the `permessage-deflate` extension.
    #[cfg(feature = "compress-ws-deflate")]
    pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder {
        Decoder {
            deflate_decompress: Some(decompress),
            ..Decoder::new()
        }
    }

    /// Switches to client mode by clearing every flag (including `SERVER`).
    fn set_client_mode(self) -> Self {
        Decoder {
            flags: Flags::empty(),
            ..self
        }
    }

    /// Resets the decompression context (if any) in place with the given parameters.
    #[cfg(feature = "compress-ws-deflate")]
    fn set_client_mode_deflate(
        mut self,
        local_no_context_takeover: bool,
        local_max_window_bits: u8,
    ) -> Self {
        if let Some(context) = self.deflate_decompress.as_mut() {
            context.reset_with(local_no_context_takeover, local_max_window_bits);
        }

        self
    }

    /// Replaces the maximum frame size.
    fn set_max_size(self, size: usize) -> Self {
        Decoder {
            max_size: size,
            ..self
        }
    }

    /// Decompresses a present payload when deflate is active. Otherwise the payload
    /// (or its absence) is returned unchanged.
    #[cfg(feature = "compress-ws-deflate")]
    fn process_payload(
        &mut self,
        fin: bool,
        opcode: OpCode,
        rsv_bits: RsvBits,
        bytes: Option<Bytes>,
    ) -> Result<Option<Bytes>, ProtocolError> {
        match (bytes, self.deflate_decompress.as_mut()) {
            (Some(payload), Some(context)) => {
                let decompressed = context.decompress(fin, opcode, rsv_bits, payload)?;
                Ok(Some(decompressed))
            }
            (untouched, _) => Ok(untouched),
        }
    }

    /// Without deflate support the payload passes through untouched.
    #[cfg(not(feature = "compress-ws-deflate"))]
    fn process_payload(
        &mut self,
        _fin: bool,
        _opcode: OpCode,
        _rsv_bits: RsvBits,
        bytes: Option<Bytes>,
    ) -> Result<Option<Bytes>, ProtocolError> {
        let untouched = bytes;
        Ok(untouched)
    }
}

impl Default for Decoder {
    fn default() -> Self {
        Self::new()
    }
}

impl codec::Decoder for Decoder {
    type Item = Frame;
    type Error = ProtocolError;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
            Ok(Some((finished, opcode, rsv_bits, payload))) => {
                let payload = self.process_payload(
                    finished,
                    opcode,
                    rsv_bits,
                    payload.map(BytesMut::freeze),
                )?;

                // continuation is not supported
                if !finished {
                    return match opcode {
                        OpCode::Continue => {
                            if self.flags.contains(Flags::CONTINUATION) {
                                Ok(Some(Frame::Continuation(Item::Continue(
                                    payload.unwrap_or_else(Bytes::new),
                                ))))
                            } else {
                                Err(ProtocolError::ContinuationNotStarted)
                            }
                        }
                        OpCode::Binary => {
                            if !self.flags.contains(Flags::CONTINUATION) {
                                self.flags.insert(Flags::CONTINUATION);
                                Ok(Some(Frame::Continuation(Item::FirstBinary(
                                    payload.unwrap_or_else(Bytes::new),
                                ))))
                            } else {
                                Err(ProtocolError::ContinuationStarted)
                            }
                        }
                        OpCode::Text => {
                            if !self.flags.contains(Flags::CONTINUATION) {
                                self.flags.insert(Flags::CONTINUATION);
                                Ok(Some(Frame::Continuation(Item::FirstText(
                                    payload.unwrap_or_else(Bytes::new),
                                ))))
                            } else {
                                Err(ProtocolError::ContinuationStarted)
                            }
                        }
                        _ => {
                            error!("Unfinished fragment {:?}", opcode);
                            Err(ProtocolError::ContinuationFragment(opcode))
                        }
                    };
                }

                match opcode {
                    OpCode::Continue => {
                        if self.flags.contains(Flags::CONTINUATION) {
                            self.flags.remove(Flags::CONTINUATION);
                            Ok(Some(Frame::Continuation(Item::Last(
                                payload.unwrap_or_else(Bytes::new),
                            ))))
                        } else {
                            Err(ProtocolError::ContinuationNotStarted)
                        }
                    }
                    OpCode::Bad => Err(ProtocolError::BadOpCode),
                    OpCode::Close => {
                        if let Some(ref pl) = payload {
                            let close_reason = Parser::parse_close_payload(pl);
                            Ok(Some(Frame::Close(close_reason)))
                        } else {
                            Ok(Some(Frame::Close(None)))
                        }
                    }
                    OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))),
                    OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))),
                    OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))),
                    OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))),
                }
            }
            Ok(None) => Ok(None),
            Err(err) => Err(err),
        }
    }
}

/// WebSocket protocol codec: pairs an [`Encoder`] for outgoing messages with a
/// [`Decoder`] for incoming frames.
#[derive(Debug, Default, Clone)]
pub struct Codec {
    /// Serializes outgoing [`Message`]s.
    encoder: Encoder,
    /// Parses incoming [`Frame`]s.
    decoder: Decoder,
}

impl Codec {
    /// Creates a new WebSocket frames codec in server mode with a 64KiB max frame size.
    pub fn new() -> Codec {
        Codec {
            encoder: Encoder::new(),
            decoder: Decoder::new(),
        }
    }

    /// Creates a new WebSocket frames codec (server mode) with DEFLATE compression, using
    /// the compression and decompression halves of `context`.
    #[cfg(feature = "compress-ws-deflate")]
    pub fn new_deflate(context: DeflateContext) -> Codec {
        let DeflateContext {
            compress,
            decompress,
        } = context;

        Codec {
            encoder: Encoder::new_deflate(compress),
            decoder: Decoder::new_deflate(decompress),
        }
    }

    /// Sets the max frame size accepted by the decoder.
    ///
    /// By default max size is set to 64KiB. Outgoing frames are not affected.
    #[must_use = "This returns a new Codec, without modifying the original."]
    pub fn max_size(self, size: usize) -> Self {
        let Self { encoder, decoder } = self;

        Codec {
            encoder,
            decoder: decoder.set_max_size(size),
        }
    }

    /// Switches the codec (both encoder and decoder) to client mode.
    ///
    /// By default the codec works in server mode. With `permessage-deflate` enabled, the
    /// encoder and decoder exchange their negotiated parameters, since the roles are
    /// reversed on the client side.
    #[must_use = "This returns a new Codec, without modifying the original."]
    pub fn client_mode(self) -> Self {
        let Self {
            mut encoder,
            mut decoder,
        } = self;

        encoder = encoder.set_client_mode();
        decoder = decoder.set_client_mode();

        #[cfg(feature = "compress-ws-deflate")]
        {
            // Snapshot BOTH parameter sets before resetting either side. Otherwise the
            // decoder could be reset from encoder fields that the encoder reset has just
            // overwritten, and the intended swap would silently degrade into a copy.
            let decoder_params = decoder
                .deflate_decompress
                .as_ref()
                .map(|d| (d.local_no_context_takeover, d.local_max_window_bits));
            let encoder_params = encoder
                .deflate_compress
                .as_ref()
                .map(|c| (c.remote_no_context_takeover, c.remote_max_window_bits));

            if let Some((no_context_takeover, max_window_bits)) = decoder_params {
                encoder = encoder.set_client_mode_deflate(no_context_takeover, max_window_bits);
            }
            if let Some((no_context_takeover, max_window_bits)) = encoder_params {
                decoder = decoder.set_client_mode_deflate(no_context_takeover, max_window_bits);
            }
        }

        Self { encoder, decoder }
    }
}

impl codec::Decoder for Codec {
    type Item = Frame;
    type Error = ProtocolError;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        self.decoder.decode(src)
    }
}

impl codec::Encoder<Message> for Codec {
    type Error = ProtocolError;

    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        self.encoder.encode(item, dst)
    }
}