/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ extern crate rusqlite; use parking_lot::Mutex; use std::collections::HashMap; #[derive(Debug, thiserror::Error)] pub enum OhttpError { #[error("Failed to fetch encryption key")] KeyFetchFailed, #[error("OHTTP key config is malformed")] MalformedKeyConfig, #[error("Unsupported OHTTP encryption algorithm")] UnsupportedKeyConfig, #[error("OhttpSession is in invalid state")] InvalidSession, #[error("Network errors communicating with Relay / Gateway")] RelayFailed, #[error("Cannot encode message as BHTTP/OHTTP")] CannotEncodeMessage, #[error("Cannot decode OHTTP/BHTTP message")] MalformedMessage, #[error("Duplicate HTTP response headers")] DuplicateHeaders, } #[derive(Default)] enum ExchangeState { #[default] Invalid, Request(ohttp::ClientRequest), Response(ohttp::ClientResponse), } pub struct OhttpSession { state: Mutex, } pub struct OhttpResponse { status_code: u16, headers: HashMap, payload: Vec, } /// Transform the headers from a BHTTP message into a HashMap for use from Swift /// later. If there are duplicate errors, we currently raise an error. fn headers_to_map(message: &bhttp::Message) -> Result, OhttpError> { let mut headers = HashMap::new(); for field in message.header().iter() { if headers .insert( std::str::from_utf8(field.name()) .map_err(|_| OhttpError::MalformedMessage)? .into(), std::str::from_utf8(field.value()) .map_err(|_| OhttpError::MalformedMessage)? .into(), ) .is_some() { return Err(OhttpError::DuplicateHeaders); } } Ok(headers) } impl OhttpSession { /// Create a new encryption session for use with specific key configuration pub fn new(config: &[u8]) -> Result { ohttp::init(); let request = ohttp::ClientRequest::from_encoded_config(config).map_err(|e| match e { ohttp::Error::Unsupported => OhttpError::UnsupportedKeyConfig, _ => OhttpError::MalformedKeyConfig, })?; let state = Mutex::new(ExchangeState::Request(request)); Ok(OhttpSession { state }) } /// Encode an HTTP request in Binary HTTP format and then encrypt it into an /// Oblivious HTTP request message. pub fn encapsulate( &self, method: &str, scheme: &str, server: &str, endpoint: &str, mut headers: HashMap, payload: &[u8], ) -> Result, OhttpError> { let mut message = bhttp::Message::request(method.into(), scheme.into(), server.into(), endpoint.into()); for (k, v) in headers.drain() { message.put_header(k, v); } message.write_content(payload); let mut encoded = vec![]; message .write_bhttp(bhttp::Mode::KnownLength, &mut encoded) .map_err(|_| OhttpError::CannotEncodeMessage)?; let mut state = self.state.lock(); let request = match std::mem::take(&mut *state) { ExchangeState::Request(request) => request, _ => return Err(OhttpError::InvalidSession), }; let (capsule, response) = request .encapsulate(&encoded) .map_err(|_| OhttpError::CannotEncodeMessage)?; *state = ExchangeState::Response(response); Ok(capsule) } /// Decode an OHTTP response returned in response to a request encoded on /// this session. pub fn decapsulate(&self, encoded: &[u8]) -> Result { let mut state = self.state.lock(); let decoder = match std::mem::take(&mut *state) { ExchangeState::Response(response) => response, _ => return Err(OhttpError::InvalidSession), }; let binary = decoder .decapsulate(encoded) .map_err(|_| OhttpError::MalformedMessage)?; let mut cursor = std::io::Cursor::new(binary); let message = bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?; let headers = headers_to_map(&message)?; Ok(OhttpResponse { status_code: match message.control() { bhttp::ControlData::Response(sc) => (*sc).into(), _ => return Err(OhttpError::InvalidSession), }, headers, payload: message.content().into(), }) } } pub struct OhttpTestServer { server: Mutex, state: Mutex>, config: Vec, } pub struct TestServerRequest { method: String, scheme: String, server: String, endpoint: String, headers: HashMap, payload: Vec, } impl OhttpTestServer { /// Create a simple OHTTP server to decrypt and respond to OHTTP messages in /// testing. The key is randomly generated. fn new() -> Self { ohttp::init(); let key = ohttp::KeyConfig::new( 0x01, ohttp::hpke::Kem::X25519Sha256, vec![ohttp::SymmetricSuite::new( ohttp::hpke::Kdf::HkdfSha256, ohttp::hpke::Aead::Aes128Gcm, )], ) .unwrap(); let config = key.encode().unwrap(); let server = ohttp::Server::new(key).unwrap(); OhttpTestServer { server: Mutex::new(server), state: Mutex::new(Option::None), config, } } /// Return a copy of the key config for clients to use. fn get_config(&self) -> Vec { self.config.clone() } /// Decode an OHTTP request message and return the cleartext contents. This /// also updates the internal server state so that a response message can be /// generated. fn receive(&self, message: &[u8]) -> Result { let (encoded, response) = self .server .lock() .decapsulate(message) .map_err(|_| OhttpError::MalformedMessage)?; let mut cursor = std::io::Cursor::new(encoded); let message = bhttp::Message::read_bhttp(&mut cursor).map_err(|_| OhttpError::MalformedMessage)?; *self.state.lock() = Some(response); let headers = headers_to_map(&message)?; match message.control() { bhttp::ControlData::Request { method, scheme, authority, path, } => Ok(TestServerRequest { method: String::from_utf8_lossy(method).into(), scheme: String::from_utf8_lossy(scheme).into(), server: String::from_utf8_lossy(authority).into(), endpoint: String::from_utf8_lossy(path).into(), headers, payload: message.content().into(), }), _ => Err(OhttpError::MalformedMessage), } } /// Encode an OHTTP response keyed to the last message received. fn respond(&self, response: OhttpResponse) -> Result, OhttpError> { let state = self.state.lock().take().unwrap(); let mut message = bhttp::Message::response(bhttp::StatusCode::try_from(response.status_code).unwrap()); message.write_content(&response.payload); for (k, v) in response.headers { message.put_header(k, v); } let mut encoded = vec![]; message .write_bhttp(bhttp::Mode::KnownLength, &mut encoded) .map_err(|_| OhttpError::CannotEncodeMessage)?; state .encapsulate(&encoded) .map_err(|_| OhttpError::CannotEncodeMessage) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_smoke() { let server = OhttpTestServer::new(); let config = server.get_config(); let body: Vec = vec![0x00, 0x01, 0x02]; let header = HashMap::from([ ("Content-Type".into(), "application/octet-stream".into()), ("X-Header".into(), "value".into()), ]); let session = OhttpSession::new(&config).unwrap(); let mut message = session .encapsulate("GET", "https", "example.com", "/api", header.clone(), &body) .unwrap(); let request = server.receive(&message).unwrap(); assert_eq!(request.method, "GET"); assert_eq!(request.scheme, "https"); assert_eq!(request.server, "example.com"); assert_eq!(request.endpoint, "/api"); assert_eq!(request.headers, header); message = server .respond(OhttpResponse { status_code: 200, headers: header.clone(), payload: body.clone(), }) .unwrap(); let response = session.decapsulate(&message).unwrap(); assert_eq!(response.status_code, 200); assert_eq!(response.headers, header); assert_eq!(response.payload, body); } } uniffi::include_scaffolding!("as_ohttp_client");