root/trunk/twisted/conch/ssh/session.py

Revision 31146, 10.6 KB (checked in by exarkun, 14 months ago)

Merge session-transport-addresses-2453

Author: salgado
Reviewer: exarkun
Fixes: #2453

Add getHost and getPeer methods to SSHSessionProcessProtocol, used as the transport
for subsystem protocols.

  • Property svn:executable set to *
Line 
1# -*- test-case-name: twisted.conch.test.test_session -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6This module contains the implementation of SSHSession, which (by default)
7allows access to a shell and a python interpreter over SSH.
8
9Maintainer: Paul Swartz
10"""
11
12import struct
13import signal
14import sys
15import os
16from zope.interface import implements
17
18from twisted.internet import interfaces, protocol
19from twisted.python import log
20from twisted.conch.interfaces import ISession
21from twisted.conch.ssh import common, channel
22
class SSHSession(channel.SSHChannel):
    """
    An SSH 'session' channel.

    Channel requests (subsystem, shell, exec, pty-req, window-change) are
    handled by the request_* methods, which delegate to the avatar.  Channel
    data is relayed to the process protocol that a successful request starts.
    """

    name = 'session'

    def __init__(self, *args, **kw):
        channel.SSHChannel.__init__(self, *args, **kw)
        # Channel data received before any client exists is buffered here.
        self.buf = ''
        # The process protocol started by a successful request, or None.
        self.client = None
        # The avatar's ISession adapter; created lazily by _getSession().
        self.session = None

    def _getSession(self):
        """
        Return the avatar's L{ISession} adapter, adapting and caching it on
        first use.
        """
        if not self.session:
            self.session = ISession(self.avatar)
        return self.session

    def request_subsystem(self, data):
        """
        Handle a 'subsystem' request by looking the subsystem up on the
        avatar and, if one is found, connecting it to this channel.

        @param data: the request data.  It begins with the subsystem name
            (read with C{common.getNS}), and the whole of C{data} is passed
            on to C{lookupSubsystem}.
        @return: 1 if the subsystem was started, 0 otherwise.
        """
        subsystem, ignored = common.getNS(data)
        log.msg('asking for subsystem "%s"' % subsystem)
        client = self.avatar.lookupSubsystem(subsystem, data)
        if not client:
            log.msg('failed to get subsystem')
            return 0
        # pp is the subsystem's transport (its writes go out over this
        # channel).  Its own transport is a _DummyTransport, so channel data
        # written to it is delivered to the subsystem's dataReceived.
        pp = SSHSessionProcessProtocol(self)
        proto = wrapProcessProtocol(pp)
        client.makeConnection(proto)
        pp.makeConnection(wrapProtocol(client))
        self.client = pp
        return 1

    def request_shell(self, data):
        """
        Handle a 'shell' request by asking the session to open a shell.

        @return: 1 on success, 0 if C{openShell} raised (the error is logged).
        """
        log.msg('getting shell')
        session = self._getSession()
        try:
            pp = SSHSessionProcessProtocol(self)
            session.openShell(pp)
        except:
            # Deliberately broad (this is Python 2 code, where old-style
            # exceptions do not derive from Exception): any failure is logged
            # and the request is refused.
            log.err()
            return 0
        self.client = pp
        return 1

    def request_exec(self, data):
        """
        Handle an 'exec' request by asking the session to run the command.

        @param data: the request data.  It begins with the command string,
            read with C{common.getNS}.
        @return: 1 on success, 0 if C{execCommand} raised (the error is
            logged).
        """
        session = self._getSession()
        command, data = common.getNS(data)
        log.msg('executing command "%s"' % command)
        try:
            pp = SSHSessionProcessProtocol(self)
            session.execCommand(pp, command)
        except:
            # Deliberately broad; see request_shell.
            log.err()
            return 0
        self.client = pp
        return 1

    def request_pty_req(self, data):
        """
        Handle a 'pty-req' request by passing the parsed terminal type,
        window size and modes to the session's C{getPty}.

        @return: 1 on success, 0 if C{getPty} raised (the error is logged).
        """
        session = self._getSession()
        term, windowSize, modes = parseRequest_pty_req(data)
        log.msg('pty request: %s %s' % (term, windowSize))
        try:
            session.getPty(term, windowSize, modes)
        except:
            log.err()
            return 0
        return 1

    def request_window_change(self, data):
        """
        Handle a 'window-change' request by passing the parsed window size
        to the session's C{windowChanged}.

        @return: 1 on success, 0 if C{windowChanged} raised (the error is
            logged).
        """
        session = self._getSession()
        winSize = parseRequest_window_change(data)
        try:
            session.windowChanged(winSize)
        except:
            log.msg('error changing window size')
            log.err()
            return 0
        return 1

    def dataReceived(self, data):
        """
        Forward channel data to the client's transport, or buffer it if no
        client has been started yet.
        """
        if not self.client:
            # buf can be None after a flush (not only ''), so treat a falsy
            # buffer as empty instead of failing on None += data.
            self.buf = (self.buf or '') + data
            return
        self.client.transport.write(data)

    def extReceived(self, dataType, data):
        """
        Forward extended (stderr) channel data to the client's transport if
        it supports C{writeErr}.  Other extended data types are only logged.
        """
        if dataType == connection.EXTENDED_DATA_STDERR:
            if self.client and hasattr(self.client.transport, 'writeErr'):
                self.client.transport.writeErr(data)
        else:
            log.msg('weird extended data: %s' % dataType)

    def eofReceived(self):
        """
        Pass EOF to the session if there is one.  Otherwise, if a client
        exists, close the channel.
        """
        if self.session:
            self.session.eofReceived()
        elif self.client:
            self.conn.sendClose(self)

    def closed(self):
        """
        Notify the session that the channel closed, or, if there is no
        session, drop the client's transport connection.
        """
        if self.session:
            self.session.closed()
        elif self.client:
            self.client.transport.loseConnection()

    def loseConnection(self):
        """
        Drop the client's transport connection (if any), then close the
        channel.
        """
        if self.client:
            self.client.transport.loseConnection()
        channel.SSHChannel.loseConnection(self)
135
class _ProtocolWrapper(protocol.ProcessProtocol):
    """
    Adapt a L{Protocol} instance so it can be used as a L{ProcessProtocol}.

    Connection setup, stdout data and process termination are forwarded to
    the wrapped protocol's connectionMade, dataReceived and connectionLost.
    """

    def __init__(self, proto):
        self.proto = proto

    def connectionMade(self):
        self.proto.connectionMade()

    def outReceived(self, data):
        wrapped = self.proto
        wrapped.dataReceived(data)

    def processEnded(self, reason):
        # The end of the process is reported as the protocol's lost connection.
        wrapped = self.proto
        wrapped.connectionLost(reason)
148
149class _DummyTransport:
150
151    def __init__(self, proto):
152        self.proto = proto
153
154    def dataReceived(self, data):
155        self.proto.transport.write(data)
156
157    def write(self, data):
158        self.proto.dataReceived(data)
159
160    def writeSequence(self, seq):
161        self.write(''.join(seq))
162
163    def loseConnection(self):
164        self.proto.connectionLost(protocol.connectionDone)
165
def wrapProcessProtocol(inst):
    """
    Return C{inst} in a form usable as a process protocol.

    A L{Protocol} instance is wrapped in a L{_ProtocolWrapper}.  Anything
    else is returned unchanged.
    """
    if not isinstance(inst, protocol.Protocol):
        return inst
    return _ProtocolWrapper(inst)
171
def wrapProtocol(proto):
    """
    Return a L{_DummyTransport} around C{proto}.
    """
    transport = _DummyTransport(proto)
    return transport
174
175
176
# Signal names (without the "SIG" prefix) that every session channel is
# supposed to accept; see RFC 4254.
SUPPORTED_SIGNALS = ("ABRT ALRM FPE HUP ILL INT KILL "
                     "PIPE QUIT SEGV TERM USR1 USR2").split()
181
182
183
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
    """I am both an L{IProcessProtocol} and an L{ITransport}.

    I am a transport to the remote endpoint and a process protocol to the
    local subsystem.
    """

    implements(interfaces.ITransport)

    # once initialized, a dictionary mapping signal values to strings
    # that follow RFC 4254.  It is built lazily by _getSignalName and stored
    # on the instance, so resetting it to None forces a rebuild.
    _signalValuesToNames = None

    def __init__(self, session):
        self.session = session
        # Set when the first of stdout/stderr closes.  EOF is sent when the
        # second one closes.
        self.lostOutOrErrFlag = False

    def connectionMade(self):
        """
        Flush any data the session buffered before this protocol was
        connected.
        """
        if self.session.buf:
            self.transport.write(self.session.buf)
            # Reset to an empty string, not None: the session appends to buf
            # with += when data arrives and no client is set.
            self.session.buf = ''

    def outReceived(self, data):
        """
        Send process stdout to the remote side as channel data.
        """
        self.session.write(data)

    def errReceived(self, err):
        """
        Send process stderr to the remote side as extended (stderr) data.
        """
        self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)

    def outConnectionLost(self):
        """
        EOF should only be sent when both STDOUT and STDERR have been closed.
        """
        if self.lostOutOrErrFlag:
            self.session.conn.sendEOF(self.session)
        else:
            self.lostOutOrErrFlag = True

    def errConnectionLost(self):
        """
        See outConnectionLost().
        """
        self.outConnectionLost()

    def connectionLost(self, reason = None):
        """
        Close the session channel.
        """
        self.session.loseConnection()


    def _getSignalName(self, signum):
        """
        Get a signal name given a signal number.

        Names for the signals in SUPPORTED_SIGNALS take priority, for
        example C{'SIGINT'}.  Any other C{SIG*} name in the C{signal} module
        maps to C{'SIGNAME@' + sys.platform}.

        @raise KeyError: if C{signum} is not the value of any C{SIG*} name in
            the C{signal} module.
        """
        if self._signalValuesToNames is None:
            self._signalValuesToNames = {}
            # make sure that the POSIX ones are the defaults
            for signame in SUPPORTED_SIGNALS:
                signame = 'SIG' + signame
                sigvalue = getattr(signal, signame, None)
                if sigvalue is not None:
                    self._signalValuesToNames[sigvalue] = signame
            for k, v in signal.__dict__.items():
                # Check for platform specific signals, ignoring Python specific
                # SIG_DFL and SIG_IGN
                if k.startswith('SIG') and not k.startswith('SIG_'):
                    if v not in self._signalValuesToNames:
                        self._signalValuesToNames[v] = k + '@' + sys.platform
        return self._signalValuesToNames[signum]


    def processEnded(self, reason=None):
        """
        When we are told the process ended, try to notify the other side about
        how the process ended using the exit-signal or exit-status requests.
        Also, close the channel.

        @param reason: None, or a failure whose C{value} carries the
            C{signal}, C{status} and C{exitCode} attributes read below.
        """
        if reason is not None:
            err = reason.value
            if err.signal is not None:
                signame = self._getSignalName(err.signal)
                # WCOREDUMP is not available on every platform.
                if (getattr(os, 'WCOREDUMP', None) is not None and
                    os.WCOREDUMP(err.status)):
                    log.msg('exitSignal: %s (core dumped)' % (signame,))
                    coreDumped = 1
                else:
                    log.msg('exitSignal: %s' % (signame,))
                    coreDumped = 0
                # Signal name without its 'SIG' prefix, the core-dump flag,
                # then two empty strings (RFC 4254 section 6.10: error
                # message and language tag).
                self.session.conn.sendRequest(self.session, 'exit-signal',
                        common.NS(signame[3:]) + chr(coreDumped) +
                        common.NS('') + common.NS(''))
            elif err.exitCode is not None:
                log.msg('exitCode: %r' % (err.exitCode,))
                self.session.conn.sendRequest(self.session, 'exit-status',
                        struct.pack('>L', err.exitCode))
        self.session.loseConnection()


    def getHost(self):
        """
        Return the host from my session's transport.
        """
        return self.session.conn.transport.getHost()


    def getPeer(self):
        """
        Return the peer from my session's transport.
        """
        return self.session.conn.transport.getPeer()


    def write(self, data):
        """
        Send C{data} to the remote side over the session channel.
        """
        self.session.write(data)


    def writeSequence(self, seq):
        """
        Send the concatenation of C{seq} over the session channel.
        """
        self.session.write(''.join(seq))


    def loseConnection(self):
        """
        Close the session channel.
        """
        self.session.loseConnection()
303
304
305
class SSHSessionClient(protocol.Protocol):
    """
    A protocol that writes every chunk it receives back to its transport.
    Data that arrives while there is no transport is discarded.
    """

    def dataReceived(self, data):
        transport = self.transport
        if not transport:
            return
        transport.write(data)
311
# Helper functions factored out to make life easier for server writers.
def parseRequest_pty_req(data):
    """Parse the data from a pty-req request into usable data.

    The request data is a terminal-type string, four big-endian uint32s
    (columns, rows, pixel width, pixel height), and an encoded
    terminal-modes string.

    @returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes),
        where modes is a list of (opcode, value) pairs.
    """
    term, rest = common.getNS(data)
    cols, rows, xpixel, ypixel = struct.unpack('>4L', rest[:16])
    encodedModes, ignored = common.getNS(rest[16:])
    winSize = (rows, cols, xpixel, ypixel)
    modes = []
    # Each encoded mode is a one-byte opcode followed by a uint32 argument.
    # The list ends at TTY_OP_END (opcode 0, no argument); anything after it
    # is ignored rather than misparsed.  Slicing instead of indexing yields a
    # one-character string for both str and bytes, so ord() works for both.
    for i in range(0, len(encodedModes) - 1, 5):
        opcode = ord(encodedModes[i:i + 1])
        if opcode == 0:
            break
        value, = struct.unpack('>L', encodedModes[i + 1:i + 5])
        modes.append((opcode, value))
    return term, winSize, modes
324
def packRequest_pty_req(term, geometry, modes):
    """Pack a pty-req request so that it is suitable for sending.

    NOTE: modes must be packed before being sent here.

    @param term: the terminal type string.
    @param geometry: a (rows, cols, xpixel, ypixel) tuple, in the same order
        that parseRequest_pty_req returns.
    @param modes: the already-encoded terminal-modes string.
    @return: the packed request data.
    """
    # Unpacked here rather than in the signature: tuple parameters are a
    # SyntaxError on Python 3 (PEP 3113).  Positional callers are unaffected.
    rows, cols, xpixel, ypixel = geometry
    termPacked = common.NS(term)
    # On the wire the column count comes before the row count.
    winSizePacked = struct.pack('>4L', cols, rows, xpixel, ypixel)
    modesPacked = common.NS(modes)  # the caller is responsible for encoding modes
    return termPacked + winSizePacked + modesPacked
334
def parseRequest_window_change(data):
    """Parse the data from a window-change request into usable data.

    The data is four big-endian uint32s in the order (cols, rows, xpixel,
    ypixel).

    @returns: a tuple of (rows, cols, xpixel, ypixel)
    """
    fields = struct.unpack('>4L', data)
    # Swap the first two fields so rows come before columns.
    return (fields[1], fields[0]) + fields[2:]
342
def packRequest_window_change(geometry):
    """Pack a window-change request so that it is suitable for sending.

    @param geometry: a (rows, cols, xpixel, ypixel) tuple, in the same order
        that parseRequest_window_change returns.
    @return: the 16-byte request data (four big-endian uint32s), with the
        column count packed before the row count.
    """
    # Unpacked here rather than in the signature: tuple parameters are a
    # SyntaxError on Python 3 (PEP 3113).  Positional callers are unaffected.
    rows, cols, xpixel, ypixel = geometry
    return struct.pack('>4L', cols, rows, xpixel, ypixel)
347
348import connection
Note: See TracBrowser for help on using the browser.