root/trunk/twisted/conch/ssh/connection.py

Revision 30752, 24.8 KB (checked in by exarkun, 15 months ago)

Rewrite the copyright headers to exclude date information.

Author: exarkun
Reviewer: glyph
Fixes: #4857

To avoid the need to perpetually update copyright dates in each file in Twisted,
remove the dates from most files and just leave them in the LICENSE file.

As a side effect, some files also have had a trailing newline added where it was
missing before.

Line 
1# -*- test-case-name: twisted.conch.test.test_connection -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6This module contains the implementation of the ssh-connection service, which
7allows access to the shell and port-forwarding.
8
9Maintainer: Paul Swartz
10"""
11
12import struct
13
14from twisted.conch.ssh import service, common
15from twisted.conch import error
16from twisted.internet import defer
17from twisted.python import log
18
19class SSHConnection(service.SSHService):
20    """
21    An implementation of the 'ssh-connection' service.  It is used to
22    multiplex multiple channels over the single SSH connection.
23
24    @ivar localChannelID: the next number to use as a local channel ID.
25    @type localChannelID: C{int}
26    @ivar channels: a C{dict} mapping a local channel ID to C{SSHChannel}
27        subclasses.
28    @type channels: C{dict}
29    @ivar localToRemoteChannel: a C{dict} mapping a local channel ID to a
30        remote channel ID.
31    @type localToRemoteChannel: C{dict}
32    @ivar channelsToRemoteChannel: a C{dict} mapping a C{SSHChannel} subclass
33        to remote channel ID.
34    @type channelsToRemoteChannel: C{dict}
35    @ivar deferreds: a C{dict} mapping a local channel ID to a C{list} of
36        C{Deferreds} for outstanding channel requests.  Also, the 'global'
37        key stores the C{list} of pending global request C{Deferred}s.
38    """
39    name = 'ssh-connection'
40
41    def __init__(self):
42        self.localChannelID = 0 # this is the current # to use for channel ID
43        self.localToRemoteChannel = {} # local channel ID -> remote channel ID
44        self.channels = {} # local channel ID -> subclass of SSHChannel
45        self.channelsToRemoteChannel = {} # subclass of SSHChannel ->
46                                          # remote channel ID
47        self.deferreds = {"global": []} # local channel -> list of deferreds
48                            # for pending requests or 'global' -> list of
49                            # deferreds for global requests
50        self.transport = None # gets set later
51
52
53    def serviceStarted(self):
54        if hasattr(self.transport, 'avatar'):
55            self.transport.avatar.conn = self
56
57
58    def serviceStopped(self):
59        """
60        Called when the connection is stopped.
61        """
62        map(self.channelClosed, self.channels.values())
63        self._cleanupGlobalDeferreds()
64
65
66    def _cleanupGlobalDeferreds(self):
67        """
68        All pending requests that have returned a deferred must be errbacked
69        when this service is stopped, otherwise they might be left uncalled and
70        uncallable.
71        """
72        for d in self.deferreds["global"]:
73            d.errback(error.ConchError("Connection stopped."))
74        del self.deferreds["global"][:]
75
76
77    # packet methods
78    def ssh_GLOBAL_REQUEST(self, packet):
79        """
80        The other side has made a global request.  Payload::
81            string  request type
82            bool    want reply
83            <request specific data>
84
85        This dispatches to self.gotGlobalRequest.
86        """
87        requestType, rest = common.getNS(packet)
88        wantReply, rest = ord(rest[0]), rest[1:]
89        ret = self.gotGlobalRequest(requestType, rest)
90        if wantReply:
91            reply = MSG_REQUEST_FAILURE
92            data = ''
93            if ret:
94                reply = MSG_REQUEST_SUCCESS
95                if isinstance(ret, (tuple, list)):
96                    data = ret[1]
97            self.transport.sendPacket(reply, data)
98
99    def ssh_REQUEST_SUCCESS(self, packet):
100        """
101        Our global request succeeded.  Get the appropriate Deferred and call
102        it back with the packet we received.
103        """
104        log.msg('RS')
105        self.deferreds['global'].pop(0).callback(packet)
106
107    def ssh_REQUEST_FAILURE(self, packet):
108        """
109        Our global request failed.  Get the appropriate Deferred and errback
110        it with the packet we received.
111        """
112        log.msg('RF')
113        self.deferreds['global'].pop(0).errback(
114            error.ConchError('global request failed', packet))
115
116    def ssh_CHANNEL_OPEN(self, packet):
117        """
118        The other side wants to get a channel.  Payload::
119            string  channel name
120            uint32  remote channel number
121            uint32  remote window size
122            uint32  remote maximum packet size
123            <channel specific data>
124
125        We get a channel from self.getChannel(), give it a local channel number
126        and notify the other side.  Then notify the channel by calling its
127        channelOpen method.
128        """
129        channelType, rest = common.getNS(packet)
130        senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
131        packet = rest[12:]
132        try:
133            channel = self.getChannel(channelType, windowSize, maxPacket,
134                            packet)
135            localChannel = self.localChannelID
136            self.localChannelID += 1
137            channel.id = localChannel
138            self.channels[localChannel] = channel
139            self.channelsToRemoteChannel[channel] = senderChannel
140            self.localToRemoteChannel[localChannel] = senderChannel
141            self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
142                struct.pack('>4L', senderChannel, localChannel,
143                    channel.localWindowSize,
144                    channel.localMaxPacket)+channel.specificData)
145            log.callWithLogger(channel, channel.channelOpen, packet)
146        except Exception, e:
147            log.msg('channel open failed')
148            log.err(e)
149            if isinstance(e, error.ConchError):
150                textualInfo, reason = e.args
151                if isinstance(textualInfo, (int, long)):
152                    # See #3657 and #3071
153                    textualInfo, reason = reason, textualInfo
154            else:
155                reason = OPEN_CONNECT_FAILED
156                textualInfo = "unknown failure"
157            self.transport.sendPacket(
158                MSG_CHANNEL_OPEN_FAILURE,
159                struct.pack('>2L', senderChannel, reason) +
160                common.NS(textualInfo) + common.NS(''))
161
162    def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
163        """
164        The other side accepted our MSG_CHANNEL_OPEN request.  Payload::
165            uint32  local channel number
166            uint32  remote channel number
167            uint32  remote window size
168            uint32  remote maximum packet size
169            <channel specific data>
170
171        Find the channel using the local channel number and notify its
172        channelOpen method.
173        """
174        (localChannel, remoteChannel, windowSize,
175                maxPacket) = struct.unpack('>4L', packet[: 16])
176        specificData = packet[16:]
177        channel = self.channels[localChannel]
178        channel.conn = self
179        self.localToRemoteChannel[localChannel] = remoteChannel
180        self.channelsToRemoteChannel[channel] = remoteChannel
181        channel.remoteWindowLeft = windowSize
182        channel.remoteMaxPacket = maxPacket
183        log.callWithLogger(channel, channel.channelOpen, specificData)
184
185    def ssh_CHANNEL_OPEN_FAILURE(self, packet):
186        """
187        The other side did not accept our MSG_CHANNEL_OPEN request.  Payload::
188            uint32  local channel number
189            uint32  reason code
190            string  reason description
191
192        Find the channel using the local channel number and notify it by
193        calling its openFailed() method.
194        """
195        localChannel, reasonCode = struct.unpack('>2L', packet[:8])
196        reasonDesc = common.getNS(packet[8:])[0]
197        channel = self.channels[localChannel]
198        del self.channels[localChannel]
199        channel.conn = self
200        reason = error.ConchError(reasonDesc, reasonCode)
201        log.callWithLogger(channel, channel.openFailed, reason)
202
203    def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
204        """
205        The other side is adding bytes to its window.  Payload::
206            uint32  local channel number
207            uint32  bytes to add
208
209        Call the channel's addWindowBytes() method to add new bytes to the
210        remote window.
211        """
212        localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
213        channel = self.channels[localChannel]
214        log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)
215
216    def ssh_CHANNEL_DATA(self, packet):
217        """
218        The other side is sending us data.  Payload::
219            uint32 local channel number
220            string data
221
222        Check to make sure the other side hasn't sent too much data (more
223        than what's in the window, or more than the maximum packet size).  If
224        they have, close the channel.  Otherwise, decrease the available
225        window and pass the data to the channel's dataReceived().
226        """
227        localChannel, dataLength = struct.unpack('>2L', packet[:8])
228        channel = self.channels[localChannel]
229        # XXX should this move to dataReceived to put client in charge?
230        if (dataLength > channel.localWindowLeft or
231           dataLength > channel.localMaxPacket): # more data than we want
232            log.callWithLogger(channel, log.msg, 'too much data')
233            self.sendClose(channel)
234            return
235            #packet = packet[:channel.localWindowLeft+4]
236        data = common.getNS(packet[4:])[0]
237        channel.localWindowLeft -= dataLength
238        if channel.localWindowLeft < channel.localWindowSize / 2:
239            self.adjustWindow(channel, channel.localWindowSize - \
240                                       channel.localWindowLeft)
241            #log.msg('local window left: %s/%s' % (channel.localWindowLeft,
242            #                                    channel.localWindowSize))
243        log.callWithLogger(channel, channel.dataReceived, data)
244
245    def ssh_CHANNEL_EXTENDED_DATA(self, packet):
246        """
247        The other side is sending us exteneded data.  Payload::
248            uint32  local channel number
249            uint32  type code
250            string  data
251
252        Check to make sure the other side hasn't sent too much data (more
253        than what's in the window, or or than the maximum packet size).  If
254        they have, close the channel.  Otherwise, decrease the available
255        window and pass the data and type code to the channel's
256        extReceived().
257        """
258        localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
259        channel = self.channels[localChannel]
260        if (dataLength > channel.localWindowLeft or
261                dataLength > channel.localMaxPacket):
262            log.callWithLogger(channel, log.msg, 'too much extdata')
263            self.sendClose(channel)
264            return
265        data = common.getNS(packet[8:])[0]
266        channel.localWindowLeft -= dataLength
267        if channel.localWindowLeft < channel.localWindowSize / 2:
268            self.adjustWindow(channel, channel.localWindowSize -
269                                       channel.localWindowLeft)
270        log.callWithLogger(channel, channel.extReceived, typeCode, data)
271
272    def ssh_CHANNEL_EOF(self, packet):
273        """
274        The other side is not sending any more data.  Payload::
275            uint32  local channel number
276
277        Notify the channel by calling its eofReceived() method.
278        """
279        localChannel = struct.unpack('>L', packet[:4])[0]
280        channel = self.channels[localChannel]
281        log.callWithLogger(channel, channel.eofReceived)
282
283    def ssh_CHANNEL_CLOSE(self, packet):
284        """
285        The other side is closing its end; it does not want to receive any
286        more data.  Payload::
287            uint32  local channel number
288
289        Notify the channnel by calling its closeReceived() method.  If
290        the channel has also sent a close message, call self.channelClosed().
291        """
292        localChannel = struct.unpack('>L', packet[:4])[0]
293        channel = self.channels[localChannel]
294        log.callWithLogger(channel, channel.closeReceived)
295        channel.remoteClosed = True
296        if channel.localClosed and channel.remoteClosed:
297            self.channelClosed(channel)
298
299    def ssh_CHANNEL_REQUEST(self, packet):
300        """
301        The other side is sending a request to a channel.  Payload::
302            uint32  local channel number
303            string  request name
304            bool    want reply
305            <request specific data>
306
307        Pass the message to the channel's requestReceived method.  If the
308        other side wants a reply, add callbacks which will send the
309        reply.
310        """
311        localChannel = struct.unpack('>L', packet[: 4])[0]
312        requestType, rest = common.getNS(packet[4:])
313        wantReply = ord(rest[0])
314        channel = self.channels[localChannel]
315        d = defer.maybeDeferred(log.callWithLogger, channel,
316                channel.requestReceived, requestType, rest[1:])
317        if wantReply:
318            d.addCallback(self._cbChannelRequest, localChannel)
319            d.addErrback(self._ebChannelRequest, localChannel)
320            return d
321
322    def _cbChannelRequest(self, result, localChannel):
323        """
324        Called back if the other side wanted a reply to a channel request.  If
325        the result is true, send a MSG_CHANNEL_SUCCESS.  Otherwise, raise
326        a C{error.ConchError}
327
328        @param result: the value returned from the channel's requestReceived()
329            method.  If it's False, the request failed.
330        @type result: C{bool}
331        @param localChannel: the local channel ID of the channel to which the
332            request was made.
333        @type localChannel: C{int}
334        @raises ConchError: if the result is False.
335        """
336        if not result:
337            raise error.ConchError('failed request')
338        self.transport.sendPacket(MSG_CHANNEL_SUCCESS, struct.pack('>L',
339                                self.localToRemoteChannel[localChannel]))
340
341    def _ebChannelRequest(self, result, localChannel):
342        """
343        Called if the other wisde wanted a reply to the channel requeset and
344        the channel request failed.
345
346        @param result: a Failure, but it's not used.
347        @param localChannel: the local channel ID of the channel to which the
348            request was made.
349        @type localChannel: C{int}
350        """
351        self.transport.sendPacket(MSG_CHANNEL_FAILURE, struct.pack('>L',
352                                self.localToRemoteChannel[localChannel]))
353
354    def ssh_CHANNEL_SUCCESS(self, packet):
355        """
356        Our channel request to the other other side succeeded.  Payload::
357            uint32  local channel number
358
359        Get the C{Deferred} out of self.deferreds and call it back.
360        """
361        localChannel = struct.unpack('>L', packet[:4])[0]
362        if self.deferreds.get(localChannel):
363            d = self.deferreds[localChannel].pop(0)
364            log.callWithLogger(self.channels[localChannel],
365                               d.callback, '')
366
367    def ssh_CHANNEL_FAILURE(self, packet):
368        """
369        Our channel request to the other side failed.  Payload::
370            uint32  local channel number
371
372        Get the C{Deferred} out of self.deferreds and errback it with a
373        C{error.ConchError}.
374        """
375        localChannel = struct.unpack('>L', packet[:4])[0]
376        if self.deferreds.get(localChannel):
377            d = self.deferreds[localChannel].pop(0)
378            log.callWithLogger(self.channels[localChannel],
379                               d.errback,
380                               error.ConchError('channel request failed'))
381
382    # methods for users of the connection to call
383
384    def sendGlobalRequest(self, request, data, wantReply=0):
385        """
386        Send a global request for this connection.  Current this is only used
387        for remote->local TCP forwarding.
388
389        @type request:      C{str}
390        @type data:         C{str}
391        @type wantReply:    C{bool}
392        @rtype              C{Deferred}/C{None}
393        """
394        self.transport.sendPacket(MSG_GLOBAL_REQUEST,
395                                  common.NS(request)
396                                  + (wantReply and '\xff' or '\x00')
397                                  + data)
398        if wantReply:
399            d = defer.Deferred()
400            self.deferreds['global'].append(d)
401            return d
402
403    def openChannel(self, channel, extra=''):
404        """
405        Open a new channel on this connection.
406
407        @type channel:  subclass of C{SSHChannel}
408        @type extra:    C{str}
409        """
410        log.msg('opening channel %s with %s %s'%(self.localChannelID,
411                channel.localWindowSize, channel.localMaxPacket))
412        self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
413                    + struct.pack('>3L', self.localChannelID,
414                    channel.localWindowSize, channel.localMaxPacket)
415                    + extra)
416        channel.id = self.localChannelID
417        self.channels[self.localChannelID] = channel
418        self.localChannelID += 1
419
420    def sendRequest(self, channel, requestType, data, wantReply=0):
421        """
422        Send a request to a channel.
423
424        @type channel:      subclass of C{SSHChannel}
425        @type requestType:  C{str}
426        @type data:         C{str}
427        @type wantReply:    C{bool}
428        @rtype              C{Deferred}/C{None}
429        """
430        if channel.localClosed:
431            return
432        log.msg('sending request %s' % requestType)
433        self.transport.sendPacket(MSG_CHANNEL_REQUEST, struct.pack('>L',
434                                    self.channelsToRemoteChannel[channel])
435                                  + common.NS(requestType)+chr(wantReply)
436                                  + data)
437        if wantReply:
438            d = defer.Deferred()
439            self.deferreds.setdefault(channel.id, []).append(d)
440            return d
441
442    def adjustWindow(self, channel, bytesToAdd):
443        """
444        Tell the other side that we will receive more data.  This should not
445        normally need to be called as it is managed automatically.
446
447        @type channel:      subclass of L{SSHChannel}
448        @type bytesToAdd:   C{int}
449        """
450        if channel.localClosed:
451            return # we're already closed
452        self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, struct.pack('>2L',
453                                    self.channelsToRemoteChannel[channel],
454                                    bytesToAdd))
455        log.msg('adding %i to %i in channel %i' % (bytesToAdd,
456            channel.localWindowLeft, channel.id))
457        channel.localWindowLeft += bytesToAdd
458
459    def sendData(self, channel, data):
460        """
461        Send data to a channel.  This should not normally be used: instead use
462        channel.write(data) as it manages the window automatically.
463
464        @type channel:  subclass of L{SSHChannel}
465        @type data:     C{str}
466        """
467        if channel.localClosed:
468            return # we're already closed
469        self.transport.sendPacket(MSG_CHANNEL_DATA, struct.pack('>L',
470                                    self.channelsToRemoteChannel[channel]) +
471                                   common.NS(data))
472
473    def sendExtendedData(self, channel, dataType, data):
474        """
475        Send extended data to a channel.  This should not normally be used:
476        instead use channel.writeExtendedData(data, dataType) as it manages
477        the window automatically.
478
479        @type channel:  subclass of L{SSHChannel}
480        @type dataType: C{int}
481        @type data:     C{str}
482        """
483        if channel.localClosed:
484            return # we're already closed
485        self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA, struct.pack('>2L',
486                            self.channelsToRemoteChannel[channel],dataType) \
487                            + common.NS(data))
488
489    def sendEOF(self, channel):
490        """
491        Send an EOF (End of File) for a channel.
492
493        @type channel:  subclass of L{SSHChannel}
494        """
495        if channel.localClosed:
496            return # we're already closed
497        log.msg('sending eof')
498        self.transport.sendPacket(MSG_CHANNEL_EOF, struct.pack('>L',
499                                    self.channelsToRemoteChannel[channel]))
500
501    def sendClose(self, channel):
502        """
503        Close a channel.
504
505        @type channel:  subclass of L{SSHChannel}
506        """
507        if channel.localClosed:
508            return # we're already closed
509        log.msg('sending close %i' % channel.id)
510        self.transport.sendPacket(MSG_CHANNEL_CLOSE, struct.pack('>L',
511                self.channelsToRemoteChannel[channel]))
512        channel.localClosed = True
513        if channel.localClosed and channel.remoteClosed:
514            self.channelClosed(channel)
515
516    # methods to override
517    def getChannel(self, channelType, windowSize, maxPacket, data):
518        """
519        The other side requested a channel of some sort.
520        channelType is the type of channel being requested,
521        windowSize is the initial size of the remote window,
522        maxPacket is the largest packet we should send,
523        data is any other packet data (often nothing).
524
525        We return a subclass of L{SSHChannel}.
526
527        By default, this dispatches to a method 'channel_channelType' with any
528        non-alphanumerics in the channelType replace with _'s.  If it cannot
529        find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
530        The method is called with arguments of windowSize, maxPacket, data.
531
532        @type channelType:  C{str}
533        @type windowSize:   C{int}
534        @type maxPacket:    C{int}
535        @type data:         C{str}
536        @rtype:             subclass of L{SSHChannel}/C{tuple}
537        """
538        log.msg('got channel %s request' % channelType)
539        if hasattr(self.transport, "avatar"): # this is a server!
540            chan = self.transport.avatar.lookupChannel(channelType,
541                                                       windowSize,
542                                                       maxPacket,
543                                                       data)
544        else:
545            channelType = channelType.translate(TRANSLATE_TABLE)
546            f = getattr(self, 'channel_%s' % channelType, None)
547            if f is not None:
548                chan = f(windowSize, maxPacket, data)
549            else:
550                chan = None
551        if chan is None:
552            raise error.ConchError('unknown channel',
553                    OPEN_UNKNOWN_CHANNEL_TYPE)
554        else:
555            chan.conn = self
556            return chan
557
558    def gotGlobalRequest(self, requestType, data):
559        """
560        We got a global request.  pretty much, this is just used by the client
561        to request that we forward a port from the server to the client.
562        Returns either:
563            - 1: request accepted
564            - 1, <data>: request accepted with request specific data
565            - 0: request denied
566
567        By default, this dispatches to a method 'global_requestType' with
568        -'s in requestType replaced with _'s.  The found method is passed data.
569        If this method cannot be found, this method returns 0.  Otherwise, it
570        returns the return value of that method.
571
572        @type requestType:  C{str}
573        @type data:         C{str}
574        @rtype:             C{int}/C{tuple}
575        """
576        log.msg('got global %s request' % requestType)
577        if hasattr(self.transport, 'avatar'): # this is a server!
578            return self.transport.avatar.gotGlobalRequest(requestType, data)
579
580        requestType = requestType.replace('-','_')
581        f = getattr(self, 'global_%s' % requestType, None)
582        if not f:
583            return 0
584        return f(data)
585
586    def channelClosed(self, channel):
587        """
588        Called when a channel is closed.
589        It clears the local state related to the channel, and calls
590        channel.closed().
591        MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
592        If you don't, things will break mysteriously.
593
594        @type channel: L{SSHChannel}
595        """
596        if channel in self.channelsToRemoteChannel: # actually open
597            channel.localClosed = channel.remoteClosed = True
598            del self.localToRemoteChannel[channel.id]
599            del self.channels[channel.id]
600            del self.channelsToRemoteChannel[channel]
601            for d in self.deferreds.setdefault(channel.id, []):
602                d.errback(error.ConchError("Channel closed."))
603            del self.deferreds[channel.id][:]
604            log.callWithLogger(channel, channel.closed)
605
# Connection-protocol message numbers (RFC 4254).
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100

# Reason codes sent in MSG_CHANNEL_OPEN_FAILURE.
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4

# Type code for extended data carrying stderr.
EXTENDED_DATA_STDERR = 1

# Reverse map: message number -> name of its MSG_* constant, built from the
# module globals defined above.  A generator expression is used so that the
# loop variables do not leak into the module namespace (and therefore into
# "from ... import *").  If two MSG_* names ever shared a number, only one of
# them would be kept.
messages = dict((value, name)
                for name, value in list(globals().items())
                if name.startswith('MSG_'))
632
import string
# ascii_letters rather than letters: in Python 2, string.letters depends on
# the process locale, which would make the channel_* dispatch-name mangling
# in SSHConnection.getChannel locale-dependent.  (string.letters is also gone
# in Python 3.)
alphanums = string.ascii_letters + string.digits
# 256-entry str.translate table: ASCII alphanumerics map to themselves and
# every other byte maps to '_'.  The and/or idiom is safe here because
# chr(i) is always a non-empty (truthy) string.
TRANSLATE_TABLE = ''.join([chr(i) in alphanums and chr(i) or '_'
                           for i in range(256)])
SSHConnection.protocolMessages = messages
Note: See TracBrowser for help on using the browser.