Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Returning success count from the .populate() call #1050

Merged
merged 23 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Changed - Migrate docs from `https://docs.datajoint.org/python` to `https://datajoint.com/docs/core/datajoint-python`
- Fixed - Updated set_password to work on MySQL 8 - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
- Added - Missing tests for set_password - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
- Changed - Returning success count after the .populate() call - PR [#1050](https://github.com/datajoint/datajoint-python/pull/1050)

### 0.14.1 -- Jun 02, 2023
- Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991)
Expand Down
179 changes: 100 additions & 79 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def populate(
to be passed down to each ``make()`` call. Computation arguments should be
specified within the pipeline e.g. using a `dj.Lookup` table.
:type make_kwargs: dict, optional
:return: a dict with two keys
"success_count": the count of successful ``make()`` calls in this ``populate()`` call
"error_list": the error list that is filled if `suppress_errors` is True
"""
if self.connection.in_transaction:
raise DataJointError("Populate cannot be called during a transaction.")
Expand Down Expand Up @@ -222,49 +225,62 @@ def handler(signum, frame):

keys = keys[:max_calls]
nkeys = len(keys)
if not nkeys:
return

processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)

error_list = []
populate_kwargs = dict(
suppress_errors=suppress_errors,
return_exception_objects=return_exception_objects,
make_kwargs=make_kwargs,
)
success_list = []

if processes == 1:
for key in (
tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
):
error = self._populate1(key, jobs, **populate_kwargs)
if error is not None:
error_list.append(error)
else:
# spawn multiple processes
self.connection.close() # disconnect parent process from MySQL server
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(
processes, _initialize_populate, (self, jobs, populate_kwargs)
) as pool, (
tqdm(desc="Processes: ", total=nkeys)
if display_progress
else contextlib.nullcontext()
) as progress_bar:
for error in pool.imap(_call_populate1, keys, chunksize=1):
if error is not None:
error_list.append(error)
if display_progress:
progress_bar.update()
self.connection.connect() # reconnect parent process to MySQL server
if nkeys:
processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)

populate_kwargs = dict(
suppress_errors=suppress_errors,
return_exception_objects=return_exception_objects,
make_kwargs=make_kwargs,
)

if processes == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is processes == 0 handled? Perhaps

if not processes:
       return {
            "success_count": 0,
            "error_list": [],
        }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is handled in the else block

for key in (
tqdm(keys, desc=self.__class__.__name__)
if display_progress
else keys
):
status = self._populate1(key, jobs, **populate_kwargs)
if status is True:
success_list.append(1)
elif isinstance(status, tuple):
error_list.append(status)
else:
assert status is False
else:
# spawn multiple processes
self.connection.close() # disconnect parent process from MySQL server
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(
processes, _initialize_populate, (self, jobs, populate_kwargs)
) as pool, (
tqdm(desc="Processes: ", total=nkeys)
if display_progress
else contextlib.nullcontext()
) as progress_bar:
for status in pool.imap(_call_populate1, keys, chunksize=1):
if status is True:
success_list.append(1)
elif isinstance(status, tuple):
error_list.append(status)
else:
assert status is False
if display_progress:
progress_bar.update()
self.connection.connect() # reconnect parent process to MySQL server

# restore original signal handler:
if reserve_jobs:
signal.signal(signal.SIGTERM, old_handler)

if suppress_errors:
return error_list
return {
"success_count": sum(success_list),
"error_list": error_list,
}

def _populate1(
self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None
Expand All @@ -275,55 +291,60 @@ def _populate1(
:param key: dict specifying job to populate
:param suppress_errors: bool if errors should be suppressed and returned
:param return_exception_objects: if True, errors must be returned as objects
:return: (key, error) when suppress_errors=True, otherwise None
:return: (key, error) when suppress_errors=True,
True if successfully invoke one `make()` call, otherwise False
"""
make = self._make_tuples if hasattr(self, "_make_tuples") else self.make

if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
self.connection.start_transaction()
if key in self.target: # already populated
if jobs is not None and not jobs.reserve(
self.target.table_name, self._job_key(key)
):
return False

self.connection.start_transaction()
if key in self.target: # already populated
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
return False

logger.debug(f"Making {key} -> {self.target.full_table_name}")
self.__class__._allow_insert = True
try:
make(dict(key), **(make_kwargs or {}))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
except LostConnectionError:
pass
error_message = "{exception}{msg}".format(
exception=error.__class__.__name__,
msg=": " + str(error) if str(error) else "",
)
logger.debug(
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
)
if jobs is not None:
# show error name and error message (if any)
jobs.error(
self.target.table_name,
self._job_key(key),
error_message=error_message,
error_stack=traceback.format_exc(),
)
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.debug(f"Making {key} -> {self.target.full_table_name}")
self.__class__._allow_insert = True
try:
make(dict(key), **(make_kwargs or {}))
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
except LostConnectionError:
pass
error_message = "{exception}{msg}".format(
exception=error.__class__.__name__,
msg=": " + str(error) if str(error) else "",
)
logger.debug(
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
)
if jobs is not None:
# show error name and error message (if any)
jobs.error(
self.target.table_name,
self._job_key(key),
error_message=error_message,
error_stack=traceback.format_exc(),
)
if not suppress_errors or isinstance(error, SystemExit):
raise
else:
logger.error(error)
return key, error if return_exception_objects else error_message
else:
self.connection.commit_transaction()
logger.debug(
f"Success making {key} -> {self.target.full_table_name}"
)
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
finally:
self.__class__._allow_insert = False
logger.error(error)
return key, error if return_exception_objects else error_message
else:
self.connection.commit_transaction()
logger.debug(f"Success making {key} -> {self.target.full_table_name}")
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
return True
finally:
self.__class__._allow_insert = False

def progress(self, *restrictions, display=False):
"""
Expand Down
17 changes: 17 additions & 0 deletions tests_old/test_autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def test_populate(self):
assert_true(self.ephys)
assert_true(self.channel)

def test_populate_with_success_count(self):
# test simple populate
assert_true(self.subject, "root tables are empty")
assert_false(self.experiment, "table already filled?")
ret = self.experiment.populate()
success_count = ret["success_count"]
assert_equal(len(self.experiment.key_source & self.experiment), success_count)

# test restricted populate
assert_false(self.trial, "table already filled?")
restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0]
d = self.trial.connection.dependencies
d.load()
ret = self.trial.populate(restriction, suppress_errors=True)
success_count = ret["success_count"]
assert_equal(len(self.trial.key_source & self.trial), success_count)

def test_populate_exclude_error_and_ignore_jobs(self):
# test simple populate
assert_true(self.subject, "root tables are empty")
Expand Down