/* * TLS Handshake IO * (C) 2012,2014,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ #include #include #include #include #include #include #include namespace Botan::TLS { namespace { inline size_t load_be24(const uint8_t q[3]) { return make_uint32(0, q[0], q[1], q[2]); } void store_be24(uint8_t out[3], size_t val) { out[0] = get_byte<1>(static_cast(val)); out[1] = get_byte<2>(static_cast(val)); out[2] = get_byte<3>(static_cast(val)); } uint64_t steady_clock_ms() { return std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()) .count(); } } // namespace Protocol_Version Stream_Handshake_IO::initial_record_version() const { return Protocol_Version::TLS_V12; } void Stream_Handshake_IO::add_record(const uint8_t record[], size_t record_len, Record_Type record_type, uint64_t /*sequence_number*/) { if(record_type == Record_Type::Handshake) { m_queue.insert(m_queue.end(), record, record + record_len); } else if(record_type == Record_Type::ChangeCipherSpec) { if(record_len != 1 || record[0] != 1) { throw Decoding_Error("Invalid ChangeCipherSpec"); } // Pretend it's a regular handshake message of zero length const uint8_t ccs_hs[] = {static_cast(Handshake_Type::HandshakeCCS), 0, 0, 0}; m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs)); } else { throw Decoding_Error("Unknown message type " + std::to_string(static_cast(record_type)) + " in handshake processing"); } } std::pair> Stream_Handshake_IO::get_next_record(bool /*expecting_ccs*/) { if(m_queue.size() >= 4) { const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]); if(m_queue.size() >= length) { Handshake_Type type = static_cast(m_queue[0]); if(type == Handshake_Type::None) { throw Decoding_Error("Invalid handshake message type"); } std::vector contents(m_queue.begin() + 4, m_queue.begin() + length); m_queue.erase(m_queue.begin(), m_queue.begin() + length); return std::make_pair(type, contents); } } return std::make_pair(Handshake_Type::None, std::vector()); } std::vector Stream_Handshake_IO::format(const std::vector& msg, Handshake_Type type) const { std::vector send_buf(4 + msg.size()); const size_t buf_size = msg.size(); send_buf[0] = static_cast(type); store_be24(&send_buf[1], buf_size); if(!msg.empty()) { copy_mem(&send_buf[4], msg.data(), msg.size()); } return send_buf; } std::vector Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/) { throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS"); } std::vector Stream_Handshake_IO::send(const Handshake_Message& msg) { const std::vector msg_bits = msg.serialize(); if(msg.type() == Handshake_Type::HandshakeCCS) { m_send_hs(Record_Type::ChangeCipherSpec, msg_bits); return std::vector(); // not included in handshake hashes } auto buf = format(msg_bits, msg.wire_type()); m_send_hs(Record_Type::Handshake, buf); return buf; } Protocol_Version Datagram_Handshake_IO::initial_record_version() const { return Protocol_Version::DTLS_V12; } void Datagram_Handshake_IO::retransmit_last_flight() { const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2); retransmit_flight(flight_idx); } void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx) { const std::vector& flight = m_flights.at(flight_idx); BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit"); uint16_t epoch = m_flight_data[flight[0]].epoch; for(auto msg_seq : flight) { auto& msg = m_flight_data[msg_seq]; if(msg.epoch != epoch) { // Epoch gap: insert the CCS std::vector ccs(1, 1); m_send_hs(epoch, Record_Type::ChangeCipherSpec, ccs); } send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits); epoch = msg.epoch; } } bool Datagram_Handshake_IO::have_more_data() const { return false; } bool Datagram_Handshake_IO::timeout_check() { if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) { /* If we haven't written anything yet obviously no timeout. Also no timeout possible if we are mid-flight, */ return false; } const uint64_t ms_since_write = steady_clock_ms() - m_last_write; if(ms_since_write < m_next_timeout) { return false; } retransmit_last_flight(); m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout); return true; } void Datagram_Handshake_IO::add_record(const uint8_t record[], size_t record_len, Record_Type record_type, uint64_t record_sequence) { const uint16_t epoch = static_cast(record_sequence >> 48); if(record_type == Record_Type::ChangeCipherSpec) { if(record_len != 1 || record[0] != 1) { throw Decoding_Error("Invalid ChangeCipherSpec"); } // TODO: check this is otherwise empty m_ccs_epochs.insert(epoch); return; } const size_t DTLS_HANDSHAKE_HEADER_LEN = 12; while(record_len) { if(record_len < DTLS_HANDSHAKE_HEADER_LEN) { return; // completely bogus? at least degenerate/weird } const Handshake_Type msg_type = static_cast(record[0]); const size_t msg_len = load_be24(&record[1]); const uint16_t message_seq = load_be(&record[4], 0); const size_t fragment_offset = load_be24(&record[6]); const size_t fragment_length = load_be24(&record[9]); const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length; if(record_len < total_size) { throw Decoding_Error("Bad lengths in DTLS header"); } if(message_seq >= m_in_message_seq) { m_messages[message_seq].add_fragment( &record[DTLS_HANDSHAKE_HEADER_LEN], fragment_length, fragment_offset, epoch, msg_type, msg_len); } else { // TODO: detect retransmitted flight } record += total_size; record_len -= total_size; } } std::pair> Datagram_Handshake_IO::get_next_record(bool expecting_ccs) { // Expecting a message means the last flight is concluded if(!m_flights.rbegin()->empty()) { m_flights.push_back(std::vector()); } if(expecting_ccs) { if(!m_messages.empty()) { const uint16_t current_epoch = m_messages.begin()->second.epoch(); if(m_ccs_epochs.contains(current_epoch)) { return std::make_pair(Handshake_Type::HandshakeCCS, std::vector()); } } return std::make_pair(Handshake_Type::None, std::vector()); } auto i = m_messages.find(m_in_message_seq); if(i == m_messages.end() || !i->second.complete()) { return std::make_pair(Handshake_Type::None, std::vector()); } m_in_message_seq += 1; return i->second.message(); } void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(const uint8_t fragment[], size_t fragment_length, size_t fragment_offset, uint16_t epoch, Handshake_Type msg_type, size_t msg_length) { if(complete()) { return; // already have entire message, ignore this } if(m_msg_type == Handshake_Type::None) { m_epoch = epoch; m_msg_type = msg_type; m_msg_length = msg_length; } if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch) { throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header"); } if(fragment_offset > m_msg_length) { throw Decoding_Error("Fragment offset past end of message"); } if(fragment_offset + fragment_length > m_msg_length) { throw Decoding_Error("Fragment overlaps past end of message"); } if(fragment_offset == 0 && fragment_length == m_msg_length) { m_fragments.clear(); m_message.assign(fragment, fragment + fragment_length); } else { /* * FIXME. This is a pretty lame way to do defragmentation, huge * overhead with a tree node per byte. * * Also should confirm that all overlaps have no changes, * otherwise we expose ourselves to the classic fingerprinting * and IDS evasion attacks on IP fragmentation. */ for(size_t i = 0; i != fragment_length; ++i) { m_fragments[fragment_offset + i] = fragment[i]; } if(m_fragments.size() == m_msg_length) { m_message.resize(m_msg_length); for(size_t i = 0; i != m_msg_length; ++i) { m_message[i] = m_fragments[i]; } m_fragments.clear(); } } } bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const { return (m_msg_type != Handshake_Type::None && m_message.size() == m_msg_length); } std::pair> Datagram_Handshake_IO::Handshake_Reassembly::message() const { if(!complete()) { throw Internal_Error("Datagram_Handshake_IO - message not complete"); } return std::make_pair(m_msg_type, m_message); } std::vector Datagram_Handshake_IO::format_fragment(const uint8_t fragment[], size_t frag_len, uint16_t frag_offset, uint16_t msg_len, Handshake_Type type, uint16_t msg_sequence) const { std::vector send_buf(12 + frag_len); send_buf[0] = static_cast(type); store_be24(&send_buf[1], msg_len); store_be(msg_sequence, &send_buf[4]); store_be24(&send_buf[6], frag_offset); store_be24(&send_buf[9], frag_len); if(frag_len > 0) { copy_mem(&send_buf[12], fragment, frag_len); } return send_buf; } std::vector Datagram_Handshake_IO::format_w_seq(const std::vector& msg, Handshake_Type type, uint16_t msg_sequence) const { return format_fragment(msg.data(), msg.size(), 0, static_cast(msg.size()), type, msg_sequence); } std::vector Datagram_Handshake_IO::format(const std::vector& msg, Handshake_Type type) const { return format_w_seq(msg, type, m_in_message_seq - 1); } std::vector Datagram_Handshake_IO::send(const Handshake_Message& msg) { return this->send_under_epoch(msg, m_seqs.current_write_epoch()); } std::vector Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch) { const std::vector msg_bits = msg.serialize(); const Handshake_Type msg_type = msg.type(); if(msg_type == Handshake_Type::HandshakeCCS) { m_send_hs(epoch, Record_Type::ChangeCipherSpec, msg_bits); return std::vector(); // not included in handshake hashes } else if(msg_type == Handshake_Type::HelloVerifyRequest) { // This message is not included in the handshake hashes send_message(m_out_message_seq, epoch, msg_type, msg_bits); m_out_message_seq += 1; return std::vector(); } // Note: not saving CCS, instead we know it was there due to change in epoch m_flights.rbegin()->push_back(m_out_message_seq); m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits); m_out_message_seq += 1; m_last_write = steady_clock_ms(); m_next_timeout = m_initial_timeout; return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); } std::vector Datagram_Handshake_IO::send_message(uint16_t msg_seq, uint16_t epoch, Handshake_Type msg_type, const std::vector& msg_bits) { const size_t DTLS_HANDSHAKE_HEADER_LEN = 12; auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) { m_send_hs(epoch, Record_Type::Handshake, no_fragment); } else { size_t frag_offset = 0; /** * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV, * 16+ for padding and 48 bytes for MAC. 128 is probably a strict * over-estimate here. When CBC ciphers are removed this can be reduced * since AEAD modes have no padding, at most 16 byte mac, and smaller * per-record nonce. */ const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0; const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN; if(m_mtu <= (header_overhead + ciphersuite_overhead)) { throw Invalid_Argument("DTLS MTU is too small to send headers"); } const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead); while(frag_offset != msg_bits.size()) { const size_t frag_len = std::min(msg_bits.size() - frag_offset, max_rec_size); const std::vector frag = format_fragment(&msg_bits[frag_offset], frag_len, static_cast(frag_offset), static_cast(msg_bits.size()), msg_type, msg_seq); m_send_hs(epoch, Record_Type::Handshake, frag); frag_offset += frag_len; } } return no_fragment; } } // namespace Botan::TLS