//-------------------------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//-------------------------------------------------------------------------------------------------
//
// TODO: Use OPENSSL_cleanse(buffer, sizeof(buffer)) to clear sensitive data from memory.
#ifdef _MSC_VER
#pragma warning(disable : 4996) // suppress MSVC deprecation warning for std::getenv
#endif
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "AttestationUtil.h"
#include "Logger.h"
#include "Constants.h"
#ifdef AZURE_LOCAL
#include
#endif
#ifdef _WIN32
#include
#endif
using namespace attest;
using json = nlohmann::json;
// Process-wide trace settings for Util: tracing starts disabled, at level 1.
bool Util::isTraceOn{false};
int Util::traceLevel{1};
/// \copydoc Util::base64_to_binary()
std::vector Util::base64_to_binary(const std::string &base64_data)
{
using namespace boost::archive::iterators;
using It = transform_width, 8, 6>;
return boost::algorithm::trim_right_copy_if(std::vector(It(std::begin(base64_data)), It(std::end(base64_data))), [](char c)
{ return c == '\0'; });
}
/// \copydoc Util::binary_to_base64()
std::string Util::binary_to_base64(const std::vector &binary_data)
{
using namespace boost::archive::iterators;
using It = base64_from_binary::const_iterator, 6, 8>>;
auto tmp = std::string(It(std::begin(binary_data)), It(std::end(binary_data)));
return tmp.append((3 - binary_data.size() % 3) % 3, '=');
}
/// \copydoc Util::binary_to_hex()
std::string Util::binary_to_hex(const std::vector &binary_data)
{
std::stringstream ss;
ss << std::hex << std::setfill('0');
for (auto c : binary_data)
{
ss << std::setw(2) << static_cast(c);
}
return ss.str();
}
/// \copydoc Util::hex_to_binary()
std::vector Util::hex_to_binary(const std::string &hex_data)
{
std::vector result;
for (size_t i = 0; i < hex_data.length(); i += 2)
{
std::string byteString = hex_data.substr(i, 2);
BYTE byte = (BYTE)strtol(byteString.c_str(), NULL, 16);
result.push_back(byte);
}
return result;
}
/// \copydoc Util::binary_to_base64url()
std::string Util::binary_to_base64url(const std::vector &binary_data)
{
using namespace boost::archive::iterators;
using It = base64_from_binary::const_iterator, 6, 8>>;
auto tmp = std::string(It(std::begin(binary_data)), It(std::end(binary_data)));
// For encoding to base64url, replace "+" with "-" and "/" with "_"
boost::replace_all(tmp, "+", "-");
boost::replace_all(tmp, "/", "_");
// We do not need to add padding characters while url encoding.
return tmp;
}
/// \copydoc Util::base64url_to_binary()
std::vector Util::base64url_to_binary(const std::string &base64_data)
{
std::string stringData = base64_data;
// While decoding base64 url, replace - with + and _ with + and
// use stanard base64 decode. we dont need to add padding characters. underlying library handles it.
boost::replace_all(stringData, "-", "+");
boost::replace_all(stringData, "_", "/");
return base64_to_binary(stringData);
}
/// \copydoc Util::base64_decode()
///
/// Decodes standard-alphabet base64 text into a byte string. Trailing '='
/// padding is optional. This decoder does not accept the base64url alphabet
/// ('-', '_'); use base64url_to_binary() for such input.
std::string Util::base64_decode(const std::string &data)
{
    using namespace boost::archive::iterators;
    using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;

    // Strip '=' padding up front instead of trimming trailing '\0' from the
    // result, so decoded data that genuinely ends in NUL bytes is preserved.
    const std::string::size_type last = data.find_last_not_of('=');
    const std::string::size_type data_len = (last == std::string::npos) ? 0 : last + 1;
    const std::string::const_iterator data_end =
        data.begin() + static_cast<std::string::difference_type>(data_len);

    std::string result(It(data.begin()), It(data_end));

    // 6 bits per input character; drop any partial trailing byte.
    const std::string::size_type byte_len = data_len * 6 / 8;
    if (result.size() > byte_len)
    {
        result.resize(byte_len);
    }
    return result;
}
/// \copydoc Util::url_encode()
///
/// Percent-encodes data with curl_easy_escape(). If curl cannot escape the
/// input, the input is returned unchanged (best effort).
std::string Util::url_encode(const std::string &data)
{
    CURL *curl = curl_easy_init();
    if (!curl)
    {
        TRACE_ERROR_EXIT("curl_easy_init() failed")
    }
    std::string encoded_str{data};
    char *output = curl_easy_escape(curl, data.c_str(), static_cast<int>(data.length()));
    if (output)
    {
        // Return the escaped text; curl allocated it, so release it with curl_free().
        encoded_str.assign(output);
        curl_free(output);
    }
    curl_easy_cleanup(curl);
    return encoded_str;
}
///
/// libcurl CURLOPT_WRITEFUNCTION callback: appends each received chunk to the
/// std::string registered through CURLOPT_WRITEDATA and reports how many bytes
/// were consumed. When no buffer was registered it consumes nothing (returns 0),
/// which libcurl treats as a write failure.
///
size_t Util::CurlWriteCallback(char *data, size_t size, size_t nmemb, std::string *buffer)
{
    if (buffer == nullptr)
    {
        return 0;
    }
    const size_t chunkBytes = size * nmemb;
    buffer->append(data, chunkBytes);
    return chunkBytes;
}
/// Build the IMDS token-request URL for a resource (token audience) URL,
/// e.g. "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fvault.azure.net"
///
/// The managed identity is selected from the environment, first non-empty wins:
///   IMDS_CLIENT_ID -> client_id, IMDS_OBJECT_ID -> object_id,
///   IMDS_MSI_RES_ID (ARM resource id) -> msi_res_id.
/// Selecting an identity is optional when only one managed identity is
/// assigned to the VM.
static inline std::string GetImdsTokenUrl(const std::string &url)
{
    // Value of an environment variable, or nullptr when unset or empty.
    auto nonEmptyEnv = [](const char *name) -> const char * {
        const char *value = std::getenv(name);
        return (value != nullptr && value[0] != '\0') ? value : nullptr;
    };

    std::ostringstream oss;
    oss << Constants::IMDS_TOKEN_URL
        << "?api-version=" << Constants::IMDS_API_VERSION
        << "&resource=" << Util::url_encode(url);

    // All identity selectors come from the environment, so percent-encode each
    // of them; an unencoded '&' or '#' would inject or truncate query parameters.
    if (const char *client_id = nonEmptyEnv("IMDS_CLIENT_ID"))
    {
        oss << "&client_id=" << Util::url_encode(client_id);
    }
    else if (const char *object_id = nonEmptyEnv("IMDS_OBJECT_ID"))
    {
        oss << "&object_id=" << Util::url_encode(object_id);
    }
    else if (const char *msi_res_id = nonEmptyEnv("IMDS_MSI_RES_ID"))
    {
        oss << "&msi_res_id=" << Util::url_encode(msi_res_id);
    }

    const std::string tokenUrl = oss.str();
    TRACE_OUT("IMDS token URL: %s", tokenUrl.c_str());
    return tokenUrl;
}
// Map a key URL (KEKUrl) to the token resource for the service that hosts it.
// A KEKUrl containing Constants::AKV_URL_SUFFIX maps to
// Constants::AKV_RESOURCE_URL; otherwise one containing
// Constants::MHSM_URL_SUFFIX maps to Constants::MHSM_RESOURCE_URL.
// With isIMDS == true the bare resource URL is returned; otherwise
// "<resource>/.default" (the scope form used by the service-principal flow).
// Any other KEKUrl raises THROW_SKR_ERROR(EXIT_USAGE, ...).
std::string getResourceUrl(const std::string &KEKUrl, bool isIMDS = true)
{
    const std::string akvSuffix(Constants::AKV_URL_SUFFIX);
    const std::string mhsmSuffix(Constants::MHSM_URL_SUFFIX);
    const std::string akvResource(Constants::AKV_RESOURCE_URL);
    const std::string mhsmResource(Constants::MHSM_RESOURCE_URL);

    auto audience = [isIMDS](const std::string &resource) -> std::string {
        return isIMDS ? resource : resource + "/.default";
    };

    if (KEKUrl.find(akvSuffix) != std::string::npos)
    {
        TRACE_OUT("AKV resource suffix found in KEKUrl");
        return audience(akvResource);
    }
    if (KEKUrl.find(mhsmSuffix) != std::string::npos)
    {
        TRACE_OUT("MHSM resource suffix found in KEKUrl");
        return audience(mhsmResource);
    }
    THROW_SKR_ERROR(EXIT_USAGE, std::string("Invalid resource suffix found in KEKUrl: " + KEKUrl))
}
/// \copydoc Util::GetIMDSToken()
std::string Util::GetIMDSToken(const std::string &KEKUrl)
{
TRACE_OUT("Entering Util::GetIMDSToken()");
CURL *curl = curl_easy_init();
if (!curl)
{
TRACE_ERROR_EXIT("curl_easy_init() failed")
}
// AKV and mHSM has different audience need to be passed to IMDS.
std::string resourceUrl = getResourceUrl(KEKUrl);
CURLcode curlRet = curl_easy_setopt(curl, CURLOPT_URL, GetImdsTokenUrl(resourceUrl).c_str());
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed")
}
// ByPassing proxy for IMDS.
// ref: https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=windows
curlRet = curl_easy_setopt(curl, CURLOPT_PROXY, "");
if (curlRet != CURLE_OK)
{
std::ostringstream oss;
oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet);
TRACE_ERROR_EXIT(oss.str().c_str())
}
struct curl_slist *headers = NULL;
headers = curl_slist_append(headers, "Metadata: true");
curlRet = curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed\n")
}
curlRet = curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed")
}
std::string responseStr;
curlRet = curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseStr);
if (curlRet != CURLE_OK)
{
std::ostringstream oss;
oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet);
TRACE_ERROR_EXIT(oss.str().c_str())
}
curlRet = curl_easy_perform(curl);
if (curlRet != CURLE_OK)
{
std::string msg = std::string("IMDS curl_easy_perform() failed: ") + curl_easy_strerror(curlRet);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
throw skr_error(EXIT_NETWORK_FAIL, msg);
}
// Check HTTP status code
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
TRACE_OUT("IMDS HTTP status: %ld, Response: %s", http_code, Util::reduct_log(responseStr).c_str());
if (http_code != 200)
{
// Try to extract the error_description from the IMDS error JSON
std::string detail = responseStr;
try
{
json errJson = json::parse(responseStr);
if (errJson.contains("error_description"))
detail = errJson["error_description"].get();
else if (errJson.contains("error"))
detail = errJson["error"].get();
}
catch (...) {} // use raw response if not JSON
std::ostringstream oss;
oss << "IMDS token request failed: HTTP " << http_code << ": " << detail;
throw skr_error(EXIT_AUTH_FAIL, oss.str());
}
json json_object = json::parse(responseStr.c_str());
if (!json_object.contains("access_token"))
{
throw skr_error(EXIT_AUTH_FAIL,
"IMDS response missing 'access_token' field: " + responseStr);
}
std::string access_token = json_object["access_token"].get();
TRACE_OUT("Access Token: %s", Util::reduct_log(access_token).c_str());
TRACE_OUT("Exiting Util::GetIMDSToken()");
return access_token;
}
/// \copydoc Util::GetAADToken()
std::string Util::GetAADToken(const std::string &KEKUrl)
{
TRACE_OUT("Entering Util::GetAADToken()");
auto clientId = std::getenv("AKV_SKR_CLIENT_ID");
auto clientSecret = std::getenv("AKV_SKR_CLIENT_SECRET");
auto tenantId = std::getenv("AKV_SKR_TENANT_ID");
if (!clientId || !clientSecret || !tenantId)
{
throw skr_error(EXIT_AUTH_FAIL,
"AAD service principal env vars not set. "
"Need AKV_SKR_CLIENT_ID, AKV_SKR_CLIENT_SECRET, AKV_SKR_TENANT_ID");
}
std::string resourceUrl = getResourceUrl(KEKUrl, false);
std::string tokenUrl = "https://login.microsoftonline.com/" + std::string(tenantId) + "/oauth2/v2.0/token";
std::string postData = "client_id=" + std::string(clientId) + "&client_secret=" + std::string(clientSecret) + "&grant_type=client_credentials&scope= " + resourceUrl;
CURL *curl = curl_easy_init();
if (!curl)
{
throw skr_error(EXIT_NETWORK_FAIL, "AAD: curl_easy_init() failed");
}
curl_easy_setopt(curl, CURLOPT_URL, tokenUrl.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, postData.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, postData.length());
curl_slist *headers = nullptr;
headers = curl_slist_append(headers, "Content-Type: application/x-www-form-urlencoded");
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
std::string response;
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
CURLcode result = curl_easy_perform(curl);
if (result != CURLE_OK)
{
std::string msg = std::string("AAD curl_easy_perform() failed: ") + curl_easy_strerror(result);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
throw skr_error(EXIT_NETWORK_FAIL, msg);
}
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
TRACE_OUT("AAD HTTP status: %ld, Response: %s", http_code, Util::reduct_log(response).c_str());
json jsonResponse = json::parse(response);
if (jsonResponse.contains("access_token"))
{
std::string token = jsonResponse["access_token"].get();
TRACE_OUT("Response: %s", Util::reduct_log(token).c_str());
TRACE_OUT("Exiting Util::GetAADToken()");
return token;
}
// No access_token — extract the AAD error details
std::string detail = response;
if (jsonResponse.contains("error_description"))
detail = jsonResponse["error_description"].get();
else if (jsonResponse.contains("error"))
detail = jsonResponse["error"].get();
std::ostringstream oss;
oss << "AAD token request failed: HTTP " << http_code << ": " << detail;
throw skr_error(EXIT_AUTH_FAIL, oss.str());
}
/// \copydoc Util::GetMAAToken()
// TODO: attestation server URL can be constructed from VM region if necessary.
std::string Util::GetMAAToken(const std::string &attestation_url, const std::string &nonce)
{
TRACE_OUT("Entering Util::GetMAAToken()");
std::string attest_server_url;
attest_server_url.assign(attestation_url);
if (attest_server_url.empty())
{
// use the default attestation url
attest_server_url.assign(Constants::DEFAULT_ATTESTATION_URL);
}
std::string nonce_token;
nonce_token.assign(nonce);
if (nonce_token.empty())
{
// use some random nonce
nonce_token.assign(Constants::NONCE);
}
AttestationClient *attestation_client = nullptr;
AttestationLogger *log_handle = new Logger(Util::get_trace());
// Initialize attestation client
if (!Initialize(log_handle, &attestation_client))
{
Uninitialize();
throw skr_error(EXIT_ATTEST_FAIL, "Failed to create attestation client object");
}
// parameters for the Attest call
attest::ClientParameters params = {};
params.attestation_endpoint_url = (PBYTE)attest_server_url.c_str();
std::string client_payload_str = "{\"nonce\": \"" + nonce_token + "\"}"; // nonce is optional
params.client_payload = (PBYTE)client_payload_str.c_str();
params.version = CLIENT_PARAMS_VERSION;
PBYTE jwt = nullptr;
attest::AttestationResult result;
bool is_cvm = false;
std::string jwt_str;
if ((result = attestation_client->Attest(params, &jwt)).code_ != attest::AttestationResult::ErrorCode::SUCCESS)
{
std::string errDesc = result.description_.empty() ? "(no description)" : result.description_;
fprintf(stderr, "MAA attestation failed: error code %d, description: %s\n",
static_cast(result.code_), errDesc.c_str());
Uninitialize();
throw skr_error(EXIT_ATTEST_FAIL, "MAA attestation failed: " + errDesc);
}
// Attestation succeeded
jwt_str = std::string(reinterpret_cast(jwt));
std::vector tokens;
boost::split(tokens, jwt_str, [](char c)
{ return c == '.'; });
if (tokens.size() < 3)
{
attestation_client->Free(jwt);
Uninitialize();
throw skr_error(EXIT_ATTEST_FAIL, "MAA returned invalid JWT token (fewer than 3 parts)");
}
json attestation_claims = json::parse(base64_decode(tokens[1]));
try
{
std::string attestation_type = attestation_claims["x-ms-isolation-tee"]["x-ms-attestation-type"].get();
std::string compliance_status = attestation_claims["x-ms-isolation-tee"]["x-ms-compliance-status"].get();
if ((boost::iequals(attestation_type, "sevsnpvm") ||
boost::iequals(attestation_type, "tdxvm")) &&
boost::iequals(compliance_status, "azure-compliant-cvm"))
{
is_cvm = true;
}
}
catch (...)
{
TRACE_OUT("TEE isolation claims not found in token (non-CVM or different token schema)");
}
attestation_client->Free(jwt);
Uninitialize();
TRACE_OUT("MAA attestation succeeded, is_cvm=%d, token length=%zu", is_cvm, jwt_str.length());
TRACE_OUT("Exiting Util::GetMAAToken()");
return jwt_str;
}
/// \copydoc Util::SplitString()
std::vector Util::SplitString(const std::string &str, char delim)
{
TRACE_OUT("Entering Util::SplitString()");
std::vector result;
std::stringstream ss(str);
std::string item;
while (std::getline(ss, item, delim))
{
result.push_back(item);
}
TRACE_OUT("Exiting Util::SplitString()");
return result;
}
/// handle openssl errors
static void handle_openssl_errors(void)
{
TRACE_OUT("Entering handle_openssl_errors()");
std::ostringstream oss;
oss << "OpenSSL error: ";
ERR_print_errors_fp(stderr);
unsigned long error;
while ((error = ERR_get_error()))
{
char error_str[120]{};
ERR_error_string_n(error, error_str, sizeof(error_str));
oss << error_str << "; ";
}
TRACE_OUT("Exiting handle_openssl_errors()");
throw skr_error(EXIT_CRYPTO_FAIL, oss.str());
}
/// Unwrap ciphertext with AES-256 key wrap with padding (EVP_aes_256_wrap_pad).
///
/// @param key            AES key; presumably 32 bytes as required by
///                       AES-256 — the caller must guarantee this.
/// @param ciphertext     wrapped bytes.
/// @param ciphertext_len number of wrapped bytes; must not be negative.
/// @param plaintext      output buffer. NOTE(review): no capacity is passed,
///                       so callers must size it for the unwrapped key
///                       (at most ciphertext_len bytes) — confirm at call sites.
/// @return number of bytes written to plaintext.
/// @throws skr_error(EXIT_CRYPTO_FAIL) on a negative length or any OpenSSL
///         failure (via handle_openssl_errors()).
static int decrypt_aes_key_unwrap(PBYTE key, PBYTE ciphertext, int ciphertext_len, PBYTE plaintext)
{
    TRACE_OUT("Entering decrypt_aes_key_unwrap()");
    if (ciphertext_len < 0)
    {
        throw skr_error(EXIT_CRYPTO_FAIL, "AES key unwrap: negative ciphertext length");
    }
    // Frees the cipher context on every exit path; handle_openssl_errors()
    // throws, which would otherwise leak the context.
    struct CipherCtxGuard
    {
        EVP_CIPHER_CTX *ctx = nullptr;
        ~CipherCtxGuard()
        {
            if (ctx != nullptr)
                EVP_CIPHER_CTX_free(ctx);
        }
    } guard;

    guard.ctx = EVP_CIPHER_CTX_new();
    EVP_CIPHER_CTX *ctx = guard.ctx;
    if (ctx == nullptr)
        handle_openssl_errors();
    // OpenSSL only permits wrap-mode ciphers on a context with this flag set.
    EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
    // Select the cipher first, then supply the key.
    if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_wrap_pad(), nullptr, nullptr, nullptr))
        handle_openssl_errors();
    if (1 != EVP_DecryptInit_ex(ctx, nullptr, nullptr, key, nullptr))
        handle_openssl_errors();

    int len = 0;
    if (1 != EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, ciphertext_len))
        handle_openssl_errors();
    int plaintext_len = len;
    if (1 != EVP_DecryptFinal_ex(ctx, plaintext + len, &len))
        handle_openssl_errors();
    plaintext_len += len;

    TRACE_OUT("Exiting decrypt_aes_key_unwrap()");
    return plaintext_len;
}
// Construct the secure key release (SKR) URL for a key by appending
// "/release?api-version=7.3" to KEKUrl, e.g.
//   https://{vault-name}.vault.azure.net/keys/{key-name}/{key-version}/release?api-version=7.3
std::string Util::GetKeyVaultSKRurl(const std::string &KEKUrl)
{
    TRACE_OUT("Entering Util::GetKeyVaultSKRurl()");
    const std::string releaseUri = KEKUrl + "/release?api-version=7.3";
    TRACE_OUT("Request URI: %s", releaseUri.c_str());
    TRACE_OUT("Exiting Util::GetKeyVaultSKRurl()");
    return releaseUri;
}
#ifdef _WIN32
///
/// Windows-specific implementation of the SKR HTTP POST using WinHTTP.
/// WinHTTP is the preferred HTTP client on Windows — it supports system proxy
/// settings, Kerberos/NTLM auth, and does not require shipping a CA bundle.
///
/// POSTs requestBodyStr to requestUri with "Authorization: Bearer <access_token>",
/// JSON Content-Type/Accept headers and a User-Agent, and returns the response body.
/// WinHTTP failures are reported through TRACE_ERROR_EXIT with the GetLastError() code;
/// any HTTP status other than 200 throws skr_error(EXIT_SKR_FAIL) carrying the body.
/// All WinHTTP handles are closed on both the success and the exception path.
///
static std::string GetKeyVaultResponseWinHttp(const std::string &requestUri,
const std::string &access_token,
const std::string &requestBodyStr)
{
TRACE_OUT("Entering GetKeyVaultResponseWinHttp()");
HINTERNET hSession = NULL;
HINTERNET hConnect = NULL;
HINTERNET hRequest = NULL;
std::string responseStr;
try
{
// Parse the URL to extract host and path components
// NOTE(review): host and path are cracked into fixed buffers of 256 and 1024
// wide chars; longer values presumably make WinHttpCrackUrl fail — confirm.
// No lpszExtraInfo buffer is requested; the "?api-version=..." query is
// expected to come back as part of szUrlPath — confirm against the WinHttpCrackUrl docs.
URL_COMPONENTS urlComp = {0};
urlComp.dwStructSize = sizeof(urlComp);
WCHAR szHostName[256] = {0};
WCHAR szUrlPath[1024] = {0};
urlComp.lpszHostName = szHostName;
urlComp.dwHostNameLength = sizeof(szHostName) / sizeof(WCHAR);
urlComp.lpszUrlPath = szUrlPath;
urlComp.dwUrlPathLength = sizeof(szUrlPath) / sizeof(WCHAR);
// Convert URI to wide string
// (each char is widened individually, which is only correct for ASCII URIs)
std::wstring wRequestUri(requestUri.begin(), requestUri.end());
TRACE_OUT("Cracking URL: %s", requestUri.c_str());
if (!WinHttpCrackUrl(wRequestUri.c_str(), (DWORD)wRequestUri.length(), 0, &urlComp))
{
std::ostringstream oss;
oss << "WinHttpCrackUrl failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Initialize WinHTTP session
// (WINHTTP_ACCESS_TYPE_DEFAULT_PROXY: use the configured WinHTTP proxy settings)
hSession = WinHttpOpen(L"AzureDiskEncryption/1.0",
WINHTTP_ACCESS_TYPE_DEFAULT_PROXY,
WINHTTP_NO_PROXY_NAME,
WINHTTP_NO_PROXY_BYPASS, 0);
if (!hSession)
{
std::ostringstream oss;
oss << "WinHttpOpen failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Connect to the server
hConnect = WinHttpConnect(hSession, szHostName, urlComp.nPort, 0);
if (!hConnect)
{
std::ostringstream oss;
oss << "WinHttpConnect failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Create an HTTP POST request
// TLS is used only when the URL scheme is https.
DWORD dwFlags = (urlComp.nScheme == INTERNET_SCHEME_HTTPS) ? WINHTTP_FLAG_SECURE : 0;
hRequest = WinHttpOpenRequest(hConnect, L"POST", szUrlPath, NULL,
WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, dwFlags);
if (!hRequest)
{
std::ostringstream oss;
oss << "WinHttpOpenRequest failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Add headers individually to avoid issues with very long bearer tokens
std::string authHeader = "Authorization: Bearer " + access_token;
std::wstring wAuthHeader(authHeader.begin(), authHeader.end());
if (!WinHttpAddRequestHeaders(hRequest, wAuthHeader.c_str(), (DWORD)wAuthHeader.length(), WINHTTP_ADDREQ_FLAG_ADD))
{
std::ostringstream oss;
oss << "WinHttpAddRequestHeaders (Authorization) failed: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// A length of -1 tells WinHTTP the header string is NUL-terminated.
if (!WinHttpAddRequestHeaders(hRequest, L"Content-Type: application/json", -1, WINHTTP_ADDREQ_FLAG_ADD))
{
std::ostringstream oss;
oss << "WinHttpAddRequestHeaders (Content-Type) failed: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
if (!WinHttpAddRequestHeaders(hRequest, L"Accept: application/json", -1, WINHTTP_ADDREQ_FLAG_ADD))
{
std::ostringstream oss;
oss << "WinHttpAddRequestHeaders (Accept) failed: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
if (!WinHttpAddRequestHeaders(hRequest, L"User-Agent: AzureDiskEncryption", -1, WINHTTP_ADDREQ_FLAG_ADD))
{
std::ostringstream oss;
oss << "WinHttpAddRequestHeaders (User-Agent) failed: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
TRACE_OUT("HTTP headers added, sending request with body length: %d", (int)requestBodyStr.length());
// Send the request
// The whole body is passed as optional data, and the total length equals its size.
BOOL bResults = WinHttpSendRequest(hRequest,
WINHTTP_NO_ADDITIONAL_HEADERS, 0,
(LPVOID)requestBodyStr.c_str(), (DWORD)requestBodyStr.length(),
(DWORD)requestBodyStr.length(), 0);
if (!bResults)
{
std::ostringstream oss;
oss << "WinHttpSendRequest failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Wait for the response
bResults = WinHttpReceiveResponse(hRequest, NULL);
if (!bResults)
{
std::ostringstream oss;
oss << "WinHttpReceiveResponse failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Check the HTTP status code
// The return value is not checked: if the query fails, dwStatusCode stays 0,
// which the != 200 check below treats as a failure.
DWORD dwStatusCode = 0;
DWORD dwSize = sizeof(dwStatusCode);
WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER,
WINHTTP_HEADER_NAME_BY_INDEX, &dwStatusCode, &dwSize, WINHTTP_NO_HEADER_INDEX);
// Read the response body
// Loop until WinHttpQueryDataAvailable reports 0 bytes, reading at most
// sizeof(szBuffer) - 1 bytes per pass and appending exactly dwDownloaded bytes.
DWORD dwDownloaded = 0;
do
{
dwSize = 0;
if (!WinHttpQueryDataAvailable(hRequest, &dwSize))
{
std::ostringstream oss;
oss << "WinHttpQueryDataAvailable failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
if (dwSize > 0)
{
char szBuffer[8192] = {0};
DWORD dwToRead = (dwSize < sizeof(szBuffer) - 1) ? dwSize : (DWORD)(sizeof(szBuffer) - 1);
if (!WinHttpReadData(hRequest, szBuffer, dwToRead, &dwDownloaded))
{
std::ostringstream oss;
oss << "WinHttpReadData failed with error: " << GetLastError();
TRACE_ERROR_EXIT(oss.str().c_str())
}
szBuffer[dwDownloaded] = '\0';
responseStr.append(szBuffer, dwDownloaded);
}
} while (dwSize > 0);
// NOTE(review): dwStatusCode is a DWORD printed with %d — consider %lu.
TRACE_OUT("HTTP status=%d, response: %s", dwStatusCode, Util::reduct_log(responseStr).c_str());
if (dwStatusCode != 200)
{
std::ostringstream oss;
oss << "SKR HTTP request failed: HTTP " << dwStatusCode << ": " << responseStr;
throw skr_error(EXIT_SKR_FAIL, oss.str());
}
}
catch (...)
{
// Cleanup on exception and re-throw
if (hRequest) WinHttpCloseHandle(hRequest);
if (hConnect) WinHttpCloseHandle(hConnect);
if (hSession) WinHttpCloseHandle(hSession);
throw;
}
// Cleanup
if (hRequest) WinHttpCloseHandle(hRequest);
if (hConnect) WinHttpCloseHandle(hConnect);
if (hSession) WinHttpCloseHandle(hSession);
TRACE_OUT("Exiting GetKeyVaultResponseWinHttp()");
return responseStr;
}
#endif // _WIN32
#ifndef _WIN32
///
/// Linux (and non-Windows) implementation of the SKR HTTP POST using libcurl.
///
static std::string GetKeyVaultResponseCurl(const std::string &requestUri,
const std::string &access_token,
const std::string &requestBodyStr)
{
TRACE_OUT("Entering GetKeyVaultResponseCurl()");
CURL *curl = curl_easy_init();
if (!curl)
{
TRACE_ERROR_EXIT("curl_easy_init() failed")
}
CURLcode curlRet = curl_easy_setopt(curl, CURLOPT_URL, requestUri.c_str());
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed for URL")
}
curlRet = curl_easy_setopt(curl, CURLOPT_POST, 1L);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed for POST")
}
curlRet = curl_easy_setopt(curl, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed for HTTP_VERSION")
}
struct curl_slist *headers = NULL;
std::ostringstream bearerToken;
bearerToken << "Authorization: Bearer " << access_token;
headers = curl_slist_append(headers, bearerToken.str().c_str());
TRACE_OUT("Bearer token: %s", Util::reduct_log(bearerToken.str()).c_str());
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, "Accept: application/json");
headers = curl_slist_append(headers, "User-Agent: AzureDiskEncryption");
curlRet = curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed")
}
curlRet = curl_easy_setopt(curl, CURLOPT_POSTFIELDS, requestBodyStr.c_str());
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_POSTFIELDS")
}
curlRet = curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, (long)requestBodyStr.size());
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_POSTFIELDSIZE")
}
char errbuf[CURL_ERROR_SIZE] = {0};
curlRet = curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf);
if (curlRet != CURLE_OK)
{
size_t len = strlen(errbuf);
std::cerr << "libcurl: " << curlRet << std::endl;
if (len)
std::cerr << errbuf << ((errbuf[len - 1] != '\n') ? "\n" : "");
std::cerr << curl_easy_strerror(curlRet) << std::endl;
TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_ERRORBUFFER")
}
curlRet = curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, Util::CurlWriteCallback);
if (curlRet != CURLE_OK)
{
TRACE_ERROR_EXIT("curl_easy_setopt() failed")
}
std::string responseStr;
curlRet = curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseStr);
if (curlRet != CURLE_OK)
{
std::ostringstream oss;
oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet);
TRACE_ERROR_EXIT(oss.str().c_str())
}
// Perform the request
curlRet = curl_easy_perform(curl);
if (curlRet != CURLE_OK)
{
std::string msg = std::string("SKR curl_easy_perform() failed: ") + curl_easy_strerror(curlRet);
if (strlen(errbuf))
msg += std::string(" (") + errbuf + ")";
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
throw skr_error(EXIT_NETWORK_FAIL, msg);
}
// Check HTTP status — AKV/MHSM errors (403, 404, etc.) come back as valid HTTP responses
long http_code = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
// Cleanup curl
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
TRACE_OUT("SKR HTTP status: %ld, response: %s", http_code, Util::reduct_log(responseStr).c_str());
if (http_code != 200)
{
std::ostringstream oss;
oss << "SKR HTTP request failed: HTTP " << http_code << ": " << responseStr;
throw skr_error(EXIT_SKR_FAIL, oss.str());
}
TRACE_OUT("Exiting GetKeyVaultResponseCurl()");
return responseStr;
}
#endif // !_WIN32
std::string Util::GetKeyVaultResponse(const std::string &requestUri,
const std::string &access_token,
const std::string &attestation_token,
const std::string &nonce)
{
TRACE_OUT("Entering Util::GetKeyVaultResponse()");
std::string nonce_token;
nonce_token.assign(nonce);
if (nonce_token.empty())
{
// use some random nonce
nonce_token.assign(Constants::NONCE);
}
// Build the JSON request body (shared across both implementations)
// RSA_AES_KEY_WRAP_256 uses RSA-OAEP with SHA-256 for the transfer key,
// upgrading from CKM_RSA_AES_KEY_WRAP which used SHA-1.
// See: https://learn.microsoft.com/en-us/rest/api/keyvault/keys/release/release
std::ostringstream requestBody;
requestBody << "{";
requestBody << "\"nonce\": \"" + nonce_token + "\",";
requestBody << "\"target\": \"" << attestation_token << "\",";
requestBody << "\"enc\": \"RSA_AES_KEY_WRAP_256\"";
requestBody << "}";
std::string requestBodyStr(requestBody.str());
TRACE_OUT("SKR wrapping algorithm: RSA_AES_KEY_WRAP_256 (RSA-OAEP-SHA256)");
TRACE_OUT("SKR request URI: %s", requestUri.c_str());
TRACE_OUT("SKR request body length: %zu, target (attestation_token) length: %zu",
requestBodyStr.length(), attestation_token.length());
// Log first 200 chars of body for diagnostics (token is large)
TRACE_OUT("SKR request body prefix: %.200s", requestBodyStr.c_str());
#ifdef _WIN32
std::string result = GetKeyVaultResponseWinHttp(requestUri, access_token, requestBodyStr);
#else
std::string result = GetKeyVaultResponseCurl(requestUri, access_token, requestBodyStr);
#endif
TRACE_OUT("SKR response: %s", Util::reduct_log(result).c_str());
TRACE_OUT("Exiting Util::GetKeyVaultResponse()");
return result;
}
bool Util::doSKR(const std::string &attestation_url,
const std::string &nonce,
std::string KEKUrl,
EVP_PKEY **pkey,
const Util::AkvCredentialSource &akv_credential_source)
{
TRACE_OUT("Entering Util::doSKR()");
try
{
std::string attest_token(Util::GetMAAToken(attestation_url, nonce));
TRACE_OUT("MAA Token: %s", Util::reduct_log(attest_token).c_str());
if (attest_token.empty())
{
throw skr_error(EXIT_ATTEST_FAIL,
"MAA attestation returned an empty token. Cannot proceed with key release.");
}
#ifndef AZURE_LOCAL
// Get Akv access token either using IMDS or Service Principal
std::string access_token;
if (akv_credential_source == Util::AkvCredentialSource::EnvServicePrincipal)
{
access_token = std::move(Util::GetAADToken(KEKUrl));
}
else
{
access_token = std::move(Util::GetIMDSToken(KEKUrl));
}
TRACE_OUT("AkvMsiAccessToken: %s", Util::reduct_log(access_token).c_str());
std::string requestUri = Util::GetKeyVaultSKRurl(KEKUrl);
std::string responseStr = Util::GetKeyVaultResponse(requestUri, access_token, attest_token, nonce);
#else
// On Azure Local, the Evidence SDK handles AKV authentication and key release
// via the host using the cluster identity.
std::string nonce_token = nonce.empty() ? Constants::NONCE : nonce;
// RAII wrapper: ensures hw_evidence_free is called on every exit path
// (success, exception, early return). hw_evidence_free is documented
// to no-op on null, so the initial null state is safe.
std::unique_ptr wrapped_key(
nullptr, &hw_evidence_free);
// release_akv_key writes the allocated buffer pointer into a uint8_t**
// out-parameter, so we need a short-lived raw pointer; we transfer
// ownership to the unique_ptr immediately after the call.
uint8_t* wrapped_key_raw = nullptr;
uint32_t wrapped_key_size = 0;
std::string encryption_algorithm = "RSA_AES_KEY_WRAP_256";
hw_evidence_result skr_result = release_akv_key(
reinterpret_cast(const_cast(KEKUrl.c_str())),
static_cast(KEKUrl.size()),
reinterpret_cast(const_cast(attest_token.c_str())),
static_cast(attest_token.size()),
reinterpret_cast(const_cast(nonce_token.c_str())),
static_cast(nonce_token.size()),
reinterpret_cast(const_cast(encryption_algorithm.c_str())),
static_cast(encryption_algorithm.size()),
&wrapped_key_raw,
&wrapped_key_size);
if (wrapped_key_raw == nullptr || wrapped_key_size == 0)
{
throw skr_error(EXIT_SKR_FAIL, "release_akv_key() did not return a wrapped key");
}
wrapped_key.reset(wrapped_key_raw);
if (skr_result != HW_EVIDENCE_OK)
{
throw skr_error(EXIT_SKR_FAIL, "release_akv_key() failed on Azure Local");
}
std::string responseStr(reinterpret_cast(wrapped_key.get()), wrapped_key_size);
TRACE_OUT("Azure Local SKR response: %s", Util::reduct_log(responseStr).c_str());
#endif
// Parse the response:
json skrJson = json::parse(responseStr.c_str());
std::string skrToken = skrJson["value"];
TRACE_OUT("SKR token: %s", Util::reduct_log(skrToken).c_str());
std::vector tokenParts = Util::SplitString(skrToken, '.');
if (tokenParts.size() != 3)
{
throw skr_error(EXIT_SKR_FAIL, "Invalid SKR token (expected 3 dot-separated parts)");
}
std::vector tokenPayload(Util::base64url_to_binary(tokenParts[1]));
std::string tokenPayloadStr(tokenPayload.begin(), tokenPayload.end());
TRACE_OUT("SKR token payload: %s", Util::reduct_log(tokenPayloadStr).c_str());
json skrPayloadJson = json::parse(tokenPayloadStr.c_str());
std::vector key_hsm = Util::base64url_to_binary(skrPayloadJson["response"]["key"]["key"]["key_hsm"]);
TRACE_OUT("SKR key_hsm: %s", Util::reduct_log(Util::binary_to_base64url(key_hsm)).c_str());
json cipherTextJson = json::parse(key_hsm);
std::vector cipherText = Util::base64url_to_binary(cipherTextJson["ciphertext"]);
TRACE_OUT("Encrypted bytes length: %ld", cipherText.size());
TRACE_OUT("Encrypted bytes: %s", Util::reduct_log(Util::binary_to_base64url(cipherText)).c_str());
AttestationClient *attestation_client = nullptr;
AttestationLogger *log_handle = new Logger(Util::get_trace());
// Initialize attestation client
if (!Initialize(log_handle, &attestation_client))
{
Uninitialize();
// Note: do NOT delete log_handle — Initialize() wraps it in a
// shared_ptr that takes ownership (see AttestationClient.cpp).
throw skr_error(EXIT_SKR_FAIL, "Failed to create attestation client object for TPM decrypt");
}
attest::AttestationResult result;
int RSASize = 2048;
int ModulusSize = RSASize / 8;
uint8_t *decryptedAESBytes = nullptr;
uint32_t decryptedBytesSize = 0;
TRACE_OUT("TPM decrypt: RSA-OAEP with SHA-256 (matching RSA_AES_KEY_WRAP_256)");
result = attestation_client->Decrypt(attest::EncryptionType::NONE,
cipherText.data(),
ModulusSize,
NULL,
0,
&decryptedAESBytes,
&decryptedBytesSize,
attest::RsaScheme::RsaOaep, // RSA-OAEP wrapping
attest::RsaHashAlg::RsaSha256 // SHA-256 to match RSA_AES_KEY_WRAP_256
);
if (result.code_ != attest::AttestationResult::ErrorCode::SUCCESS)
{
std::ostringstream oss;
oss << "Failed to decrypt AES key: error code " << static_cast(result.code_)
<< ", TPM error code=" << result.tpm_error_code_
<< ", Desc=" << result.description_;
fprintf(stderr, "%s\n", oss.str().c_str());
Uninitialize();
throw skr_error(EXIT_CRYPTO_FAIL, oss.str());
}
{
std::vector decryptedAESBytesVec(decryptedAESBytes, decryptedAESBytes + decryptedBytesSize);
TRACE_OUT("Decrypted Transfer key: %s", Util::reduct_log(Util::binary_to_base64url(decryptedAESBytesVec)).c_str());
}
// The remaining bytes are the encrypted CMK bytes with the decrypted AES key.
// use openssl AES to decrypt the CMK bytes.
BYTE private_key[8192];
int private_key_len = 0;
private_key_len = decrypt_aes_key_unwrap(decryptedAESBytes,
cipherText.data() + ModulusSize,
(int)(cipherText.size() - ModulusSize),
private_key);
// Securely zero and free the decrypted AES transfer key
OPENSSL_cleanse(decryptedAESBytes, decryptedBytesSize);
free(decryptedAESBytes);
decryptedAESBytes = nullptr;
if (private_key_len == 0)
{
OPENSSL_cleanse(private_key, sizeof(private_key));
Uninitialize();
throw skr_error(EXIT_CRYPTO_FAIL, "Failed to decrypt the CMK (AES key unwrap returned 0 bytes)");
}
TRACE_OUT("CMK private key has length=%d", private_key_len);
// PKCS#8: parse the decrypted private key material
BIO *bio_key = BIO_new_mem_buf(private_key, private_key_len);
if (!bio_key)
{
OPENSSL_cleanse(private_key, sizeof(private_key));
Uninitialize();
throw skr_error(EXIT_CRYPTO_FAIL, "Error creating memory BIO for private key");
}
*pkey = d2i_PrivateKey_bio(bio_key, NULL);
BIO_free(bio_key);
if (!*pkey)
{
// Collect OpenSSL error details
std::ostringstream oss;
oss << "Failed to parse PKCS#8 private key: ";
unsigned long oerr;
while ((oerr = ERR_get_error()))
{
char buf[120];
ERR_error_string_n(oerr, buf, sizeof(buf));
oss << buf << "; ";
}
OPENSSL_cleanse(private_key, sizeof(private_key));
Uninitialize();
throw skr_error(EXIT_CRYPTO_FAIL, oss.str());
}
TRACE_OUT("Parsed private key: type=%d", EVP_PKEY_base_id(*pkey));
// Securely zero the private key material on the stack
OPENSSL_cleanse(private_key, sizeof(private_key));
// Cleanup attestation client resources
Uninitialize();
// Note: do NOT delete log_handle — Initialize() wraps it in a
// shared_ptr that takes ownership (see AttestationClient.cpp).
TRACE_OUT("Exiting Util::doSKR()");
return true;
}
catch (skr_error &)
{
throw; // let structured errors propagate to main()
}
catch (std::exception &e)
{
// Wrap unexpected exceptions with EXIT_SKR_FAIL
throw skr_error(EXIT_SKR_FAIL, std::string("doSKR failed: ") + e.what());
}
return false;
}
/// @brief Report all pending OpenSSL errors and abort the current operation.
///
/// Pops every entry from OpenSSL's thread-local error queue. Each entry is
/// written to stderr as its own line and also appended to the exception
/// message. The function then throws.
///
/// @throws skr_error with code EXIT_CRYPTO_FAIL. This function always throws
///         and never returns normally.
void handleErrors()
{
    std::ostringstream oss;
    oss << "OpenSSL error: ";
    unsigned long err = 0;
    // Print each entry while popping it. ERR_get_error() removes entries from
    // the queue, so a trailing ERR_print_errors_fp() would find it empty and
    // print nothing.
    while ((err = ERR_get_error()) != 0)
    {
        // 256 bytes is the size OpenSSL documents for ERR_error_string buffers.
        // The _n variant truncates safely if a message is ever longer.
        char buf[256];
        ERR_error_string_n(err, buf, sizeof(buf));
        fprintf(stderr, "%s\n", buf);
        oss << buf << "; ";
    }
    throw skr_error(EXIT_CRYPTO_FAIL, oss.str());
}
// RSA-OAEP helpers built on OpenSSL's EVP_PKEY API: hash-name lookup, encryption with a public key (EVP_PKEY_encrypt), and decryption.
/// @brief Map a hash algorithm name to an OpenSSL EVP_MD.
///
/// Matching is case-insensitive and ignores dashes, so "SHA-256", "sha256"
/// and "Sha-256" are all equivalent.
/// Supported names: sha1, sha256, sha384, sha512.
///
/// @param name Hash algorithm name as supplied by the caller.
/// @return The matching EVP_MD from OpenSSL's EVP_sha*() accessors (never NULL).
/// @throws std::runtime_error if @p name is not a supported algorithm.
static const EVP_MD *get_evp_md_by_name(const std::string &name)
{
    std::string lower;
    lower.reserve(name.size());
    for (const char ch : name)
    {
        // Dashes are optional separators ("sha-256" -> "sha256").
        if (ch == '-')
            continue;
        // ::tolower is undefined for negative char values (e.g. non-ASCII bytes),
        // so widen through unsigned char first.
        lower.push_back(static_cast<char>(::tolower(static_cast<unsigned char>(ch))));
    }
    if (lower == "sha1") return EVP_sha1();
    if (lower == "sha256") return EVP_sha256();
    if (lower == "sha384") return EVP_sha384();
    if (lower == "sha512") return EVP_sha512();
    throw std::runtime_error("Unsupported hash algorithm: " + name +
                             ". Supported: sha1, sha256, sha384, sha512");
}
/// @brief RSA-OAEP encrypt a message, using SHA-256 as the OAEP digest.
///
/// @param pkey   RSA key whose public component performs the encryption.
/// @param msg    Plaintext bytes (read only).
/// @param msglen Plaintext length in bytes.
/// @param enc    [out] Receives an OPENSSL_malloc'd ciphertext buffer. The caller
///               releases it with OPENSSL_free().
/// @param enclen [out] Receives the ciphertext length in bytes.
/// @return 0 on success.
/// @throws skr_error (via handleErrors()) on any OpenSSL failure. The context and
///         any allocated output buffer are released before the exception escapes,
///         and *enc is reset to nullptr in that case.
int rsa_encrypt(EVP_PKEY *pkey, const PBYTE msg, size_t msglen, PBYTE *enc, size_t *enclen)
{
    TRACE_OUT("Entering rsa_encrypt()");
    // Own the context so it is freed even though handleErrors() throws.
    std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new(pkey, NULL),
                                                                    &EVP_PKEY_CTX_free);
    if (!ctx)
        handleErrors();
    if (EVP_PKEY_encrypt_init(ctx.get()) <= 0)
        handleErrors();
    // PKCS #1 OAEP padding.
    if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
        handleErrors();
    // OAEP digest = SHA-256. The MGF1 digest is not set explicitly here.
    if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), EVP_sha256()) <= 0)
        handleErrors();
    // First call (NULL output) only reports the required ciphertext size.
    size_t outlen = 0;
    if (EVP_PKEY_encrypt(ctx.get(), NULL, &outlen, msg, msglen) <= 0)
        handleErrors();
    *enc = static_cast<PBYTE>(OPENSSL_malloc(outlen));
    if (!*enc)
        handleErrors();
    if (EVP_PKEY_encrypt(ctx.get(), *enc, &outlen, msg, msglen) <= 0)
    {
        // Don't leak the output buffer when the real encryption fails.
        OPENSSL_free(*enc);
        *enc = nullptr;
        handleErrors();
    }
    *enclen = outlen;
    TRACE_OUT("Exiting rsa_encrypt()");
    return 0;
}
/// @brief RSA-OAEP decrypt with caller-specified hash algorithms.
///
/// @param pkey    RSA private key used for decryption.
/// @param msg     Ciphertext bytes (read only).
/// @param msglen  Ciphertext length in bytes.
/// @param dec     [out] Receives an OPENSSL_malloc'd plaintext buffer. It holds key
///                material, so the caller should release it with OPENSSL_clear_free().
/// @param declen  [out] Receives the plaintext length in bytes.
/// @param oaep_md OAEP hash (e.g. EVP_sha256()). Must not be NULL.
/// @param mgf1_md MGF1 hash. If NULL, defaults to oaep_md.
/// @return 0 on success.
/// @throws skr_error (via handleErrors()) on any OpenSSL failure. The context and
///         any allocated (scrubbed) output buffer are released before the exception
///         escapes, and *dec is reset to nullptr in that case.
int rsa_decrypt(EVP_PKEY *pkey, const PBYTE msg, size_t msglen, PBYTE *dec, size_t *declen,
                const EVP_MD *oaep_md, const EVP_MD *mgf1_md)
{
    TRACE_OUT("Entering rsa_decrypt()");
    if (mgf1_md == nullptr)
        mgf1_md = oaep_md;
    TRACE_OUT("  OAEP hash: %s, MGF1 hash: %s",
              EVP_MD_get0_name(oaep_md),
              EVP_MD_get0_name(mgf1_md));
    // Own the context so it is freed even though handleErrors() throws.
    std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new(pkey, NULL),
                                                                    &EVP_PKEY_CTX_free);
    if (!ctx)
        handleErrors();
    if (EVP_PKEY_decrypt_init(ctx.get()) <= 0)
        handleErrors();
    // PKCS #1 OAEP padding.
    if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
        handleErrors();
    if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), oaep_md) <= 0)
        handleErrors();
    // Set the MGF1 hash explicitly to avoid platform-dependent defaults.
    if (EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf1_md) <= 0)
        handleErrors();
    // First call (NULL output) only reports an upper bound for the plaintext size.
    size_t outlen = 0;
    if (EVP_PKEY_decrypt(ctx.get(), NULL, &outlen, msg, msglen) <= 0)
        handleErrors();
    *dec = static_cast<PBYTE>(OPENSSL_malloc(outlen));
    if (!*dec)
        handleErrors();
    const size_t capacity = outlen; // outlen is overwritten by the next call
    if (EVP_PKEY_decrypt(ctx.get(), *dec, &outlen, msg, msglen) <= 0)
    {
        // Scrub and free the buffer: a failed decrypt may leave partial plaintext.
        OPENSSL_clear_free(*dec, capacity);
        *dec = nullptr;
        handleErrors();
    }
    *declen = outlen;
    TRACE_OUT("Exiting rsa_decrypt()");
    return 0;
}
/// @brief RSA-OAEP decrypt (exception-safe variant for batch operations).
///
/// rsa_decrypt() reports failures through handleErrors(), which throws skr_error.
/// This variant instead throws std::runtime_error, so a batch caller can record
/// the failure for one item and move on to the next.
///
/// @param pkey    RSA private key used for decryption.
/// @param msg     Ciphertext bytes (read only).
/// @param msglen  Ciphertext length in bytes.
/// @param dec     [out] On success, receives an OPENSSL_malloc'd plaintext buffer. It
///                holds key material, so the caller should release it with
///                OPENSSL_clear_free(). Left untouched on failure.
/// @param declen  [out] On success, receives the plaintext length in bytes.
/// @param oaep_md OAEP hash. Must not be NULL.
/// @param mgf1_md MGF1 hash. If NULL, defaults to oaep_md.
/// @throws std::runtime_error on any OpenSSL failure. OpenSSL's error queue is
///         cleared first, so a failed item leaves no stale entries behind for a
///         later handleErrors() call to misreport.
static void rsa_decrypt_safe(EVP_PKEY *pkey, const PBYTE msg, size_t msglen,
                             PBYTE *dec, size_t *declen,
                             const EVP_MD *oaep_md, const EVP_MD *mgf1_md)
{
    if (mgf1_md == nullptr)
        mgf1_md = oaep_md;
    // Single failure path: discard queued OpenSSL errors, then throw.
    auto fail = [](const char *what)
    {
        ERR_clear_error();
        throw std::runtime_error(what);
    };
    // RAII ownership replaces the manual catch/cleanup/rethrow.
    std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(EVP_PKEY_CTX_new(pkey, NULL),
                                                                    &EVP_PKEY_CTX_free);
    if (!ctx)
        fail("EVP_PKEY_CTX_new failed");
    if (EVP_PKEY_decrypt_init(ctx.get()) <= 0)
        fail("EVP_PKEY_decrypt_init failed");
    if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
        fail("set_rsa_padding failed");
    if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), oaep_md) <= 0)
        fail("set_rsa_oaep_md failed");
    if (EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf1_md) <= 0)
        fail("set_rsa_mgf1_md failed");
    size_t outlen = 0;
    if (EVP_PKEY_decrypt(ctx.get(), NULL, &outlen, msg, msglen) <= 0)
        fail("EVP_PKEY_decrypt (size query) failed");
    PBYTE buffer = static_cast<PBYTE>(OPENSSL_malloc(outlen));
    if (!buffer)
        fail("OPENSSL_malloc failed");
    const size_t capacity = outlen; // outlen is overwritten by the next call
    if (EVP_PKEY_decrypt(ctx.get(), buffer, &outlen, msg, msglen) <= 0)
    {
        // Scrub before freeing: a failed decrypt may leave partial plaintext.
        OPENSSL_clear_free(buffer, capacity);
        fail("EVP_PKEY_decrypt failed");
    }
    *dec = buffer;
    *declen = outlen;
}
/// @brief Release the key-encryption key via SKR and RSA-OAEP-wrap @p sym_key with it.
///
/// @param attestation_url       Attestation endpoint forwarded to doSKR().
/// @param nonce                 Nonce forwarded to doSKR().
/// @param sym_key               Bytes to wrap (the full std::string contents).
/// @param key_enc_key_url       URL of the key-encryption key forwarded to doSKR().
/// @param akv_credential_source Credential source forwarded to doSKR().
/// @return Standard base64 encoding of the RSA-OAEP (SHA-256) ciphertext.
/// @throws skr_error EXIT_SKR_FAIL if the key cannot be released, or
///         EXIT_CRYPTO_FAIL if the key is not RSA or encryption fails. The
///         released key is freed on every path.
std::string Util::WrapKey(const std::string &attestation_url,
                          const std::string &nonce,
                          const std::string &sym_key,
                          const std::string &key_enc_key_url,
                          const Util::AkvCredentialSource &akv_credential_source)
{
    TRACE_OUT("Entering Util::WrapKey()");
    EVP_PKEY *released = nullptr;
    const bool released_ok = Util::doSKR(attestation_url, nonce, key_enc_key_url, &released, akv_credential_source);
    // Take ownership right away: rsa_encrypt() throws on failure, which used to leak the key.
    std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> pkey(released, &EVP_PKEY_free);
    if (!released_ok)
    {
        throw skr_error(EXIT_SKR_FAIL, "WrapKey: Failed to release the private key");
    }
    const int pkeyBaseId = EVP_PKEY_base_id(pkey.get());
    TRACE_OUT("Key release completed successfully. EVP_PKEY_base_id=%d", pkeyBaseId);
    // Only RSA keys (EVP_PKEY_RSA or EVP_PKEY_RSA2) can wrap; EC keys do not support wrapKey/unwrapKey.
    if (pkeyBaseId != EVP_PKEY_RSA && pkeyBaseId != EVP_PKEY_RSA2)
    {
        throw skr_error(EXIT_CRYPTO_FAIL, "The key is not of type RSA. Only RSA keys are supported for wrapKey/unwrapKey");
    }
    TRACE_OUT("Wrapping: %s", Util::reduct_log(sym_key).c_str());
    size_t encrypted_length = 0;
    PBYTE encryptedKey = nullptr;
    // rsa_encrypt() takes a non-const pointer but only reads the plaintext.
    PBYTE plaintext = reinterpret_cast<PBYTE>(const_cast<char *>(sym_key.data()));
    if (rsa_encrypt(pkey.get(), plaintext, sym_key.size(), &encryptedKey, &encrypted_length) == -1)
    {
        handle_openssl_errors(); // throws skr_error(EXIT_CRYPTO_FAIL)
    }
    // The ciphertext buffer came from OPENSSL_malloc, so release it with OPENSSL_free.
    const auto openssl_free = [](PBYTE p) { OPENSSL_free(p); };
    std::unique_ptr<BYTE, decltype(openssl_free)> encryptedGuard(encryptedKey, openssl_free);
    TRACE_OUT("Wrapping the symmetric key succeeded: encrypted_length=%zu", encrypted_length);
    const std::vector<BYTE> encryptedKeyVector(encryptedKey, encryptedKey + encrypted_length);
    std::string cipherText = Util::binary_to_base64(encryptedKeyVector);
    TRACE_OUT("Wrapped symmetric key in base64: %s", Util::reduct_log(cipherText).c_str());
    TRACE_OUT("Exiting Util::WrapKey()");
    return cipherText;
}
/// @brief Release the key-encryption key via SKR and RSA-OAEP-unwrap one wrapped key.
///
/// @param attestation_url       Attestation endpoint forwarded to doSKR().
/// @param nonce                 Nonce forwarded to doSKR().
/// @param wrapped_key_base64    Standard-base64 RSA-OAEP ciphertext.
/// @param key_enc_key_url       URL of the key-encryption key forwarded to doSKR().
/// @param akv_credential_source Credential source forwarded to doSKR().
/// @param oaep_hash_alg         OAEP hash name (see get_evp_md_by_name()).
/// @param mgf1_hash_alg         MGF1 hash name. Empty means "same as the OAEP hash".
/// @return The unwrapped key bytes, returned verbatim (may contain any byte value).
/// @throws std::runtime_error for an unsupported hash name. This is checked before
///         any key release is attempted.
/// @throws skr_error EXIT_SKR_FAIL if the key cannot be released, or
///         EXIT_CRYPTO_FAIL if the key is not RSA or decryption fails. The released
///         key is freed on every path.
std::string Util::UnwrapKey(const std::string &attestation_url,
                            const std::string &nonce,
                            const std::string &wrapped_key_base64,
                            const std::string &key_enc_key_url,
                            const Util::AkvCredentialSource &akv_credential_source,
                            const std::string &oaep_hash_alg,
                            const std::string &mgf1_hash_alg)
{
    TRACE_OUT("Entering Util::UnwrapKey()");
    TRACE_OUT("  OAEP hash: %s, MGF1 hash: %s",
              oaep_hash_alg.c_str(),
              mgf1_hash_alg.empty() ? oaep_hash_alg.c_str() : mgf1_hash_alg.c_str());
    // Resolve the hash algorithms first. An unsupported name then fails before a
    // key is released, and cannot leak that key.
    const EVP_MD *oaep_md = get_evp_md_by_name(oaep_hash_alg);
    const EVP_MD *mgf1_md = mgf1_hash_alg.empty() ? nullptr : get_evp_md_by_name(mgf1_hash_alg);

    EVP_PKEY *released = nullptr;
    const bool released_ok = Util::doSKR(attestation_url, nonce, key_enc_key_url, &released, akv_credential_source);
    // Take ownership right away: rsa_decrypt() throws on failure.
    std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> pkey(released, &EVP_PKEY_free);
    if (!released_ok)
    {
        throw skr_error(EXIT_SKR_FAIL, "UnwrapKey: Failed to release the private key");
    }
    const int pkeyBaseId = EVP_PKEY_base_id(pkey.get());
    TRACE_OUT("Key release completed successfully. EVP_PKEY_base_id=%d", pkeyBaseId);
    // Only RSA keys (EVP_PKEY_RSA or EVP_PKEY_RSA2) can unwrap; EC keys do not support wrapKey/unwrapKey.
    if (pkeyBaseId != EVP_PKEY_RSA && pkeyBaseId != EVP_PKEY_RSA2)
    {
        throw skr_error(EXIT_CRYPTO_FAIL, "The key is not of type RSA. Only RSA keys are supported for wrapKey/unwrapKey");
    }
    const int rsaSize = EVP_PKEY_get_size(pkey.get());
    TRACE_OUT("Unwrapping: %s", wrapped_key_base64.c_str());
    std::vector<BYTE> wrapped_key = Util::base64_to_binary(wrapped_key_base64);
    TRACE_OUT("RSA key size=%d bytes, wrapped_key decoded size=%zu bytes", rsaSize, wrapped_key.size());
    if (static_cast<int>(wrapped_key.size()) != rsaSize)
    {
        TRACE_OUT("WARNING: wrapped_key size (%zu) != RSA key size (%d). Possible base64 or key-size mismatch.",
                  wrapped_key.size(), rsaSize);
    }
    size_t decrypted_length = 0;
    PBYTE decryptedKey = nullptr;
    if (rsa_decrypt(pkey.get(), wrapped_key.data(), wrapped_key.size(), &decryptedKey, &decrypted_length,
                    oaep_md, mgf1_md) == -1)
    {
        handle_openssl_errors(); // throws skr_error(EXIT_CRYPTO_FAIL)
    }
    // The plaintext is key material: scrub it before handing it back to OpenSSL's allocator.
    struct ScrubbedBuffer
    {
        PBYTE data;
        size_t size;
        ~ScrubbedBuffer() { OPENSSL_clear_free(data, size); }
    } decryptedGuard{decryptedKey, decrypted_length};

    TRACE_OUT("Unwrapping the symmetric key succeeded: decrypted_length=%zu", decrypted_length);
    TRACE_OUT("Unwrapped symmetric key in base64: %s",
              Util::reduct_log(Util::binary_to_base64(
                                   std::vector<BYTE>(decryptedKey, decryptedKey + decrypted_length)))
                  .c_str());
    // Return the bytes directly. A base64 encode/decode round trip is redundant
    // and can drop trailing NUL bytes.
    std::string plainText(reinterpret_cast<const char *>(decryptedKey), decrypted_length);
    TRACE_OUT("Exiting Util::UnwrapKey()");
    return plainText;
}
/// @brief Release the key-encryption key once via SKR and unwrap many keys with it.
///
/// Input: a JSON object with a "keys" array. Each element is an object with a
/// required string field "wrapped" (standard base64 ciphertext) and an optional
/// "id". A missing id defaults to the element's index; a non-string id is
/// serialized as JSON text.
///
/// Output: pretty-printed JSON {"results": [...]} with one entry per input
/// element. Each entry carries "id" plus either "unwrapped" (the plaintext as a
/// JSON string) or "error". A failure on one element never aborts the rest of
/// the batch.
///
/// @throws skr_error EXIT_USAGE for malformed input JSON, EXIT_SKR_FAIL if the key
///         cannot be released, or EXIT_CRYPTO_FAIL if the key is not RSA.
/// @throws std::runtime_error for an unsupported hash name. This is checked before
///         any key release is attempted.
std::string Util::UnwrapKeyBatch(const std::string &attestation_url,
                                 const std::string &nonce,
                                 const std::string &batch_json,
                                 const std::string &key_enc_key_url,
                                 const Util::AkvCredentialSource &akv_credential_source,
                                 const std::string &oaep_hash_alg,
                                 const std::string &mgf1_hash_alg)
{
    TRACE_OUT("Entering Util::UnwrapKeyBatch()");
    TRACE_OUT("  OAEP hash: %s, MGF1 hash: %s",
              oaep_hash_alg.c_str(),
              mgf1_hash_alg.empty() ? oaep_hash_alg.c_str() : mgf1_hash_alg.c_str());
    // --- Parse the input JSON ---
    json inputJson;
    try
    {
        inputJson = json::parse(batch_json);
    }
    catch (const json::parse_error &e)
    {
        throw skr_error(EXIT_USAGE, std::string("UnwrapKeyBatch: Invalid JSON input: ") + e.what());
    }
    if (!inputJson.contains("keys") || !inputJson.at("keys").is_array())
    {
        throw skr_error(EXIT_USAGE, "UnwrapKeyBatch: JSON must contain a \"keys\" array");
    }
    const json &keysArray = inputJson.at("keys");
    TRACE_OUT("Batch contains %zu keys to unwrap", keysArray.size());

    // Resolve the hash algorithms once for the whole batch, before the key
    // release. An unsupported name then cannot leak a released key.
    const EVP_MD *oaep_md = get_evp_md_by_name(oaep_hash_alg);
    const EVP_MD *mgf1_md = mgf1_hash_alg.empty() ? nullptr : get_evp_md_by_name(mgf1_hash_alg);

    // --- Single SKR call for the entire batch ---
    EVP_PKEY *released = nullptr;
    const bool released_ok = Util::doSKR(attestation_url, nonce, key_enc_key_url, &released, akv_credential_source);
    std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> pkey(released, &EVP_PKEY_free);
    if (!released_ok)
    {
        throw skr_error(EXIT_SKR_FAIL, "UnwrapKeyBatch: Failed to release the private key");
    }
    const int pkeyBaseId = EVP_PKEY_base_id(pkey.get());
    TRACE_OUT("Key release completed successfully. EVP_PKEY_base_id=%d", pkeyBaseId);
    if (pkeyBaseId != EVP_PKEY_RSA && pkeyBaseId != EVP_PKEY_RSA2)
    {
        throw skr_error(EXIT_CRYPTO_FAIL, "UnwrapKeyBatch: Released key is not RSA. Only RSA keys are supported for unwrap.");
    }

    // Owns one decrypted plaintext buffer. It is scrubbed and freed on scope
    // exit, including when an exception is thrown.
    struct DecryptedBuffer
    {
        PBYTE data = nullptr;
        size_t size = 0;
        DecryptedBuffer() = default;
        DecryptedBuffer(const DecryptedBuffer &) = delete;
        DecryptedBuffer &operator=(const DecryptedBuffer &) = delete;
        ~DecryptedBuffer() { OPENSSL_clear_free(data, size); }
    };

    // --- Iterate and unwrap each key ---
    json resultsArray = json::array();
    int successCount = 0;
    int errorCount = 0;
    for (size_t i = 0; i < keysArray.size(); ++i)
    {
        const json &entry = keysArray[i];
        json resultEntry;
        // "id" is optional and defaults to the index. It is read defensively
        // because json::value() throws for non-object entries and non-string ids,
        // and that would abort the whole batch.
        std::string id = std::to_string(i);
        if (entry.is_object() && entry.contains("id"))
        {
            const json &idField = entry.at("id");
            id = idField.is_string() ? idField.get<std::string>() : idField.dump();
        }
        resultEntry["id"] = id;
        try
        {
            // "wrapped" field is required (contains() is false for non-objects).
            if (!entry.contains("wrapped") || !entry.at("wrapped").is_string())
            {
                throw std::runtime_error("missing or invalid \"wrapped\" field");
            }
            const std::string wrapped_key_base64 = entry.at("wrapped").get<std::string>();
            TRACE_OUT("Unwrapping key [%s] (%zu/%zu)", id.c_str(), i + 1, keysArray.size());
            std::vector<BYTE> wrapped_key = Util::base64_to_binary(wrapped_key_base64);
            DecryptedBuffer plain;
            rsa_decrypt_safe(pkey.get(), wrapped_key.data(), wrapped_key.size(),
                             &plain.data, &plain.size, oaep_md, mgf1_md);
            json unwrapped = std::string(reinterpret_cast<const char *>(plain.data), plain.size);
            // nlohmann::json refuses to serialize invalid UTF-8. Without this
            // check, a single binary key would make the final dump() throw and
            // lose every result. The fixed message avoids echoing key bytes,
            // which nlohmann's own message would include.
            try
            {
                (void)unwrapped.dump();
            }
            catch (const json::type_error &)
            {
                throw std::runtime_error("unwrapped key is not valid UTF-8 and cannot be returned as a JSON string");
            }
            resultEntry["unwrapped"] = std::move(unwrapped);
            ++successCount;
        }
        catch (const std::exception &e)
        {
            resultEntry["error"] = e.what();
            ++errorCount;
            TRACE_OUT("Key [%s] failed: %s", id.c_str(), e.what());
        }
        resultsArray.push_back(std::move(resultEntry));
    }
    pkey.reset(); // release the private key as soon as it is no longer needed
    TRACE_OUT("Batch unwrap complete: %d succeeded, %d failed out of %zu",
              successCount, errorCount, keysArray.size());
    TRACE_OUT("Exiting Util::UnwrapKeyBatch()");
    json outputJson;
    outputJson["results"] = std::move(resultsArray);
    return outputJson.dump(2); // pretty-print with 2-space indent
}
/// @brief Release the key via SKR and report on stderr which operations its type supports.
///
/// RSA keys (EVP_PKEY_RSA or EVP_PKEY_RSA2) and EC keys are reported as usable.
/// Any other key type is reported with its numeric id and counts as "not OK".
/// The released key is freed before returning.
///
/// @return true for RSA or EC keys, false for any other key type.
/// @throws skr_error EXIT_SKR_FAIL if doSKR() reports that the key was not released.
bool Util::ReleaseKey(const std::string &attestation_url,
                      const std::string &nonce,
                      const std::string &key_enc_key_url,
                      const Util::AkvCredentialSource &akv_credential_source)
{
    TRACE_OUT("Entering Util::ReleaseKey()");
    EVP_PKEY *releasedKey = nullptr;
    const bool skrSucceeded = Util::doSKR(attestation_url, nonce, key_enc_key_url, &releasedKey, akv_credential_source);
    if (!skrSucceeded)
    {
        throw skr_error(EXIT_SKR_FAIL, "Failed to release the private key");
    }
    TRACE_OUT("Key release completed successfully.");

    const int keyType = EVP_PKEY_base_id(releasedKey);
    bool releaseOk = true;
    if (keyType == EVP_PKEY_RSA || keyType == EVP_PKEY_RSA2)
    {
        std::cerr << "The released key is of type RSA. It can be used for wrapKey/unwrapKey operations." << std::endl;
    }
    else if (keyType == EVP_PKEY_EC)
    {
        std::cerr << "The released key is of type EC. It can be used for sign/verify operations." << std::endl;
    }
    else
    {
        std::cerr << "The released key is of type " << keyType << ". Not sure what operations are supported." << std::endl;
        releaseOk = false;
    }

    EVP_PKEY_free(releasedKey);
    return releaseOk;
}