Skip to content

Commit

Permalink
#183: add type hints for all code
Browse files Browse the repository at this point in the history
  • Loading branch information
kmyk committed Nov 3, 2018
1 parent 22b4e5a commit 783f461
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 52 deletions.
11 changes: 6 additions & 5 deletions onlinejudge/implementation/command/code_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import os.path
import subprocess
import contextlib
from typing import *
if TYPE_CHECKING:
import argparse


def get_char_class(c):
assert isinstance(c, int)
def get_char_class(c: int) -> str:
assert 0 <= c < 256
if chr(c) in string.ascii_letters + string.digits:
return 'alnum'
Expand All @@ -21,8 +23,7 @@ def get_char_class(c):
else:
return 'binary'

def get_statistics(s):
assert isinstance(s, bytes)
def get_statistics(s: bytes) -> Dict[str, int]:
stat = {
'binary': 0,
'alnum': 0,
Expand All @@ -33,7 +34,7 @@ def get_statistics(s):
stat[get_char_class(c)] += 1
return stat

def code_statistics(args):
def code_statistics(args: 'argparse.Namespace') -> None:
with open(args.file, 'rb') as fh:
code = fh.read()
stat = get_statistics(code)
Expand Down
8 changes: 6 additions & 2 deletions onlinejudge/implementation/command/generate_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import onlinejudge.implementation.logging as log
import onlinejudge.implementation.command.utils as cutils
import time
from typing import *
if TYPE_CHECKING:
import argparse

def generate_output(args):
def generate_output(args: 'argparse.Namespace') -> None:
if not args.test:
args.test = cutils.glob_with_format(args.directory, args.format) # by default
if args.ignore_backup:
Expand All @@ -28,7 +31,8 @@ def generate_output(args):
log.info('skipped.')
continue
log.emit(log.bold(answer.decode().rstrip()))
path = cutils.path_from_format(args.directory, args.format, name=cutils.match_with_format(args.directory, args.format, it['in']).groupdict()['name'], ext='out')
name = cutils.match_with_format(args.directory, args.format, it['in']).groupdict()['name'] # type: ignore
path = cutils.path_from_format(args.directory, args.format, name=name, ext='out')
with open(path, 'w') as fh:
fh.buffer.write(answer)
log.success('saved to: %s', path)
34 changes: 17 additions & 17 deletions onlinejudge/implementation/command/generate_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import colorama
import collections
import sys
from typing import *
if TYPE_CHECKING:
import argparse

def tokenize(pre): # => [ [ dict ] ]
def tokenize(pre: str) -> Generator[List[Dict[str, str]], None, None]:
for y, line in enumerate(pre.splitlines()):
# remove mathjax tokens
line = line.replace('$', '').replace('\\(', '').replace('\\)', '')
line = line.replace('\\ ', ' ').replace('\\quad', ' ')
# tokenize each line
tokens = []
tokens: List[Dict[str, str]] = []
for x, s in enumerate(line.split()):
if s in [ '..', '...', '\\dots', '…', '⋯' ]:
tokens += [ { 'kind': 'dots', 'dir': ['hr', 'vr'][x == 0] } ]
Expand All @@ -36,13 +39,13 @@ def tokenize(pre): # => [ [ dict ] ]
tokens += [ { 'kind': 'fixed', 'name': s } ]
yield tokens

def simplify_expr(s):
def simplify_expr(s: str) -> str:
transformations = sympy_parser.standard_transformations + ( sympy_parser.implicit_multiplication_application ,)
local_dict = { 'N': sympy.Symbol('N') }
return str(sympy_parser.parse_expr(s, local_dict=local_dict, transformations=transformations))

def parse(tokens):
env = collections.defaultdict(dict)
def parse(tokens: List[List[Dict[str, Any]]]) -> Generator[Dict[str, Any], None, None]:
env: Dict[str, Any] = collections.defaultdict(dict)
for y, line in enumerate(tokens):
for x, item in enumerate(line):
if item['kind'] == 'indexed':
Expand All @@ -55,7 +58,7 @@ def parse(tokens):
f['r'] = item['index']
for name in env:
env[name]['n'] = simplify_expr('{}-{}+1'.format(env[name]['r'], env[name]['l']))
used = set()
used: Set[Any] = set()
for y, line in enumerate(tokens):
for x, item in enumerate(line):
if item['kind'] == 'fixed':
Expand All @@ -74,7 +77,7 @@ def parse(tokens):
yield { 'kind': 'loop', 'length': n, 'body': [ { 'kind': 'read-indexed', 'targets': [ { 'name': name, 'index': 0 } ] } ] }
used.add(name)
elif item['dir'] == 'vr':
names = []
names: List[str] = []
for item in tokens[y-1]:
if item['kind'] != 'indexed':
raise NotImplementedError
Expand All @@ -85,25 +88,22 @@ def parse(tokens):
used.add(name)
if not names:
continue
acc = []
n = env[names[0]]['n']
body = []
body: List[Dict[str, Any]] = []
for name in names:
assert env[name]['n'] == n
yield { 'kind': 'decl-vector', 'targets': [ { 'name': name, 'length': n } ] }
body += [ { 'kind': 'read-indexed', 'targets': [ { 'name': name, 'index': 0 } ] } ]
yield { 'kind': 'loop', 'length': n, 'body': body }
decls = []
reads = []
else:
assert False
else:
assert False

def get_names(targets):
def get_names(targets: List[Dict[str, str]]) -> List[str]:
return list(map(lambda target: target['name'], targets))

def postprocess(it):
def postprocess(it: Any) -> Any:
def go(it):
i = 0
while i < len(it):
Expand All @@ -125,14 +125,14 @@ def go(it):
it = go(it)
return it

def paren_if(n, lr):
def paren_if(n: str, lr: Iterable[str]) -> str:
l, r = lr
if n:
return l + n + r
else:
return n

def export(it, repeat_macro=None, use_scanf=False):
def export(it, repeat_macro: Optional[str] = None, use_scanf: bool = False) -> str:
def go(it, nest):
if it['kind'] == 'decl':
if it['names']:
Expand Down Expand Up @@ -173,7 +173,7 @@ def go(it, nest):
s += go(line, 0)
return s

def generate_scanner(args):
def generate_scanner(args: 'argparse.Namespace') -> None:
if not args.silent:
log.warning('This feature is ' + log.red('experimental') + '.')
if args.silent:
Expand All @@ -183,7 +183,7 @@ def generate_scanner(args):
if problem is None:
sys.exit(1)
with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
it = problem.get_input_format(session=sess)
it: Any = problem.get_input_format(session=sess)
if not it:
log.error('input format not found')
sys.exit(1)
Expand Down
5 changes: 4 additions & 1 deletion onlinejudge/implementation/command/get_standings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import onlinejudge.implementation.logging as log
import json
import sys
from typing import *
if TYPE_CHECKING:
import argparse


def get_standings(args):
def get_standings(args: 'argparse.Namespace') -> None:
# parse url
problem = onlinejudge.dispatch.problem_from_url(args.url)
if problem is None:
Expand Down
9 changes: 6 additions & 3 deletions onlinejudge/implementation/command/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import onlinejudge.implementation.logging as log
import sys
import getpass
from typing import *
if TYPE_CHECKING:
import argparse

def login(args):
def login(args: 'argparse.Namespace') -> None:
# get service
service = onlinejudge.dispatch.service_from_url(args.url)
if service is None:
Expand All @@ -26,11 +29,11 @@ def login(args):
sys.exit(1)

# login
def get_credentials():
def get_credentials() -> Tuple[str, str]:
if args.username is None:
args.username = input('Username: ')
if args.password is None:
args.password = getpass.getpass()
return args.username, args.password
with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
service.login(get_credentials, session=sess, **kwargs)
service.login(get_credentials, session=sess, **kwargs) # type: ignore
7 changes: 5 additions & 2 deletions onlinejudge/implementation/command/split_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import sys
import subprocess
import time
from typing import *
if TYPE_CHECKING:
import argparse

def non_block_read(fh):
def non_block_read(fh: IO[Any]) -> str:
# workaround
import fcntl
import os
Expand All @@ -20,7 +23,7 @@ def non_block_read(fh):

split_input_auto_footer = ('__AUTO_FOOTER__', ) # this shouldn't be a string, so a tuple

def split_input(args):
def split_input(args: 'argparse.Namespace') -> None:
with open(args.input) as fh:
inf = fh.read()
if args.footer == split_input_auto_footer:
Expand Down
21 changes: 12 additions & 9 deletions onlinejudge/implementation/command/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import subprocess
import os
import re
from typing import *
if TYPE_CHECKING:
import argparse

default_url_opener = [ 'sensible-browser', 'xdg-open', 'open' ]

def submit(args):
def submit(args: 'argparse.Namespace') -> None:
# parse url
problem = onlinejudge.dispatch.problem_from_url(args.url)
if problem is None:
Expand Down Expand Up @@ -39,6 +42,7 @@ def submit(args):
with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
# guess or select language ids
langs = problem.get_language_dict(session=sess)
matched_lang_ids: Optional[List[str]] = None
if args.guess:
kwargs = {
'language_dict': langs,
Expand All @@ -56,7 +60,7 @@ def submit(args):
elif args.language in langs:
matched_lang_ids = [ args.language ]
else:
matched_lang_ids = select_ids_of_matched_languages(args.language.split(), langs.keys(), language_dict=langs)
matched_lang_ids = select_ids_of_matched_languages(args.language.split(), list(langs.keys()), language_dict=langs)

# report selected language ids
if matched_lang_ids is not None and len(matched_lang_ids) == 1:
Expand Down Expand Up @@ -95,7 +99,7 @@ def submit(args):
kwargs['kind'] = 'full'
else:
kwargs['kind'] = 'example'
submission = problem.submit(code, language=args.language, session=sess, **kwargs)
submission = problem.submit(code, language=args.language, session=sess, **kwargs) # type: ignore

# show result
if submission is None:
Expand All @@ -115,8 +119,7 @@ def submit(args):
log.info('open the submission page with: %s', browser)
subprocess.check_call([ browser, submission.get_url() ], stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr)

def select_ids_of_matched_languages(words, lang_ids, language_dict, split=False, remove=False):
assert isinstance(words, list)
def select_ids_of_matched_languages(words: List[str], lang_ids: List[str], language_dict, split: bool = False, remove: bool = False) -> List[str]:
result = []
for lang_id in lang_ids:
desc = language_dict[lang_id]['description'].lower()
Expand All @@ -129,12 +132,12 @@ def select_ids_of_matched_languages(words, lang_ids, language_dict, split=False,
result.append(lang_id)
return result

def guess_lang_ids_of_file(filename, code, language_dict, cxx_latest=False, cxx_compiler='all', python_version='all', python_interpreter='all'):
def guess_lang_ids_of_file(filename: str, code: bytes, language_dict, cxx_latest: bool = False, cxx_compiler: str = 'all', python_version: str = 'all', python_interpreter: str = 'all') -> List[str]:
assert cxx_compiler.lower() in ( 'gcc', 'clang', 'all' )
assert python_version.lower() in ( '2', '3', 'auto', 'all' )
assert python_interpreter.lower() in ( 'cpython', 'pypy', 'all' )

select = lambda word, lang_ids, **kwargs: select_ids_of_matched_languages([ word ], lang_ids, **kwargs, language_dict=language_dict)
select = lambda word, lang_ids, **kwargs: select_ids_of_matched_languages([ word ], lang_ids, language_dict=language_dict, **kwargs)
_, ext = os.path.splitext(filename)
lang_ids = language_dict.keys()

Expand Down Expand Up @@ -225,7 +228,7 @@ def guess_lang_ids_of_file(filename, code, language_dict, cxx_latest=False, cxx_

else:
log.debug('language guessing: othres')
table = [
table: List[Dict[str, Any]] = [
{ 'names': [ 'awk' ], 'exts': [ 'awk' ] },
{ 'names': [ 'bash' ], 'exts': [ 'sh' ] },
{ 'names': [ 'brainfuck' ], 'exts': [ 'bf' ] },
Expand Down Expand Up @@ -266,7 +269,7 @@ def guess_lang_ids_of_file(filename, code, language_dict, cxx_latest=False, cxx_
return list(set(lang_ids))


def format_code(code, dos2unix=False, rstrip=False):
def format_code(code: bytes, dos2unix: bool = False, rstrip: bool = False) -> bytes:
if dos2unix:
log.info('dos2unix...')
code = code.replace(b'\r\n', b'\n')
Expand Down
24 changes: 14 additions & 10 deletions onlinejudge/implementation/command/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import collections
import time
import math
from typing import *
if TYPE_CHECKING:
import argparse

def compare_as_floats(xs, ys, error):
def compare_as_floats(xs_: str, ys_: str, error: float) -> bool:
def f(x):
try:
y = float(x)
Expand All @@ -22,8 +25,8 @@ def f(x):
return y
except ValueError:
return x
xs = list(map(f, xs.split()))
ys = list(map(f, ys.split()))
xs = list(map(f, xs_.split()))
ys = list(map(f, ys_.split()))
if len(xs) != len(ys):
return False
for x, y in zip(xs, ys):
Expand All @@ -35,7 +38,7 @@ def f(x):
return False
return True

def test(args):
def test(args: 'argparse.Namespace') -> None:
# prepare
if not args.test:
args.test = cutils.glob_with_format(args.directory, args.format) # by default
Expand All @@ -53,7 +56,8 @@ def match(a, b):
return True
return False
rstrip_targets = ' \t\r\n\f\v\0' # ruby's one, follow AnarchyGolf
slowest, slowest_name = -1, ''
slowest: Union[int, float] = -1
slowest_name = ''
ac_count = 0

for name, it in sorted(tests.items()):
Expand All @@ -71,9 +75,9 @@ def print_input():
# run the binary
with open(it['in']) as inf:
begin = time.perf_counter()
answer, proc = utils.exec_command(args.command, shell=True, stdin=inf, timeout=args.tle)
answer_byte, proc = utils.exec_command(args.command, shell=True, stdin=inf, timeout=args.tle)
end = time.perf_counter()
answer = answer.decode()
answer = answer_byte.decode()
if slowest < end - begin:
slowest = end - begin
slowest_name = name
Expand Down Expand Up @@ -105,9 +109,9 @@ def print_input():
log.emit('expected:\n%s', log.bold(correct))
result = 'WA'
elif args.mode == 'line':
answer = answer .splitlines()
correct = correct.splitlines()
for i, (x, y) in enumerate(zip(answer + [ None ] * len(correct), correct + [ None ] * len(answer))):
answer_words = answer .splitlines()
correct_words = correct.splitlines()
for i, (x, y) in enumerate(zip(answer_words + [ None ] * len(correct_words), correct_words + [ None ] * len(answer_words))): # type: ignore
if x is None and y is None:
break
elif x is None:
Expand Down
Loading

0 comments on commit 783f461

Please sign in to comment.