Ticket #764: Twistedssl.diff

File Twistedssl.diff, 9.2 KB (added by jknight, 12 years ago)
  • twisted/internet/abstract.py

     
    129129                return self._postLoseConnection()
    130130            elif self._writeDisconnecting:
    131131                # I was previously asked to to half-close the connection.
    132                 self._closeWriteConnection()
     132                result = self._closeWriteConnection()
    133133                self._writeDisconnected = True
    134134                return result
    135135        return result
  • twisted/internet/tcp.py

     
    7474class _TLSMixin:
    7575    writeBlockedOnRead = 0
    7676    readBlockedOnWrite = 0
    77     sslShutdown = 0
    78 
     77    _userWantRead = _userWantWrite = True
     78   
    7979    def getPeerCertificate(self):
    8080        return self.socket.get_peer_certificate()
    8181
    8282    def doRead(self):
    8383        if self.writeBlockedOnRead:
    8484            self.writeBlockedOnRead = 0
    85             self.startWriting()
     85            self._resetReadWrite()
    8686        try:
    8787            return Connection.doRead(self)
    8888        except SSL.ZeroReturnError:
    89             # close SSL layer, since other side has done so, if we haven't
    90             if not self.sslShutdown:
    91                 try:
    92                     self.socket.shutdown()
    93                     self.sslShutdown = 1
    94                 except SSL.Error:
    95                     pass
     89            #print "ZeroReturnError", self
    9690            return main.CONNECTION_DONE
    9791        except SSL.WantReadError:
    9892            return
    9993        except SSL.WantWriteError:
    10094            self.readBlockedOnWrite = 1
    101             self.startWriting()
     95            Connection.startWriting(self)
     96            Connection.stopReading(self)
    10297            return
     98        except SSL.SysCallError, (retval, desc):
     99            if ((retval == -1 and desc == 'Unexpected EOF')
     100                or retval > 0):
     101                return main.CONNECTION_LOST
     102            log.err()
     103            return main.CONNECTION_LOST
    103104        except SSL.Error:
    104105            log.err()
    105106            return main.CONNECTION_LOST
    106107
    107     def loseConnection(self):
    108         Connection.loseConnection(self)
    109         if self.connected:
    110             self.startReading()
    111 
    112     def halfCloseConnection(self, read=False, write=False):
    113         raise RuntimeError, "TLS connections currently do not support half-closing"
    114    
    115108    def doWrite(self):
    116         if self.writeBlockedOnRead:
    117             self.stopWriting()
    118             return
     109        #print "doWrite", self
     110        # Retry disconnecting
     111        if self.disconnecting:
     112            return self._postLoseConnection()
     113        if self._writeDisconnected:
     114            return self._closeWriteConnection()
     115       
    119116        if self.readBlockedOnWrite:
    120117            self.readBlockedOnWrite = 0
    121             # XXX - This is touching internal guts bad bad bad
    122             if not self.dataBuffer and not self._tempDataBuffer:
    123                 self.stopWriting()
    124             return self.doRead()
     118            self._resetReadWrite()
    125119        return Connection.doWrite(self)
    126120
    127121    def writeSomeData(self, data):
     
    131125            return 0
    132126        except SSL.WantReadError:
    133127            self.writeBlockedOnRead = 1
     128            Connection.stopWriting(self)
     129            Connection.startReading(self)
    134130            return 0
     131        except SSL.ZeroReturnError:
     132            return main.CONNECTION_LOST
    135133        except SSL.SysCallError, e:
    136134            if e[0] == -1 and data == "":
    137135                # errors when writing empty strings are expected
     
    156154    def _postLoseConnection(self):
    157155        """Gets called after loseConnection(), after buffered data is sent.
    158156
    159         We close the SSL transport layer, and if the other side hasn't
    160         closed it yet we start reading, waiting for a ZeroReturnError
    161         which will indicate the SSL shutdown has completed.
     157        We try to send an SSL shutdown alert, but if it doesn't work, retry
     158        when the socket is writable.
    162159        """
     160        #print "_postLoseConnection", self
     161        self.socket.set_shutdown(SSL.RECEIVED_SHUTDOWN)
     162        return self._sendCloseAlert()
     163
     164    _first=False
     165    def _sendCloseAlert(self):
     166        # Okay, *THIS* is a bit complicated.
     167       
     168        # Basically, the issue is, OpenSSL seems to not actually return
     169        # errors from SSL_shutdown. Therefore, the only way to
     170        # determine if the close notification has been sent is by
     171        # SSL_shutdown returning "done". However, it will not claim it's
     172        # done until it's both sent *and* received a shutdown notification.
     173
     174        # I don't actually want to wait for a received shutdown
     175        # notification, though, so, I have to set RECEIVED_SHUTDOWN
     176        # before calling shutdown. Then, it'll return True once it's
     177        # *SENT* the shutdown.
     178
     179        # However, RECEIVED_SHUTDOWN can't be left set, because then
     180        # reads will fail, breaking half close.
     181
     182        # Also, since shutdown doesn't report errors, an empty write call is
     183        # done first, to try to detect if the connection has gone away.
     184        # (*NOT* an SSL_write call, because that fails once you've called
     185        # shutdown)
     186       
     187        #print "_sendCloseAlert"
     188        #import pdb; pdb.set_trace()
    163189        try:
     190            os.write(self.socket.fileno(), '')
     191        except OSError, se:
     192            if se.args[0] in (EINTR, EWOULDBLOCK, ENOBUFS):
     193                return 0
     194            # Write error, socket gone
     195            return main.CONNECTION_LOST
     196       
     197        try:
     198            laststate = self.socket.get_shutdown()
     199            self.socket.set_shutdown(laststate | SSL.RECEIVED_SHUTDOWN)
    164200            done = self.socket.shutdown()
    165             self.sslShutdown = 1
     201            if not (laststate & SSL.RECEIVED_SHUTDOWN):
     202                self.socket.set_shutdown(SSL.SENT_SHUTDOWN)
     203            #print "SSL_SHUTDOWN:", done
    166204        except SSL.Error:
    167205            log.err()
    168206            return main.CONNECTION_LOST
     207
    169208        if done:
     209            self.stopWriting()
    170210            return main.CONNECTION_DONE
    171211        else:
    172             # we wait for other side to close SSL connection -
    173             # this will be signaled by SSL.ZeroReturnError when reading
    174             # from the socket
    175             self.stopWriting()
    176             self.startReading()
     212            #print "writeBlockedOnRead:", self.writeBlockedOnRead
     213            self.startWriting()
     214            #import default
     215            #print default.writes
     216            return None
    177217
    178             # don't close socket just yet
     218    def _closeWriteConnection(self):
     219        #print "_closeWriteConnection", self
     220        result = self._sendCloseAlert()
     221       
     222        if result is main.CONNECTION_DONE:
     223            self.socket.sock_shutdown(1)
     224            p = interfaces.IHalfCloseableProtocol(self.protocol, None)
     225            if p:
     226                p.writeConnectionLost()
    179227            return None
     228       
     229        return result
    180230
     231    def _closeReadConnection(self):
     232        # Keeps further reads from being received.
     233        self.socket.set_shutdown(SSL.RECEIVED_SHUTDOWN)
     234        self.socket.sock_shutdown(0)
     235        p = interfaces.IHalfCloseableProtocol(self.protocol, None)
     236        if p:
     237            p.readConnectionLost()
     238
     239    def startReading(self):
     240        self._userWantRead = True
     241        if not self.readBlockedOnWrite:
     242            return Connection.startReading(self)
     243
     244    def stopReading(self):
     245        self._userWantRead = False
     246        if not self.writeBlockedOnRead:
     247            return Connection.stopReading(self)
     248
     249    def startWriting(self):
     250        self._userWantWrite = True
     251        if not self.writeBlockedOnRead:
     252            return Connection.startWriting(self)
     253
     254    def stopWriting(self):
     255        self._userWantWrite = False
     256        if not self.readBlockedOnWrite:
     257            #print "stopWriting"
     258            return Connection.stopWriting(self)
     259
     260    def _resetReadWrite(self):
     261        # After changing readBlockedOnWrite or writeBlockedOnRead,
     262        # call this to reset the state to what the user requested.
     263        if self._userWantWrite:
     264            self.startWriting()
     265        else:
     266            self.stopWriting()
     267       
     268        if self._userWantRead:
     269            self.startReading()
     270        else:
     271            self.stopReading()
     272   
    181273class Connection(abstract.FileDescriptor):
    182274    """I am the superclass of all socket-based FileDescriptors.
    183275
     
    248340                return
    249341            else:
    250342                return main.CONNECTION_LOST
    251         except SSL.SysCallError, (retval, desc):
    252             # Yes, SSL might be None, but self.socket.recv() can *only*
    253             # raise socket.error, if anything else is raised, it must be an
    254             # SSL socket, and so SSL can't be None. (That's my story, I'm
    255             # stickin' to it)
    256             if retval == -1 and desc == 'Unexpected EOF':
    257                 return main.CONNECTION_DONE
    258             raise
    259343        if not data:
    260344            return main.CONNECTION_DONE
    261345        return self.protocol.dataReceived(data)