Skip to content

Commit

Permalink
Merge pull request #59 from galme/proactive_policy_fetching
Browse files Browse the repository at this point in the history
Proactive policy fetching
  • Loading branch information
Snawoot authored Mar 22, 2020
2 parents 3456620 + 6ba848d commit 36354fc
Show file tree
Hide file tree
Showing 20 changed files with 588 additions and 49 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ postfix-mta-sts-resolver

Daemon which provides TLS client policy for Postfix via socketmap, according to domain MTA-STS policy. Current support of RFC8461 is limited - daemon lacks some minor features:

* Proactive policy fetch
* Fetch error reporting
* Fetch ratelimit (but actual fetch rate partially restricted with `cache_grace` config option).

Expand Down
13 changes: 13 additions & 0 deletions man/mta-sts-daemon.yml.5.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ The file is in YAML syntax with the following elements:
** Options for _redis_ type:
*** All parameters are passed to `aioredis.create_redis_pool` [0]. Check there for a parameter reference.

*proactive_policy_fetching*::

* *enabled*: (_bool_) enable proactive policy fetching in the background. Default: false
* *interval*: (_int_) if proactive policy fetching is enabled, a full refresh run is scheduled every this many seconds.
This interval is independent of `cache_grace`, and vice versa. Default: 86400
* *concurrency_limit*: (_int_) the maximum number of concurrent domain updates. Default: 100
* *grace_ratio*: (_float_) proactive fetch for a particular domain is skipped if its cached policy age is less than `interval/grace_ratio`. Default: 2.0

*default_zone*::

* *strict_testing*: (_bool_) enforce policy for testing domains
Expand All @@ -65,6 +73,11 @@ domains operate under "testing" mode.
port: 8461
reuse_port: true
shutdown_timeout: 20
proactive_policy_fetching:
enabled: true
interval: 86400
concurrency_limit: 100
grace_ratio: 2
cache:
type: internal
options:
Expand Down
21 changes: 21 additions & 0 deletions postfix_mta_sts_resolver/base_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import collections

from abc import ABC, abstractmethod
Expand All @@ -19,6 +20,26 @@ async def get(self, key):
async def set(self, key, value):
""" Abstract method """

async def safe_set(self, domain, entry, logger):
    """Store *entry* under *domain*, logging failures instead of raising.

    Task cancellation is always re-raised so shutdown is never delayed;
    any other exception is reported through *logger* and swallowed,
    making this a best-effort write suitable for background updates.
    """
    try:
        await self.set(domain, entry)
    except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
        # Never swallow cancellation — propagate it to the caller.
        raise
    except Exception as exc:  # pragma: no cover
        # Best-effort semantics: record the failure and carry on.
        logger.exception("Cache set failed: %s", str(exc))

@abstractmethod
async def scan(self, token, amount_hint):
    """Return (next_token, items) for one page of cached entries.

    *token* is an opaque backend-specific cursor; None starts a new
    scan.  *amount_hint* caps the page size.  *items* is a list of
    (domain, entry) pairs; next_token is None once the scan is done.
    """

@abstractmethod
async def get_proactive_fetch_ts(self):
    """Return the timestamp of the last completed proactive fetch run
    (0 if no run has been recorded yet)."""

@abstractmethod
async def set_proactive_fetch_ts(self, timestamp):
    """Persist *timestamp* as the time of the last proactive fetch run."""

@abstractmethod
async def teardown(self):
    """Release any resources held by the cache backend."""
2 changes: 2 additions & 0 deletions postfix_mta_sts_resolver/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
CHUNK = 4096
QUEUE_LIMIT = 128
REQUEST_LIMIT = 1024
# Upper bound on items buffered in the proactive fetcher's domain queue,
# also used as the page-size hint passed to cache.scan().
DOMAIN_QUEUE_LIMIT = 1000
# Floor (seconds) on the sleep between proactive policy fetch runs.
MIN_PROACTIVE_FETCH_INTERVAL = 1
26 changes: 24 additions & 2 deletions postfix_mta_sts_resolver/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from .asdnotify import AsyncSystemdNotifier
from . import utils
from . import defaults
from .proactive_fetcher import STSProactiveFetcher
from .responder import STSSocketmapResponder
from .utils import create_cache


