Skip to content
This repository has been archived by the owner on Jan 13, 2021. It is now read-only.

Commit

Permalink
Merge branch 'development' of github.com:Lukasa/hyper into development
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukasa committed Jun 13, 2016
2 parents 9766ad0 + 806aff7 commit a4f185e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 8 deletions.
7 changes: 7 additions & 0 deletions hyper/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ def __init__(self, negotiated, sock):
super(HTTPUpgrade, self).__init__()
self.negotiated = negotiated
self.sock = sock


class MissingCertFile(Exception):
"""
The certificate file could not be found.
"""
pass
27 changes: 19 additions & 8 deletions hyper/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Contains the TLS/SSL logic for use in hyper.
"""
import os.path as path

from .common.exceptions import MissingCertFile
from .compat import ignore_missing, ssl


Expand All @@ -29,14 +29,17 @@ def wrap_socket(sock, server_hostname, ssl_context=None, force_proto=None):
A vastly simplified SSL wrapping function. We'll probably extend this to
do more things later.
"""
global _context

# create the singleton SSLContext we use
if _context is None: # pragma: no cover
_context = init_context()
global _context

# if an SSLContext is provided then use it instead of default context
_ssl_context = ssl_context or _context
if ssl_context:
# if an SSLContext is provided then use it instead of default context
_ssl_context = ssl_context
else:
# create the singleton SSLContext we use
if _context is None: # pragma: no cover
_context = init_context()
_ssl_context = _context

# the spec requires SNI support
ssl_sock = _ssl_context.wrap_socket(sock, server_hostname=server_hostname)
Expand Down Expand Up @@ -94,9 +97,17 @@ def init_context(cert_path=None, cert=None, cert_password=None):
encrypted and no password is needed.
:returns: An ``SSLContext`` correctly set up for HTTP/2.
"""
cafile = cert_path or cert_loc
if not cafile or not path.exists(cafile):
err_msg = ("No certificate found at " + str(cafile) + ". Either " +
"ensure the default cert.pem file is included in the " +
"distribution or provide a custom certificate when " +
"creating the connection.")
raise MissingCertFile(err_msg)

context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.set_default_verify_paths()
context.load_verify_locations(cafile=cert_path or cert_loc)
context.load_verify_locations(cafile=cafile)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True

Expand Down
15 changes: 15 additions & 0 deletions test/test_SSLContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CLIENT_CERT_FILE = os.path.join(TEST_CERTS_DIR, 'client.crt')
CLIENT_KEY_FILE = os.path.join(TEST_CERTS_DIR, 'client.key')
CLIENT_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'nopassword.pem')
MISSING_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'missing.pem')


class TestSSLContext(object):
Expand Down Expand Up @@ -60,3 +61,17 @@ def test_client_certificates(self):
cert=(CLIENT_CERT_FILE, CLIENT_KEY_FILE),
cert_password=b'abc123')
hyper.tls.init_context(cert=CLIENT_PEM_FILE)

def test_missing_certs(self):
succeeded = False
threw_expected_exception = False
try:
hyper.tls.init_context(MISSING_PEM_FILE)
succeeded = True
except hyper.common.exceptions.MissingCertFile:
threw_expected_exception = True
except:
pass

assert not succeeded
assert threw_expected_exception

0 comments on commit a4f185e

Please sign in to comment.