mirror of
https://github.com/fafhrd91/actix-web
synced 2025-01-18 05:41:50 +01:00
add support for specifying protocols on websocket handshake (#835)
* add support for specifying protocols on websocket handshake * separated the handshake function with and without protocols changed protocols type from Vec<&str> to [&str]
This commit is contained in:
parent
e53e9c8ba3
commit
03ca408e94
@ -60,15 +60,43 @@ where
|
|||||||
Ok((addr, res.streaming(out_stream)))
|
Ok((addr, res.streaming(out_stream)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Do websocket handshake and start ws actor.
|
||||||
|
///
|
||||||
|
/// `protocols` is a sequence of known protocols.
|
||||||
|
pub fn start_with_protocols<A, T>(
|
||||||
|
actor: A,
|
||||||
|
protocols: &[&str],
|
||||||
|
req: &HttpRequest,
|
||||||
|
stream: T,
|
||||||
|
) -> Result<HttpResponse, Error>
|
||||||
|
where
|
||||||
|
A: Actor<Context = WebsocketContext<A>> + StreamHandler<Message, ProtocolError>,
|
||||||
|
T: Stream<Item = Bytes, Error = PayloadError> + 'static,
|
||||||
|
{
|
||||||
|
let mut res = handshake_with_protocols(req, protocols)?;
|
||||||
|
Ok(res.streaming(WebsocketContext::create(actor, stream)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepare `WebSocket` handshake response.
|
||||||
|
///
|
||||||
|
/// This function returns handshake `HttpResponse`, ready to send to peer.
|
||||||
|
/// It does not perform any IO.
|
||||||
|
pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeError> {
|
||||||
|
handshake_with_protocols(req, &[])
|
||||||
|
}
|
||||||
|
|
||||||
/// Prepare `WebSocket` handshake response.
|
/// Prepare `WebSocket` handshake response.
|
||||||
///
|
///
|
||||||
/// This function returns handshake `HttpResponse`, ready to send to peer.
|
/// This function returns handshake `HttpResponse`, ready to send to peer.
|
||||||
/// It does not perform any IO.
|
/// It does not perform any IO.
|
||||||
///
|
///
|
||||||
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
/// `protocols` is a sequence of known protocols. On successful handshake,
|
||||||
// /// the returned response headers contain the first protocol in this list
|
/// the returned response headers contain the first protocol in this list
|
||||||
// /// which the server also knows.
|
/// which the server also knows.
|
||||||
pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeError> {
|
pub fn handshake_with_protocols(
|
||||||
|
req: &HttpRequest,
|
||||||
|
protocols: &[&str],
|
||||||
|
) -> Result<HttpResponseBuilder, HandshakeError> {
|
||||||
// WebSocket accepts only GET
|
// WebSocket accepts only GET
|
||||||
if *req.method() != Method::GET {
|
if *req.method() != Method::GET {
|
||||||
return Err(HandshakeError::GetMethodRequired);
|
return Err(HandshakeError::GetMethodRequired);
|
||||||
@ -117,11 +145,28 @@ pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeErro
|
|||||||
hash_key(key.as_ref())
|
hash_key(key.as_ref())
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
|
// check requested protocols
|
||||||
|
let protocol =
|
||||||
|
req.headers()
|
||||||
|
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||||
|
.and_then(|req_protocols| {
|
||||||
|
let req_protocols = req_protocols.to_str().ok()?;
|
||||||
|
req_protocols
|
||||||
|
.split(", ")
|
||||||
|
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
.upgrade("websocket")
|
.upgrade("websocket")
|
||||||
.header(header::TRANSFER_ENCODING, "chunked")
|
.header(header::TRANSFER_ENCODING, "chunked")
|
||||||
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
|
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
|
||||||
.take())
|
.take();
|
||||||
|
|
||||||
|
if let Some(protocol) = protocol {
|
||||||
|
response.header(&header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execution context for `WebSockets` actors
|
/// Execution context for `WebSockets` actors
|
||||||
@ -609,5 +654,87 @@ mod tests {
|
|||||||
StatusCode::SWITCHING_PROTOCOLS,
|
StatusCode::SWITCHING_PROTOCOLS,
|
||||||
handshake(&req).unwrap().finish().status()
|
handshake(&req).unwrap().finish().status()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let req = TestRequest::default()
|
||||||
|
.header(
|
||||||
|
header::UPGRADE,
|
||||||
|
header::HeaderValue::from_static("websocket"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::CONNECTION,
|
||||||
|
header::HeaderValue::from_static("upgrade"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_VERSION,
|
||||||
|
header::HeaderValue::from_static("13"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_KEY,
|
||||||
|
header::HeaderValue::from_static("13"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_PROTOCOL,
|
||||||
|
header::HeaderValue::from_static("graphql"),
|
||||||
|
)
|
||||||
|
.to_http_request();
|
||||||
|
|
||||||
|
let protocols = ["graphql"];
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
StatusCode::SWITCHING_PROTOCOLS,
|
||||||
|
handshake_with_protocols(&req, &protocols)
|
||||||
|
.unwrap()
|
||||||
|
.finish()
|
||||||
|
.status()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Some(&header::HeaderValue::from_static("graphql")),
|
||||||
|
handshake_with_protocols(&req, &protocols)
|
||||||
|
.unwrap()
|
||||||
|
.finish()
|
||||||
|
.headers()
|
||||||
|
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||||
|
);
|
||||||
|
|
||||||
|
let req = TestRequest::default()
|
||||||
|
.header(
|
||||||
|
header::UPGRADE,
|
||||||
|
header::HeaderValue::from_static("websocket"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::CONNECTION,
|
||||||
|
header::HeaderValue::from_static("upgrade"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_VERSION,
|
||||||
|
header::HeaderValue::from_static("13"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_KEY,
|
||||||
|
header::HeaderValue::from_static("13"),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_PROTOCOL,
|
||||||
|
header::HeaderValue::from_static("p1, p2, p3"),
|
||||||
|
)
|
||||||
|
.to_http_request();
|
||||||
|
|
||||||
|
let protocols = vec!["p3", "p2"];
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
StatusCode::SWITCHING_PROTOCOLS,
|
||||||
|
handshake_with_protocols(&req, &protocols)
|
||||||
|
.unwrap()
|
||||||
|
.finish()
|
||||||
|
.status()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Some(&header::HeaderValue::from_static("p2")),
|
||||||
|
handshake_with_protocols(&req, &protocols)
|
||||||
|
.unwrap()
|
||||||
|
.finish()
|
||||||
|
.headers()
|
||||||
|
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user