diff --git a/twisted/internet/abstract.py b/twisted/internet/abstract.py index 6321239..ad6c8b9 100644 --- a/twisted/internet/abstract.py +++ b/twisted/internet/abstract.py @@ -32,6 +32,7 @@ class FileDescriptor(log.Logger, styles.Ephemeral, object): disconnecting = 0 _writeDisconnecting = False _writeDisconnected = False + _postLoseConnectionReason = failure.Failure(main.CONNECTION_DONE) dataBuffer = "" offset = 0 @@ -154,7 +155,7 @@ class FileDescriptor(log.Logger, styles.Ephemeral, object): Whatever this returns is then returned by doWrite. """ # default implementation, telling reactor we're finished - return main.CONNECTION_DONE + return self._postLoseConnectionReason def _closeWriteConnection(self): # override in subclasses @@ -244,12 +245,24 @@ class FileDescriptor(log.Logger, styles.Ephemeral, object): else: self.stopReading() self.startWriting() + self._postLoseConnectionReason = _connDone self.disconnecting = 1 def loseWriteConnection(self): self._writeDisconnecting = True self.startWriting() + def abortConnection(self): + """ + Unregister the producer, empty the send buffer, and abruptly abort the + connection. + """ + self.unregisterProducer() + self.dataBuffer = "" + self._tempDataBuffer = [] + self.offset = 0 + self.loseConnection(failure.Failure(main.CONNECTION_LOST)) + def stopReading(self): """Stop waiting for read availability. diff --git a/twisted/internet/interfaces.py b/twisted/internet/interfaces.py index cb3438c..34712c8 100644 --- a/twisted/internet/interfaces.py +++ b/twisted/internet/interfaces.py @@ -1389,6 +1389,14 @@ class ITCPTransport(ITransport): producer. """ + def abortConnection(): + """ + Close the connection abruptly. + + Discards any buffered data, unregisters any producer, and, if + possible, notifies the other end of the unclean closure. + """ + def getTcpNoDelay(): """ Return if C{TCP_NODELAY} is enabled. diff --git a/twisted/internet/tcp.py b/twisted/internet/tcp.py index 41ed68a..d9fa9de 100644 --- a/twisted/internet/tcp.py +++ b/twisted/internet/tcp.py @@ -17,6 +17,7 @@ import types import socket import sys import operator +import struct from zope.interface import implements, classImplements @@ -83,15 +84,23 @@ from twisted.internet import abstract, main, interfaces, error class _SocketCloser: _socketShutdownMethod = 'shutdown' - def _closeSocket(self): - # socket.close() doesn't *really* close if there's another reference - # to it in the TCP/IP stack, e.g. if it was was inherited by a - # subprocess. And we really do want to close the connection. So we - # use shutdown() instead, and then close() in order to release the - # filedescriptor. + def _closeSocket(self, orderly): + # The call to shutdown() before close() isn't really necessary, because + # we set FD_CLOEXEC now, which will ensure this is the only process + # holding the FD, thus ensuring close() really will shutdown the TCP + # socket. However, do it anyways, just to be safe. But when doing + # a non-orderly shutdown, it needs to be done via close only. skt = self.socket try: - getattr(skt, self._socketShutdownMethod)(2) + if orderly: + getattr(skt, self._socketShutdownMethod)(2) + else: + # Set SO_LINGER to 1,0 which, by convention, causes a + # connection reset to be sent when close is called, + # instead of the standard FIN shutdown sequence. + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack("ii", 1, 0)) + except socket.error: pass try: @@ -218,7 +227,7 @@ class _TLSMixin: # already received a TLS close alert from the peer. Why do # this??? self.socket.set_shutdown(SSL.RECEIVED_SHUTDOWN) - return self._sendCloseAlert() + return self._sendCloseAlert()[0] def _sendCloseAlert(self): @@ -246,9 +255,9 @@ class _TLSMixin: os.write(self.socket.fileno(), '') except OSError, se: if se.args[0] in (EINTR, EWOULDBLOCK, ENOBUFS): - return 0 + return 0, False # Write error, socket gone - return main.CONNECTION_LOST + return main.CONNECTION_LOST, False try: if hasattr(self.socket, 'set_shutdown'): @@ -263,12 +272,11 @@ class _TLSMixin: self.socket.shutdown() done = True except SSL.Error, e: - return e + return e, False if done: self.stopWriting() - # Note that this is tested for by identity below. - return main.CONNECTION_DONE + return self._postLoseConnectionReason, True else: # For some reason, the close alert wasn't sent. Start writing # again so that we'll get another chance to send it. @@ -284,12 +292,12 @@ class _TLSMixin: # Linux, it doesn't implement select in terms of poll and then map # POLLHUP to select's in fd_set). self.startReading() - return None + return None, False def _closeWriteConnection(self): - result = self._sendCloseAlert() + result, shouldClose = self._sendCloseAlert() - if result is main.CONNECTION_DONE: + if shouldClose: return Connection._closeWriteConnection(self) return result @@ -511,7 +519,7 @@ class Connection(abstract.FileDescriptor, _SocketCloser): """See abstract.FileDescriptor.connectionLost(). """ abstract.FileDescriptor.connectionLost(self, reason) - self._closeSocket() + self._closeSocket(reason.value.__class__ == error.ConnectionDone) protocol = self.protocol del self.protocol del self.socket @@ -587,7 +595,7 @@ class BaseClient(Connection): del self.connector try: - self._closeSocket() + self._closeSocket(True) except AttributeError: pass else: @@ -978,7 +986,7 @@ class Port(base.BasePort, _SocketCloser): base.BasePort.connectionLost(self, reason) self.connected = False - self._closeSocket() + self._closeSocket(True) del self.socket del self.fileno diff --git a/twisted/test/test_ssl.py b/twisted/test/test_ssl.py index 7df6ca9..443ab09 100644 --- a/twisted/test/test_ssl.py +++ b/twisted/test/test_ssl.py @@ -6,12 +6,14 @@ Tests for twisted SSL support. """ from twisted.trial import unittest -from twisted.internet import protocol, reactor, interfaces, defer +from twisted.internet import protocol, reactor, interfaces, defer, error from twisted.protocols import basic from twisted.python import util from twisted.python.reflect import getClass, fullyQualifiedName from twisted.python.runtime import platform -from twisted.test.test_tcp import WriteDataTestCase, ProperlyCloseFilesMixin +from twisted.test.test_tcp import ( + WriteDataTestCase, ProperlyCloseFilesMixin, AbortServerFactory, + AbortingClient) import os, errno @@ -653,6 +655,34 @@ class ClientContextFactoryTests(unittest.TestCase): +class AbortConnectionTest(unittest.TestCase, ContextGeneratingMixin): + ## Hrmuf. This test is pretty much a copy of the same named one in twisted.test.test_tcp. + def test_abort(self): + org = "twisted.test.test_ssl" + self.setupServerAndClient( + (org, org + ", client"), {}, + (org, org + ", server"), {}) + + serverDoneDeferred = defer.Deferred() + clientDoneDeferred = defer.Deferred() + + server = AbortServerFactory(serverDoneDeferred) + serverport = reactor.listenSSL(0, server, self.serverCtxFactory, interface="127.0.0.1") + self.addCleanup(serverport.stopListening) + + c = protocol.ClientCreator(reactor, AbortingClient, clientDoneDeferred) + d = c.connectSSL(serverport.getHost().host, serverport.getHost().port, self.clientCtxFactory) + + def checkConnectionLost(result): + result.trap(error.ConnectionLost) + serverDoneDeferred.addBoth(checkConnectionLost) + clientDoneDeferred.addBoth(checkConnectionLost) + d.addCallback(lambda x: clientDoneDeferred) + d.addCallback(lambda x: serverDoneDeferred) + return d + + + if interfaces.IReactorSSL(reactor, None) is None: for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase, BufferingTestCase, ConnectionLostTestCase, diff --git a/twisted/test/test_tcp.py b/twisted/test/test_tcp.py index 1b7b463..397d6b7 100644 --- a/twisted/test/test_tcp.py +++ b/twisted/test/test_tcp.py @@ -1899,6 +1899,61 @@ class CallBackOrderTestCase(unittest.TestCase): +class AbortServerProtocol(protocol.Protocol): + def dataReceived(self, data): + raise Exception("Unexpectedly received data.") + + def connectionLost(self, reason): + self.factory.done.callback(reason) + + + +class AbortServerFactory(protocol.ServerFactory): + protocol = AbortServerProtocol + + def __init__(self, done): + self.done = done + + + +class AbortingClient(protocol.Protocol): + def __init__(self, done): + self.done = done + + def connectionMade(self): + reactor.callLater(0, self.writeAndAbort) + + def writeAndAbort(self): + self.transport.write("X"*2000000) + self.transport.abortConnection() + + def connectionLost(self, reason): + self.done.callback(reason) + + + +class AbortConnectionTest(unittest.TestCase): + def test_abort(self): + serverDoneDeferred = defer.Deferred() + clientDoneDeferred = defer.Deferred() + + server = AbortServerFactory(serverDoneDeferred) + serverport = reactor.listenTCP(0, server, interface="127.0.0.1") + self.addCleanup(serverport.stopListening) + + c = protocol.ClientCreator(reactor, AbortingClient, clientDoneDeferred) + d = c.connectTCP(serverport.getHost().host, serverport.getHost().port) + + def checkConnectionLost(result): + result.trap(error.ConnectionLost) + serverDoneDeferred.addBoth(checkConnectionLost) + clientDoneDeferred.addBoth(checkConnectionLost) + d.addCallback(lambda x: clientDoneDeferred) + d.addCallback(lambda x: serverDoneDeferred) + return d + + + try: import resource except ImportError: