// Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. use enumset::{EnumSet, EnumSetType, enum_set}; use neqo_common::{Header, header::HeadersExt as _}; use crate::{Error, MessageType, Res}; #[derive(EnumSetType, Debug)] enum PseudoHeaderState { Status, Method, Scheme, Authority, Path, Protocol, Regular, } impl PseudoHeaderState { fn is_pseudo(self) -> bool { self != Self::Regular } } impl TryFrom<(MessageType, &str)> for PseudoHeaderState { type Error = Error; fn try_from(v: (MessageType, &str)) -> Res { match v { (MessageType::Response, ":status") => Ok(Self::Status), (MessageType::Request, ":method") => Ok(Self::Method), (MessageType::Request, ":scheme") => Ok(Self::Scheme), (MessageType::Request, ":authority") => Ok(Self::Authority), (MessageType::Request, ":path") => Ok(Self::Path), (MessageType::Request, ":protocol") => Ok(Self::Protocol), (_, _) => Err(Error::InvalidHeader), } } } /// Check whether the response is informational(1xx). /// /// # Errors /// /// Returns an error if response headers do not contain /// a status header or if the value of the header is 101 or cannot be parsed. pub fn is_interim(headers: &[Header]) -> Res { if let Some(h) = headers.iter().take(1).find_header(":status") { let status_code = std::str::from_utf8(h.value()) .map_err(|_| Error::InvalidHeader)? .parse::() .map_err(|_| Error::InvalidHeader)?; if status_code == 101 { // https://datatracker.ietf.org/doc/html/draft-ietf-quic-http#section-4.3 Err(Error::InvalidHeader) } else { Ok((100..200).contains(&status_code)) } } else { Err(Error::InvalidHeader) } } fn track_pseudo( name: &str, result_state: &mut EnumSet, message_type: MessageType, ) -> Res { let new_state = if name.starts_with(':') { if result_state.contains(PseudoHeaderState::Regular) { return Err(Error::InvalidHeader); } PseudoHeaderState::try_from((message_type, name))? } else { PseudoHeaderState::Regular }; let pseudo = new_state.is_pseudo(); if *result_state & new_state == EnumSet::empty() || !pseudo { *result_state |= new_state; Ok(pseudo) } else { Err(Error::InvalidHeader) } } /// Checks if request/response headers are well formed, i.e. contain /// allowed pseudo headers and in a right order, etc. /// /// # Errors /// /// Returns an error if headers are not well formed. pub fn headers_valid(headers: &[Header], message_type: MessageType) -> Res<()> { let mut method_value: Option<&[u8]> = None; let mut protocol_value: Option<&[u8]> = None; let mut scheme_value: Option<&[u8]> = None; let mut pseudo_state = EnumSet::new(); for header in headers { let is_pseudo = track_pseudo(header.name(), &mut pseudo_state, message_type)?; let mut bytes = header.name().bytes(); if is_pseudo { if header.name() == ":method" { method_value = Some(header.value()); } else if header.name() == ":protocol" { protocol_value = Some(header.value()); } else if header.name() == ":scheme" { scheme_value = Some(header.value()); } _ = bytes.next(); } if bytes.any(|b| matches!(b, 0 | 0x10 | 0x13 | 0x3a | 0x41..=0x5a)) { return Err(Error::InvalidHeader); // illegal characters. } } // Clear the regular header bit, since we only check pseudo headers below. pseudo_state.remove(PseudoHeaderState::Regular); let pseudo_header_mask = match message_type { MessageType::Response => enum_set!(PseudoHeaderState::Status), MessageType::Request => { if method_value == Some(b"CONNECT".as_ref()) { let connect_mask = PseudoHeaderState::Method | PseudoHeaderState::Authority; if let Some(protocol) = protocol_value { // For a webtransport CONNECT, the :scheme field must be set to https. if protocol == b"webtransport" && scheme_value != Some(b"https".as_ref()) { return Err(Error::InvalidHeader); } // The CONNECT request for with :protocol included must have the scheme, // authority, and path set. connect_mask | PseudoHeaderState::Scheme | PseudoHeaderState::Path } else { connect_mask } } else { PseudoHeaderState::Method | PseudoHeaderState::Scheme | PseudoHeaderState::Path } } }; if (MessageType::Request == message_type) && pseudo_state.contains(PseudoHeaderState::Protocol) && method_value != Some(b"CONNECT".as_ref()) { return Err(Error::InvalidHeader); } if pseudo_state & pseudo_header_mask != pseudo_header_mask { return Err(Error::InvalidHeader); } Ok(()) } /// Checks if trailers are well formed, i.e. pseudo headers are not /// allowed in trailers. /// /// # Errors /// /// Returns an error if trailers are not well formed. pub fn trailers_valid(headers: &[Header]) -> Res<()> { for header in headers { if header.name().starts_with(':') { return Err(Error::InvalidHeader); } } Ok(()) } #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use neqo_common::Header; use super::{headers_valid, is_interim}; use crate::MessageType; fn create_connect_headers() -> Vec
{ vec![ Header::new(":method", "CONNECT"), Header::new(":protocol", "webtransport"), Header::new(":scheme", "https"), Header::new(":authority", "something.com"), Header::new(":path", "/here"), ] } fn create_connect_headers_without_field(field: &str) -> Vec
{ create_connect_headers() .into_iter() .filter(|header| header.name() != field) .collect() } #[test] fn connect_with_missing_header() { for field in &[":scheme", ":path", ":authority"] { assert!( headers_valid( &create_connect_headers_without_field(field), MessageType::Request ) .is_err() ); } } #[test] fn invalid_scheme_webtransport_connect() { let mut headers = create_connect_headers(); headers[2] = Header::new(":scheme", "http"); assert!(headers_valid(&headers, MessageType::Request).is_err()); } #[test] fn valid_webtransport_connect() { assert!(headers_valid(&create_connect_headers(), MessageType::Request).is_ok()); } #[test] fn invalid_webtransport_connect_with_status() { assert!( headers_valid( [ create_connect_headers(), vec![Header::new(":status", "200")] ] .concat() .as_slice(), MessageType::Request ) .is_err() ); } #[test] fn is_interim_invalid_utf8() { // Create a header with invalid UTF-8 bytes in the status value let invalid_utf8_bytes = vec![0xFF, 0xFE, 0xFD]; let header = Header::new(":status", invalid_utf8_bytes.as_slice()); let headers = vec![header]; assert!(is_interim(&headers).is_err()); } #[test] fn is_interim_not_a_number() { let headers = vec![Header::new(":status", "not-a-number")]; assert!(is_interim(&headers).is_err()); } #[test] fn protocol_requires_connect_method() { // :protocol is only valid with CONNECT method. let mut headers = create_connect_headers(); headers[0] = Header::new(":method", "GET"); assert!(headers_valid(&headers, MessageType::Request).is_err()); } #[test] fn classic_connect_valid() { // Classic CONNECT only requires :method and :authority. let headers = vec![ Header::new(":method", "CONNECT"), Header::new(":authority", "proxy.example.com:443"), ]; assert!(headers_valid(&headers, MessageType::Request).is_ok()); } #[test] fn response_requires_status() { let headers = vec![Header::new(":status", "200")]; assert!(headers_valid(&headers, MessageType::Response).is_ok()); } #[test] fn response_missing_status() { let headers: Vec
= vec![]; assert!(headers_valid(&headers, MessageType::Response).is_err()); } #[test] fn regular_request_valid() { let headers = vec![ Header::new(":method", "GET"), Header::new(":scheme", "https"), Header::new(":path", "/index.html"), ]; assert!(headers_valid(&headers, MessageType::Request).is_ok()); } #[test] fn regular_request_missing_method() { let headers = vec![ Header::new(":scheme", "https"), Header::new(":path", "/index.html"), ]; assert!(headers_valid(&headers, MessageType::Request).is_err()); } #[test] fn regular_request_missing_scheme() { let headers = vec![ Header::new(":method", "GET"), Header::new(":path", "/index.html"), ]; assert!(headers_valid(&headers, MessageType::Request).is_err()); } #[test] fn regular_request_missing_path() { let headers = vec![ Header::new(":method", "GET"), Header::new(":scheme", "https"), ]; assert!(headers_valid(&headers, MessageType::Request).is_err()); } }