Skip to content

Commit

Permalink
#183: add type hints for all classes for services
Browse files Browse the repository at this point in the history
  • Loading branch information
kmyk committed Nov 3, 2018
1 parent 51fb05b commit 1424236
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 168 deletions.
2 changes: 1 addition & 1 deletion onlinejudge/anarchygolf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AnarchyGolfProblem(onlinejudge.problem.Problem):
def __init__(self, problem_id: str):
self.problem_id = problem_id

def download(self, session: Optional[requests.Session] = None):
def download(self, session: Optional[requests.Session] = None) -> List[onlinejudge.problem.TestCase]:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
Expand Down
30 changes: 17 additions & 13 deletions onlinejudge/aoj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import onlinejudge.service
import onlinejudge.problem
from onlinejudge.problem import LabeledString, TestCase
import onlinejudge.dispatch
import onlinejudge.implementation.utils as utils
import onlinejudge.implementation.logging as log
Expand All @@ -13,6 +14,7 @@
import zipfile
import collections
import itertools
from typing import *


@utils.singleton
Expand All @@ -25,24 +27,25 @@ def get_name(self):
return 'aoj'

@classmethod
def from_url(cls, s):
def from_url(cls, s: str) -> Optional['AOJService']:
# example: http://judge.u-aizu.ac.jp/onlinejudge/
result = urllib.parse.urlparse(s)
if result.scheme in ('', 'http', 'https') \
and result.netloc == 'judge.u-aizu.ac.jp':
return cls()
return None


class AOJProblem(onlinejudge.problem.Problem):
def __init__(self, problem_id):
self.problem_id = problem_id

