import argparse
import struct
import sys

from twisted.python import log
from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import (
    TCP4ClientEndpoint, TCP4ServerEndpoint, connectProtocol)


class BufferedProtocol(protocol.Protocol, object):
    """Protocol that holds received bytes until a caller asks for them.

    While buffering is enabled, incoming data is appended to an internal
    buffer and handed out through ``waitfor()``.  After
    ``stop_buffering()`` every received chunk is passed directly to
    ``rawDataReceived()``, which subclasses have to implement.
    """

    def __init__(self):
        self._buffer_enabled = True
        self._buffer = b''
        # Deferred of the single outstanding waitfor() call, or None.
        self._deferred = None
        # Number of bytes that outstanding call is waiting for.
        self._wait_length = 0
        super(BufferedProtocol, self).__init__()

    def _read_buffer(self, length):
        """Remove and return ``length`` bytes from the front of the buffer.

        A ``length`` of 0 is special-cased: the entire buffer is returned
        and the buffer is left empty.
        """
        assert len(self._buffer) >= length
        if length == 0:
            data, self._buffer = self._buffer, b''
        else:
            data, self._buffer = self._buffer[:length], self._buffer[length:]
        return data

    def stop_buffering(self):
        """Switch to pass-through mode, flushing any buffered bytes first.

        Must not be called while a ``waitfor()`` Deferred is outstanding.
        """
        assert self._deferred is None
        self._buffer_enabled = False
        if self._buffer:
            # Detach the buffer before delivering it, so the same bytes
            # cannot be flushed a second time if rawDataReceived() raises.
            pending, self._buffer = self._buffer, b''
            self.rawDataReceived(pending)

    def waitfor(self, length):
        """Return a Deferred that fires with ``length`` bytes of input.

        If enough data is already buffered the Deferred has already fired;
        otherwise it fires from ``dataReceived()`` once enough bytes have
        arrived, or fails from ``connectionLost()``.  Only one call may be
        outstanding at a time.  A ``length`` of 0 yields everything that
        is currently buffered (see ``_read_buffer()``).
        """
        assert isinstance(length, int)
        # A negative length would pass the size check below and then be
        # interpreted as a slice offset, silently returning the wrong bytes.
        assert length >= 0
        assert self._deferred is None
        if len(self._buffer) >= length:
            return defer.succeed(self._read_buffer(length))
        self._wait_length = length
        self._deferred = defer.Deferred()
        return self._deferred

    def connectionLost(self, reason):
        """Log the disconnect and fail the outstanding waitfor(), if any."""
        log.msg('{} connection lost: {}'.format(self.__class__, reason))
        if self._deferred is not None:
            # Unregister the waiter before firing it (as dataReceived()
            # does) so an already-fired Deferred is never left behind.
            waiter, self._deferred = self._deferred, None
            self._wait_length = 0
            waiter.errback(reason)

    def dataReceived(self, data):
        """Buffer ``data`` (or pass it through) and wake a ready waiter."""
        log.msg('{} received: {!r}'.format(self.__class__, data))
        if not self._buffer_enabled:
            self.rawDataReceived(data)
            return
        self._buffer += data
        if (self._deferred is not None
                and len(self._buffer) >= self._wait_length):
            chunk = self._read_buffer(self._wait_length)
            # Clear the registration before firing: the callback may call
            # waitfor() again.
            waiter, self._deferred = self._deferred, None
            self._wait_length = 0
            waiter.callback(chunk)

    def rawDataReceived(self, data):
        """Handle ``data`` received after buffering was stopped."""
        raise NotImplementedError()


class MySQLForwardBaseProtocol(BufferedProtocol):
    """Behaviour shared by the protocols at both ends of the forwarder.

    ``peer`` is the protocol object for the other end; the rest of the
    file writes to ``peer.transport`` and calls ``peer.read_packet()`` and
    ``peer.stop_buffering()``.

    NOTE(review): this class is incomplete.  The source text is missing
    from the middle of ``parse_header()`` onwards; see the notes below.
    """

    # NOTE(review): every use of this constant is in the lost text.
    # Presumably the MySQL CLIENT_SSL capability bit -- confirm.
    SSL_FLAG = 0x00000800

    def __init__(self, peer=None):
        self.peer = peer
        super(MySQLForwardBaseProtocol, self).__init__()

    @staticmethod
    def parse_header(packet):
        """Decode the 4-byte header at the start of ``packet``.

        Returns a ``(length, seq_id)`` tuple: the three leading bytes and
        the one-byte sequence id that follows them.
        """
        assert len(packet) >= 4
        header = packet[:4]
        (length_s, seq_id) = struct.unpack('<3sB', header)
        # NOTE(review): the original text breaks off after
        # "(length,) = struct.unpack('".  The two statements below are a
        # reconstruction: the 3-byte field is padded to 4 bytes and read
        # as a little-endian unsigned int, and the caller unpacks the
        # result as ``_length, seq``.  Confirm against version control.
        (length,) = struct.unpack('<I', length_s + b'\x00')
        return length, seq_id

    # NOTE(review): SOURCE TEXT LOST HERE.  Everything that originally
    # followed parse_header() is missing, except for the residue "= 4"
    # (presumably the end of an ``assert len(packet) >= 4``).  The
    # surviving code further down calls ``read_packet()`` on both peers
    # and ``modify_seq(packet, delta)`` on the server-side protocol, so
    # both names are kept as explicit placeholders instead of guessed
    # implementations.  Other members (for example a rawDataReceived()
    # override) may have been lost as well.

    def read_packet(self):
        """Placeholder for the lost packet reader."""
        raise NotImplementedError(
            'read_packet() is missing from this file; '
            'restore it from version control')

    def modify_seq(self, packet, delta):
        """Placeholder for the lost sequence-id rewriting helper."""
        raise NotImplementedError(
            'modify_seq() is missing from this file; '
            'restore it from version control')
