Skip to content

Commit

Permalink
Add GenericCreator for loading SSL certs in processes (#2578)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins authored Oct 31, 2022
1 parent 3f4663b commit d70636b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 18 deletions.
37 changes: 19 additions & 18 deletions sanic/worker/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,10 @@

from importlib import import_module
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast

from sanic.http.tls.creators import CertCreator, MkcertCreator, TrustmeCreator
from sanic.http.tls.context import process_to_context
from sanic.http.tls.creators import MkcertCreator, TrustmeCreator


if TYPE_CHECKING:
Expand Down Expand Up @@ -106,21 +98,30 @@ def load(self) -> SanicApp:


class CertLoader:
_creator_class: Type[CertCreator]
_creators = {
"mkcert": MkcertCreator,
"trustme": TrustmeCreator,
}

def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]):
creator_name = ssl_data.get("creator")
if creator_name not in ("mkcert", "trustme"):
self._ssl_data = ssl_data

creator_name = cast(str, ssl_data.get("creator"))

self._creator_class = self._creators.get(creator_name)
if not creator_name:
return

if not self._creator_class:
raise RuntimeError(f"Unknown certificate creator: {creator_name}")
elif creator_name == "mkcert":
self._creator_class = MkcertCreator
elif creator_name == "trustme":
self._creator_class = TrustmeCreator

self._key = ssl_data["key"]
self._cert = ssl_data["cert"]
self._localhost = cast(str, ssl_data["localhost"])

def load(self, app: SanicApp):
if not self._creator_class:
return process_to_context(self._ssl_data)

creator = self._creator_class(app, self._key, self._cert)
return creator.generate_cert(self._localhost)
27 changes: 27 additions & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess

from contextlib import contextmanager
from multiprocessing import Event
from pathlib import Path
from unittest.mock import Mock, patch
from urllib.parse import urlparse
Expand Down Expand Up @@ -636,3 +637,29 @@ def test_sanic_ssl_context_create():

assert sanic_context is context
assert isinstance(sanic_context, SanicSSLContext)


def test_ssl_in_multiprocess_mode(app: Sanic, caplog):

ssl_dict = {"cert": localhost_cert, "key": localhost_key}
event = Event()

@app.main_process_start
async def main_start(app: Sanic):
app.shared_ctx.event = event

@app.after_server_start
async def shutdown(app):
app.shared_ctx.event.set()
app.stop()

assert not event.is_set()
with caplog.at_level(logging.INFO):
app.run(ssl=ssl_dict)
assert event.is_set()

assert (
"sanic.root",
logging.INFO,
"Goin' Fast @ https://127.0.0.1:8000",
) in caplog.record_tuples
4 changes: 4 additions & 0 deletions tests/worker/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def test_input_is_module():
@patch("sanic.worker.loader.TrustmeCreator")
@patch("sanic.worker.loader.MkcertCreator")
def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str):
CertLoader._creators = {
"mkcert": MkcertCreator,
"trustme": TrustmeCreator,
}
MkcertCreator.return_value = MkcertCreator
TrustmeCreator.return_value = TrustmeCreator
data = {
Expand Down

0 comments on commit d70636b

Please sign in to comment.