def download(self, session=None, is_system=False):
def download(self, session: Optional[requests.Session] = None, is_system: bool = False) -> List[TestCase]:
if is_system:
return self.download_system(session=session)
else:
return self.download_samples(session=session)
def download_samples(self, session=None):
def download_samples(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
Expand All @@ -65,34 +68,34 @@ def download_samples(self, session=None):
name = hn.string
samples.add(s, name)
return samples.get()
def download_system(self, session=None):
def download_system(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
get_url = lambda case, type: 'http://analytic.u-aizu.ac.jp:8080/aoj/testcase.jsp?id={}&case={}&type={}'.format(self.problem_id, case, type)
testcases = []
testcases: List[TestCase] = []
for case in itertools.count(1):
# input
# get
resp = utils.request('GET', get_url(case, 'in'), session=session, raise_for_status=False)
if resp.status_code != 200:
break
in_txt = resp.text
if case == 2 and testcases[0]['input']['data'] == in_txt:
if case == 2 and testcases[0].input.data == in_txt:
break # if the querystring case=??? is ignored
# output
# get
resp = utils.request('GET', get_url(case, 'out'), session=session)
out_txt = resp.text
testcases += [ {
'input': { 'data': in_txt, 'name': 'in%d.txt' % case },
'output': { 'data': out_txt, 'name': 'out%d.txt' % case },
} ]
testcases += [ TestCase(
LabeledString('in%d.txt' % case, in_txt),
LabeledString('out%d.txt' % case, out_txt),
) ]
return testcases

def get_url(self):
def get_url(self) -> str:
return 'http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id={}'.format(self.problem_id)

@classmethod
def from_url(cls, s):
def from_url(cls, s: str) -> Optional['AOJProblem']:
# example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1169
# example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_1_A&lang=jp
result = urllib.parse.urlparse(s)
Expand All @@ -104,8 +107,9 @@ def from_url(cls, s):
and len(querystring['id']) == 1:
n, = querystring['id']
return cls(n)
return None

def get_service(self):
def get_service(self) -> AOJService:
return AOJService()


Expand Down
84 changes: 50 additions & 34 deletions onlinejudge/atcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import onlinejudge.service
import onlinejudge.problem
from onlinejudge.problem import SubmissionError
import onlinejudge.submission
import onlinejudge.dispatch
import onlinejudge.implementation.utils as utils
Expand All @@ -12,12 +13,13 @@
import urllib.parse
import posixpath
import json
from typing import *


@utils.singleton
class AtCoderService(onlinejudge.service.Service):

def login(self, get_credentials, session=None):
def login(self, get_credentials: onlinejudge.service.CredentialsProvider, session: Optional[requests.Session] = None) -> bool:
session = session or utils.new_default_session()
url = 'https://practice.contest.atcoder.jp/login'
# get
Expand All @@ -34,31 +36,32 @@ def login(self, get_credentials, session=None):
AtCoderService._report_messages(msgs)
return 'login' not in resp.url # AtCoder redirects to the top page if success

def get_url(self):
def get_url(self) -> str:
return 'https://atcoder.jp/'

def get_name(self):
def get_name(self) -> str:
return 'atcoder'

@classmethod
def from_url(cls, s):
def from_url(cls, s: str) -> Optional['AtCoderService']:
# example: https://atcoder.jp/
# example: http://agc012.contest.atcoder.jp/
result = urllib.parse.urlparse(s)
if result.scheme in ('', 'http', 'https') \
and (result.netloc in ( 'atcoder.jp', 'beta.atcoder.jp' ) or result.netloc.endswith('.contest.atcoder.jp')):
return cls()
return None

@classmethod
def _get_messages_from_cookie(cls, cookies):
msgtags = []
def _get_messages_from_cookie(cls, cookies) -> List[str]:
msgtags: List[str] = []
for cookie in cookies:
log.debug('cookie: %s', str(cookie))
if cookie.name.startswith('__message_'):
msg = json.loads(urllib.parse.unquote_plus(cookie.value))
msgtags += [ msg['c'] ]
log.debug('message: %s: %s', cookie.name, str(msg))
msgs = []
msgs: List[str] = []
for msgtag in msgtags:
soup = bs4.BeautifulSoup(msgtag, utils.html_parser)
msg = None
Expand All @@ -73,7 +76,7 @@ def _get_messages_from_cookie(cls, cookies):
return msgs

@classmethod
def _report_messages(cls, msgs, unexpected=False):
def _report_messages(cls, msgs: List[str], unexpected: bool = False) -> bool:
for msg in msgs:
log.status('message: %s', msg)
if msgs and unexpected:
Expand All @@ -82,12 +85,12 @@ def _report_messages(cls, msgs, unexpected=False):


class AtCoderProblem(onlinejudge.problem.Problem):
def __init__(self, contest_id, problem_id):
def __init__(self, contest_id: str, problem_id: str):
self.contest_id = contest_id
self.problem_id = problem_id
self._task_id = None
self._task_id: Optional[int] = None

def download(self, session=None):
def download(self, session: Optional[requests.Session] = None) -> List[onlinejudge.problem.TestCase]:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
Expand Down Expand Up @@ -135,14 +138,14 @@ def _find_sample_tags(self, soup):
result += [( pre, prv )]
return result

def get_url(self):
def get_url(self) -> str:
return 'http://{}.contest.atcoder.jp/tasks/{}'.format(self.contest_id, self.problem_id)

def get_service(self):
def get_service(self) -> AtCoderService:
return AtCoderService()

@classmethod
def from_url(cls, s):
def from_url(cls, s: str) -> Optional['AtCoderProblem']:
# example: http://agc012.contest.atcoder.jp/tasks/agc012_d
result = urllib.parse.urlparse(s)
dirname, basename = posixpath.split(utils.normpath(result.path))
Expand All @@ -165,7 +168,9 @@ def from_url(cls, s):
problem_id = m.group(2)
return cls(contest_id, problem_id)

def get_input_format(self, session=None):
return None

def get_input_format(self, session: Optional[requests.Session] = None) -> str:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
Expand All @@ -186,8 +191,9 @@ def get_input_format(self, session=None):
for it in tag:
s += it.string or it # AtCoder uses <var>...</var> for math symbols
return s
return ''

def get_language_dict(self, session=None):
def get_language_dict(self, session: Optional[requests.Session] = None) -> Dict[str, Any]:
session = session or utils.new_default_session()
# get
url = 'http://{}.contest.atcoder.jp/submit'.format(self.contest_id)
Expand All @@ -208,26 +214,26 @@ def get_language_dict(self, session=None):
language_dict[option.attrs['value']] = { 'description': option.string }
return language_dict

def submit(self, code, language, session=None):
def submit(self, code: str, language: str, session: Optional[requests.Session] = None) -> 'AtCoderSubmission':
assert language in self.get_language_dict(session=session)
session = session or utils.new_default_session()
# get
url = 'http://{}.contest.atcoder.jp/submit'.format(self.contest_id) # TODO: use beta.atcoder.jp
resp = utils.request('GET', url, session=session)
msgs = AtCoderService._get_messages_from_cookie(resp.cookies)
if AtCoderService._report_messages(msgs, unexpected=True):
return None
raise SubmissionError
# check whether logged in
path = utils.normpath(urllib.parse.urlparse(resp.url).path)
if path.startswith('/login'):
log.error('not logged in')
return None
raise SubmissionError
# parse
soup = bs4.BeautifulSoup(resp.content.decode(resp.encoding), utils.html_parser)
form = soup.find('form', action=re.compile(r'^/submit\?task_id='))
if not form:
log.error('form not found')
return None
raise SubmissionError
log.debug('form: %s', str(form))
# post
task_id = self._get_task_id(session=session)
Expand All @@ -246,38 +252,43 @@ def submit(self, code, language, session=None):
log.success('success: result: %s', resp.url)
# NOTE: ignore the returned legacy URL and use beta.atcoder.jp's one
url = 'https://beta.atcoder.jp/contests/{}/submissions/me'.format(self.contest_id)
return onlinejudge.submission.CompatibilitySubmission(url)
submission = AtCoderSubmission.from_url(url, problem_id=self.problem_id)
if not submission:
raise SubmissionError
return submission
else:
log.failure('failure')
return None
raise SubmissionError

def _get_task_id(self, session=None):
def _get_task_id(self, session: Optional[requests.Session] = None) -> int:
if self._task_id is None:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
msgs = AtCoderService._get_messages_from_cookie(resp.cookies)
if AtCoderService._report_messages(msgs, unexpected=True):
return {}
raise SubmissionError
# parse
soup = bs4.BeautifulSoup(resp.content.decode(resp.encoding), utils.html_parser)
submit = soup.find('a', href=re.compile(r'^/submit\?task_id='))
if not submit:
log.error('link to submit not found')
return False
raise SubmissionError
m = re.match(r'^/submit\?task_id=([0-9]+)$', submit.attrs['href'])
assert m
self._task_id = int(m.group(1))
return self._task_id

class AtCoderSubmission(onlinejudge.submission.Submission):
def __init__(self, contest_id, submission_id, problem_id=None):
def __init__(self, contest_id: str, submission_id: int, problem_id: Optional[str] = None):
self.contest_id = contest_id
self.submission_id = submission_id
self.problem_id = problem_id

@classmethod
def from_url(cls, s, problem_id=None):
def from_url(cls, s: str, problem_id: Optional[str] = None) -> Optional['AtCoderSubmission']:
submission_id: Optional[int] = None

# example: http://agc001.contest.atcoder.jp/submissions/1246803
result = urllib.parse.urlparse(s)
dirname, basename = posixpath.split(utils.normpath(result.path))
Expand All @@ -290,6 +301,7 @@ def from_url(cls, s, problem_id=None):
try:
submission_id = int(basename)
except ValueError:
pass
submission_id = None
if submission_id is not None:
return cls(contest_id, submission_id, problem_id=problem_id)
Expand All @@ -307,23 +319,26 @@ def from_url(cls, s, problem_id=None):
if submission_id is not None:
return cls(contest_id, submission_id, problem_id=problem_id)

def get_url(self):
return None

def get_url(self) -> str:
return 'http://{}.contest.atcoder.jp/submissions/{}'.format(self.contest_id, self.submission_id)

def get_problem(self):
if self.problem_id is not None:
return AtCoderProblem(self.contest_id, self.problem_id)
def get_problem(self) -> AtCoderProblem:
if self.problem_id is None:
raise ValueError
return AtCoderProblem(self.contest_id, self.problem_id)

def get_service(self):
def get_service(self) -> AtCoderService:
return AtCoderService()

def download(self, session=None):
def download(self, session: Optional[requests.Session] = None) -> str:
session = session or utils.new_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
msgs = AtCoderService._get_messages_from_cookie(resp.cookies)
if AtCoderService._report_messages(msgs, unexpected=True):
return []
raise RuntimeError
# parse
soup = bs4.BeautifulSoup(resp.content.decode(resp.encoding), utils.html_parser)
code = None
Expand All @@ -335,6 +350,7 @@ def download(self, session=None):
code = pre.string
if code is None:
log.error('source code not found')
raise RuntimeError
return code

onlinejudge.dispatch.services += [ AtCoderService ]
Expand Down
Loading

0 comments on commit 1424236

Please sign in to comment.