// Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. #![expect( clippy::unwrap_used, reason = "Let's assume the use of `unwrap` was checked when the use of `unsafe` was reviewed." )] use std::{ cmp::min, fmt::{self, Display, Formatter}, mem, ops::Deref, os::raw::{c_uint, c_void}, pin::Pin, ptr::{null, null_mut}, }; use neqo_common::{hex, hex_with_len, qtrace}; use crate::{ constants::{ContentType, Epoch}, err::{Error, PR_SetError, Res, nspr}, null_safe_slice, prio, ssl, }; // Alias common types. type PrFd = *mut prio::PRFileDesc; type PrStatus = prio::PRStatus::Type; const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS; const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE; /// Convert a pinned, boxed object into a void pointer. pub fn as_c_void(pin: &mut Pin>) -> *mut c_void { (std::ptr::from_mut::(Pin::into_inner(pin.as_mut()))).cast() } /// A slice of the output. #[derive(Default)] pub struct Record { pub epoch: Epoch, pub ct: ContentType, pub data: Vec, } impl Record { #[must_use] pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self { Self { epoch, ct, data: data.to_vec(), } } // Shoves this record into the socket, returns true if blocked. pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> { unsafe { ssl::SSL_RecordLayerData( fd, self.epoch, ssl::SSLContentType::Type::from(self.ct), self.data.as_ptr(), c_uint::try_from(self.data.len())?, ) } } } impl fmt::Debug for Record { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, "Record {:?}:{:?} {}", self.epoch, self.ct, hex_with_len(&self.data[..]) ) } } #[derive(Debug, Default)] pub struct RecordList { records: Vec, } impl RecordList { fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) { self.records.push(Record::new(epoch, ct, data)); } unsafe extern "C" fn ingest( _fd: *mut ssl::PRFileDesc, epoch: ssl::PRUint16, ct: ssl::SSLContentType::Type, data: *const ssl::PRUint8, len: c_uint, arg: *mut c_void, ) -> ssl::SECStatus { let Ok(epoch) = Epoch::try_from(epoch) else { return ssl::SECFailure; }; let Ok(ct) = ContentType::try_from(ct) else { return ssl::SECFailure; }; let Some(records) = (unsafe { arg.cast::().as_mut() }) else { return ssl::SECFailure; }; let slice = unsafe { null_safe_slice(data, len) }; records.append(epoch, ct, slice); ssl::SECSuccess } /// Create a new record list. pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res>> { let mut records = Box::pin(Self::default()); unsafe { ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records)) }?; Ok(records) } } impl Deref for RecordList { type Target = Vec; fn deref(&self) -> &Vec { &self.records } } pub struct RecordListIter(std::vec::IntoIter); impl Iterator for RecordListIter { type Item = Record; fn next(&mut self) -> Option { self.0.next() } } impl IntoIterator for RecordList { type Item = Record; type IntoIter = RecordListIter; fn into_iter(self) -> Self::IntoIter { RecordListIter(self.records.into_iter()) } } pub struct AgentIoInputContext<'a> { input: &'a mut AgentIoInput, } impl Drop for AgentIoInputContext<'_> { fn drop(&mut self) { self.input.reset(); } } #[derive(Debug, Default)] struct AgentIoInput { // input is data that is read by TLS. input: *const u8, // input_available is how much data is left for reading. available: usize, } impl AgentIoInput { fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { assert!(self.input.is_null()); self.input = input.as_ptr(); self.available = input.len(); qtrace!("AgentIoInput wrap {:p}", self.input); AgentIoInputContext { input: self } } // Take the data provided as input and provide it to the TLS stack. fn read_input(&mut self, buf: *mut u8, count: usize) -> Res { let amount = min(self.available, count); if amount == 0 { unsafe { PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0); } return Err(Error::NoDataAvailable); } #[expect( clippy::disallowed_methods, reason = "We just checked if this was empty." )] let src = unsafe { std::slice::from_raw_parts(self.input, amount) }; qtrace!("[{self}] read {}", hex(src)); let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) }; dst.copy_from_slice(src); self.input = self.input.wrapping_add(amount); self.available -= amount; Ok(amount) } fn reset(&mut self) { qtrace!("[{self}] reset"); self.input = null(); self.available = 0; } } impl Display for AgentIoInput { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "AgentIoInput {:p}", self.input) } } #[derive(Debug, Default)] pub struct AgentIo { // input collects the input we might provide to TLS. input: AgentIoInput, // output contains data that is written by TLS. output: Vec, } impl AgentIo { unsafe fn borrow(fd: &mut PrFd) -> &mut Self { unsafe { (**fd).secret.cast::().as_mut().unwrap() } } pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { assert_eq!(self.output.len(), 0); self.input.wrap(input) } // Stage output from TLS into the output buffer. fn save_output(&mut self, buf: *const u8, count: usize) { let slice = unsafe { null_safe_slice(buf, count) }; qtrace!("[{self}] save output {}", hex(slice)); self.output.extend_from_slice(slice); } pub fn take_output(&mut self) -> Vec { qtrace!("[{self}] take output"); mem::take(&mut self.output) } } impl Display for AgentIo { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "AgentIo") } } unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus { unsafe { (*fd).secret = null_mut(); if let Some(dtor) = (*fd).dtor { dtor(fd); } } PR_SUCCESS } unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus { let io = unsafe { AgentIo::borrow(&mut fd) }; let Ok(a) = usize::try_from(amount) else { return PR_FAILURE; }; match io.input.read_input(buf.cast(), a) { Ok(_) => PR_SUCCESS, Err(_) => PR_FAILURE, } } unsafe extern "C" fn agent_recv( mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32 { let io = unsafe { AgentIo::borrow(&mut fd) }; if flags != 0 { return PR_FAILURE; } let Ok(a) = usize::try_from(amount) else { return PR_FAILURE; }; io.input.read_input(buf.cast(), a).map_or(PR_FAILURE, |v| { prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE) }) } unsafe extern "C" fn agent_write( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, ) -> PrStatus { let io = unsafe { AgentIo::borrow(&mut fd) }; usize::try_from(amount).map_or(PR_FAILURE, |a| { io.save_output(buf.cast(), a); amount }) } unsafe extern "C" fn agent_send( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32 { let io = unsafe { AgentIo::borrow(&mut fd) }; if flags != 0 { return PR_FAILURE; } usize::try_from(amount).map_or(PR_FAILURE, |a| { io.save_output(buf.cast(), a); amount }) } unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 { let io = unsafe { AgentIo::borrow(&mut fd) }; io.input.available.try_into().unwrap_or(PR_FAILURE) } unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 { let io = unsafe { AgentIo::borrow(&mut fd) }; io.input .available .try_into() .unwrap_or_else(|_| PR_FAILURE.into()) } #[expect( clippy::cast_possible_truncation, reason = "Cast is safe because prio::PR_AF_INET is 2." )] const unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus { let Some(a) = (unsafe { addr.as_mut() }) else { return PR_FAILURE; }; a.inet.family = prio::PR_AF_INET as prio::PRUint16; a.inet.port = 0; a.inet.ip = 0; PR_SUCCESS } const unsafe extern "C" fn agent_getsockopt( _fd: PrFd, opt: *mut prio::PRSocketOptionData, ) -> PrStatus { let Some(o) = (unsafe { opt.as_mut() }) else { return PR_FAILURE; }; if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking { o.value.non_blocking = 1; return PR_SUCCESS; } PR_FAILURE } pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods { file_type: prio::PRDescType::PR_DESC_LAYERED, close: Some(agent_close), read: Some(agent_read), write: Some(agent_write), available: Some(agent_available), available64: Some(agent_available64), fsync: None, seek: None, seek64: None, fileInfo: None, fileInfo64: None, writev: None, connect: None, accept: None, bind: None, listen: None, shutdown: None, recv: Some(agent_recv), send: Some(agent_send), recvfrom: None, sendto: None, poll: None, acceptread: None, transmitfile: None, getsockname: Some(agent_getname), getpeername: Some(agent_getname), reserved_fn_6: None, reserved_fn_5: None, getsocketoption: Some(agent_getsockopt), setsocketoption: None, sendfile: None, connectcontinue: None, reserved_fn_3: None, reserved_fn_2: None, reserved_fn_1: None, reserved_fn_0: None, }; #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use std::ptr::addr_of_mut; use super::*; #[test] fn ingest_errors() { let mut records = RecordList::default(); let data = [0u8]; unsafe { assert_eq!( RecordList::ingest( null_mut(), 999, 0x17, data.as_ptr(), 1, addr_of_mut!(records).cast() ), ssl::SECFailure ); assert_eq!( RecordList::ingest(null_mut(), 0, 0x17, data.as_ptr(), 1, null_mut()), ssl::SECFailure ); // Test invalid content type (value outside u8 range) assert_eq!( RecordList::ingest( null_mut(), 0, 256, data.as_ptr(), 1, addr_of_mut!(records).cast() ), ssl::SECFailure ); } } #[test] fn formatting() { let record = Record::new(Epoch::ApplicationData, 0x17, &[1, 2, 3]); let dbg = format!("{record:?}"); assert_eq!(&dbg[..6], "Record"); let input = AgentIoInput::default(); let disp = format!("{input}"); assert_eq!(&disp[..12], "AgentIoInput"); let io = AgentIo::default(); assert_eq!(format!("{io}"), "AgentIo"); } }