root / trunk / twisted / internet / ssl.py

Revision 26525, 7.3 kB (checked in by exarkun, 3 months ago)

Merge early-ssl-context-error-reporting-3700

Author: exarkun
Reviewer: glyph
Fixes: #3700

Reinstate the early failure mode for invalid parameters passed to
DefaultOpenSSLContextFactory and add tests for this behavior so
it is preserved in the future.

Line 
1 # -*- test-case-name: twisted.test.test_ssl -*-
2 # Copyright (c) 2001-2009 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5
6 """
7 SSL transport. Requires PyOpenSSL (http://pyopenssl.sf.net).
8
9 SSL connections require a ContextFactory so they can create SSL contexts.
10 End users should only use the ContextFactory classes directly - for SSL
11 connections use the reactor.connectSSL/listenSSL and so on, as documented
12 in IReactorSSL.
13
14 All server context factories should inherit from ContextFactory, and all
15 client context factories should inherit from ClientContextFactory. At the
16 moment this is not enforced, but in the future it might be.
17
18 Future Plans:
19     - split module so reactor-specific classes are in a separate module
20     - support for switching TCP into SSL
21     - more options
22
23 Maintainer: Itamar Shtull-Trauring
24 """
25
# If importing this module fails part-way (most commonly because OpenSSL
# cannot be imported), sys.modules['twisted.internet.ssl'] may be left
# bound to a partially initialized module object.  This module relies on
# that to publish whether SSL is usable: ``supported`` starts out False
# here and is only set to True by the final statement of the module.
#
# The correct idiom for importing this module is therefore:
#
#   try:
#       from twisted.internet import ssl
#   except ImportError:
#       # happens the first time the interpreter tries to import it
#       ssl = None
#   if ssl and not ssl.supported:
#       # happens the second and later times
#       ssl = None

supported = False
44
45 # System imports
46 from OpenSSL import SSL
47 from zope.interface import implements, implementsOnly, implementedBy
48
49 # Twisted imports
50 from twisted.internet import tcp, interfaces, base, address
51
52
class ContextFactory:
    """
    Base class for factories that produce SSL context objects for
    server-side connections.  Subclasses must override L{getContext}.
    """

    # 0 marks a server-side factory (ClientContextFactory sets this to 1).
    isClient = 0

    def getContext(self):
        """
        Return an C{SSL.Context} object.

        @raise NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()
61
62
class DefaultOpenSSLContextFactory(ContextFactory):
    """
    L{DefaultOpenSSLContextFactory} is a factory for server-side SSL context
    objects.  These objects define certain parameters related to SSL
    handshakes and the subsequent connection.

    The context object is built eagerly by the constructor, so invalid key
    or certificate file names are reported immediately, and is then cached.
    The cached context is left out of the pickled state and is rebuilt on
    demand by L{getContext} after unpickling.

    @ivar _contextFactory: A callable which will be used to create new
        context objects.  This is typically L{SSL.Context}.
    """
    # Class-level default; an instance without its own cached context (for
    # example one restored by __setstate__) sees None here.
    _context = None

    def __init__(self, privateKeyFileName, certificateFileName,
                 sslmethod=SSL.SSLv23_METHOD, _contextFactory=SSL.Context):
        """
        @param privateKeyFileName: Name of a file containing a private key
        @param certificateFileName: Name of a file containing a certificate
        @param sslmethod: The SSL method to use
        @param _contextFactory: A callable which is passed C{sslmethod} and
            returns a new context object; defaults to L{SSL.Context}.
        """
        self.privateKeyFileName = privateKeyFileName
        self.certificateFileName = certificateFileName
        self.sslmethod = sslmethod
        self._contextFactory = _contextFactory

        # Create a context object right now.  This is to force validation of
        # the given parameters so that errors are detected earlier rather
        # than later.
        self.cacheContext()


    def cacheContext(self):
        """
        Build and cache the context object unless one is already cached.

        Any error raised while creating the context or loading the
        certificate/private key files propagates to the caller, and nothing
        is cached in that case.
        """
        if self._context is None:
            ctx = self._contextFactory(self.sslmethod)
            # Disallow SSLv2!  It's insecure!  SSLv3 has been around since
            # 1996.  It's time to move on.
            ctx.set_options(SSL.OP_NO_SSLv2)
            ctx.use_certificate_file(self.certificateFileName)
            ctx.use_privatekey_file(self.privateKeyFileName)
            self._context = ctx


    def __getstate__(self):
        """
        Return the instance state for pickling, excluding the cached context
        object (it is rebuilt on demand by L{getContext}).
        """
        state = self.__dict__.copy()
        # pop() instead of del: an instance restored by __setstate__ has no
        # '_context' entry until getContext() rebuilds it, and pickling such
        # an instance again must not raise KeyError.
        state.pop('_context', None)
        return state


    def __setstate__(self, state):
        """
        Restore the pickled state.  The context is not part of C{state}; it
        is recreated lazily by L{getContext}.
        """
        self.__dict__ = state


    def getContext(self):
        """
        Return an SSL context.

        The context is (re)built here if it is not cached yet, which is the
        case after unpickling; previously this returned None in that case.
        """
        self.cacheContext()
        return self._context
118
119
class ClientContextFactory:
    """
    A context factory for SSL clients.

    Each call to L{getContext} creates a new context object from C{method}
    with SSLv2 switched off.
    """

    # 1 marks a client-side factory (ContextFactory uses 0).
    isClient = 1

    # SSLv23_METHOD allows SSLv2, SSLv3, and TLSv1; getContext turns SSLv2
    # back off via OP_NO_SSLv2.
    method = SSL.SSLv23_METHOD

    # Callable used to build context objects.
    _contextFactory = SSL.Context

    def getContext(self):
        """
        Return a newly created client context object.
        """
        context = self._contextFactory(self.method)
        # SSLv2 is insecure, so never allow it to be negotiated.
        context.set_options(SSL.OP_NO_SSLv2)
        return context
136
137
138
class Client(tcp.Client):
    """
    I am an SSL client.
    """

    # Declare ISSLTransport plus every interface tcp.Client implements,
    # except ITLSTransport.  filter() is used rather than a list
    # comprehension because, under Python 2, a comprehension in a class body
    # leaks its loop variable and left a stray class attribute Client.i.
    implementsOnly(interfaces.ISSLTransport,
                   *filter(lambda iface: iface != interfaces.ITLSTransport,
                           implementedBy(tcp.Client)))

    def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
        # tcp.Client.__init__ depends on self.ctxFactory being set, so it
        # must be assigned before the base initializer runs.
        self.ctxFactory = ctxFactory
        tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)

    def getHost(self):
        """Returns the address from which I am connecting."""
        h, p = self.socket.getsockname()
        return address.IPv4Address('TCP', h, p, 'SSL')

    def getPeer(self):
        """Returns the address I am connected to."""
        return address.IPv4Address('TCP', self.addr[0], self.addr[1], 'SSL')

    def _connectDone(self):
        # Begin TLS with the stored context factory and start writing
        # (presumably so the handshake can be sent -- confirm) before the
        # base class finishes connection setup.
        self.startTLS(self.ctxFactory)
        self.startWriting()
        tcp.Client._connectDone(self)
163
164
class Server(tcp.Server):
    """
    I am an SSL server, the transport for one accepted SSL connection.
    """

    implements(interfaces.ISSLTransport)

    def getHost(self):
        """Return server's address."""
        localHost, localPort = self.socket.getsockname()
        return address.IPv4Address('TCP', localHost, localPort, 'SSL')

    def getPeer(self):
        """Return address of peer."""
        peerHost, peerPort = self.client
        return address.IPv4Address('TCP', peerHost, peerPort, 'SSL')
