Ticket #6024: tls_notify_v3.patch

File tls_notify_v3.patch, 20.6 KB (added by Andy Lutomirski, 7 years ago)

Version 3, against trunk

  • twisted/internet/interfaces.py

    diff --git twisted/internet/interfaces.py twisted/internet/interfaces.py
    index 4021e57..907914a 100644
    class ISSLTransport(ITCPTransport): 
    21522152        Return an object with the peer's certificate info.
    21532153        """
    21542154
     2155    def whenHandshakeDone():
     2156        """
     2157        Returns a Deferred that will complete when the initial handshake
     2158        is done and will errback if the handshake fails.  (Connection
     2159        loss during the handshake is considered to be a handshake failure.)
     2160
     2161        If the handshake is already complete, then the returned Deferred
     2162        will already be complete.
     2163        """
     2164
    21552165
    21562166class IProcessTransport(ITransport):
    21572167    """
  • twisted/protocols/test/test_tls.py

    diff --git twisted/protocols/test/test_tls.py twisted/protocols/test/test_tls.py
    index 49e3a79..c5651d6 100644
    from twisted.test.test_tcp import ConnectionLostNotifyingProtocol 
    4242from twisted.test.proto_helpers import StringTransport
    4343
    4444
    45 class HandshakeCallbackContextFactory:
     45class TestContextFactory:
    4646    """
    47     L{HandshakeCallbackContextFactory} is a factory for SSL contexts which
    48     allows applications to get notification when the SSL handshake completes.
    49 
    50     @ivar _finished: A L{Deferred} which will be called back when the handshake
    51         is done.
     47    L{TestContextFactory} is a trivial factory for SSL contexts.
    5248    """
    53     # pyOpenSSL needs to expose this.
    54     # https://bugs.launchpad.net/pyopenssl/+bug/372832
    55     SSL_CB_HANDSHAKE_DONE = 0x20
    56 
    57     def __init__(self):
    58         self._finished = Deferred()
    59 
    60 
    61     def factoryAndDeferred(cls):
    62         """
    63         Create a new L{HandshakeCallbackContextFactory} and return a two-tuple
    64         of it and a L{Deferred} which will fire when a connection created with
    65         it completes a TLS handshake.
    66         """
    67         contextFactory = cls()
    68         return contextFactory, contextFactory._finished
    69     factoryAndDeferred = classmethod(factoryAndDeferred)
    70 
    71 
    72     def _info(self, connection, where, ret):
    73         """
    74         This is the "info callback" on the context.  It will be called
    75         periodically by pyOpenSSL with information about the state of a
    76         connection.  When it indicates the handshake is complete, it will fire
    77         C{self._finished}.
    78         """
    79         if where & self.SSL_CB_HANDSHAKE_DONE:
    80             self._finished.callback(None)
    81 
    82 
    8349    def getContext(self):
    8450        """
    85         Create and return an SSL context configured to use L{self._info} as the
    86         info callback.
     51        Create and return an SSL context.
    8752        """
    88         context = Context(TLSv1_METHOD)
    89         context.set_info_callback(self._info)
    90         return context
     53        return Context(TLSv1_METHOD)
    9154
    9255
    9356
    class TLSMemoryBIOTests(TestCase): 
    290253        clientFactory = ClientFactory()
    291254        clientFactory.protocol = Protocol
    292255
    293         clientContextFactory, handshakeDeferred = (
    294             HandshakeCallbackContextFactory.factoryAndDeferred())
     256        clientContextFactory = TestContextFactory()
    295257        wrapperFactory = TLSMemoryBIOFactory(
    296258            clientContextFactory, True, clientFactory)
    297259        sslClientProtocol = wrapperFactory.buildProtocol(None)
     260        handshakeDeferred = sslClientProtocol.whenHandshakeDone()
    298261
    299262        serverFactory = ServerFactory()
    300263        serverFactory.protocol = Protocol
    class TLSMemoryBIOTests(TestCase): 
    333296            lambda: ConnectionLostNotifyingProtocol(
    334297                clientConnectionLost))
    335298
    336         clientContextFactory = HandshakeCallbackContextFactory()
     299        clientContextFactory = TestContextFactory()
    337300        wrapperFactory = TLSMemoryBIOFactory(
    338301            clientContextFactory, True, clientFactory)
    339302        sslClientProtocol = wrapperFactory.buildProtocol(None)
    class TLSMemoryBIOTests(TestCase): 
    371334                connectionDeferred])
    372335
    373336
     337    def test_notifyAfterSuccessfulHandshake(self):
     338        """
     339        Calling L{TLSMemoryBIOProtocol.whenHandshakeDone} after a
     340        successful handshake should work.
     341        """
     342        tlsClient, tlsServer, handshakeDeferred, _ = self.handshakeProtocols()
     343
     344        result = Deferred()
     345
     346        def check(_):
     347            d = tlsClient.whenHandshakeDone()
     348            d.addCallback(result.callback)
     349            d.addErrback(result.errback)
     350
     351        handshakeDeferred.addCallback(check)
     352        return result
     353
     354
     355    def test_notifyAfterFailedHandshake(self):
     356        """
     357        Calling L{TLSMemoryBIOProtocol.whenHandshakeDone} after a
     358        failed handshake should work.
     359        """
     360        clientConnectionLost = Deferred()
     361        clientFactory = ClientFactory()
     362        clientFactory.protocol = Protocol
     363
     364        clientContextFactory = TestContextFactory()
     365        wrapperFactory = TLSMemoryBIOFactory(
     366            clientContextFactory, True, clientFactory)
     367        sslClientProtocol = wrapperFactory.buildProtocol(None)
     368
     369        serverConnectionLost = Deferred()
     370        serverFactory = ServerFactory()
     371        serverFactory.protocol = Protocol
     372
     373        # This context factory rejects any clients which do not present a
     374        # certificate.
     375        certificateData = FilePath(certPath).getContent()
     376        certificate = PrivateCertificate.loadPEM(certificateData)
     377        serverContextFactory = certificate.options(certificate)
     378        wrapperFactory = TLSMemoryBIOFactory(
     379            serverContextFactory, False, serverFactory)
     380        sslServerProtocol = wrapperFactory.buildProtocol(None)
     381
     382        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
     383
     384        result = Deferred()
     385
     386        def fail(_):
     387            result.errback(False)
     388
     389        def check(reason):
     390            d = sslClientProtocol.whenHandshakeDone()
     391            if not d.called:
     392                result.errback(Exception('notification should be called'))
     393                return
     394            d.addCallback(fail)
     395            d.addErrback(lambda _: result.callback(None))
     396
     397        sslClientProtocol.whenHandshakeDone().addCallbacks(fail, check)
     398
     399        return gatherResults([connectionDeferred, result])
     400
     401
     402    def test_handshakeAfterConnectionLost(self):
     403        """
     404        Make sure that the correct handshake paths get run after a connection
     405        is lost.
     406        """
     407        clientConnectionLost = Deferred()
     408        clientFactory = ClientFactory()
     409        clientFactory.protocol = Protocol
     410
     411        clientContextFactory = TestContextFactory()
     412        wrapperFactory = TLSMemoryBIOFactory(
     413            clientContextFactory, True, clientFactory)
     414        sslClientProtocol = wrapperFactory.buildProtocol(None)
     415
     416        serverConnectionLost = Deferred()
     417        serverFactory = ServerFactory()
     418        serverFactory.protocol = Protocol
     419
     420        # This context factory rejects any clients which do not present a
     421        # certificate.
     422        certificateData = FilePath(certPath).getContent()
     423        certificate = PrivateCertificate.loadPEM(certificateData)
     424        serverContextFactory = certificate.options(certificate)
     425        wrapperFactory = TLSMemoryBIOFactory(
     426            serverContextFactory, False, serverFactory)
     427        sslServerProtocol = wrapperFactory.buildProtocol(None)
     428
     429        connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
     430        result = Deferred()
     431
     432        def checkSide(side):
     433            return self.assertFailure(side.whenHandshakeDone(), Error)
     434
     435        return gatherResults([connectionDeferred, checkSide(sslClientProtocol),
     436                              checkSide(sslServerProtocol)])
     437
     438
    374439    def test_getPeerCertificate(self):
    375440        """
    376441        L{TLSMemoryBIOProtocol.getPeerCertificate} returns the
    class TLSMemoryBIOTests(TestCase): 
    381446        clientFactory = ClientFactory()
    382447        clientFactory.protocol = Protocol
    383448
    384         clientContextFactory, handshakeDeferred = (
    385             HandshakeCallbackContextFactory.factoryAndDeferred())
     449        clientContextFactory = TestContextFactory()
    386450        wrapperFactory = TLSMemoryBIOFactory(
    387451            clientContextFactory, True, clientFactory)
    388452        sslClientProtocol = wrapperFactory.buildProtocol(None)
     453        handshakeDeferred = sslClientProtocol.whenHandshakeDone()
    389454
    390455        serverFactory = ServerFactory()
    391456        serverFactory.protocol = Protocol
    class TLSMemoryBIOTests(TestCase): 
    421486        clientFactory = ClientFactory()
    422487        clientFactory.protocol = lambda: clientProtocol
    423488
    424         clientContextFactory, handshakeDeferred = (
    425             HandshakeCallbackContextFactory.factoryAndDeferred())
     489        clientContextFactory = TestContextFactory()
    426490        wrapperFactory = TLSMemoryBIOFactory(
    427491            clientContextFactory, True, clientFactory)
    428492        sslClientProtocol = wrapperFactory.buildProtocol(None)
     493        handshakeDeferred = sslClientProtocol.whenHandshakeDone()
    429494
    430495        serverProtocol = AccumulatingProtocol(len(bytes))
    431496        serverFactory = ServerFactory()
    class TLSMemoryBIOTests(TestCase): 
    463528        clientFactory = ClientFactory()
    464529        clientFactory.protocol = sendingProtocol
    465530
    466         clientContextFactory, handshakeDeferred = (
    467             HandshakeCallbackContextFactory.factoryAndDeferred())
     531        clientContextFactory = TestContextFactory()
    468532        wrapperFactory = TLSMemoryBIOFactory(
    469533            clientContextFactory, True, clientFactory)
    470534        sslClientProtocol = wrapperFactory.buildProtocol(None)
     535        handshakeDeferred = sslClientProtocol.whenHandshakeDone()
    471536
    472537        serverProtocol = AccumulatingProtocol(len(bytes))
    473538        serverFactory = ServerFactory()
    class TLSMemoryBIOTests(TestCase): 
    565630        clientFactory = ClientFactory()
    566631        clientFactory.protocol = SimpleSendingProtocol
    567632
    568         clientContextFactory = HandshakeCallbackContextFactory()
     633        clientContextFactory = TestContextFactory()
    569634        wrapperFactory = TLSMemoryBIOFactory(
    570635            clientContextFactory, True, clientFactory)
    571636        sslClientProtocol = wrapperFactory.buildProtocol(None)
    class TLSMemoryBIOTests(TestCase): 
    604669        clientFactory = ClientFactory()
    605670        clientFactory.protocol = SimpleSendingProtocol
    606671
    607         clientContextFactory = HandshakeCallbackContextFactory()
     672        clientContextFactory = TestContextFactory()
    608673        wrapperFactory = TLSMemoryBIOFactory(
    609674            clientContextFactory, True, clientFactory)
    610675        sslClientProtocol = wrapperFactory.buildProtocol(None)
    class TLSMemoryBIOTests(TestCase): 
    639704            lambda: ConnectionLostNotifyingProtocol(
    640705                clientConnectionLost))
    641706
    642         clientContextFactory = HandshakeCallbackContextFactory()
     707        clientContextFactory = TestContextFactory()
    643708        wrapperFactory = TLSMemoryBIOFactory(
    644709            clientContextFactory, True, clientFactory)
    645710        sslClientProtocol = wrapperFactory.buildProtocol(None)
    class TLSMemoryBIOTests(TestCase): 
    679744        clientProtocol = NotifyingProtocol(clientConnectionLost)
    680745        clientFactory.protocol = lambda: clientProtocol
    681746
    682         clientContextFactory, handshakeDeferred = (
    683             HandshakeCallbackContextFactory.factoryAndDeferred())
     747        clientContextFactory = TestContextFactory()
    684748        wrapperFactory = TLSMemoryBIOFactory(
    685749            clientContextFactory, True, clientFactory)
    686750        sslClientProtocol = wrapperFactory.buildProtocol(None)
     751        handshakeDeferred = sslClientProtocol.whenHandshakeDone()
    687752
    688753        serverConnectionLost = Deferred()
    689754        serverProtocol = NotifyingProtocol(serverConnectionLost)
  • twisted/protocols/tls.py

    diff --git twisted/protocols/tls.py twisted/protocols/tls.py
    index f139c6a..57a38b5 100644
    to run TLS over unusual transports, such as UNIX sockets and stdio. 
    3737
    3838from __future__ import division, absolute_import
    3939
    40 from OpenSSL.SSL import Error, ZeroReturnError, WantReadError
     40from OpenSSL.SSL import Error, ZeroReturnError, WantReadError, WantWriteError
    4141from OpenSSL.SSL import TLSv1_METHOD, Context, Connection
    4242
    4343try:
    from twisted.python import log 
    5555from twisted.python._reflectpy3 import safe_str
    5656from twisted.internet.interfaces import ISystemHandle, ISSLTransport
    5757from twisted.internet.interfaces import IPushProducer, ILoggingContext
     58from twisted.internet import defer
    5859from twisted.internet.main import CONNECTION_LOST
    5960from twisted.internet.protocol import Protocol
    6061from twisted.internet.task import cooperate
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    244245        on, and which has no interest in a new transport.  See #3821.
    245246
    246247    @ivar _handshakeDone: A flag indicating whether or not the handshake is
    247         known to have completed successfully (C{True}) or not (C{False}).  This
    248         is used to control error reporting behavior.  If the handshake has not
    249         completed, the underlying L{OpenSSL.SSL.Error} will be passed to the
    250         application's C{connectionLost} method.  If it has completed, any
    251         unexpected L{OpenSSL.SSL.Error} will be turned into a
    252         L{ConnectionLost}.  This is weird; however, it is simply an attempt at
    253         a faithful re-implementation of the behavior provided by
    254         L{twisted.internet.ssl}.
     248        complete (C{True}) or not (C{False}).
     249
     250    @ivar _handshakeError: If the handshake failed, then this will store
     251        the reason (a L{twisted.python.failure.Failure} object).
     252        Otherwise it will be C{None}.
     253
     254    @ivar _handshakeDeferreds: If the handshake is not done, then this
     255        is a list of L{twisted.internet.defer.Deferred} instances to
     256        be completed when the handshake finishes.  Once the handshake
     257        is done, this is C{None}.
    255258
    256259    @ivar _reason: If an unexpected L{OpenSSL.SSL.Error} occurs which causes
    257260        the connection to be lost, it is saved here.  If appropriate, this may
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    265268
    266269    _reason = None
    267270    _handshakeDone = False
     271    _handshakeError = None
    268272    _lostTLSConnection = False
    269273    _writeBlockedOnRead = False
    270274    _producer = None
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    272276    def __init__(self, factory, wrappedProtocol, _connectWrapped=True):
    273277        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
    274278        self._connectWrapped = _connectWrapped
     279        self._handshakeDeferreds = []
    275280
    276281
    277282    def getHandle(self):
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    316321        # Now that we ourselves have a transport (initialized by the
    317322        # ProtocolWrapper.makeConnection call above), kick off the TLS
    318323        # handshake.
    319         try:
    320             self._tlsConnection.do_handshake()
    321         except WantReadError:
    322             # This is the expected case - there's no data in the connection's
    323             # input buffer yet, so it won't be able to complete the whole
    324             # handshake now.  If this is the speak-first side of the
    325             # connection, then some bytes will be in the send buffer now; flush
    326             # them.
    327             self._flushSendBIO()
     324        self._tryHandshake()
     325
     326
     327    def whenHandshakeDone(self):
     328        d = defer.Deferred()
     329        if self._handshakeDone:
     330            if self._handshakeError is None:
     331                d.callback(None)
     332            else:
     333                d.errback(self._handshakeError)
     334        else:
     335            self._handshakeDeferreds.append(d)
     336        return d
     337
     338
     339    def _tryHandshake(self):
     340        """
     341        Attempts to handshake.  OpenSSL wants us to keep trying to
     342        handshake until either it works or fails (as opposed to needing
     343        to do I/O).
     344        """
     345        while True:
     346            try:
     347                self._tlsConnection.do_handshake()
     348            except WantReadError:
     349                self._flushSendBIO()  # do_handshake may have queued up a send
     350                return
     351            except WantWriteError:
     352                self._flushSendBIO()
     353                # And try again immediately
     354            except Error as e:
     355                self._tlsShutdownFinished(Failure())
     356                return
     357            else:
     358                self._handshakeSucceeded()
     359                return
     360
     361
     362    def _handshakeSucceeded(self):
     363        """
     364        Mark the handshake done and notify everyone.  It's okay to call
     365        this more than once.
     366        """
     367        if not self._handshakeDone:
     368            self._handshakeDone = True
     369            deferreds = self._handshakeDeferreds
     370            self._handshakeDeferreds = None
     371            for d in deferreds:
     372                d.callback(None)
    328373
    329374
    330375    def _flushSendBIO(self):
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    349394        the protocol, as well as handling of the various exceptions which
    350395        can come from trying to get such bytes.
    351396        """
     397        # SSL_read can transparently complete a handshake, but we can't
     398        # rely on it: if the handshake is done but there's no application
     399        # data, then SSL_read won't tell us.
     400        if not self._handshakeDone:
     401            self._tryHandshake()
     402        if not self._handshakeDone:
     403            return  # Save some effort: SSL_read can't possibly work
     404
    352405        # Keep trying this until an error indicates we should stop or we
    353406        # close the connection.  Looping is necessary to make sure we
    354407        # process all of the data which was put into the receive BIO, as
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    383436                self._flushSendBIO()
    384437                self._tlsShutdownFinished(failure)
    385438            else:
    386                 # If we got application bytes, the handshake must be done by
    387                 # now.  Keep track of this to control error reporting later.
    388                 self._handshakeDone = True
    389439                ProtocolWrapper.dataReceived(self, bytes)
    390440
    391441        # The received bytes might have generated a response which needs to be
    392         # sent now.  For example, the handshake involves several round-trip
    393         # exchanges without ever producing application-bytes.
     442        # sent now.  This is most likely to occur during renegotiation.
    394443        self._flushSendBIO()
    395444
    396445
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    438487        Called when TLS connection has gone away; tell underlying transport to
    439488        disconnect.
    440489        """
     490        if not self._handshakeDone:
     491            # This is a handshake failure (either an explicit failure from
     492            # OpenSSL or an implicit failure due to a dropped transport
     493            # connection).
     494            #
     495            # Note: Some testcases evilly call _tlsShutdownFinished(None)
     496            # before the handshake finishes.  This can't happen in real life
     497            # (none of the call sites allow it), so it's okay that we'll
     498            # crash if there's actually anyone waiting for notification
     499            # of the handshake result.
     500            self._handshakeDone = True
     501            self._handshakeError = reason
     502
     503            deferreds = self._handshakeDeferreds
     504            self._handshakeDeferreds = None
     505            for d in deferreds:
     506                d.errback(reason)
     507
    441508        self._reason = reason
    442509        self._lostTLSConnection = True
    443510        # Using loseConnection causes the application protocol's
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    457524        """
    458525        if not self._lostTLSConnection:
    459526            # Tell the TLS connection that it's not going to get any more data
    460             # and give it a chance to finish reading.
     527            # and give it a chance to finish handshaking and/or reading.
    461528            self._tlsConnection.bio_shutdown()
    462529            self._flushReceiveBIO()
    463530            self._lostTLSConnection = True
    class TLSMemoryBIOProtocol(ProtocolWrapper): 
    532599                self._tlsShutdownFinished(Failure())
    533600                break
    534601            else:
    535                 # If we sent some bytes, the handshake must be done.  Keep
    536                 # track of this to control error reporting behavior.
    537                 self._handshakeDone = True
     602                # SSL_write can transparently complete a handshake.  If we
     603                # get here, then we're done handshaking.
     604                self._handshakeSucceeded()
    538605                self._flushSendBIO()
    539606                alreadySent += sent
    540607
  • new file twisted/topfiles/6204.feature

    diff --git twisted/topfiles/6204.feature twisted/topfiles/6204.feature
    new file mode 100644
    index 0000000..46cfabb
    - +  
     1twisted.internet.interfaces.ISSLTransport now has a whenHandshakeDone method to request notification when the handshake succeeds or fails.