def parse_args():
Expand Down Expand Up @@ -61,12 +63,28 @@ async def heartbeat():

async def amain(cfg, loop): # pragma: no cover
logger = logging.getLogger("MAIN")
# Construct request handler instance
responder = STSSocketmapResponder(cfg, loop)

proactive_fetch_enabled = cfg['proactive_policy_fetching']['enabled']

# Create policy cache
cache = create_cache(cfg["cache"]["type"],
cfg["cache"]["options"])
await cache.setup()

# Construct request handler
responder = STSSocketmapResponder(cfg, loop, cache)
await responder.start()
logger.info("Server started.")

# Conditionally construct proactive policy fetcher
proactive_fetcher = None
if proactive_fetch_enabled:
proactive_fetcher = STSProactiveFetcher(cfg, loop, cache)
await proactive_fetcher.start()
logger.info("Proactive policy fetcher started.")
else:
logger.info("Proactive policy fetching is disabled.")

exit_event = asyncio.Event()
beat = asyncio.ensure_future(heartbeat())
sig_handler = partial(exit_handler, exit_event)
Expand All @@ -79,6 +97,9 @@ async def amain(cfg, loop): # pragma: no cover
await notifier.notify(b"STOPPING=1")
beat.cancel()
await responder.stop()
if proactive_fetch_enabled:
await proactive_fetcher.stop()
await cache.teardown()


def main(): # pragma: no cover
Expand All @@ -87,6 +108,7 @@ def main(): # pragma: no cover
with utils.AsyncLoggingHandler(args.logfile) as log_handler:
logger = utils.setup_logger('MAIN', args.verbosity, log_handler)
utils.setup_logger('STS', args.verbosity, log_handler)
utils.setup_logger('PF', args.verbosity, log_handler)
logger.info("MTA-STS daemon starting...")

# Read config and populate with defaults
Expand Down
4 changes: 4 additions & 0 deletions postfix_mta_sts_resolver/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@
SQLITE_TIMEOUT = 5
REDIS_TIMEOUT = 5
CACHE_GRACE = 60
# Defaults for the "proactive_policy_fetching" config section:
# disabled by default; when enabled, a full refresh runs daily with at
# most 100 concurrent domain updates, skipping entries younger than
# interval / grace_ratio.
PROACTIVE_FETCH_ENABLED = False
PROACTIVE_FETCH_INTERVAL = 86400
PROACTIVE_FETCH_CONCURRENCY_LIMIT = 100
PROACTIVE_FETCH_GRACE_RATIO = 2.0
USER_AGENT = "postfix-mta-sts-resolver"
24 changes: 24 additions & 0 deletions postfix_mta_sts_resolver/internal_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
from itertools import islice

from .base_cache import BaseCache

Expand All @@ -7,6 +8,7 @@ class InternalLRUCache(BaseCache):
def __init__(self, cache_size=10000):
    # Maximum number of entries kept before eviction.
    self._cache_size = cache_size
    # OrderedDict ordered from least- to most-recently used.
    self._cache = collections.OrderedDict()
    # Timestamp of the last completed proactive fetch run (0 = never).
    self._proactive_fetch_ts = 0

async def setup(self):
    # Purely in-memory backend: nothing to initialise.
    pass
Expand All @@ -29,3 +31,25 @@ async def set(self, key, value):
if len(self._cache) >= self._cache_size:
self._cache.popitem(last=False)
self._cache[key] = value

async def scan(self, token, amount_hint):
    """Page through cached entries, oldest (LRU) first.

    *token* is a plain integer offset into the cache (None starts a
    new scan).  Returns (next_token, [(key, entry), ...]); next_token
    is None once the whole cache has been covered.
    """
    if token is None:
        token = 0

    total = len(self._cache)
    left = total - token
    if left > 0:
        amount = min(left, amount_hint)
        new_token = token + amount if token + amount < total else None
        # Take "amount" of oldest
        result = list(islice(self._cache.items(), amount))
        # Touch each returned key so it moves to the MRU end of the
        # OrderedDict; the next scan() call then finds the following
        # batch at the front even though the offset arithmetic is flat.
        # NOTE(review): relies on get() re-inserting accessed keys at
        # the end — confirm against this class's get() implementation.
        for key, _ in result:  # for LRU consistency
            await self.get(key)
        return new_token, result
    return None, []

async def get_proactive_fetch_ts(self):
    # Timestamp of the last completed proactive fetch run (0 = never).
    return self._proactive_fetch_ts

async def set_proactive_fetch_ts(self, timestamp):
    # Record when the last proactive fetch run finished.
    self._proactive_fetch_ts = timestamp
109 changes: 109 additions & 0 deletions postfix_mta_sts_resolver/proactive_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import asyncio
import logging
import time

from postfix_mta_sts_resolver import constants
from postfix_mta_sts_resolver.base_cache import CacheEntry
from postfix_mta_sts_resolver.resolver import STSResolver, STSFetchResult


# pylint: disable=too-many-instance-attributes
class STSProactiveFetcher:
    """Background service that periodically refreshes cached MTA-STS policies.

    Once per configured interval it walks the whole policy cache and
    re-resolves each domain's policy, so entries stay fresh without
    waiting for an incoming request to trigger a fetch.
    """

    def __init__(self, cfg, loop, cache):
        # NOTE(review): _shutdown_timeout is stored but never read in
        # this class — confirm whether it is dead config plumbing.
        self._shutdown_timeout = cfg['shutdown_timeout']
        # Seconds between full refresh runs.
        self._pf_interval = cfg['proactive_policy_fetching']['interval']
        # Maximum number of concurrently updated domains.
        self._pf_concurrency_limit = cfg['proactive_policy_fetching']['concurrency_limit']
        # Entries younger than interval / grace_ratio are skipped.
        self._pf_grace_ratio = cfg['proactive_policy_fetching']['grace_ratio']
        self._logger = logging.getLogger("PF")
        self._loop = loop
        self._cache = cache
        # Handle of the fetch_periodically() task; set by start().
        self._periodic_fetch_task = None
        self._resolver = STSResolver(loop=loop,
                                     timeout=cfg["default_zone"]["timeout"])

    async def process_domain(self, domain_queue):
        """Worker coroutine: consume (domain, cached_entry) items forever.

        Each item is either skipped (cached entry still within the grace
        window) or refreshed through the resolver.  Failures are logged
        and never terminate the worker; only cancellation stops it.
        """
        async def update(cached):
            # Pass the cached policy id so the resolver can answer
            # NOT_CHANGED instead of re-downloading an unchanged policy.
            status, policy = await self._resolver.resolve(domain, cached.pol_id)
            if status is STSFetchResult.VALID:
                pol_id, pol_body = policy
                updated = CacheEntry(ts, pol_id, pol_body)
                await self._cache.safe_set(domain, updated, self._logger)
            elif status is STSFetchResult.NOT_CHANGED:
                # Same policy: refresh only the timestamp.
                updated = CacheEntry(ts, cached.pol_id, cached.pol_body)
                await self._cache.safe_set(domain, updated, self._logger)
            else:
                self._logger.warning("Domain %s does not have a valid policy.", domain)

        while True:  # Run until cancelled
            cache_item = await domain_queue.get()
            ts = time.time()  # pylint: disable=invalid-name
            try:
                domain, cached = cache_item
                # Grace window: skip entries fetched recently enough.
                if ts - cached.ts < self._pf_interval / self._pf_grace_ratio:
                    self._logger.debug("Domain %s skipped (cache recent enough).", domain)
                else:
                    await update(cached)
            except asyncio.CancelledError:  # pragma: no cover pylint: disable=try-except-raise
                raise
            except Exception as exc:  # pragma: no cover
                self._logger.exception("Unhandled exception: %s", exc)
            finally:
                # Always account for the item so queue.join() can return.
                domain_queue.task_done()

    async def iterate_domains(self):
        """Run one full refresh pass over every domain in the cache."""
        self._logger.info("Proactive policy fetching "
                          "for all domains in cache started...")

        # Create domain processor tasks
        domain_processors = []
        domain_queue = asyncio.Queue(maxsize=constants.DOMAIN_QUEUE_LIMIT)
        for _ in range(self._pf_concurrency_limit):
            domain_processor = self._loop.create_task(self.process_domain(domain_queue))
            domain_processors.append(domain_processor)

        # Produce work for domain processors
        try:
            token = None
            while True:
                # Page through the cache; put() applies backpressure when
                # the bounded queue is full.
                token, cache_items = await self._cache.scan(token, constants.DOMAIN_QUEUE_LIMIT)
                self._logger.debug("Enqueued %d domains for processing.", len(cache_items))
                for cache_item in cache_items:
                    await domain_queue.put(cache_item)
                if token is None:
                    break

            # Wait for queue to clear
            await domain_queue.join()
        # Clean up the domain processors
        finally:
            for domain_processor in domain_processors:
                domain_processor.cancel()
            await asyncio.gather(*domain_processors, return_exceptions=True)

        # Update the proactive fetch timestamp
        # (reached only on a successful pass; on failure the old
        # timestamp keeps the next run's schedule unchanged)
        await self._cache.set_proactive_fetch_ts(time.time())

        self._logger.info("Proactive policy fetching "
                          "for all domains in cache finished.")

    async def fetch_periodically(self):
        """Sleep until the next scheduled run, fetch, repeat until cancelled.

        NOTE(review): an exception escaping iterate_domains() terminates
        this loop permanently — confirm that is intended.
        """
        while True:  # Run until cancelled
            # Next run is one interval after the previously recorded run.
            next_fetch_ts = await self._cache.get_proactive_fetch_ts() + self._pf_interval
            sleep_duration = max(constants.MIN_PROACTIVE_FETCH_INTERVAL,
                                 next_fetch_ts - time.time() + 1)

            self._logger.debug("Sleeping for %ds until next fetch.", sleep_duration)
            await asyncio.sleep(sleep_duration)
            await self.iterate_domains()

    async def start(self):
        """Schedule the periodic fetch loop on the event loop."""
        self._periodic_fetch_task = self._loop.create_task(self.fetch_periodically())

    async def stop(self):
        """Cancel the periodic fetch loop and wait for it to wind down."""
        self._periodic_fetch_task.cancel()

        try:
            self._logger.warning("Awaiting periodic fetching to finish...")
            await self._periodic_fetch_task
        except asyncio.CancelledError:  # pragma: no cover
            pass
26 changes: 26 additions & 0 deletions postfix_mta_sts_resolver/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,32 @@ async def set(self, key, value):
pipe.zremrangebyrank(key, 0, -2)
await pipe.execute()

async def scan(self, token, amount_hint):
    """Return one page of cached policies from Redis.

    *token* is the Redis SCAN cursor (None starts a new scan);
    *amount_hint* is passed as SCAN's COUNT.  Returns
    (next_token, [(domain, entry), ...]); next_token is None once the
    cursor has wrapped around.  The internal '_metadata' hash is
    excluded from results.

    Note: Redis SCAN may return a key more than once over a full
    iteration; callers must tolerate duplicates.
    """
    assert self._pool is not None
    if token is None:
        token = b'0'

    new_token, keys = await self._pool.scan(cursor=token, count=amount_hint)
    if not new_token:
        new_token = None

    result = []
    for key in keys:
        key = key.decode('utf-8')
        if key == '_metadata':
            continue
        value = await self.get(key)
        # The key may have expired or been evicted between SCAN and the
        # fetch; skip it rather than yield (key, None), which would
        # crash consumers that dereference the entry's fields.
        if value is not None:
            result.append((key, value))
    return new_token, result

async def get_proactive_fetch_ts(self):
    """Return the recorded proactive-fetch timestamp, or 0 if unset."""
    assert self._pool is not None
    raw = await self._pool.hget('_metadata', 'proactive_fetch_ts')
    if not raw:
        return 0
    return float(raw.decode('utf-8'))

async def set_proactive_fetch_ts(self, timestamp):
    """Store *timestamp* as UTF-8 text in the '_metadata' hash."""
    assert self._pool is not None
    encoded = str(timestamp).encode('utf-8')
    await self._pool.hset('_metadata', 'proactive_fetch_ts', encoded)

async def teardown(self):
assert self._pool is not None
self._pool.close()
Expand Down
Loading

0 comments on commit 36354fc

Please sign in to comment.