root / trunk / twisted / conch / ssh / session.py

Revision 27060, 9.7 kB (checked in by therve, 3 days ago)

Merge session-signal-2687

Author: therve
Reviewer: jml
Fixes #2687

Manage process termination in Conch due to signals, sending exit-signal
message instead of exit-status.

  • Property svn:executable set to *
Line 
1 # -*- test-case-name: twisted.conch.test.test_session -*-
2 # Copyright (c) 2001-2009 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 This module contains the implementation of SSHSession, which (by default)
7 allows access to a shell and a python interpreter over SSH.
8
9 Maintainer: Paul Swartz
10 """
11
12 import struct
13 import signal
14 import sys
15 import os
16
17 from twisted.internet import protocol
18 from twisted.python import log
19 from twisted.conch.interfaces import ISession
20 from twisted.conch.ssh import common, channel
21
class SSHSession(channel.SSHChannel):
    """
    An SSH channel of type C{'session'}.

    The request_* methods below start a subsystem, shell or command, or
    configure a pty, using an L{ISession} adaptation of the avatar (or, for
    subsystems, a protocol returned by the avatar).  Once a client is
    attached, channel data is written to the client's transport; before
    that, it is accumulated in C{self.buf}.
    """

    name = 'session'

    def __init__(self, *args, **kw):
        channel.SSHChannel.__init__(self, *args, **kw)
        # Channel data that arrives before a client is attached.
        self.buf = ''
        # The SSHSessionProcessProtocol once a subsystem/shell/exec started.
        self.client = None
        # ISession(self.avatar), adapted lazily by _adaptSession.
        self.session = None

    def _adaptSession(self):
        # Adapt the avatar to ISession the first time a request needs it.
        if not self.session:
            self.session = ISession(self.avatar)

    def _startClient(self, launch):
        # Build a process protocol for this channel and pass it to `launch`.
        # The protocol becomes self.client only if launch does not raise;
        # any failure is logged and reported as 0.
        try:
            processProtocol = SSHSessionProcessProtocol(self)
            launch(processProtocol)
        except:
            log.deferr()
            return 0
        self.client = processProtocol
        return 1

    def request_subsystem(self, data):
        """
        Start the subsystem named at the front of C{data}.

        @return: 1 if the avatar supplied a subsystem protocol, 0 otherwise.
        """
        name, _rest = common.getNS(data)
        log.msg('asking for subsystem "%s"' % name)
        subsystemProtocol = self.avatar.lookupSubsystem(name, data)
        if not subsystemProtocol:
            log.msg('failed to get subsystem')
            return 0
        processProtocol = SSHSessionProcessProtocol(self)
        # Each side becomes the other's transport.
        subsystemProtocol.makeConnection(wrapProcessProtocol(processProtocol))
        processProtocol.makeConnection(wrapProtocol(subsystemProtocol))
        self.client = processProtocol
        return 1

    def request_shell(self, data):
        """
        Open a shell through the ISession adapter.

        @return: 1 on success, 0 if opening the shell raised.
        """
        log.msg('getting shell')
        self._adaptSession()
        return self._startClient(
            lambda processProtocol: self.session.openShell(processProtocol))

    def request_exec(self, data):
        """
        Execute the command carried in C{data} through the ISession adapter.

        @return: 1 on success, 0 if starting the command raised.
        """
        self._adaptSession()
        command, _rest = common.getNS(data)
        log.msg('executing command "%s"' % command)
        return self._startClient(
            lambda processProtocol:
                self.session.execCommand(processProtocol, command))

    def request_pty_req(self, data):
        """
        Ask the ISession adapter for a pty described by C{data}.

        @return: 1 on success, 0 if getPty raised.
        """
        self._adaptSession()
        term, size, modes = parseRequest_pty_req(data)
        log.msg('pty request: %s %s' % (term, size))
        try:
            self.session.getPty(term, size, modes)
        except:
            log.err()
            return 0
        return 1

    def request_window_change(self, data):
        """
        Forward a window-size change to the ISession adapter.

        @return: 1 on success, 0 if windowChanged raised.
        """
        self._adaptSession()
        size = parseRequest_window_change(data)
        try:
            self.session.windowChanged(size)
        except:
            log.msg('error changing window size')
            log.err()
            return 0
        return 1

    def dataReceived(self, data):
        # Forward to the client if one is attached; otherwise buffer.
        if self.client:
            self.client.transport.write(data)
        else:
            self.buf += data

    def extReceived(self, dataType, data):
        # Only stderr-typed extended data is understood; it is passed on
        # when the client's transport offers writeErr.
        if dataType != connection.EXTENDED_DATA_STDERR:
            log.msg('weird extended data: %s'%dataType)
            return
        if self.client and hasattr(self.client.transport, 'writeErr'):
            self.client.transport.writeErr(data)

    def eofReceived(self):
        # Let the ISession handle EOF; without one, a client means we close
        # the channel ourselves.
        if self.session:
            self.session.eofReceived()
            return
        if self.client:
            self.conn.sendClose(self)

    def closed(self):
        # Notify the ISession if there is one; otherwise drop the client's
        # transport.
        if self.session:
            self.session.closed()
            return
        if self.client:
            self.client.transport.loseConnection()

    def loseConnection(self):
        # Disconnect the client first, then close the channel itself.
        client = self.client
        if client:
            client.transport.loseConnection()
        channel.SSHChannel.loseConnection(self)
134
class _ProtocolWrapper(protocol.ProcessProtocol):
    """
    Present a plain L{Protocol} instance as a L{ProcessProtocol}.

    Process events are translated for the wrapped protocol:
      - connectionMade is forwarded unchanged;
      - process stdout (outReceived) is delivered to dataReceived;
      - process termination (processEnded) becomes connectionLost.
    """

    def __init__(self, proto):
        self.proto = proto

    def connectionMade(self):
        self.proto.connectionMade()

    def outReceived(self, data):
        self.proto.dataReceived(data)

    def processEnded(self, reason):
        self.proto.connectionLost(reason)
147
class _DummyTransport:
    """
    A minimal transport that hands everything written to it to a protocol.

    write()/writeSequence() call C{proto.dataReceived}; dataReceived() goes
    the other way and writes to C{proto.transport}; loseConnection() tells
    the protocol the connection ended cleanly.
    """

    def __init__(self, proto):
        self.proto = proto

    def dataReceived(self, data):
        # Data coming "in" to this transport goes out through the wrapped
        # protocol's own transport.
        peerTransport = self.proto.transport
        peerTransport.write(data)

    def write(self, data):
        self.proto.dataReceived(data)

    def writeSequence(self, seq):
        joined = ''.join(seq)
        self.write(joined)

    def loseConnection(self):
        cleanClose = protocol.connectionDone
        self.proto.connectionLost(cleanClose)
164
def wrapProcessProtocol(inst):
    """
    Return C{inst} in a form usable as a L{ProcessProtocol}.

    A plain L{Protocol} is wrapped in a L{_ProtocolWrapper}; anything else
    is assumed to already be process-protocol-like and is returned as is.
    """
    if not isinstance(inst, protocol.Protocol):
        return inst
    return _ProtocolWrapper(inst)
170
def wrapProtocol(proto):
    """
    Return a L{_DummyTransport} for C{proto}: writes to the result are
    delivered to C{proto.dataReceived}.
    """
    transport = _DummyTransport(proto)
    return transport
173
174
175
# Standard signal names from RFC 4254 (written without the "SIG" prefix).
# Every session channel is supposed to accept these; when a process dies
# from one of them, it is reported to the peer under exactly this name.
SUPPORTED_SIGNALS = [
    "ABRT", "ALRM", "FPE", "HUP", "ILL", "INT", "KILL",
    "PIPE", "QUIT", "SEGV", "TERM", "USR1", "USR2",
]
180
181
182
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
    """
    Connect a process (or a wrapped protocol) to an L{SSHSession} channel.

    Process stdout is written to the channel and stderr is sent as extended
    data.  When the process ends, an exit-signal or exit-status request is
    sent before the channel is closed.  Instances also serve as a transport:
    write(), writeSequence() and loseConnection() act on the channel.
    """

    # Built lazily by _getSignalName: a dictionary mapping signal values to
    # names that follow RFC 4254 ('SIGxxx' for the standard signals,
    # 'SIGxxx@<sys.platform>' for the others).  It is stored on the instance
    # that builds it, so this class attribute only marks "not built yet".
    _signalValuesToNames = None

    def __init__(self, session):
        self.session = session

    def connectionMade(self):
        # Flush anything the channel buffered before we were attached.
        if self.session.buf:
            self.transport.write(self.session.buf)
            self.session.buf = None

    def outReceived(self, data):
        self.session.write(data)

    def errReceived(self, err):
        self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)

    def inConnectionLost(self):
        # The process closed its stdin: signal EOF on the channel.
        self.session.conn.sendEOF(self.session)

    def connectionLost(self, reason=None):
        self.session.loseConnection()

    def _getSignalName(self, signum):
        """
        Get a signal name given a signal number.

        The RFC 4254 names in SUPPORTED_SIGNALS take precedence.  Any other
        name exposed by the L{signal} module is suffixed with
        C{'@' + sys.platform}.  A number the signal module does not name at
        all gets the synthetic name C{'SIG<number>@<platform>'}. Examples
        include real-time signals between SIGRTMIN and SIGRTMAX. Previously
        this raised KeyError and aborted processEnded.

        @param signum: the signal number.
        @return: a name starting with C{'SIG'}.
        """
        if self._signalValuesToNames is None:
            self._signalValuesToNames = {}
            # make sure that the POSIX ones are the defaults
            for signame in SUPPORTED_SIGNALS:
                signame = 'SIG' + signame
                sigvalue = getattr(signal, signame, None)
                if sigvalue is not None:
                    self._signalValuesToNames[sigvalue] = signame
            # Sorting by name makes the choice between aliases that share a
            # value (e.g. SIGCHLD/SIGCLD on some platforms) deterministic
            # instead of depending on dict iteration order.  Keys are unique,
            # so the values themselves are never compared.
            for k, v in sorted(signal.__dict__.items()):
                # Check for platform specific signals, ignoring Python specific
                # SIG_DFL and SIG_IGN
                if k.startswith('SIG') and not k.startswith('SIG_'):
                    if v not in self._signalValuesToNames:
                        self._signalValuesToNames[v] = k + '@' + sys.platform
        try:
            return self._signalValuesToNames[signum]
        except KeyError:
            return 'SIG%d@%s' % (signum, sys.platform)

    def _notifyExit(self, err):
        """
        Send the peer an exit-signal or exit-status request for C{err}.

        @param err: the value of the reason passed to processEnded; its
            C{signal}, C{status} and C{exitCode} attributes are read.
        """
        if err.signal is not None:
            signame = self._getSignalName(err.signal)
            if (getattr(os, 'WCOREDUMP', None) is not None and
                    os.WCOREDUMP(err.status)):
                log.msg('exitSignal: %s (core dumped)' % (signame,))
                coreDumped = 1
            else:
                log.msg('exitSignal: %s' % (signame,))
                coreDumped = 0
            # exit-signal payload (RFC 4254): signal name without "SIG",
            # core-dumped flag byte, error message and language tag (the
            # last two are sent empty).
            self.session.conn.sendRequest(self.session, 'exit-signal',
                    common.NS(signame[3:]) + chr(coreDumped) +
                    common.NS('') + common.NS(''))
        elif err.exitCode is not None:
            log.msg('exitCode: %r' % (err.exitCode,))
            self.session.conn.sendRequest(self.session, 'exit-status',
                    struct.pack('>L', err.exitCode))

    def processEnded(self, reason=None):
        """
        When we are told the process ended, try to notify the other side about
        how the process ended using the exit-signal or exit-status requests.
        Also, close the channel.

        The channel is closed even if building or sending the notification
        raises; such an exception still propagates after the close.
        """
        try:
            if reason is not None:
                self._notifyExit(reason.value)
        finally:
            self.session.loseConnection()

    # transport stuff (we are also a transport!)

    def write(self, data):
        self.session.write(data)

    def writeSequence(self, seq):
        self.session.write(''.join(seq))

    def loseConnection(self):
        self.session.loseConnection()
267
class SSHSessionClient(protocol.Protocol):
    """
    A protocol that writes everything it receives back to its transport.
    Data is dropped if no (truthy) transport is attached.
    """

    def dataReceived(self, data):
        transport = self.transport
        if transport:
            transport.write(data)
273
274 # Helper functions factored out to make life easier for server writers.
def parseRequest_pty_req(data):
    """Parse the data from a pty-req request into usable data.

    The payload contains three parts, in order:
      - an SSH string holding the terminal type;
      - four big-endian uint32 values in wire order
        (cols, rows, xpixel, ypixel);
      - an SSH string of encoded terminal modes.

    The modes are decoded as consecutive 5-byte records: one opcode byte
    followed by a big-endian uint32 argument.  A single trailing byte, e.g.
    the end-of-modes terminator, is not decoded.

    @returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
        where modes is a list of (opcode, argument) pairs.
    """
    term, remainder = common.getNS(data)
    cols, rows, xpixel, ypixel = struct.unpack('>4L', remainder[:16])
    encodedModes, _unused = common.getNS(remainder[16:])
    modes = []
    for offset in range(0, len(encodedModes) - 1, 5):
        opcode = ord(encodedModes[offset])
        argument, = struct.unpack('>L', encodedModes[offset + 1:offset + 5])
        modes.append((opcode, argument))
    return term, (rows, cols, xpixel, ypixel), modes
286
def packRequest_pty_req(term, winSize, modes):
    """Pack a pty-req request so that it is suitable for sending.

    NOTE: modes must be packed before being sent here.

    @param term: the terminal type.
    @param winSize: a (rows, cols, xpixel, ypixel) tuple.  The wire format
        puts columns before rows; the swap is done here.
    @param modes: the already-encoded terminal modes; they are wrapped in
        an SSH string but otherwise sent unchanged.
    @return: the pty-req request payload.
    """
    # The window size used to be unpacked in the signature
    # (``def f(term, (rows, cols, xpixel, ypixel), modes)``).  That syntax
    # was removed in Python 3 (PEP 3113).  Tuple parameters could only be
    # passed positionally, so taking one argument is backward-compatible.
    rows, cols, xpixel, ypixel = winSize
    termPacked = common.NS(term)
    winSizePacked = struct.pack('>4L', cols, rows, xpixel, ypixel)
    modesPacked = common.NS(modes)  # depend on the client packing modes
    return termPacked + winSizePacked + modesPacked
296
def parseRequest_window_change(data):
    """Parse the data from a window-change request into usable data.

    The payload is exactly four big-endian uint32 values in wire order
    (cols, rows, xpixel, ypixel).  The first two are swapped in the result.
    Payloads of any other length raise struct.error.

    @returns: a tuple of (rows, cols, xpixel, ypixel)
    """
    fields = struct.unpack('>4L', data)
    return (fields[1], fields[0]) + fields[2:]
304
def packRequest_window_change(winSize):
    """Pack a window-change request so that it is suitable for sending.

    @param winSize: a (rows, cols, xpixel, ypixel) tuple.  This is the same
        layout that parseRequest_window_change returns.
    @return: 16 bytes: four big-endian uint32 values in wire order
        (cols, rows, xpixel, ypixel).
    """
    # Previously the tuple was unpacked in the signature, which is invalid
    # syntax on Python 3 (PEP 3113).  The argument was always positional,
    # so accepting it as one parameter is backward-compatible.
    rows, cols, xpixel, ypixel = winSize
    return struct.pack('>4L', cols, rows, xpixel, ypixel)
309
310 import connection
Note: See TracBrowser for help on using the browser.