// Copyright (c) 2024 Mozilla Corporation and contributors. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use std::{ collections::HashMap, path::Path, sync::{Arc, Mutex}, }; use mls_rs::{ client_builder::MlsConfig, crypto::{SignaturePublicKey, SignatureSecretKey}, error::IntoAnyError, identity::{Credential, SigningIdentity}, mls_rules::{CommitOptions, DefaultMlsRules}, storage_provider::KeyPackageData, CipherSuite, Client, ProtocolVersion, }; use crate::ClientConfig; use mls_rs_provider_sqlite::{ connection_strategy::{ CipheredConnectionStrategy, ConnectionStrategy, FileConnectionStrategy, SqlCipherConfig, SqlCipherKey, }, SqLiteDataStorageEngine, }; use crate::{Identity, PlatformError}; #[derive(serde::Serialize, serde::Deserialize, Clone)] pub struct GroupData { state_data: Vec, epoch_data: HashMap>, } #[derive(serde::Serialize, serde::Deserialize, Clone)] pub struct PlatformState { pub db_path: String, pub db_key: [u8; 32], } #[derive(serde::Serialize, serde::Deserialize, Clone, Default)] pub struct TemporaryState { pub groups: Arc, GroupData>>>, /// signing identity => key data pub sigkeys: HashMap, SignatureData>, pub key_packages: Arc, KeyPackageData>>>, } #[derive(serde::Serialize, serde::Deserialize, Clone)] pub struct SignatureData { #[serde(with = "hex::serde")] pub public_key: Vec, pub cs: u16, #[serde(with = "hex::serde")] pub secret_key: Vec, pub credential: Option, } impl PlatformState { pub fn new(db_path: &str, db_key: &[u8; 32]) -> Result { let state = Self { db_path: db_path.to_string(), db_key: *db_key, }; // This will create an empty database if it doesn't exist. state .get_sqlite_engine()? .application_data_storage() .map_err(|e| (PlatformError::StorageError(e.into_any_error())))?; Ok(state) } pub fn get_signing_identities(&self) -> Result, PlatformError> { todo!(); } pub fn client( &self, myself_identifier: &[u8], myself_credential: Option, version: ProtocolVersion, config: &ClientConfig, ) -> Result, PlatformError> { let crypto_provider = mls_rs_crypto_nss::NssCryptoProvider::default(); let engine = self .get_sqlite_engine()? .with_context(myself_identifier.to_vec()); let mut myself_sig_data = self .get_sig_data(myself_identifier)? .ok_or(PlatformError::UnavailableSecret)?; let myself_credential = if let Some(cred) = myself_credential { myself_sig_data.credential = Some(cred.clone()); self.store_sigdata(myself_identifier, &myself_sig_data)?; cred } else if let Some(cred) = myself_sig_data.credential { cred } else { return Err(PlatformError::UndefinedIdentity); }; let myself_signing_identity = SigningIdentity::new(myself_credential, myself_sig_data.public_key.into()); let mut builder = mls_rs::client_builder::ClientBuilder::new_sqlite(engine) .map_err(|e| PlatformError::StorageError(e.into_any_error()))? .crypto_provider(crypto_provider) .identity_provider(mls_rs::identity::basic::BasicIdentityProvider) .signing_identity( myself_signing_identity, myself_sig_data.secret_key.into(), myself_sig_data.cs.into(), ) .protocol_version(version); if let Some(key_package_lifetime_s) = config.key_package_lifetime_s { builder = builder.key_package_lifetime(key_package_lifetime_s); } let mls_rules = DefaultMlsRules::new().with_commit_options( CommitOptions::default().with_allow_external_commit(config.allow_external_commits), ); builder = builder.mls_rules(mls_rules); Ok(builder.build()) } pub fn client_default( &self, myself_identifier: &[u8], ) -> Result, PlatformError> { self.client( myself_identifier, None, ProtocolVersion::MLS_10, &Default::default(), ) } pub fn insert_sigkey( &self, myself_sigkey: &SignatureSecretKey, myself_pubkey: &SignaturePublicKey, cs: CipherSuite, identifier: &[u8], ) -> Result<(), PlatformError> { let signature_data = SignatureData { public_key: myself_pubkey.to_vec(), cs: *cs, secret_key: myself_sigkey.to_vec(), credential: None, }; self.store_sigdata(identifier, &signature_data)?; Ok(()) } fn store_sigdata(&self, identifier: &[u8], data: &SignatureData) -> Result<(), PlatformError> { let data = bincode::serialize(&data)?; let engine = self.get_sqlite_engine()?; let storage = engine .application_data_storage() .map_err(|e| PlatformError::StorageError(e.into_any_error()))?; storage .insert(&hex::encode(identifier), &data) .map_err(|e| PlatformError::StorageError(e.into_any_error()))?; Ok(()) } pub fn get_sig_data( &self, myself_identifier: &[u8], ) -> Result, PlatformError> { // TODO: Not clear if the option is needed here, the underlying function needs it. let key = myself_identifier.to_vec(); let engine = self.get_sqlite_engine()?; let storage = engine .application_data_storage() .map_err(|e| PlatformError::StorageError(e.into_any_error()))?; storage .get(&hex::encode(key)) .map_err(|e| PlatformError::StorageError(e.into_any_error()))? .map_or_else( || Ok(None), |data| bincode::deserialize(&data).map(Some).map_err(Into::into), ) } pub fn get_sqlite_engine( &self, ) -> Result, PlatformError> { let path = Path::new(&self.db_path); let file_conn = FileConnectionStrategy::new(path); let cipher_config = SqlCipherConfig::new(SqlCipherKey::RawKey(self.db_key)); let cipher_conn = CipheredConnectionStrategy::new(file_conn, cipher_config); SqLiteDataStorageEngine::new(cipher_conn) .map_err(|e| PlatformError::StorageError(e.into_any_error())) } pub fn delete_group(&self, gid: &[u8], identifier: &[u8]) -> Result<(), PlatformError> { let storage = self .get_sqlite_engine()? .with_context(identifier.to_vec()) .group_state_storage() .map_err(|_| PlatformError::InternalError)?; // Delete the group storage .delete_group(gid) .map_err(|_| PlatformError::InternalError) } pub fn delete(db_path: &str) -> Result<(), PlatformError> { let path = Path::new(db_path); if path.exists() { std::fs::remove_file(path)?; } Ok(()) } } // Dependencies for implementation of InMemoryState // use crate::GroupConfig; // use mls_rs::mls_rs_codec::{MlsDecode, MlsEncode}; // use mls_rs::{GroupStateStorage, KeyPackageStorage}; // use mls_rs_core::group::{EpochRecord, GroupState}; // use std::{collections::hash_map::Entry, convert::Infallible}; // impl InMemoryState { // pub fn new() -> Self { // Default::default() // } // pub fn to_bytes(&self) -> Result, PlatformError> { // bincode::serialize(self).map_err(Into::into) // } // pub fn from_bytes(bytes: &[u8]) -> Result { // bincode::deserialize(bytes).map_err(Into::into) // } // pub fn client( // &self, // myself: SigningIdentity, // group_config: Option, // ) -> Result, PlatformError> { // let crypto_provider = mls_rs_crypto_nss::NssCryptoProvider::default(); // let myself_sigkey = self.get_sigkey(&myself)?; // let mut builder = mls_rs::client_builder::ClientBuilder::new() // .key_package_repo(self.clone()) // .group_state_storage(self.clone()) // .crypto_provider(crypto_provider) // .identity_provider(mls_rs::identity::basic::BasicIdentityProvider) // .signing_identity( // myself, // myself_sigkey.secret_key.into(), // myself_sigkey.cs.into(), // ); // if let Some(config) = group_config { // builder = builder // .key_package_extensions(config.options) // .protocol_version(config.version); // } // Ok(builder.build()) // } // pub fn insert_sigkey( // &mut self, // myself_sigkey: &SignatureSecretKey, // myself_pubkey: &SignaturePublicKey, // cs: CipherSuite, // identifier: Identity, // ) -> Result<(), PlatformError> { // let signature_data = SignatureData { // public_key: myself_pubkey.to_vec(), // cs: *cs, // secret_key: myself_sigkey.to_vec(), // credential: None, // }; // self.sigkeys.insert(identifier, signature_data); // // TODO: We could return the value to indicate if the key // // existed (see the definition of insert). // Ok(()) // } // pub fn get_sigkey(&self, myself: &SigningIdentity) -> Result { // let key = myself.mls_encode_to_vec()?; // self.sigkeys // .get(&key) // .cloned() // .ok_or(PlatformError::UnavailableSecret) // } // } // impl GroupStateStorage for TemporaryState { // type Error = mls_rs::mls_rs_codec::Error; // fn max_epoch_id(&self, group_id: &[u8]) -> Result, Self::Error> { // let group_locked = self.groups.lock().unwrap(); // Ok(group_locked // .get(group_id) // .and_then(|group_data| group_data.epoch_data.keys().max().copied())) // } // fn state(&self, group_id: &[u8]) -> Result, Self::Error> // where // T: GroupState + MlsDecode, // { // self.groups // .lock() // .unwrap() // .get(group_id) // .map(|v| T::mls_decode(&mut v.state_data.as_slice())) // .transpose() // .map_err(Into::into) // } // fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result, Self::Error> // where // T: EpochRecord + MlsEncode + MlsDecode, // { // self.groups // .lock() // .unwrap() // .get(group_id) // .and_then(|group_data| group_data.epoch_data.get(&epoch_id)) // .map(|v| T::mls_decode(&mut &v[..])) // .transpose() // .map_err(Into::into) // } // fn write( // &mut self, // state: ST, // epoch_inserts: Vec, // epoch_updates: Vec, // ) -> Result<(), Self::Error> // where // ST: GroupState + MlsEncode + MlsDecode + Send + Sync, // ET: EpochRecord + MlsEncode + MlsDecode + Send + Sync, // { // let state_data = state.mls_encode_to_vec()?; // let mut states = self.groups.lock().unwrap(); // let group_data = match states.entry(state.id()) { // Entry::Occupied(entry) => { // let data = entry.into_mut(); // data.state_data = state_data; // data // } // Entry::Vacant(entry) => entry.insert(GroupData { // state_data, // epoch_data: Default::default(), // }), // }; // epoch_inserts.into_iter().try_for_each(|e| { // group_data.epoch_data.insert(e.id(), e.mls_encode_to_vec()?); // Ok::<_, Self::Error>(()) // })?; // epoch_updates.into_iter().try_for_each(|e| { // if let Some(data) = group_data.epoch_data.get_mut(&e.id()) { // *data = e.mls_encode_to_vec()?; // }; // Ok::<_, Self::Error>(()) // })?; // Ok(()) // } // } // impl KeyPackageStorage for TemporaryState { // type Error = Infallible; // fn insert(&mut self, id: Vec, pkg: KeyPackageData) -> Result<(), Self::Error> { // let mut states = self.key_packages.lock().unwrap(); // states.insert(id, pkg); // Ok(()) // } // fn get(&self, id: &[u8]) -> Result, Self::Error> { // Ok(self.key_packages.lock().unwrap().get(id).cloned()) // } // fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> { // let mut states = self.key_packages.lock().unwrap(); // states.remove(id); // Ok(()) // } // }