Skip to content

Commit

Permalink
Merge pull request #30 from Snawoot/lint
Browse files Browse the repository at this point in the history
Lint
  • Loading branch information
Snawoot authored May 27, 2019
2 parents dbd9b68 + c8c5fc3 commit cd92e4f
Show file tree
Hide file tree
Showing 9 changed files with 696 additions and 112 deletions.
574 changes: 574 additions & 0 deletions .pylintrc

Large diffs are not rendered by default.

11 changes: 2 additions & 9 deletions postfix_mta_sts_resolver/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3

import sys
import argparse
import asyncio

Expand All @@ -10,11 +9,6 @@
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-v", "--verbosity",
help="logging verbosity",
type=utils.LogLevel.__getitem__,
choices=list(utils.LogLevel),
default=utils.LogLevel.warn)
parser.add_argument("domain",
help="domain to fetch MTA-STS policy from")
parser.add_argument("known_version",
Expand All @@ -27,11 +21,10 @@ def parse_args():

def main():
args = parse_args()
mainLogger = utils.setup_logger('MAIN', args.verbosity)

loop = asyncio.get_event_loop()
R = STSResolver(loop=loop)
result = loop.run_until_complete(R.resolve(args.domain, args.known_version))
resolver = STSResolver(loop=loop)
result = loop.run_until_complete(resolver.resolve(args.domain, args.known_version))
print(result)


Expand Down
25 changes: 12 additions & 13 deletions postfix_mta_sts_resolver/daemon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3

import sys
import os
import argparse
import asyncio
Expand Down Expand Up @@ -36,19 +35,19 @@ def parse_args():
return parser.parse_args()


def exit_handler(exit_event, signum, frame):
def exit_handler(exit_event, signum, frame): # pylint: disable=unused-argument
logger = logging.getLogger('MAIN')
if exit_event.is_set():
logger.warning("Got second exit signal! Terminating hard.")
os._exit(1)
os._exit(1) # pylint: disable=protected-access
else:
logger.warning("Got first exit signal! Terminating gracefully.")
exit_event.set()


async def heartbeat():
""" Hacky coroutine which keeps event loop spinning with some interval
even if no events are coming. This is required to handle Futures and
""" Hacky coroutine which keeps event loop spinning with some interval
even if no events are coming. This is required to handle Futures and
Events state change when no events are occuring."""
while True:
await asyncio.sleep(.5)
Expand Down Expand Up @@ -79,25 +78,25 @@ async def amain(cfg, loop):
def main():
# Parse command line arguments and setup basic logging
args = parse_args()
mainLogger = utils.setup_logger('MAIN', args.verbosity, args.logfile )
logger = utils.setup_logger('MAIN', args.verbosity, args.logfile)
utils.setup_logger('STS', args.verbosity, args.logfile)
mainLogger.info("MTA-STS daemon starting...")
logger.info("MTA-STS daemon starting...")

# Read config and populate with defaults
cfg = utils.load_config(args.config)

# Construct event loop
mainLogger.info("Starting eventloop...")
logger.info("Starting eventloop...")
if not args.disable_uvloop:
if utils.enable_uvloop():
mainLogger.info("uvloop enabled.")
logger.info("uvloop enabled.")
else:
mainLogger.info("uvloop is not available. "
"Falling back to built-in event loop.")
logger.info("uvloop is not available. "
"Falling back to built-in event loop.")
evloop = asyncio.get_event_loop()
mainLogger.info("Eventloop started.")
logger.info("Eventloop started.")


evloop.run_until_complete(amain(cfg, evloop))
evloop.close()
mainLogger.info("Server finished its work.")
logger.info("Server finished its work.")
13 changes: 8 additions & 5 deletions postfix_mta_sts_resolver/redis_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import logging
import uuid

import aioredis
Expand All @@ -8,7 +7,7 @@


def pack_entry(entry):
ts, pol_id, pol_body = entry
ts, pol_id, pol_body = entry # pylint: disable=invalid-name,unused-variable
obj = (pol_id, pol_body)
# add unique seed to entry in order to avoid set collisions
# and use ZSET two-index table
Expand All @@ -29,30 +28,34 @@ def __init__(self, **opts):
self._opts['timeout'] = self._opts.get('timeout',
defaults.REDIS_TIMEOUT)
self._opts['encoding'] = None
self._pool = None

async def setup(self):
self._pool = await aioredis.create_redis_pool(**self._opts)

async def get(self, key):
assert self._pool is not None
key = key.encode('utf-8')
res = await self._pool.zrevrange(key, 0, 0, "WITHSCORES")
if not res:
return None
packed, ts = res[0]
packed, ts = res[0] # pylint: disable=invalid-name
entry = unpack_entry(packed)
return CacheEntry(ts=ts, pol_id=entry.pol_id, pol_body=entry.pol_body)

async def set(self, key, value):
assert self._pool is not None
packed = pack_entry(value)
ts = value.ts
ts = value.ts # pylint: disable=invalid-name
key = key.encode('utf-8')

# Write
pipe = self._pool.pipeline()
pipe.zadd(key, ts, packed)
pipe.zremrangebyrank(key, 0, -2)
results = await pipe.execute()
await pipe.execute()

async def teardown(self):
assert self._pool is not None
self._pool.close()
await self._pool.wait_closed()
42 changes: 21 additions & 21 deletions postfix_mta_sts_resolver/resolver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import aiodns
import aiohttp
import enum
from io import BytesIO

import aiodns
import aiohttp

from . import defaults
from .utils import *
from .constants import *
from .utils import parse_mta_sts_record, parse_mta_sts_policy, is_plaintext, filter_text
from .constants import HARD_RESP_LIMIT, CHUNK


class BadSTSPolicy(Exception):
Expand All @@ -20,23 +20,26 @@ class STSFetchResult(enum.Enum):
NOT_CHANGED = 3


class STSResolver(object):
_HEADERS = {"User-Agent": defaults.USER_AGENT}


# pylint: disable=too-few-public-methods
class STSResolver:
def __init__(self, *, timeout=defaults.TIMEOUT, loop):
self._loop = loop
self._timeout = timeout
self._resolver = aiodns.DNSResolver(timeout=timeout, loop=loop)
self._http_timeout = aiohttp.ClientTimeout(total=timeout)
self._proxy_info = aiohttp.helpers.proxies_from_env().get('https',
None)
self._headers = {}

self._proxy_info = aiohttp.helpers.proxies_from_env().get('https', None)

if self._proxy_info is None:
self._proxy = None
self._proxy_auth = None
else:
self._proxy = self._proxy_info.proxy
self._proxy_auth = self._proxy_info.proxy_auth

# pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
async def resolve(self, domain, last_known_id=None):
if domain.startswith('.'):
return STSFetchResult.NONE, None
Expand All @@ -49,15 +52,15 @@ async def resolve(self, domain, last_known_id=None):
# Try to fetch it
try:
txt_records = await self._resolver.query(sts_txt_domain, 'TXT')
except aiodns.error.DNSError as e:
if e.args[0] == aiodns.error.ARES_ETIMEOUT:
except aiodns.error.DNSError as error:
if error.args[0] == aiodns.error.ARES_ETIMEOUT: # pylint: disable=no-else-return
# It's hard to decide what to do in case of timeout
# Probably it's better to threat this as fetch error
# so caller probably shall report such cases.
return STSFetchResult.FETCH_ERROR, None
elif e.args[0] == aiodns.error.ARES_ENOTFOUND:
elif error.args[0] == aiodns.error.ARES_ENOTFOUND:
return STSFetchResult.NONE, None
elif e.args[0] == aiodns.error.ARES_ENODATA:
elif error.args[0] == aiodns.error.ARES_ENODATA:
return STSFetchResult.NONE, None
else:
return STSFetchResult.NONE, None
Expand Down Expand Up @@ -88,17 +91,14 @@ async def resolve(self, domain, last_known_id=None):
domain +
'/.well-known/mta-sts.txt')

# Construct headers for MTA-STS policy fetch
self._headers["User-Agent"] = defaults.USER_AGENT

# Fetch actual policy
try:
async with aiohttp.ClientSession(loop=self._loop,
timeout=self._http_timeout) \
as session:
async with session.get(sts_policy_url,
allow_redirects=False,
proxy=self._proxy, headers=self._headers,
proxy=self._proxy, headers=_HEADERS,
proxy_auth=self._proxy_auth) as resp:
if resp.status != 200:
raise BadSTSPolicy()
Expand All @@ -118,7 +118,7 @@ async def resolve(self, domain, last_known_id=None):
charset = (resp.charset if resp.charset is not None
else 'ascii')
policy_text = policy_file.getvalue().decode(charset)
except:
except Exception:
return STSFetchResult.FETCH_ERROR, None

# Parse policy
Expand All @@ -131,10 +131,10 @@ async def resolve(self, domain, last_known_id=None):
try:
max_age = int(pol.get('max_age', '-1'))
pol['max_age'] = max_age
except:
except ValueError:
return STSFetchResult.FETCH_ERROR, None

if not (0 <= max_age <= 31557600):
if not 0 <= max_age <= 31557600:
return STSFetchResult.FETCH_ERROR, None

if 'mode' not in pol:
Expand Down
Loading

0 comments on commit cd92e4f

Please sign in to comment.