Skip to content

Commit

Permalink
#318: make a shared default session
Browse files Browse the repository at this point in the history
  • Loading branch information
kmyk committed Mar 1, 2019
1 parent 9b6356b commit f1a0184
Show file tree
Hide file tree
Showing 15 changed files with 59 additions and 48 deletions.
2 changes: 1 addition & 1 deletion onlinejudge/_implementation/command/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def download(args: 'argparse.Namespace') -> None:
args.format = '%b.%e'

# get samples from the server
with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
with utils.with_cookiejar(utils.new_default_session_with_our_user_agent(), path=args.cookie) as sess:
if args.system:
try:
samples = problem.download_system_cases(session=sess) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/_implementation/command/generate_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def generate_scanner(args: 'argparse.Namespace') -> None:
problem = onlinejudge.dispatch.problem_from_url(args.url)
if problem is None:
sys.exit(1)
with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
with utils.with_cookiejar(utils.new_default_session_with_our_user_agent(), path=args.cookie) as sess:
it = problem.get_input_format(session=sess) # type: Any
if not it:
log.error('input format not found')
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/_implementation/command/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def login(args: 'argparse.Namespace') -> None:
log.failure('login for %s: invalid option: --method %s', service.get_name(), args.method)
sys.exit(1)

with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
with utils.with_cookiejar(utils.new_default_session_with_our_user_agent(), path=args.cookie) as sess:

if args.check:
if service.is_logged_in(session=sess):
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/_implementation/command/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def submit(args: 'argparse.Namespace') -> None:
log.info('code (%d byte):', len(code))
log.emit(utils.snip_large_file_content(s.rstrip(), limit=30, head=10, tail=10, bold=True))

with utils.with_cookiejar(utils.new_default_session(), path=args.cookie) as sess:
with utils.with_cookiejar(utils.new_default_session_with_our_user_agent(), path=args.cookie) as sess:
# guess or select language ids
langs = {language.id: {'description': language.name} for language in problem.get_available_languages(session=sess)} # type: Dict[LanguageId, Dict[str, str]]
matched_lang_ids = None # type: Optional[List[str]]
Expand Down
13 changes: 12 additions & 1 deletion onlinejudge/_implementation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,22 @@ def next_sibling_tag(tag: bs4.Tag) -> bs4.Tag:
return tag


def new_default_session() -> requests.Session: # without setting cookiejar
def new_session_with_our_user_agent() -> requests.Session:
session = requests.Session()
session.headers['User-Agent'] += ' (+{})'.format(version.__url__)
return session

_default_session = None # Optional[requests.Session]

def get_default_session() -> requests.Session:
"""
:note: cookie is not saved to disk
"""
global _default_session
if _default_session is None:
_default_session = new_session_with_our_user_agent()
return _default_session


default_cookie_path = data_dir / 'cookie.jar'

Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/service/anarchygolf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, problem_id: str):
self.problem_id = problem_id

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[onlinejudge.type.TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
# parse
Expand Down
4 changes: 2 additions & 2 deletions onlinejudge/service/aoj.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, problem_id):
self.problem_id = problem_id

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get samples via the official API
# reference: http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2Fsamples%2F%7BproblemId%7D_GET
url = 'https://judgedat.u-aizu.ac.jp/testcases/samples/{}'.format(self.problem_id)
Expand All @@ -68,7 +68,7 @@ def download_sample_cases(self, session: Optional[requests.Session] = None) -> L
return samples

