// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use mls_rs_crypto_traits::{DhType, KdfType, KemResult, KemType, SamplingMethod}; use mls_rs_core::{ crypto::{HpkePublicKey, HpkeSecretKey}, error::{AnyError, IntoAnyError}, }; use zeroize::Zeroizing; use crate::kdf::HpkeKdf; use alloc::vec::Vec; #[derive(Debug)] #[cfg_attr(feature = "std", derive(thiserror::Error))] pub enum DhKemError { #[cfg_attr(feature = "std", error(transparent))] KdfError(AnyError), #[cfg_attr(feature = "std", error(transparent))] DhError(AnyError), /// NIST key derivation from bytes failure. This is statistically unlikely #[cfg_attr( feature = "std", error("Failed to derive nist keypair from raw bytes after 255 attempts") )] KeyDerivationError, } impl IntoAnyError for DhKemError { #[cfg(feature = "std")] fn into_dyn_error(self) -> Result, Self> { Ok(self.into()) } } #[derive(Clone, Debug, Eq, PartialEq)] pub struct DhKem { dh: DH, kdf: HpkeKdf, kem_id: u16, n_secret: usize, #[cfg(feature = "test_utils")] test_key_data: Vec, } impl DhKem { pub fn new(dh: DH, kdf: KDF, kem_id: u16, n_secret: usize) -> Self { let suite_id = [b"KEM", &kem_id.to_be_bytes() as &[u8]].concat(); let kdf = HpkeKdf::new(suite_id, kdf); Self { dh, kdf, kem_id, n_secret, #[cfg(feature = "test_utils")] test_key_data: alloc::vec![], } } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))] #[cfg_attr( all(not(target_arch = "wasm32"), mls_build_async), maybe_async::must_be_async )] impl KemType for DhKem { type Error = DhKemError; fn kem_id(&self) -> u16 { self.kem_id } async fn generate_deterministic( &self, seed: &[u8], ) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> { match self.dh.bitmask_for_rejection_sampling() { SamplingMethod::HpkeWithBitmask(bitmask) => { self.derive_with_rejection_sampling(seed, bitmask).await } SamplingMethod::HpkeWithoutBitmask => { self.derive_without_rejection_sampling(seed).await } SamplingMethod::Raw => self.derive_raw(seed.to_vec()).await, } } async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> { #[cfg(feature = "test_utils")] if !self.test_key_data.is_empty() { let dkp_prk = self .kdf .labeled_extract(&[], b"dkp_prk", &self.test_key_data) .await .map_err(|e| DhKemError::KdfError(e.into_any_error()))?; return self.generate_deterministic(&dkp_prk).await; } self.dh .generate() .await .map_err(|e| DhKemError::DhError(e.into_any_error())) } async fn encap(&self, remote_pk: &HpkePublicKey) -> Result { let (ephemeral_sk, ephemeral_pk) = self.generate().await?; let ecdh_ss = self .dh .dh(&ephemeral_sk, remote_pk) .await .map(Zeroizing::new) .map_err(|e| DhKemError::DhError(e.into_any_error()))?; let kem_context = [ephemeral_pk.as_ref(), remote_pk.as_ref()].concat(); let shared_secret = self .kdf .labeled_extract_then_expand(&ecdh_ss, &kem_context, self.n_secret) .await .map_err(|e| DhKemError::KdfError(e.into_any_error()))?; Ok(KemResult::new(shared_secret, ephemeral_pk.into())) } async fn decap( &self, enc: &[u8], secret_key: &HpkeSecretKey, public_key: &HpkePublicKey, ) -> Result, Self::Error> { let remote_pk = enc.to_vec().into(); let ecdh_ss = self .dh .dh(secret_key, &remote_pk) .await .map(Zeroizing::new) .map_err(|e| DhKemError::DhError(e.into_any_error()))?; let kem_context = [enc, public_key].concat(); self.kdf .labeled_extract_then_expand(&ecdh_ss, &kem_context, self.n_secret) .await .map_err(|e| DhKemError::KdfError(e.into_any_error())) } fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error> { self.dh .public_key_validate(key) .map_err(|e| DhKemError::DhError(e.into_any_error())) } fn seed_length_for_derive(&self) -> usize { self.n_secret } fn public_key_size(&self) -> usize { self.dh.public_key_size() } fn secret_key_size(&self) -> usize { self.dh.secret_key_size() } } impl DhKem { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn derive_with_rejection_sampling( &self, dkp_prk: &[u8], bitmask: u8, ) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> { // The RFC specifies we get 255 chances to generate bytes that will be within range of the order for the curve for i in 0u8..255 { let mut secret_key = self .kdf .labeled_expand(dkp_prk, b"candidate", &[i], self.dh.secret_key_size()) .await .map_err(|e| DhKemError::KdfError(e.into_any_error()))?; secret_key[0] &= bitmask; let secret_key = secret_key.into(); // Compute the public key and if it succeeds, return the key pair if let Ok(pair) = self .dh .to_public(&secret_key) .await .map(|pk| (secret_key, pk)) { return Ok(pair); } } // If we never generate bytes that work, throw an error Err(DhKemError::KeyDerivationError) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn derive_without_rejection_sampling( &self, dkp_prk: &[u8], ) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> { let sk = self .kdf .labeled_expand(dkp_prk, b"sk", &[], self.dh.secret_key_size()) .await .map_err(|e| DhKemError::KdfError(e.into_any_error()))?; self.derive_raw(sk).await } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn derive_raw( &self, seed: Vec, ) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> { let sk = seed.into(); let pk = self .dh .to_public(&sk) .await .map_err(|e| DhKemError::DhError(e.into_any_error()))?; Ok((sk, pk)) } #[cfg(feature = "test_utils")] pub fn set_test_data(&mut self, test_data: Vec) { self.test_key_data = test_data } }