diff --git a/src/lib.rs b/src/lib.rs index 09c27218a..c091139cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,9 +15,9 @@ //! [`Logger`]: https://docs.rs/actix-web/3.0.2/actix_web/middleware/struct.Logger.html //! [`log`]: https://docs.rs/log //! [`tracing`]: https://docs.rs/tracing -use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; -use actix_web::Error; -use futures::future::{ok, Ready}; +use actix_web::dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform}; +use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; +use futures::future::{ok, ready, Ready}; use futures::task::{Context, Poll}; use std::future::Future; use std::pin::Pin; @@ -112,6 +112,52 @@ pub struct TracingLoggerMiddleware { service: S, } +/// A unique identifier for each incomming request. This ID is added to the logger span, even if +/// the `RequestId` is never extracted. +/// +/// Extracting a `RequestId` when the `TracingLogger` middleware is not registered, will result in +/// a internal server error. +/// +/// # Usage +/// ```rust +/// use actix_web::get; +/// use tracing_actix_web::RequestId; +/// use uuid::Uuid; +/// +/// #[get("/")] +/// async fn index(request_id: RequestId) -> String { +/// format!("{}", request_id) +/// } +/// +/// #[get("/2")] +/// async fn index2(request_id: RequestId) -> String { +/// let uuid: Uuid = request_id.into(); +/// format!("{}", uuid) +/// } +/// ``` +#[derive(Clone, Copy)] +pub struct RequestId(Uuid); + +impl std::ops::Deref for RequestId { + type Target = Uuid; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::convert::Into for RequestId { + fn into(self) -> Uuid { + self.0 + } +} + +impl std::fmt::Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + impl Service for TracingLoggerMiddleware where S: Service, Error = Error>, @@ -133,14 +179,16 @@ where .get("User-Agent") .map(|h| h.to_str().unwrap_or("")) .unwrap_or(""); + let request_id = RequestId(Uuid::new_v4()); let span = tracing::info_span!( "Request", request_path = %req.path(), user_agent = %user_agent, client_ip_address = %req.connection_info().realip_remote_addr().unwrap_or(""), - request_id = %Uuid::new_v4(), + request_id = %request_id.0, status_code = tracing::field::Empty, ); + req.extensions_mut().insert(request_id); let fut = self.service.call(req); Box::pin( async move { @@ -156,3 +204,13 @@ where ) } } + +impl FromRequest for RequestId { + type Error = (); + type Future = Ready>; + type Config = (); + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ready(req.extensions().get::().copied().ok_or(())) + } +}