def download_system_cases(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# get header
# reference: http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2F%7BproblemId%7D%2Fheader_GET
Expand Down
22 changes: 11 additions & 11 deletions onlinejudge/service/atcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
:raises LoginError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
if self.is_logged_in(session=session):
return

Expand Down Expand Up @@ -90,7 +90,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
raise LoginError

def is_logged_in(self, session: Optional[requests.Session] = None) -> bool:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://atcoder.jp/contests/practice/submit'
resp = _request('GET', url, session=session, allow_redirects=False)
return resp.status_code == 200
Expand Down Expand Up @@ -127,7 +127,7 @@ def iterate_contests(self, lang: str = 'ja', session: Optional[requests.Session]
"""

assert lang in ('ja', 'en')
session = session or utils.new_default_session()
session = session or utils.get_default_session()
last_page = None
for page in itertools.count(1): # 1-based
if last_page is not None and page > last_page:
Expand Down Expand Up @@ -234,7 +234,7 @@ def _parse_start_time(self, url: str) -> datetime.datetime:
return datetime.datetime.strptime(query['iso'][0], '%Y%m%dT%H%M').replace(tzinfo=utils.tzinfo_jst)

def _load_details(self, session: Optional[requests.Session] = None, lang: Optional[str] = None):
session = session or utils.new_default_session()
session = session or utils.get_default_session()
resp = _request('GET', self.get_url(type='beta', lang=lang), session=session)
soup = bs4.BeautifulSoup(resp.content.decode(resp.encoding), utils.html_parser)

Expand Down Expand Up @@ -286,7 +286,7 @@ def get_name(self, lang: Optional[str] = None, session: Optional[requests.Sessio

def list_problems(self, session: Optional[requests.Session] = None) -> List['AtCoderProblem']:
# get
session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://atcoder.jp/contests/{}/tasks'.format(self.contest_id)
resp = _request('GET', url, session=session)

Expand Down Expand Up @@ -329,7 +329,7 @@ def _from_table_row(cls, tr: bs4.Tag) -> 'AtCoderProblem':
return self

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[onlinejudge.type.TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = _request('GET', self.get_url(type='beta'), session=session)
if _list_alert(resp):
Expand Down Expand Up @@ -430,7 +430,7 @@ def from_url(cls, s: str) -> Optional['AtCoderProblem']:
return None

def get_input_format(self, session: Optional[requests.Session] = None) -> str:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = _request('GET', self.get_url(type='beta'), session=session)
if _list_alert(resp):
Expand All @@ -456,7 +456,7 @@ def get_available_languages(self, session: Optional[requests.Session] = None) ->
"""
:raises NotLoggedInError:
"""
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# get
resp = _request('GET', self.get_url(type='beta'), session=session)
Expand All @@ -482,7 +482,7 @@ def submit_code(self, code: bytes, language_id: LanguageId, filename: Optional[s
"""

assert language_id in [language.id for language in self.get_available_languages(session=session)]
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# get
url = 'https://atcoder.jp/contests/{}/submit'.format(self.contest_id)
Expand Down Expand Up @@ -517,7 +517,7 @@ def submit_code(self, code: bytes, language_id: LanguageId, filename: Optional[s
raise SubmissionError('it may be a rate limit')

def _load_details(self, session: Optional[requests.Session] = None) -> None:
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# get
resp = _request('GET', self.get_url(type='beta', lang='ja'), session=session)
Expand Down Expand Up @@ -627,7 +627,7 @@ def download_code(self, session: Optional[requests.Session] = None) -> bytes:
return self.get_source_code(session=session)

def _load_details(self, session: Optional[requests.Session] = None) -> None:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
resp = _request('GET', self.get_url(type='beta', lang='en'), session=session)
soup = bs4.BeautifulSoup(resp.content.decode(resp.encoding), utils.html_parser)

Expand Down
10 changes: 5 additions & 5 deletions onlinejudge/service/codeforces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
"""
:raises LoginError:
"""
session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://codeforces.com/enter'
# get
resp = utils.request('GET', url, session=session)
Expand All @@ -53,7 +53,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
raise LoginError('Invalid handle or password.')

def is_logged_in(self, session: Optional[requests.Session] = None) -> bool:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://codeforces.com/enter'
resp = utils.request('GET', url, session=session, allow_redirects=False)
return resp.status_code == 302
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(self, contest_id: int, index: str, kind: Optional[str] = None):
self.kind = kind # It seems 'gym' is specialized, 'contest' and 'problemset' are the same thing

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[onlinejudge.type.TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
# parse
Expand All @@ -127,7 +127,7 @@ def get_available_languages(self, session: Optional[requests.Session] = None) ->
:raises NotLoggedInError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
# parse
Expand All @@ -146,7 +146,7 @@ def submit_code(self, code: bytes, language_id: LanguageId, filename: Optional[s
:raises SubmissionError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
# parse
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/service/csacademy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, contest_name: str, task_name: str):
self.task_name = task_name

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
base_url = self.get_url()

# get csrftoken
Expand Down
14 changes: 7 additions & 7 deletions onlinejudge/service/hackerrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
:raises LoginError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://www.hackerrank.com/auth/login'
# get
resp = utils.request('GET', url, session=session)
Expand Down Expand Up @@ -61,7 +61,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
raise LoginError('You failed to sign in. Wrong user ID or password.')

def is_logged_in(self, session: Optional[requests.Session] = None) -> bool:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
url = 'https://www.hackerrank.com/auth/login'
resp = utils.request('GET', url, session=session)
return '/auth' not in resp.url
Expand Down Expand Up @@ -98,7 +98,7 @@ def download_sample_cases(self, session: Optional[requests.Session] = None) -> L
raise NotImplementedError

def download_system_cases(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# example: https://www.hackerrank.com/rest/contests/hourrank-1/challenges/beautiful-array/download_testcases
url = 'https://www.hackerrank.com/rest/contests/{}/challenges/{}/download_testcases'.format(self.contest_slug, self.challenge_slug)
resp = utils.request('GET', url, session=session, raise_for_status=False)
Expand Down Expand Up @@ -136,7 +136,7 @@ def _get_model(self, session: Optional[requests.Session] = None) -> Dict[str, An
:raises SubmissionError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
url = 'https://www.hackerrank.com/rest/contests/{}/challenges/{}'.format(self.contest_slug, self.challenge_slug)
resp = utils.request('GET', url, session=session)
Expand All @@ -149,7 +149,7 @@ def _get_model(self, session: Optional[requests.Session] = None) -> Dict[str, An
return it['model']

def _get_lang_display_mapping(self, session: Optional[requests.Session] = None) -> Dict[str, str]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
url = 'https://hrcdn.net/hackerrank/assets/codeshell/dist/codeshell-cdffcdf1564c6416e1a2eb207a4521ce.js' # at "Mon Feb 4 14:51:27 JST 2019"
resp = utils.request('GET', url, session=session)
Expand All @@ -168,7 +168,7 @@ def _get_lang_display_mapping(self, session: Optional[requests.Session] = None)
return lang_display_mapping

def get_available_languages(self, session: Optional[requests.Session] = None) -> List[Language]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
info = self._get_model(session=session)
lang_display_mapping = self._get_lang_display_mapping()
result = [] # type: List[Language]
Expand All @@ -186,7 +186,7 @@ def submit_code(self, code: bytes, language_id: LanguageId, filename: Optional[s
:raises SubmissionError:
"""

session = session or utils.new_default_session()
session = session or utils.get_default_session()
if not self.get_service().is_logged_in(session=session):
raise NotLoggedInError
# get
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/service/kattis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, problem_id: str, contest_id: Optional[str] = None, domain: st
self.problem_id = problem_id

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[onlinejudge.type.TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
url = self.get_url(contests=False) + '/file/statement/samples.zip'
resp = utils.request('GET', url, session=session, raise_for_status=False)
Expand Down
2 changes: 1 addition & 1 deletion onlinejudge/service/poj.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, problem_id: int):
self.problem_id = problem_id

def download_sample_cases(self, session: Optional[requests.Session] = None) -> List[TestCase]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()
# get
resp = utils.request('GET', self.get_url(), session=session)
# parse
Expand Down
8 changes: 4 additions & 4 deletions onlinejudge/service/topcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def login(self, get_credentials: onlinejudge.type.CredentialsProvider, session:
"""
:raises LoginError:
"""
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# NOTE: you can see this login page with https://community.topcoder.com/longcontest/?module=Submit
url = 'https://community.topcoder.com/longcontest/'
Expand Down Expand Up @@ -108,7 +108,7 @@ def from_url(cls, url: str) -> Optional['TopcoderLongContestProblem']:
return None

def get_available_languages(self, session: Optional[requests.Session] = None) -> List[Language]:
session = session or utils.new_default_session()
session = session or utils.get_default_session()

return [
Language(LanguageId('1'), 'Java 8'),
Expand All @@ -126,7 +126,7 @@ def submit_code(self, code: bytes, language_id: LanguageId, filename: Optional[s
"""

assert kind in ['example', 'full']
session = session or utils.new_default_session()
session = session or utils.get_default_session()

# TODO: implement self.is_logged_in()
# if not self.is_logged_in(session=session):
Expand Down Expand Up @@ -194,7 +194,7 @@ def get_standings(self, session: Optional[requests.Session] = None) -> Tuple[Lis
.. deprecated:: 6.0.0
This method may be deleted in future.
"""
session = session or utils.new_default_session()
session = session or utils.get_default_session()

header = None # type: Optional[List[str]]
rows = [] # type: List[Dict[str, str]]
Expand Down
Loading

0 comments on commit f1a0184

Please sign in to comment.