180
181
class Port(tcp.Port):
    """I am an SSL port."""

    # Name of the method used to shut the socket down.  The socket here is
    # an SSL.Connection (see createInternetSocket); presumably tcp.Port
    # looks this name up on it -- confirm in tcp.Port.
    _socketShutdownMethod = 'sock_shutdown'

    # Transport class for accepted connections.
    transport = Server

    def __init__(self, port, factory, ctxFactory, backlog=50, interface='', reactor=None):
        tcp.Port.__init__(self, port, factory, backlog, interface, reactor)
        self.ctxFactory = ctxFactory

    def createInternetSocket(self):
        """
        (internal) Create the listening socket and wrap it in an
        C{SSL.Connection} that uses the context from C{self.ctxFactory}.
        """
        plainSocket = tcp.Port.createInternetSocket(self)
        context = self.ctxFactory.getContext()
        return SSL.Connection(context, plainSocket)

    def _preMakeConnection(self, transport):
        # The accepted socket already is an SSL.Connection (created by
        # createInternetSocket), so startTLS must *not* be called here; only
        # the transport's _startTLS() is invoked.
        transport._startTLS()
        return tcp.Port._preMakeConnection(self, transport)
203
204
class Connector(base.BaseConnector):
    """
    Connector for outgoing SSL connections; it creates L{Client} transports.
    """

    def __init__(self, host, port, factory, contextFactory, timeout, bindAddress, reactor=None):
        self.host, self.port = host, port
        self.bindAddress = bindAddress
        self.contextFactory = contextFactory
        base.BaseConnector.__init__(self, factory, timeout, reactor)

    def _makeTransport(self):
        """Create a L{Client} transport from this connector's settings."""
        return Client(self.host, self.port, self.bindAddress,
                      self.contextFactory, self, self.reactor)

    def getDestination(self):
        """Return the SSL address this connector connects to."""
        return address.IPv4Address('TCP', self.host, self.port, 'SSL')
218
219 from twisted.internet._sslverify import DistinguishedName, DN, Certificate
220 from twisted.internet._sslverify import CertificateRequest, PrivateCertificate
221 from twisted.internet._sslverify import KeyPair
222 from twisted.internet._sslverify import OpenSSLCertificateOptions as CertificateOptions
223
# Public API of this module.  The reactor-level transport classes defined
# above (Client, Server, Port, Connector) are not listed.
__all__ = [
    # Context factories defined in this module.
    "ContextFactory", "DefaultOpenSSLContextFactory", "ClientContextFactory",
    # Certificate helpers re-exported from twisted.internet._sslverify.
    "DistinguishedName", "DN",
    "Certificate", "CertificateRequest", "PrivateCertificate",
    "KeyPair",
    "CertificateOptions",
]

# Reaching this statement means every import above succeeded, so SSL is
# usable.  It must stay the module's final statement; ``supported`` is
# initialised to False near the top of the module.
supported = True
Note: See TracBrowser for help on using the browser.