=== modified file 'twisted/internet/abstract.py' --- twisted/internet/abstract.py 2008-11-19 17:01:07 +0000 +++ twisted/internet/abstract.py 2009-04-22 14:55:47 +0000 @@ -250,6 +250,12 @@ self._writeDisconnecting = True self.startWriting() + def abortConnection(self): + """Aborts the connection -- default implementation calls + self.connectionLost(failure.Failure(main.CONNECTION_LOST)). + """ + self.connectionLost(failure.Failure(main.CONNECTION_LOST)) + def stopReading(self): """Stop waiting for read availability. === modified file 'twisted/internet/interfaces.py' --- twisted/internet/interfaces.py 2009-01-22 13:01:46 +0000 +++ twisted/internet/interfaces.py 2009-04-22 15:01:30 +0000 @@ -1385,6 +1385,14 @@ producer. """ + def abortConnection(): + """ + Close the connection abruptly. + + Discards any buffered data, unregisters any producer, and, if + possible, notifies the other end of the unclean closure. + """ + def getTcpNoDelay(): """ Return if C{TCP_NODELAY} is enabled. === modified file 'twisted/internet/tcp.py' --- twisted/internet/tcp.py 2009-02-26 17:26:02 +0000 +++ twisted/internet/tcp.py 2009-04-22 14:55:47 +0000 @@ -17,6 +17,7 @@ import socket import sys import operator +import struct from zope.interface import implements, classImplements @@ -83,15 +84,23 @@ class _SocketCloser: _socketShutdownMethod = 'shutdown' - def _closeSocket(self): - # socket.close() doesn't *really* close if there's another reference - # to it in the TCP/IP stack, e.g. if it was was inherited by a - # subprocess. And we really do want to close the connection. So we - # use shutdown() instead, and then close() in order to release the - # filedescriptor. + def _closeSocket(self, orderly): + # The call to shutdown() before close() isn't really necessary, because + # we set FD_CLOEXEC now, which will ensure this is the only process + # holding the FD, thus ensuring close() really will shutdown the TCP + # socket. However, do it anyways, just to be safe. But when doing + # a non-orderly shutdown, it needs to be done via close only. skt = self.socket try: - getattr(skt, self._socketShutdownMethod)(2) + if orderly: + getattr(skt, self._socketShutdownMethod)(2) + else: + # Set SO_LINGER to 1,0 which, by convention, causes a + # connection reset to be sent when close is called, + # instead of the standard FIN shutdown sequence. + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack("ii", 1, 0)) + except socket.error: pass try: @@ -511,7 +520,7 @@ """See abstract.FileDescriptor.connectionLost(). """ abstract.FileDescriptor.connectionLost(self, reason) - self._closeSocket() + self._closeSocket(reason.value.__class__ == error.ConnectionDone) protocol = self.protocol del self.protocol del self.socket @@ -587,7 +596,7 @@ del self.connector try: - self._closeSocket() + self._closeSocket(True) except AttributeError: pass else: @@ -978,7 +987,7 @@ base.BasePort.connectionLost(self, reason) self.connected = False - self._closeSocket() + self._closeSocket(True) del self.socket del self.fileno === modified file 'twisted/test/test_ssl.py' --- twisted/test/test_ssl.py 2009-03-24 17:52:07 +0000 +++ twisted/test/test_ssl.py 2009-04-22 14:55:47 +0000 @@ -6,11 +6,13 @@ """ from twisted.trial import unittest -from twisted.internet import protocol, reactor, interfaces, defer +from twisted.internet import protocol, reactor, interfaces, defer, error from twisted.protocols import basic from twisted.python import util from twisted.python.runtime import platform -from twisted.test.test_tcp import WriteDataTestCase, ProperlyCloseFilesMixin +from twisted.test.test_tcp import ( + WriteDataTestCase, ProperlyCloseFilesMixin, AbortServerFactory, + AbortingClient) import os, errno @@ -642,6 +644,34 @@ +class AbortConnectionTest(unittest.TestCase, ContextGeneratingMixin): + ## Hrmuf. This test is pretty much a copy of the same named one in twisted.test.test_tcp. + def test_abort(self): + org = "twisted.test.test_ssl" + self.setupServerAndClient( + (org, org + ", client"), {}, + (org, org + ", server"), {}) + + serverDoneDeferred = defer.Deferred() + clientDoneDeferred = defer.Deferred() + + server = AbortServerFactory(serverDoneDeferred) + serverport = reactor.listenSSL(0, server, self.serverCtxFactory, interface="127.0.0.1") + self.addCleanup(serverport.stopListening) + + c = protocol.ClientCreator(reactor, AbortingClient, clientDoneDeferred) + d = c.connectSSL(serverport.getHost().host, serverport.getHost().port, self.clientCtxFactory) + + def checkConnectionLost(result): + result.trap(error.ConnectionLost) + serverDoneDeferred.addBoth(checkConnectionLost) + clientDoneDeferred.addBoth(checkConnectionLost) + d.addCallback(lambda x: clientDoneDeferred) + d.addCallback(lambda x: serverDoneDeferred) + return d + + + if interfaces.IReactorSSL(reactor, None) is None: for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase, BufferingTestCase, ConnectionLostTestCase, === modified file 'twisted/test/test_tcp.py' --- twisted/test/test_tcp.py 2008-09-12 21:32:29 +0000 +++ twisted/test/test_tcp.py 2009-04-22 14:55:47 +0000 @@ -1923,6 +1923,61 @@ +class AbortServerProtocol(protocol.Protocol): + def dataReceived(self, data): + raise Exception("Unexpectedly received data.") + + def connectionLost(self, reason): + self.factory.done.callback(reason) + + + +class AbortServerFactory(protocol.ServerFactory): + protocol = AbortServerProtocol + + def __init__(self, done): + self.done = done + + + +class AbortingClient(protocol.Protocol): + def __init__(self, done): + self.done = done + + def connectionMade(self): + reactor.callLater(0, self.writeAndAbort) + + def writeAndAbort(self): + self.transport.write("X"*10000) + self.transport.abortConnection() + + def connectionLost(self, reason): + self.done.callback(reason) + + + +class AbortConnectionTest(unittest.TestCase): + def test_abort(self): + serverDoneDeferred = defer.Deferred() + clientDoneDeferred = defer.Deferred() + + server = AbortServerFactory(serverDoneDeferred) + serverport = reactor.listenTCP(0, server, interface="127.0.0.1") + self.addCleanup(serverport.stopListening) + + c = protocol.ClientCreator(reactor, AbortingClient, clientDoneDeferred) + d = c.connectTCP(serverport.getHost().host, serverport.getHost().port) + + def checkConnectionLost(result): + result.trap(error.ConnectionLost) + serverDoneDeferred.addBoth(checkConnectionLost) + clientDoneDeferred.addBoth(checkConnectionLost) + d.addCallback(lambda x: clientDoneDeferred) + d.addCallback(lambda x: serverDoneDeferred) + return d + + + try: import resource except ImportError: