diff --git a/src/header/value.rs b/src/header/value.rs index 02622c2e..7819e15d 100644 --- a/src/header/value.rs +++ b/src/header/value.rs @@ -23,11 +23,18 @@ pub struct HeaderValue { is_sensitive: bool, } +#[derive(Debug)] +enum InvalidHeaderValueKind { + InvalidByte, + LeadingWhitespace, + TrailingWhitespace +} + /// A possible error when converting a `HeaderValue` from a string or byte /// slice. #[derive(Debug)] pub struct InvalidHeaderValue { - _priv: (), + kind: InvalidHeaderValueKind } /// A possible error when converting a `HeaderValue` from a string or byte @@ -71,7 +78,16 @@ impl HeaderValue { panic!("invalid header value"); } } - + if let Some(&b) = bytes.first() { + if is_whitespace(b) { + panic!("invalid header value"); + } + } + if let Some(&b) = bytes.last() { + if is_whitespace(b) { + panic!("invalid header value"); + } + } HeaderValue { inner: Bytes::from_static(bytes), is_sensitive: false, @@ -183,7 +199,21 @@ impl HeaderValue { for &b in src.as_ref() { if !is_valid(b) { return Err(InvalidHeaderValue { - _priv: (), + kind: InvalidHeaderValueKind::InvalidByte + }); + } + } + if let Some(&b) = src.as_ref().first() { + if is_whitespace(b) { + return Err(InvalidHeaderValue { + kind: InvalidHeaderValueKind::LeadingWhitespace + }); + } + } + if let Some(&b) = src.as_ref().last() { + if is_whitespace(b) { + return Err(InvalidHeaderValue { + kind: InvalidHeaderValueKind::TrailingWhitespace }); } } @@ -565,6 +595,10 @@ fn is_visible_ascii(b: u8) -> bool { b >= 32 && b < 127 || b == b'\t' } +fn is_whitespace(b: u8) -> bool { + b == b' ' || b == b'\t' +} + #[inline] fn is_valid(b: u8) -> bool { b >= 32 && b != 127 || b == b'\t' @@ -578,7 +612,11 @@ impl fmt::Display for InvalidHeaderValue { impl Error for InvalidHeaderValue { fn description(&self) -> &str { - "failed to parse header value" + match self.kind { + InvalidHeaderValueKind::InvalidByte => "failed to parse header value (invalid character)", + InvalidHeaderValueKind::LeadingWhitespace => "failed to parse header value (leading whitespace)", + InvalidHeaderValueKind::TrailingWhitespace => "failed to parse header value (trailing whitespace)" + } } } @@ -763,7 +801,7 @@ impl<'a> PartialOrd for &'a str { #[test] fn test_try_from() { - HeaderValue::try_from(vec![127]).unwrap_err(); + assert_eq!(HeaderValue::try_from(vec![127]).unwrap_err().description(), "failed to parse header value (invalid character)"); } #[test] @@ -784,3 +822,39 @@ fn test_debug() { sensitive.set_sensitive(true); assert_eq!("Sensitive", format!("{:?}", sensitive)); } + +#[test] +fn test_leading_whitespace() { + assert_eq!(HeaderValue::from_str(" A").unwrap_err().description(), "failed to parse header value (leading whitespace)"); + assert_eq!(HeaderValue::from_str("\tA").unwrap_err().description(), "failed to parse header value (leading whitespace)"); +} + +#[test] +#[should_panic(expected = "invalid header value")] +fn test_leading_whitespace_static() { + HeaderValue::from_static(" A"); +} + +#[test] +#[should_panic(expected = "invalid header value")] +fn test_leading_tab_static() { + HeaderValue::from_static("\tA"); +} + +#[test] +fn test_trailing_whitespace() { + assert_eq!(HeaderValue::from_str("A ").unwrap_err().description(), "failed to parse header value (trailing whitespace)"); + assert_eq!(HeaderValue::from_str("A\t").unwrap_err().description(), "failed to parse header value (trailing whitespace)"); +} + +#[test] +#[should_panic(expected = "invalid header value")] +fn test_trailing_whitespace_static() { + HeaderValue::from_static("A "); +} + +#[test] +#[should_panic(expected = "invalid header value")] +fn test_trailing_tab_static() { + HeaderValue::from_static("A\t"); +}