// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::{util::credential_to_chain, CertificateChain, X509IdentityError}; use alloc::vec; use alloc::vec::Vec; use mls_rs_core::{ crypto::SignaturePublicKey, error::IntoAnyError, extension::ExtensionList, identity::{CredentialType, IdentityProvider, MemberValidationContext, SigningIdentity}, time::MlsTime, }; #[cfg(not(feature = "std"))] use alloc::boxed::Box; #[cfg(all(test, feature = "std"))] use mockall::automock; #[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))] /// X.509 certificate unique identity trait. pub trait X509IdentityExtractor { type Error: IntoAnyError; /// Produce a unique identity value to represent the entity controlling a /// certificate credential within an MLS group. fn identity(&self, certificate_chain: &CertificateChain) -> Result, Self::Error>; /// Determine if `successor` is controlled by the same entity as /// `predecessor`. fn valid_successor( &self, predecessor: &CertificateChain, successor: &CertificateChain, ) -> Result; } #[cfg_attr(all(test, feature = "std"), automock(type Error = crate::test_utils::TestError;))] /// X.509 certificate validation trait. pub trait X509CredentialValidator { type Error: IntoAnyError; /// Validate a certificate chain. /// /// If `timestamp` is set to `None` then expiration checks should be skipped. fn validate_chain( &self, chain: &CertificateChain, timestamp: Option, ) -> Result; } #[derive(Clone, Debug)] #[non_exhaustive] /// A customizable generic X.509 certificate identity provider. /// /// This provider forwards its individual [`IdentityProvider`] /// behavior to its generic sub-components. /// /// Only X509 credentials are supported by this provider. pub struct X509IdentityProvider { pub identity_extractor: IE, pub validator: V, } impl X509IdentityProvider where IE: X509IdentityExtractor, V: X509CredentialValidator, { /// Create a new identity provider. pub fn new(identity_extractor: IE, validator: V) -> Self { Self { identity_extractor, validator, } } /// Determine if a certificate is valid based on the behavior of the /// underlying validator provided. fn validate( &self, signing_identity: &SigningIdentity, timestamp: Option, ) -> Result<(), X509IdentityError> { let chain = credential_to_chain(&signing_identity.credential)?; let leaf_public_key = self .validator .validate_chain(&chain, timestamp) .map_err(|e| X509IdentityError::X509ValidationError(e.into_any_error()))?; if leaf_public_key != signing_identity.signature_key { return Err(X509IdentityError::SignatureKeyMismatch); } Ok(()) } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(mls_build_async, maybe_async::must_be_async)] impl IdentityProvider for X509IdentityProvider where IE: X509IdentityExtractor + Send + Sync, V: X509CredentialValidator + Send + Sync, { type Error = X509IdentityError; /// Determine if a certificate is valid based on the behavior of the /// underlying validator provided. async fn validate_member( &self, signing_identity: &SigningIdentity, timestamp: Option, _context: MemberValidationContext<'_>, ) -> Result<(), X509IdentityError> { self.validate(signing_identity, timestamp) } /// Produce a unique identity value to represent the entity controlling a /// certificate credential within an MLS group. async fn identity( &self, signing_id: &SigningIdentity, _extensions: &ExtensionList, ) -> Result, X509IdentityError> { self.identity_extractor .identity(&credential_to_chain(&signing_id.credential)?) .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error())) } /// Determine if `successor` is controlled by the same entity as /// `predecessor` based on the behavior of the underlying identity /// extractor provided. async fn valid_successor( &self, predecessor: &SigningIdentity, successor: &SigningIdentity, _extensions: &ExtensionList, ) -> Result { self.identity_extractor .valid_successor( &credential_to_chain(&predecessor.credential)?, &credential_to_chain(&successor.credential)?, ) .map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error())) } async fn validate_external_sender( &self, signing_identity: &SigningIdentity, timestamp: Option, _extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error> { self.validate(signing_identity, timestamp) } /// Supported credential types. /// /// Only [`CredentialType::X509`] is supported. fn supported_types(&self) -> Vec { vec![CredentialType::X509] } } #[cfg(all(test, feature = "std"))] mod tests { use super::*; use mls_rs_core::{crypto::SignaturePublicKey, identity::CredentialType, time::MlsTime}; use crate::{ test_utils::{ test_certificate_chain, test_signing_identity, test_signing_identity_with_chain, TestError, }, MockX509CredentialValidator, MockX509IdentityExtractor, X509IdentityError, X509IdentityProvider, }; use alloc::vec; use assert_matches::assert_matches; #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; fn test_setup( mut mock_setup: F, ) -> X509IdentityProvider where F: FnMut(&mut MockX509IdentityExtractor, &mut MockX509CredentialValidator), { let mut identity_extractor = MockX509IdentityExtractor::new(); let mut validator = MockX509CredentialValidator::new(); mock_setup(&mut identity_extractor, &mut validator); X509IdentityProvider::new(identity_extractor, validator) } #[test] fn test_supported_types() { let test_provider = test_setup(|_, _| ()); assert_eq!( test_provider.supported_types(), vec![CredentialType::new(2)] ) } #[test] fn test_successful_validation() { let chain = test_certificate_chain(); let test_signing_identity = test_signing_identity_with_chain(chain.clone()); let test_timestamp = MlsTime::now(); let test_provider = test_setup(|_, validator| { let validation_result = test_signing_identity.signature_key.clone(); validator .expect_validate_chain() .once() .with( mockall::predicate::eq(chain.clone()), mockall::predicate::eq(Some(test_timestamp)), ) .return_once_st(|_, _| Ok(validation_result)); }); test_provider .validate(&test_signing_identity, Some(test_timestamp)) .unwrap(); } #[test] fn test_signing_identity_key_mismatch() { let test_signing_identity = test_signing_identity(); let test_provider = test_setup(|_, validator| { let validation_result = SignaturePublicKey::from(vec![42u8; 32]); validator .expect_validate_chain() .return_once_st(|_, _| Ok(validation_result)); }); assert_matches!( test_provider.validate(&test_signing_identity, None), Err(X509IdentityError::SignatureKeyMismatch) ); } #[test] fn test_failing_validation() { let test_provider = test_setup(|_, validator| { validator .expect_validate_chain() .return_once_st(|_, _| Err(TestError)); }); assert_matches!( test_provider.validate(&test_signing_identity(), None), Err(X509IdentityError::X509ValidationError(_)) ) } }