use std::{
    cell::Cell,
    task::{Context, Poll},
};

use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_http::{
    body::BodySize,
    h1,
    ws::{self, CloseCode, Frame, Item, Message},
    Error, HttpService, Request, Response,
};
use actix_http_test::test_server;
use actix_service::{fn_factory, Service};
use bytes::Bytes;
use futures_core::future::LocalBoxFuture;
use futures_util::{SinkExt as _, StreamExt as _};

/// Upgrade service handed to the test HTTP server.
///
/// Holds a single readiness flag so the `call` path can check that readiness
/// was recorded beforehand. Cloning copies the current flag value.
#[derive(Clone)]
struct WsService(Cell<bool>);

impl WsService {
    /// Builds a service whose readiness flag starts out cleared.
    fn new() -> Self {
        Self(Cell::new(false))
    }

    /// Records that the service has been polled for readiness.
    fn set_polled(&self) {
        let flag = &self.0;
        flag.set(true);
    }

    /// Reports whether `set_polled` has been recorded on this value.
    fn was_polled(&self) -> bool {
        let Self(flag) = self;
        flag.get()
    }
}

impl<T> Service<(Request, Framed<T, h1::Codec>)> for WsService
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    type Response = ();
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always reports ready, remembering that readiness was checked.
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.set_polled();
        Poll::Ready(Ok(()))
    }

    /// Performs the WebSocket handshake on the upgraded HTTP/1 connection and
    /// then drives the echo `service` over WebSocket framing until the
    /// dispatcher completes.
    ///
    /// # Panics
    ///
    /// Panics if readiness was never recorded (i.e. `poll_ready` was skipped).
    fn call(&self, (req, mut framed): (Request, Framed<T, h1::Codec>)) -> Self::Future {
        assert!(self.was_polled());

        Box::pin(async move {
            // Validate the upgrade request and build the handshake response
            // with an empty body.
            let response = ws::handshake(req.head())?.message_body(())?;
            framed.send((response, BodySize::None).into()).await?;

            // After the handshake the same transport is re-framed with the
            // WebSocket codec and handed to the frame dispatcher.
            let ws_framed = framed.replace_codec(ws::Codec::new());
            ws::Dispatcher::with(ws_framed, service).await?;

            Ok(())
        })
    }
}

async fn service(msg: Frame) -> Result<Message, Error> {
    let msg = match msg {
        Frame::Ping(msg) => Message::Pong(msg),
        Frame::Text(text) => {
            Message::Text(String::from_utf8_lossy(&text).into_owned().into())
        }
        Frame::Binary(bin) => Message::Binary(bin),
        Frame::Continuation(item) => Message::Continuation(item),
        Frame::Close(reason) => Message::Close(reason),
        _ => return Err(Error::from(ws::ProtocolError::BadOpCode)),
    };

    Ok(msg)
}

#[actix_rt::test]
async fn test_simple() {
    let mut srv = test_server(|| {
        HttpService::build()
            .upgrade(fn_factory(|| async { Ok::<_, ()>(WsService::new()) }))
            .finish(|_| async { Ok::<_, ()>(Response::not_found()) })
            .tcp()
    })
    .await;

    // client service
    let mut framed = srv.ws().await.unwrap();
    framed.send(Message::Text("text".into())).await.unwrap();

    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(item, Frame::Text(Bytes::from_static(b"text")));

    framed.send(Message::Binary("text".into())).await.unwrap();

    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(item, Frame::Binary(Bytes::from_static(&b"text"[..])));

    framed.send(Message::Ping("text".into())).await.unwrap();
    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(item, Frame::Pong("text".to_string().into()));

    framed
        .send(Message::Continuation(Item::FirstText("text".into())))
        .await
        .unwrap();
    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(
        item,
        Frame::Continuation(Item::FirstText(Bytes::from_static(b"text")))
    );

    assert!(framed
        .send(Message::Continuation(Item::FirstText("text".into())))
        .await
        .is_err());
    assert!(framed
        .send(Message::Continuation(Item::FirstBinary("text".into())))
        .await
        .is_err());

    framed
        .send(Message::Continuation(Item::Continue("text".into())))
        .await
        .unwrap();
    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(
        item,
        Frame::Continuation(Item::Continue(Bytes::from_static(b"text")))
    );

    framed
        .send(Message::Continuation(Item::Last("text".into())))
        .await
        .unwrap();
    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(
        item,
        Frame::Continuation(Item::Last(Bytes::from_static(b"text")))
    );

    assert!(framed
        .send(Message::Continuation(Item::Continue("text".into())))
        .await
        .is_err());

    assert!(framed
        .send(Message::Continuation(Item::Last("text".into())))
        .await
        .is_err());

    framed
        .send(Message::Close(Some(CloseCode::Normal.into())))
        .await
        .unwrap();

    let item = framed.next().await.unwrap().unwrap();
    assert_eq!(item, Frame::Close(Some(CloseCode::Normal.into())));
}