# NOTE(review): SOURCE TEXT LOST ABOVE THIS POINT.  This part of the file
# opens in the middle of missing text.  The only residue before the first
# intact statement is
#     (seq,) = struct.unpack('
# followed by the tail of a comment (" server packet needs to have its
# seq_id incremented; server -> client packets need to have them
# decremented").  The class header, the method header and the opening
# ``try:`` below are reconstructed scaffolding around the statements that
# did survive; confirm all three against version control.


class MySQLForwardServerProtocol(MySQLForwardBaseProtocol):
    """Protocol for the listening side of the forwarder.

    ``MySQLForwardServerFactory`` names this class as its protocol.  The
    surviving code relays packets between ``self`` and ``self.peer`` and
    then switches both to pass-through mode.
    """

    def _lost_handshake_prologue(self):
        """Placeholder for the part of the handshake that is missing.

        The original statements that set up ``self.peer`` and produced
        ``client_handshake`` are not present in this file, so this raises
        instead of guessing at them.
        """
        raise NotImplementedError(
            'the start of the MySQLForwardServerProtocol handshake is '
            'missing from this file; restore it from version control')

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the handshake, then hand the connection over to forwarding.

        NOTE(review): the method name is an assumption (the original
        header is lost); the decorator is implied by the ``yield`` below.
        Any failure is logged and closes this side's connection.
        """
        try:
            client_handshake = yield self._lost_handshake_prologue()

            # Sequence ids are shifted by one per direction: the packet
            # written to the peer here gets +1, and
            # forward_until_seq_reset() starts the opposite direction
            # at -1.
            client_handshake = self.modify_seq(client_handshake, 1)
            self.peer.transport.write(client_handshake)

            # send packets back and forth until the client resets
            # the sequence numbers
            yield self.forward_until_seq_reset()

            # start forwarding traffic with no further manipulation
            self.stop_buffering()
            self.peer.stop_buffering()
        except Exception:
            self.transport.loseConnection()
            log.err()

    @defer.inlineCallbacks
    def forward_until_seq_reset(self):
        """Relay packets, alternating direction, until one has seq id 0.

        Every relayed packet with a non-zero sequence id has that id
        adjusted by -1 or +1 depending on its direction.  The first packet
        with sequence id 0 is passed on unchanged and ends the loop.
        """
        # The first packet we forward is server-to-client: it is read
        # from self.peer, written to self, and has its id decremented.
        source, dest, delta = self.peer, self, -1
        while True:
            packet = yield source.read_packet()
            _length, seq = self.parse_header(packet)
            if seq == 0:
                # The sequence number was reset: forward as-is and stop.
                dest.transport.write(packet)
                break
            dest.transport.write(self.modify_seq(packet, delta))
            # Reverse direction for the next packet.
            source, dest, delta = dest, source, -delta


class MySQLForwardClientProtocol(MySQLForwardBaseProtocol):
    """Protocol for the outgoing side; adds nothing to the base class."""


class MySQLForwardServerFactory(protocol.Factory):
    """Factory for MySQLForwardServerProtocol.

    ``dest`` is the client endpoint for the destination host and port.
    """

    protocol = MySQLForwardServerProtocol

    def __init__(self, dest_host, dest_port):
        self.dest = TCP4ClientEndpoint(reactor, dest_host, dest_port)


def main():
    """Parse the command line, start listening and run the reactor.

    The positional ``dest`` argument is ``host`` or ``host:port``; the
    port defaults to 3306.
    """
    p = argparse.ArgumentParser()
    p.add_argument('-p', '--listen-port', type=int, default=3306)
    p.add_argument('-i', '--listen-interface', default='127.0.0.1')
    p.add_argument('dest')
    args = p.parse_args()

    # Split on the last colon only, so extra colons no longer crash the
    # tuple unpacking, and report a bad port as a usage error instead of
    # a traceback.
    dest_host, sep, port_text = args.dest.rpartition(':')
    if sep:
        try:
            dest_port = int(port_text)
        except ValueError:
            p.error('invalid destination port: {!r}'.format(port_text))
    else:
        dest_host, dest_port = args.dest, 3306

    log.startLogging(sys.stdout)
    log.msg('listen: {}:{}; connect: {}:{}'.format(
        args.listen_interface, args.listen_port, dest_host, dest_port))

    endpoint = TCP4ServerEndpoint(
        reactor, args.listen_port, interface=args.listen_interface)
    listening = endpoint.listen(
        MySQLForwardServerFactory(dest_host, dest_port))

    def listen_failed(failure):
        # Without this the failure is an unhandled Deferred error and the
        # reactor keeps running with nothing to serve.
        log.err(failure, 'unable to listen on {}:{}'.format(
            args.listen_interface, args.listen_port))
        # The failure may arrive before reactor.run() is called, in which
        # case a direct reactor.stop() would raise; callWhenRunning()
        # covers both cases.
        reactor.callWhenRunning(reactor.stop)

    listening.addErrback(listen_failed)
    reactor.run()


if __name__ == '__main__':
    main()