diff --git a/src/fs.rs b/src/fs.rs index 56534bbf4..0f82818c5 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -36,6 +36,7 @@ pub struct NamedFile { md: Metadata, modified: Option, cpu_pool: Option, + only_get: bool, } impl NamedFile { @@ -54,7 +55,14 @@ impl NamedFile { let path = path.as_ref().to_path_buf(); let modified = md.modified().ok(); let cpu_pool = None; - Ok(NamedFile{path, file, md, modified, cpu_pool}) + Ok(NamedFile{path, file, md, modified, cpu_pool, only_get: false}) + } + + /// Allow only GET and HEAD methods + #[inline] + pub fn only_get(mut self) -> Self { + self.only_get = true; + self } /// Returns reference to the underlying `File` object. @@ -168,7 +176,7 @@ impl Responder for NamedFile { type Error = io::Error; fn respond_to(self, req: HttpRequest) -> Result { - if *req.method() != Method::GET && *req.method() != Method::HEAD { + if self.only_get && *req.method() != Method::GET && *req.method() != Method::HEAD { return Ok(HttpMethodNotAllowed.build() .header(header::http::CONTENT_TYPE, "text/plain") .header(header::http::ALLOW, "GET, HEAD") @@ -215,7 +223,9 @@ impl Responder for NamedFile { return Ok(resp.status(StatusCode::NOT_MODIFIED).finish().unwrap()) } - if *req.method() == Method::GET { + if *req.method() == Method::HEAD { + Ok(resp.finish().unwrap()) + } else { let reader = ChunkedReadFile { size: self.md.len(), offset: 0, @@ -224,8 +234,6 @@ impl Responder for NamedFile { fut: None, }; Ok(resp.streaming(reader).unwrap()) - } else { - Ok(resp.finish().unwrap()) } } } @@ -487,7 +495,8 @@ impl Handler for StaticFiles { } } else { Ok(FilesystemElement::File( - NamedFile::open(path)?.set_cpu_pool(self.cpu_pool.clone()))) + NamedFile::open(path)? + .set_cpu_pool(self.cpu_pool.clone()).only_get())) } } } @@ -517,10 +526,18 @@ mod tests { let req = TestRequest::default().method(Method::POST).finish(); let file = NamedFile::open("Cargo.toml").unwrap(); - let resp = file.respond_to(req).unwrap(); + let resp = file.only_get().respond_to(req).unwrap(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); } + #[test] + fn test_named_file_any_method() { + let req = TestRequest::default().method(Method::POST).finish(); + let file = NamedFile::open("Cargo.toml").unwrap(); + let resp = file.respond_to(req).unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + #[test] fn test_static_files() { let mut st = StaticFiles::new(".", true);