Index: twisted/internet/_sslverify.py =================================================================== --- twisted/internet/_sslverify.py (revision 44494) +++ twisted/internet/_sslverify.py (working copy) @@ -1281,7 +1281,8 @@ extraCertChain=None, acceptableCiphers=None, dhParameters=None, - trustRoot=None): + trustRoot=None, + nextProtos=None): """ Create an OpenSSL context SSL connection context factory. @@ -1372,6 +1373,14 @@ @type trustRoot: L{IOpenSSLTrustRoot} + @param nextProtos: The protocols this peer is willing to speak after + the TLS negotation has completed, advertised over both ALPN and + NPN. If this argument is specified, and no overlap can be found + with the other peer, the connection will fail to be established. + Protocols earlier in the list are preferred over those later in + the list. + @type nextProtos: C{list} of C{bytestring}s + @raise ValueError: when C{privateKey} or C{certificate} are set without setting the respective other. @raise ValueError: when C{verify} is L{True} but C{caCerts} doesn't @@ -1467,7 +1476,9 @@ trustRoot = IOpenSSLTrustRoot(trustRoot) self.trustRoot = trustRoot + self.nextProtos = nextProtos + def __getstate__(self): d = self.__dict__.copy() try: @@ -1536,6 +1547,38 @@ except BaseException: pass # ECDHE support is best effort only. + if self.nextProtos: + # Set both NPN and ALPN. Configure both the server and client + # side logic. + def _npnSelectCallback(conn, protocols): + overlap = set(protocols) & set(self.nextProtos) + + for p in self.nextProtos: + if p in overlap: + return p + else: + return b'' + + def _npnAdvertiseCallback(conn): + return self.nextProtos + + # We can reuse the NPN logic in the ALPN case. + _alpnSelectCallback = _npnSelectCallback + + # Now, attach all the things. + # If NPN is not supported this will raise a NotImplementedError, + # which is ideal. + ctx.set_npn_advertise_callback(_npnAdvertiseCallback) + ctx.set_npn_select_callback(_npnSelectCallback) + + # Anything that does not support ALPN will also not support NPN, + # so we can just catch the NotImplementedError here. + try: + ctx.set_alpn_select_callback(_alpnSelectCallback) + ctx.set_alpn_protos(self.nextProtos) + except NotImplementedError: + pass + return ctx Index: twisted/protocols/tls.py =================================================================== --- twisted/protocols/tls.py (revision 44494) +++ twisted/protocols/tls.py (working copy) @@ -586,6 +586,26 @@ return self._tlsConnection.get_peer_certificate() + def getNextProtocol(self): + """ + Returns the protocol selected to be spoken using ALPN/NPN. This + function prefers the result from ALPN. If no protocol was chosen, + returns C{None}. + + @return: the selected protocol, or None if none is chosen. + @rtype: C{str} + """ + proto = self._tlsConnection.get_alpn_proto_negotiated() + if proto: + return proto + + proto = self._tlsConnection.get_next_proto_negotiated() + if proto: + return proto + + return None + + def registerProducer(self, producer, streaming): # If we've already disconnected, nothing to do here: if self._lostTLSConnection: Index: twisted/test/test_sslverify.py =================================================================== --- twisted/test/test_sslverify.py (revision 44494) +++ twisted/test/test_sslverify.py (working copy) @@ -15,11 +15,13 @@ skipSSL = None skipSNI = None +skipNPN = None try: import OpenSSL except ImportError: skipSSL = "OpenSSL is required for SSL tests." skipSNI = skipSSL + skipNPN = skipSSL else: from OpenSSL import SSL from OpenSSL.crypto import PKey, X509 @@ -27,6 +29,14 @@ if getattr(SSL.Context, "set_tlsext_servername_callback", None) is None: skipSNI = "PyOpenSSL 0.13 or greater required for SNI support." + try: + c = SSL.Context(SSL.SSLv23_METHOD) + c.set_npn_advertise_callback(lambda c: None) + except AttributeError: + skipNPN = "PyOpenSSL 0.15 or greater is required for NPN support" + except NotImplementedError: + skipNPN = "OpenSSL 1.0.1 or greater required for NPN support" + from twisted.test.test_twisted import SetAsideModule from twisted.test.iosim import connectedServerAndClient @@ -289,7 +299,18 @@ self.factory.onLost.errback(reason) +class NegotiatedProtocol(protocol.Protocol): + def __init__(self): + self.deferred = defer.Deferred() + def connectionMade(self): + self.transport.write(b'x') + + def dataReceived(self, data): + self.deferred.callback(self.transport.getNextProtocol()) + + + class FakeContext(object): """ Introspectable fake of an C{OpenSSL.SSL.Context}. @@ -1719,6 +1740,109 @@ +class NPNAndALPNTest(unittest.TestCase): + if skipSSL: + skip = skipSSL + elif skipNPN: + skip = skipNPN + + serverPort = clientConn = None + + sKey = None + sCert = None + cKey = None + cCert = None + + def setUp(self): + """ + Create class variables of client and server certificates. + """ + self.sKey, self.sCert = makeCertificate( + O=b"Server Test Certificate", + CN=b"server") + self.cKey, self.cCert = makeCertificate( + O=b"Client Test Certificate", + CN=b"client") + self.caCert1 = makeCertificate( + O=b"CA Test Certificate 1", + CN=b"ca1")[1] + self.caCert2 = makeCertificate( + O=b"CA Test Certificate", + CN=b"ca2")[1] + + + def tearDown(self): + if self.serverPort is not None: + self.serverPort.stopListening() + if self.clientConn is not None: + self.clientConn.disconnect() + + + def loopback(self, serverCertOpts, clientCertOpts): + self.proto = NegotiatedProtocol() + + serverFactory = protocol.ServerFactory() + serverFactory.protocol = lambda: self.proto + + clientFactory = protocol.ClientFactory() + clientFactory.protocol = NegotiatedProtocol + + self.serverPort = reactor.listenSSL(0, serverFactory, serverCertOpts) + self.clientConn = reactor.connectSSL('127.0.0.1', + self.serverPort.getHost().port, clientFactory, clientCertOpts) + + + def test_NPNAndALPNSuccess(self): + protocols = [b'h2', b'http/1.1'] + self.loopback( + sslverify.OpenSSLCertificateOptions( + privateKey=self.sKey, + certificate=self.sCert, + verify=True, + requireCertificate=True, + caCerts=[self.cCert], + nextProtos=protocols, + ), + sslverify.OpenSSLCertificateOptions( + privateKey=self.cKey, + certificate=self.cCert, + verify=True, + requireCertificate=True, + caCerts=[self.sCert], + nextProtos=protocols, + ), + ) + + return self.proto.deferred.addCallback( + self.assertEqual, b'h2') + + + def test_NPNAndALPNFailure(self): + protocols = [b'h2', b'http/1.1'] + self.loopback( + sslverify.OpenSSLCertificateOptions( + privateKey=self.sKey, + certificate=self.sCert, + verify=True, + requireCertificate=True, + caCerts=[self.cCert], + nextProtos=protocols, + ), + sslverify.OpenSSLCertificateOptions( + privateKey=self.cKey, + certificate=self.cCert, + verify=True, + requireCertificate=True, + caCerts=[self.sCert], + nextProtos=[], + ), + ) + + return self.proto.deferred.addCallback( + self.assertEqual, None) + + + class _NotSSLTransport: def getHandle(self): return self Index: twisted/topfiles/7860.feature =================================================================== --- twisted/topfiles/7860.feature (revision 0) +++ twisted/topfiles/7860.feature (working copy) @@ -0,0 +1 @@ +twisted.internet.ssl.CertificateOptions now takes a nextProtos parameter that enables negotiation of the next protocol to speak, after the TLS handshake has completed. This field advertises protocols over both NPN and ALPN. \ No newline at end of file