// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::SqLiteDataStorageError; use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage}; use rusqlite::{params, Connection, OptionalExtension}; use std::{ ops::Deref, sync::{Arc, Mutex}, }; #[derive(Debug, Clone)] /// SQLite storage for MLS pre-shared keys. pub struct SqLitePreSharedKeyStorage { connection: Arc>, } impl SqLitePreSharedKeyStorage { pub(crate) fn new(connection: Connection) -> SqLitePreSharedKeyStorage { SqLitePreSharedKeyStorage { connection: Arc::new(Mutex::new(connection)), } } /// Insert a pre-shared key into storage. pub fn insert(&self, psk_id: &[u8], psk: &PreSharedKey) -> Result<(), SqLiteDataStorageError> { let connection = self.connection.lock().unwrap(); // Upsert into the database connection .execute( "INSERT INTO psk (psk_id, data) VALUES (?,?) ON CONFLICT(psk_id) DO UPDATE SET data=excluded.data", params![psk_id, psk.deref()], ) .map(|_| ()) .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } /// Get a pre-shared key from storage based on a unique id. pub fn get(&self, psk_id: &[u8]) -> Result, SqLiteDataStorageError> { let connection = self.connection.lock().unwrap(); connection .query_row( "SELECT data FROM psk WHERE psk_id = ?", params![psk_id], |row| Ok(PreSharedKey::new(row.get(0)?)), ) .optional() .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } /// Delete a pre-shared key from storage based on a unique id. pub fn delete(&self, psk_id: &[u8]) -> Result<(), SqLiteDataStorageError> { let connection = self.connection.lock().unwrap(); connection .execute("DELETE FROM psk WHERE psk_id = ?", params![psk_id]) .map(|_| ()) .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(mls_build_async, maybe_async::must_be_async)] impl PreSharedKeyStorage for SqLitePreSharedKeyStorage { type Error = SqLiteDataStorageError; async fn get(&self, id: &ExternalPskId) -> Result, Self::Error> { self.get(id) .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into())) } } #[cfg(test)] mod tests { use mls_rs_core::psk::PreSharedKey; use crate::{ SqLiteDataStorageEngine, {connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes}, }; use super::SqLitePreSharedKeyStorage; fn test_psk() -> (Vec, PreSharedKey) { let psk_id = gen_rand_bytes(32); let stored_psk = PreSharedKey::new(gen_rand_bytes(64)); (psk_id, stored_psk) } fn test_storage() -> SqLitePreSharedKeyStorage { SqLiteDataStorageEngine::new(MemoryStrategy) .unwrap() .pre_shared_key_storage() .unwrap() } #[test] fn test_insert() { let (psk_id, psk) = test_psk(); let storage = test_storage(); storage.insert(&psk_id, &psk).unwrap(); let from_storage = storage.get(&psk_id).unwrap().unwrap(); assert_eq!(from_storage, psk); } #[test] fn test_insert_existing_overwrite() { let (psk_id, psk) = test_psk(); let (_, new_psk) = test_psk(); let storage = test_storage(); storage.insert(&psk_id, &psk).unwrap(); storage.insert(&psk_id, &new_psk).unwrap(); let from_storage = storage.get(&psk_id).unwrap().unwrap(); assert_eq!(from_storage, new_psk); } #[test] fn test_delete() { let (psk_id, psk) = test_psk(); let storage = test_storage(); storage.insert(&psk_id, &psk).unwrap(); storage.delete(&psk_id).unwrap(); assert!(storage.get(&psk_id).unwrap().is_none()); } }