root / trunk / twisted / conch / ssh / connection.py

Revision 24441, 24.1 kB (checked in by thijs, 1 year ago)

Merge maintainer-email-2438: Get rid of references to maintainer email addresses from code.

Author: thijs
Reviewer: exarkun
Fixes: #2438

Line 
1 # -*- test-case-name: twisted.conch.test.test_connection -*-
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 #
6
7 """
8 This module contains the implementation of the ssh-connection service, which
9 allows access to the shell and port-forwarding.
10
11 Maintainer: Paul Swartz
12 """
13
14 import struct
15
16 from twisted.conch.ssh import service, common
17 from twisted.conch import error
18 from twisted.internet import defer
19 from twisted.python import log
20
class SSHConnection(service.SSHService):
    """
    An implementation of the 'ssh-connection' service.  It is used to
    multiplex multiple channels over the single SSH connection.

    @ivar localChannelID: the next number to use as a local channel ID.
    @type localChannelID: C{int}
    @ivar channels: a C{dict} mapping a local channel ID to C{SSHChannel}
        subclasses.
    @type channels: C{dict}
    @ivar localToRemoteChannel: a C{dict} mapping a local channel ID to a
        remote channel ID.
    @type localToRemoteChannel: C{dict}
    @ivar channelsToRemoteChannel: a C{dict} mapping a C{SSHChannel} subclass
        to remote channel ID.
    @type channelsToRemoteChannel: C{dict}
    @ivar deferreds: a C{dict} mapping a local channel ID to a C{list} of
        C{Deferreds} for outstanding channel requests.  Also, the 'global'
        key stores the C{list} of pending global request C{Deferred}s.
    @ivar transport: the transport used to send packets; C{None} until it is
        assigned from outside this class.
    """
    name = 'ssh-connection'

    def __init__(self):
        self.localChannelID = 0 # this is the current # to use for channel ID
        self.localToRemoteChannel = {} # local channel ID -> remote channel ID
        self.channels = {} # local channel ID -> subclass of SSHChannel
        self.channelsToRemoteChannel = {} # subclass of SSHChannel ->
                                          # remote channel ID
        self.deferreds = {} # local channel -> list of deferreds for pending
                            # requests or 'global' -> list of deferreds for
                            # global requests
        self.transport = None # gets set later

    def serviceStarted(self):
        """
        If the transport has an avatar, give the avatar a reference to this
        connection through its C{conn} attribute.
        """
        if hasattr(self.transport, 'avatar'):
            self.transport.avatar.conn = self

    def serviceStopped(self):
        """
        Run L{channelClosed} for every channel still registered with this
        connection.
        """
        # Iterate over a snapshot: channelClosed() removes entries from
        # self.channels while we are looping.
        for channel in list(self.channels.values()):
            self.channelClosed(channel)

    # packet methods
    def ssh_GLOBAL_REQUEST(self, packet):
        """
        The other side has made a global request.  Payload::
            string  request type
            bool    want reply
            <request specific data>

        This dispatches to self.gotGlobalRequest.  If a reply is wanted,
        MSG_REQUEST_SUCCESS is sent when the result is true (with the second
        element of the result as data if the result is a tuple or list) and
        MSG_REQUEST_FAILURE is sent otherwise.
        """
        requestType, rest = common.getNS(packet)
        wantReply, rest = ord(rest[0]), rest[1:]
        ret = self.gotGlobalRequest(requestType, rest)
        if wantReply:
            reply = MSG_REQUEST_FAILURE
            data = ''
            if ret:
                reply = MSG_REQUEST_SUCCESS
                if isinstance(ret, (tuple, list)):
                    data = ret[1]
            self.transport.sendPacket(reply, data)

    def ssh_REQUEST_SUCCESS(self, packet):
        """
        Our global request succeeded.  Get the appropriate Deferred and call
        it back with the packet we received.
        """
        log.msg('RS')
        # Replies are matched to requests in the order the requests were made.
        self.deferreds['global'].pop(0).callback(packet)

    def ssh_REQUEST_FAILURE(self, packet):
        """
        Our global request failed.  Get the appropriate Deferred and errback
        it with the packet we received.
        """
        log.msg('RF')
        self.deferreds['global'].pop(0).errback(
            error.ConchError('global request failed', packet))

    def ssh_CHANNEL_OPEN(self, packet):
        """
        The other side wants to get a channel.  Payload::
            string  channel name
            uint32  remote channel number
            uint32  remote window size
            uint32  remote maximum packet size
            <channel specific data>

        We get a channel from self.getChannel(), give it a local channel number
        and notify the other side.  Then notify the channel by calling its
        channelOpen method.  If anything in that sequence raises, a
        MSG_CHANNEL_OPEN_FAILURE is sent instead.
        """
        # sys.exc_info() is used below instead of binding the exception in
        # the except clause so that the handler parses on any Python version.
        import sys
        channelType, rest = common.getNS(packet)
        senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
        packet = rest[12:]
        # NOTE(review): the try block also covers channelOpen(), so a failure
        # there sends OPEN_FAILURE after OPEN_CONFIRMATION has already gone
        # out and leaves the channel registered -- confirm whether intended.
        try:
            channel = self.getChannel(channelType, windowSize, maxPacket,
                            packet)
            localChannel = self.localChannelID
            self.localChannelID += 1
            channel.id = localChannel
            self.channels[localChannel] = channel
            self.channelsToRemoteChannel[channel] = senderChannel
            self.localToRemoteChannel[localChannel] = senderChannel
            self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
                struct.pack('>4L', senderChannel, localChannel,
                    channel.localWindowSize,
                    channel.localMaxPacket)+channel.specificData)
            log.callWithLogger(channel, channel.channelOpen, packet)
        except Exception:
            e = sys.exc_info()[1]
            log.msg('channel open failed')
            log.err(e)
            if isinstance(e, error.ConchError):
                # ConchError carries (description, reason code) in its args.
                textualInfo, reason = e.args
            else:
                reason = OPEN_CONNECT_FAILED
                textualInfo = "unknown failure"
            self.transport.sendPacket(MSG_CHANNEL_OPEN_FAILURE,
                                struct.pack('>2L', senderChannel, reason) +
                               common.NS(textualInfo) + common.NS(''))

    def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
        """
        The other side accepted our MSG_CHANNEL_OPEN request.  Payload::
            uint32  local channel number
            uint32  remote channel number
            uint32  remote window size
            uint32  remote maximum packet size
            <channel specific data>

        Find the channel using the local channel number and notify its
        channelOpen method.
        """
        (localChannel, remoteChannel, windowSize,
                maxPacket) = struct.unpack('>4L', packet[: 16])
        specificData = packet[16:]
        channel = self.channels[localChannel]
        channel.conn = self
        # Only now do we know the remote ID, so record both mappings.
        self.localToRemoteChannel[localChannel] = remoteChannel
        self.channelsToRemoteChannel[channel] = remoteChannel
        channel.remoteWindowLeft = windowSize
        channel.remoteMaxPacket = maxPacket
        log.callWithLogger(channel, channel.channelOpen, specificData)

    def ssh_CHANNEL_OPEN_FAILURE(self, packet):
        """
        The other side did not accept our MSG_CHANNEL_OPEN request.  Payload::
            uint32  local channel number
            uint32  reason code
            string  reason description

        Find the channel using the local channel number and notify it by
        calling its openFailed() method with a C{error.ConchError} built from
        the description and the reason code.
        """
        localChannel, reasonCode = struct.unpack('>2L', packet[:8])
        reasonDesc = common.getNS(packet[8:])[0]
        channel = self.channels[localChannel]
        del self.channels[localChannel]
        channel.conn = self
        reason = error.ConchError(reasonDesc, reasonCode)
        log.callWithLogger(channel, channel.openFailed, reason)

    def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
        """
        The other side is adding bytes to its window.  Payload::
            uint32  local channel number
            uint32  bytes to add

        Call the channel's addWindowBytes() method to add new bytes to the
        remote window.
        """
        localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)

    def ssh_CHANNEL_DATA(self, packet):
        """
        The other side is sending us data.  Payload::
            uint32 local channel number
            string data

        Check to make sure the other side hasn't sent too much data (more
        than what's in the window, or more than the maximum packet size).  If
        they have, close the channel.  Otherwise, decrease the available
        window and pass the data to the channel's dataReceived().
        """
        # The second uint32 is the length prefix of the data string.
        localChannel, dataLength = struct.unpack('>2L', packet[:8])
        channel = self.channels[localChannel]
        # XXX should this move to dataReceived to put client in charge?
        if (dataLength > channel.localWindowLeft or
           dataLength > channel.localMaxPacket): # more data than we want
            log.callWithLogger(channel, log.msg, 'too much data')
            self.sendClose(channel)
            return
        data = common.getNS(packet[4:])[0]
        channel.localWindowLeft -= dataLength
        # Top the window back up to its full size once less than half of it
        # is left.
        if channel.localWindowLeft < channel.localWindowSize // 2:
            self.adjustWindow(channel, channel.localWindowSize -
                                       channel.localWindowLeft)
        log.callWithLogger(channel, channel.dataReceived, data)

    def ssh_CHANNEL_EXTENDED_DATA(self, packet):
        """
        The other side is sending us extended data.  Payload::
            uint32  local channel number
            uint32  type code
            string  data

        Check to make sure the other side hasn't sent too much data (more
        than what's in the window, or more than the maximum packet size).  If
        they have, close the channel.  Otherwise, decrease the available
        window and pass the data and type code to the channel's
        extReceived().
        """
        # The third uint32 is the length prefix of the data string.
        localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
        channel = self.channels[localChannel]
        if (dataLength > channel.localWindowLeft or
                dataLength > channel.localMaxPacket):
            log.callWithLogger(channel, log.msg, 'too much extdata')
            self.sendClose(channel)
            return
        data = common.getNS(packet[8:])[0]
        channel.localWindowLeft -= dataLength
        # Top the window back up to its full size once less than half of it
        # is left.
        if channel.localWindowLeft < channel.localWindowSize // 2:
            self.adjustWindow(channel, channel.localWindowSize -
                                       channel.localWindowLeft)
        log.callWithLogger(channel, channel.extReceived, typeCode, data)

    def ssh_CHANNEL_EOF(self, packet):
        """
        The other side is not sending any more data.  Payload::
            uint32  local channel number

        Notify the channel by calling its eofReceived() method.
        """
        localChannel = struct.unpack('>L', packet[:4])[0]
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.eofReceived)

    def ssh_CHANNEL_CLOSE(self, packet):
        """
        The other side is closing its end; it does not want to receive any
        more data.  Payload::
            uint32  local channel number

        Notify the channel by calling its closeReceived() method.  If
        the channel has also sent a close message, call self.channelClosed().
        """
        localChannel = struct.unpack('>L', packet[:4])[0]
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.closeReceived)
        channel.remoteClosed = True
        if channel.localClosed and channel.remoteClosed:
            self.channelClosed(channel)

    def ssh_CHANNEL_REQUEST(self, packet):
        """
        The other side is sending a request to a channel.  Payload::
            uint32  local channel number
            string  request name
            bool    want reply
            <request specific data>

        Pass the message to the channel's requestReceived method.  If the
        other side wants a reply, add callbacks which will send the
        reply.

        @return: the C{Deferred} wrapping the channel's answer if a reply was
            requested, otherwise C{None}.
        """
        localChannel = struct.unpack('>L', packet[: 4])[0]
        requestType, rest = common.getNS(packet[4:])
        wantReply = ord(rest[0])
        channel = self.channels[localChannel]
        # maybeDeferred lets requestReceived answer either synchronously or
        # with a Deferred, and turns an exception into a failure.
        d = defer.maybeDeferred(log.callWithLogger, channel,
                channel.requestReceived, requestType, rest[1:])
        if wantReply:
            d.addCallback(self._cbChannelRequest, localChannel)
            d.addErrback(self._ebChannelRequest, localChannel)
            return d

    def _cbChannelRequest(self, result, localChannel):
        """
        Called back if the other side wanted a reply to a channel request.  If
        the result is true, send a MSG_CHANNEL_SUCCESS.  Otherwise, raise
        a C{error.ConchError}

        @param result: the value returned from the channel's requestReceived()
            method.  If it's False, the request failed.
        @type result: C{bool}
        @param localChannel: the local channel ID of the channel to which the
            request was made.
        @type localChannel: C{int}
        @raises ConchError: if the result is False.
        """
        if not result:
            # Raising here hands control to _ebChannelRequest, which sends
            # the failure reply.
            raise error.ConchError('failed request')
        self.transport.sendPacket(MSG_CHANNEL_SUCCESS, struct.pack('>L',
                                self.localToRemoteChannel[localChannel]))

    def _ebChannelRequest(self, result, localChannel):
        """
        Called if the other side wanted a reply to the channel request and
        the channel request failed.

        @param result: a Failure, but it's not used.
        @param localChannel: the local channel ID of the channel to which the
            request was made.
        @type localChannel: C{int}
        """
        self.transport.sendPacket(MSG_CHANNEL_FAILURE, struct.pack('>L',
                                self.localToRemoteChannel[localChannel]))

    def ssh_CHANNEL_SUCCESS(self, packet):
        """
        Our channel request to the other side succeeded.  Payload::
            uint32  local channel number

        Get the C{Deferred} out of self.deferreds and call it back.  Nothing
        happens if no request is pending for that channel.
        """
        localChannel = struct.unpack('>L', packet[:4])[0]
        if self.deferreds.get(localChannel):
            d = self.deferreds[localChannel].pop(0)
            log.callWithLogger(self.channels[localChannel],
                               d.callback, '')

    def ssh_CHANNEL_FAILURE(self, packet):
        """
        Our channel request to the other side failed.  Payload::
            uint32  local channel number

        Get the C{Deferred} out of self.deferreds and errback it with a
        C{error.ConchError}.  Nothing happens if no request is pending for
        that channel.
        """
        localChannel = struct.unpack('>L', packet[:4])[0]
        if self.deferreds.get(localChannel):
            d = self.deferreds[localChannel].pop(0)
            log.callWithLogger(self.channels[localChannel],
                               d.errback,
                               error.ConchError('channel request failed'))

    # methods for users of the connection to call

    def sendGlobalRequest(self, request, data, wantReply=0):
        """
        Send a global request for this connection.  Currently this is only
        used for remote->local TCP forwarding.

        @type request:      C{str}
        @type data:         C{str}
        @type wantReply:    C{bool}
        @rtype:             C{Deferred}/C{None}
        @return: a C{Deferred} that fires when the reply arrives if
            C{wantReply} is true, otherwise C{None}.
        """
        self.transport.sendPacket(MSG_GLOBAL_REQUEST,
                                  common.NS(request)
                                  + (wantReply and '\xff' or '\x00')
                                  + data)
        if wantReply:
            d = defer.Deferred()
            self.deferreds.setdefault('global', []).append(d)
            return d

    def openChannel(self, channel, extra=''):
        """
        Open a new channel on this connection.  The channel is given the next
        local channel ID and registered in C{self.channels}.

        @type channel:  subclass of C{SSHChannel}
        @type extra:    C{str}
        """
        log.msg('opening channel %s with %s %s'%(self.localChannelID,
                channel.localWindowSize, channel.localMaxPacket))
        self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
                    + struct.pack('>3L', self.localChannelID,
                    channel.localWindowSize, channel.localMaxPacket)
                    + extra)
        channel.id = self.localChannelID
        self.channels[self.localChannelID] = channel
        self.localChannelID += 1

    def sendRequest(self, channel, requestType, data, wantReply=0):
        """
        Send a request to a channel.  Nothing is sent if we have already
        closed our side of the channel.

        @type channel:      subclass of C{SSHChannel}
        @type requestType:  C{str}
        @type data:         C{str}
        @type wantReply:    C{bool}
        @rtype:             C{Deferred}/C{None}
        @return: a C{Deferred} that fires when the reply arrives if
            C{wantReply} is true, otherwise C{None}.
        """
        if channel.localClosed:
            return
        log.msg('sending request %s' % requestType)
        self.transport.sendPacket(MSG_CHANNEL_REQUEST, struct.pack('>L',
                                    self.channelsToRemoteChannel[channel])
                                  + common.NS(requestType)+chr(wantReply)
                                  + data)
        if wantReply:
            d = defer.Deferred()
            self.deferreds.setdefault(channel.id, []).append(d)
            return d

    def adjustWindow(self, channel, bytesToAdd):
        """
        Tell the other side that we will receive more data.  This should not
        normally need to be called as it is managed automatically.

        @type channel:      subclass of L{SSHChannel}
        @type bytesToAdd:   C{int}
        """
        if channel.localClosed:
            return # we're already closed
        self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, struct.pack('>2L',
                                    self.channelsToRemoteChannel[channel],
                                    bytesToAdd))
        log.msg('adding %i to %i in channel %i' % (bytesToAdd,
            channel.localWindowLeft, channel.id))
        channel.localWindowLeft += bytesToAdd

    def sendData(self, channel, data):
        """
        Send data to a channel.  This should not normally be used: instead use
        channel.write(data) as it manages the window automatically.

        @type channel:  subclass of L{SSHChannel}
        @type data:     C{str}
        """
        if channel.localClosed:
            return # we're already closed
        self.transport.sendPacket(MSG_CHANNEL_DATA, struct.pack('>L',
                                    self.channelsToRemoteChannel[channel]) +
                                   common.NS(data))

    def sendExtendedData(self, channel, dataType, data):
        """
        Send extended data to a channel.  This should not normally be used:
        instead use channel.writeExtendedData(data, dataType) as it manages
        the window automatically.

        @type channel:  subclass of L{SSHChannel}
        @type dataType: C{int}
        @type data:     C{str}
        """
        if channel.localClosed:
            return # we're already closed
        self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA, struct.pack('>2L',
                            self.channelsToRemoteChannel[channel],dataType)
                            + common.NS(data))

    def sendEOF(self, channel):
        """
        Send an EOF (End of File) for a channel.

        @type channel:  subclass of L{SSHChannel}
        """
        if channel.localClosed:
            return # we're already closed
        log.msg('sending eof')
        self.transport.sendPacket(MSG_CHANNEL_EOF, struct.pack('>L',
                                    self.channelsToRemoteChannel[channel]))

    def sendClose(self, channel):
        """
        Close a channel.  If the other side has already closed its end too,
        self.channelClosed() is called.

        @type channel:  subclass of L{SSHChannel}
        """
        if channel.localClosed:
            return # we're already closed
        log.msg('sending close %i' % channel.id)
        self.transport.sendPacket(MSG_CHANNEL_CLOSE, struct.pack('>L',
                self.channelsToRemoteChannel[channel]))
        channel.localClosed = True
        if channel.localClosed and channel.remoteClosed:
            self.channelClosed(channel)

    # methods to override
    def getChannel(self, channelType, windowSize, maxPacket, data):
        """
        The other side requested a channel of some sort.
        channelType is the type of channel being requested,
        windowSize is the initial size of the remote window,
        maxPacket is the largest packet we should send,
        data is any other packet data (often nothing).

        We return a subclass of L{SSHChannel}.

        If the transport has an avatar, the lookup is delegated to the
        avatar's lookupChannel().  Otherwise, this dispatches to a method
        'channel_channelType' with any non-alphanumerics in the channelType
        replaced with _'s.  The method is called with arguments of
        windowSize, maxPacket, data.  If no channel is found, a
        C{error.ConchError} carrying OPEN_UNKNOWN_CHANNEL_TYPE is raised.

        @type channelType:  C{str}
        @type windowSize:   C{int}
        @type maxPacket:    C{int}
        @type data:         C{str}
        @rtype:             subclass of L{SSHChannel}
        @raises ConchError: if no channel could be found for channelType.
        """
        log.msg('got channel %s request' % channelType)
        if hasattr(self.transport, "avatar"): # this is a server!
            chan = self.transport.avatar.lookupChannel(channelType,
                                                       windowSize,
                                                       maxPacket,
                                                       data)
        else:
            channelType = channelType.translate(TRANSLATE_TABLE)
            f = getattr(self, 'channel_%s' % channelType, None)
            if f is not None:
                chan = f(windowSize, maxPacket, data)
            else:
                chan = None
        if chan is None:
            raise error.ConchError('unknown channel',
                    OPEN_UNKNOWN_CHANNEL_TYPE)
        else:
            chan.conn = self
            return chan

    def gotGlobalRequest(self, requestType, data):
        """
        We got a global request.  pretty much, this is just used by the client
        to request that we forward a port from the server to the client.
        Returns either:
            - 1: request accepted
            - 1, <data>: request accepted with request specific data
            - 0: request denied

        If the transport has an avatar, the request is delegated to the
        avatar's gotGlobalRequest().  Otherwise, this dispatches to a method
        'global_requestType' with -'s in requestType replaced with _'s.  The
        found method is passed data.  If this method cannot be found, this
        method returns 0.  Otherwise, it returns the return value of that
        method.

        @type requestType:  C{str}
        @type data:         C{str}
        @rtype:             C{int}/C{tuple}
        """
        log.msg('got global %s request' % requestType)
        if hasattr(self.transport, 'avatar'): # this is a server!
            return self.transport.avatar.gotGlobalRequest(requestType, data)

        requestType = requestType.replace('-','_')
        f = getattr(self, 'global_%s' % requestType, None)
        if not f:
            return 0
        return f(data)

    def channelClosed(self, channel):
        """
        Called when a channel is closed.
        It clears the local state related to the channel, errbacks any
        channel requests that are still waiting for a reply, and calls
        channel.closed().
        MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
        If you don't, things will break mysteriously.

        Channels that were never fully opened (no remote channel ID recorded)
        are left untouched.
        """
        if channel in self.channelsToRemoteChannel: # actually open
            channel.localClosed = channel.remoteClosed = True
            del self.localToRemoteChannel[channel.id]
            del self.channels[channel.id]
            del self.channelsToRemoteChannel[channel]
            # No reply can arrive for these requests any more, so fail them
            # instead of leaving their callers waiting forever.  Popping the
            # entry also stops self.deferreds growing with every closed
            # channel.
            for d in self.deferreds.pop(channel.id, []):
                log.callWithLogger(channel, d.errback,
                                   error.ConchError('Channel closed.'))
            log.callWithLogger(channel, channel.closed)
581
# Message numbers of the packets sent and handled by SSHConnection.
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100

# Reason codes carried in a MSG_CHANNEL_OPEN_FAILURE packet.
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4

# Presumably the extended-data type code for stderr; it is not used in this
# module -- confirm against callers.
EXTENDED_DATA_STDERR = 1

# Build a message number -> 'MSG_*' name table from the constants above.
messages = {}
for name, value in locals().copy().items():
    if name[:4] == 'MSG_':
        messages[value] = name # doesn't handle doubles

import string
# ascii_letters rather than letters: string.letters depends on the current
# locale, which would make the table vary with the environment at import time.
alphanums = string.ascii_letters + string.digits
# 256-entry table for str.translate() that keeps ASCII letters and digits and
# maps every other character to '_'.
TRANSLATE_TABLE = ''.join([chr(i) in alphanums and chr(i) or '_'
    for i in range(256)])
SSHConnection.protocolMessages = messages
Note: See TracBrowser for help on using the browser.