/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #include #include #include #include #include #include "RefCounted.h" #include "content_decryption_module.h" // Define the include guard of WMFClearKeyCDM.h before including // WMFClearKeySession.cpp so the preprocessor skips the real header when // WMFClearKeySession.cpp includes it from its own directory. This allows the // session source to compile without pulling in the full CDM stack. #define DOM_MEDIA_PLATFORM_WMF_CLEARKEY_WMFCLEARKEYCDM_H namespace mozilla { class WMFClearKeySession; class SessionManagerWrapper : public RefCounted { public: HRESULT GenerateRequest(cdm::InitDataType, const BYTE*, DWORD, cdm::SessionType, WMFClearKeySession*, std::string&) { return E_NOTIMPL; } HRESULT UpdateSession(const std::string&, const BYTE*, DWORD) { return E_NOTIMPL; } HRESULT CloseSession(const std::string&) { return E_NOTIMPL; } HRESULT RemoveSession(const std::string&) { return E_NOTIMPL; } void Shutdown() {} bool IsShutdown() { return false; } private: ~SessionManagerWrapper() override = default; }; } // namespace mozilla #include "../WMFClearKeySession.cpp" using namespace Microsoft::WRL; using namespace mozilla; namespace { class WMFClearKeySessionTest : public ::testing::Test { protected: void SetUp() override { ASSERT_HRESULT_SUCCEEDED( MakeAndInitialize(&mCallbacks)); // Assign from a raw pointer so RefPtr::operator=(T*) runs and correctly // AddRefs the object (avoids a UAF with this RefPtr impl, as this is not // Mozilla's standard RefPtr). SessionManagerWrapper* mgr = new SessionManagerWrapper(); mMgr = mgr; ASSERT_HRESULT_SUCCEEDED(MakeAndInitialize( &mSession, MF_MEDIAKEYSESSION_TYPE_TEMPORARY, mCallbacks.Get(), mgr)); mSession->SetSessionIdForTesting("test-session"); } class MockSessionCallbacks : public RuntimeClass, IMFContentDecryptionModuleSessionCallbacks> { public: STDMETHODIMP KeyMessage(MF_MEDIAKEYSESSION_MESSAGETYPE, const BYTE*, DWORD, LPCWSTR) override { return S_OK; } STDMETHODIMP KeyStatusChanged() override { return S_OK; } }; static void ExpectSingleKeyStatusEquals(WMFClearKeySession* session, const uint8_t* keyId, size_t keyIdSize) { cdm::KeyInformation keyInfo; keyInfo.key_id = keyId; keyInfo.key_id_size = static_cast(keyIdSize); keyInfo.status = cdm::KeyStatus::kUsable; keyInfo.system_code = 0; session->OnKeyStatusChanged(&keyInfo, 1); MFMediaKeyStatus* statuses = nullptr; UINT count = 0; ASSERT_HRESULT_SUCCEEDED(session->GetKeyStatuses(&statuses, &count)); ASSERT_NE(statuses, nullptr); ASSERT_EQ(count, 1u); EXPECT_EQ(statuses[0].cbKeyId, keyIdSize); ASSERT_NE(statuses[0].pbKeyId, nullptr); EXPECT_EQ(memcmp(statuses[0].pbKeyId, keyId, keyIdSize), 0); CoTaskMemFree(statuses[0].pbKeyId); CoTaskMemFree(statuses); } ComPtr mCallbacks; RefPtr mMgr; ComPtr mSession; }; } // namespace TEST_F(WMFClearKeySessionTest, GetKeyStatusesSixteenByteKeyId) { const uint8_t kKeyId[16] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}; ExpectSingleKeyStatusEquals(mSession.Get(), kKeyId, sizeof(kKeyId)); } TEST_F(WMFClearKeySessionTest, GetKeyStatusesEightByteKeyId) { const uint8_t kKeyId[8] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; ExpectSingleKeyStatusEquals(mSession.Get(), kKeyId, sizeof(kKeyId)); } TEST_F(WMFClearKeySessionTest, GetKeyStatusesThirtyTwoByteKeyId) { const uint8_t kKeyId[32] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}; ExpectSingleKeyStatusEquals(mSession.Get(), kKeyId, sizeof(kKeyId)); }