Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout for DB access and sanitize #57

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 49 additions & 49 deletions src/o2tuner/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
"""
import sys
from time import time
from os import getcwd, chdir, remove
from os import getcwd, chdir
from os.path import join
from inspect import signature
from copy import deepcopy
import pickle
# import needed to ensure sanity of storage
from sqlalchemy.exc import NoSuchModuleError

import optuna

Expand Down Expand Up @@ -38,51 +40,67 @@ def make_trial_directory(trial):
return cwd


def get_storage_identifier(storage_path):
def adjust_storage_url(storage_url, workdir="./"):
"""
From the full storage path, try to extract the identifier for what to be used
"""
# check if there is a known identifier
storage_prefixes = ["sqlite:///", "mysql:///"]

for prefix in storage_prefixes:
if storage_path.find(prefix) == 0:
return prefix
if storage_url.find(prefix) == 0:
if prefix == "sqlite:///":
# in case of SQLite, the file could be requested to be stored under an absolute or relative path
path = storage_url[len(prefix):]
if path[0] != "/":
# if not an absolute path, put the working directory in between
storage_url = prefix + join(workdir, path)
return storage_url

return None
LOG.warning("Unknown storage identifier in URL %s, might fail", storage_url)

return storage_url


def get_default_storage_url(study_name="o2tuner_study"):
"""
Construct a default storage path
"""
return f"sqlite:///{study_name}.db"


def adjust_storage_path(storage_path, workdir="./"):
def create_storage(storage, workdir="./"):
"""
Make sure the path is either absolute path or relative to the specified workdir.
Take care of storage identifier. Right now check for MySQL and SQLite.
"""

if not storage_path:
if not storage:
# Empty path, cannot know how to deal with it, return None
return None

# check if there is a known identifier
check_prefix = get_storage_identifier(storage_path)

if not check_prefix:
# either no or unknown identifier
return storage_path
# default arguments which we will use
# for now, use a high timeout so we don't fail if another process is currently using the storage backend
engine_kwargs = {"connect_args": {"timeout": 100}}

path = storage_path[len(check_prefix):]
if path[0] == "/":
# Absolute path, just return
return storage_path
if isinstance(storage, str):
# simply treat this as the storage url
url = storage
else:
# first pop the url...
url = storage.pop("url", get_default_storage_url())
if storage:
# ...then check, if there is more in the dictionary; if so, use it
engine_kwargs = storage

# re-assemble, put the working directory in between
return check_prefix + join(workdir, path)


def get_default_storage(study_name):
"""
Construct a default storage path
"""
return f"sqlite:///{study_name}.db"
# check if there is a known identifier
url = adjust_storage_url(url, workdir)
try:
storage = optuna.storages.RDBStorage(url=url, engine_kwargs=engine_kwargs)
return storage
except (ImportError, NoSuchModuleError) as import_error:
LOG.error(import_error)
return None


def load_or_create_study_from_storage(study_name, storage, sampler=None, create_if_not_exists=True):
Expand All @@ -91,17 +109,13 @@ def load_or_create_study_from_storage(study_name, storage, sampler=None, create_
"""
try:
study = optuna.load_study(study_name=study_name, storage=storage, sampler=sampler)
LOG.debug("Loading existing study %s from storage %s", study_name, storage)
LOG.debug("Loading existing study %s from storage %s", study_name, storage.url)
return study
except KeyError:
if create_if_not_exists:
study = optuna.create_study(study_name=study_name, storage=storage, sampler=sampler)
LOG.debug("Creating new study %s at storage %s", study_name, storage)
LOG.debug("Creating new study %s at storage %s", study_name, storage.url)
return study
except ImportError as exc:
# Probably cannot import MySQL or SQLite stuff
LOG.warning("Probably cannot import what is needed for database access. Will try to attempt a serial run.")
LOG.warning(exc)

return None

Expand Down Expand Up @@ -140,7 +154,7 @@ def load_or_create_study(study_name=None, storage=None, sampler=None, workdir=".
file in the given directory with <study_name>.pkl. If found, tru to load.
If also this does not exist, create a new in-memory study.
"""
storage = adjust_storage_path(storage, workdir)
storage = create_storage(storage, workdir)
if study_name and storage:
# Although optuna would come up with a unique name when study_name is None,
# we force a name to be given by the user for those cases
Expand Down Expand Up @@ -174,25 +188,11 @@ def pickle_study(study, workdir="./"):
return file_name


def can_do_storage(storage):
def can_do_storage(storage_url):
"""
Basically a dry run to try and create a study for given storage
"""
identifier = get_storage_identifier(storage)
if not identifier:
LOG.error("Storage %s has unknown identifier, cannot create study.", storage)
return False
filepath = "/tmp/o2tuner_dry_run.db"
if exists_file(filepath):
remove(filepath)
storage = f"{identifier}{filepath}"
can_do, _ = load_or_create_study("o2tuner_dry_study", storage)
if exists_file(filepath):
# E.g. in case of SQLite, remove it
remove(filepath)
if not can_do:
LOG.error("Tested storage via %s, cannot create study at storage %s.", identifier, storage)
return can_do
return create_storage(storage_url) is not None


class OptunaHandler:
Expand Down
6 changes: 3 additions & 3 deletions src/o2tuner/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import functools

from o2tuner.io import make_dir, parse_yaml
from o2tuner.backends import OptunaHandler, can_do_storage, get_default_storage
from o2tuner.backends import OptunaHandler, can_do_storage, get_default_storage_url
from o2tuner.sampler import construct_sampler
from o2tuner.inspector import O2TunerInspector
from o2tuner.exception import O2TunerStopOptimisation
Expand Down Expand Up @@ -81,13 +81,13 @@ def prepare_optimisation(optuna_config, work_dir="o2tuner_optimise"):

if not storage and not in_memory:
# make a default storage, optimisation via storage should be the way to go
storage = get_default_storage(study_name)
storage = get_default_storage_url(study_name)

if not in_memory and not can_do_storage(storage):
# no worries - at this point - if optimisation via storage is not possible
optuna_storage_config["storage"] = None
if jobs > 1:
# however, if more than 1 one requested, abort the preparation here
# however, if more than 1 job requested, abort the preparation here
LOG.error("Requested %d jobs but problem to set up storage %s", jobs, storage)
return None, None, None
else:
Expand Down
Loading