Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback: step-parameter accepts list of steps #1542

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 61 additions & 42 deletions odl/solvers/util/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ def __init__(self, results=None, function=None, step=1):
Deprecated, use composition instead. See examples.
Function to be called on all incoming results before storage.
Default: copy
step : int, optional
Number of iterates between storing iterates.
step : positive int, list, optional
Number of iterates between storing or
list of steps when to store iterates.

Examples
--------
Expand All @@ -227,12 +228,12 @@ def __init__(self, results=None, function=None, step=1):
'instead. '
'See Examples in the documentation.',
DeprecationWarning)
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.iter = 0

def __call__(self, result):
"""Append result to results list."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
if self.function:
self.results.append(self.function(result))
else:
Expand Down Expand Up @@ -279,8 +280,9 @@ def __init__(self, function, step=1):
----------
function : callable
Function to call on the current iterate.
step : int, optional
Number of iterates between applications of ``function``.
step : positive int, list, optional
Number of iterates between applying ``function``
or list of steps when to apply ``function``.

Examples
--------
Expand All @@ -306,12 +308,12 @@ def __init__(self, function, step=1):
"""
assert callable(function)
self.function = function
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.iter = 0

def __call__(self, result):
"""Apply function to result."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
self.function(result)
self.iter += 1

Expand Down Expand Up @@ -346,8 +348,9 @@ def __init__(self, fmt='iter = {}', step=1, **kwargs):
print(fmt.format(cur_iter_num))

where ``cur_iter_num`` is the current iteration number.
step : positive int, optional
Number of iterations between output.
step : positive int, list, optional
Number of iterations between output or
list of steps when to output.

Other Parameters
----------------
Expand Down Expand Up @@ -376,13 +379,13 @@ def __init__(self, fmt='iter = {}', step=1, **kwargs):
Current iter is 2.
"""
self.fmt = str(fmt)
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.iter = 0
self.kwargs = kwargs

def __call__(self, _):
"""Print the current iteration."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
print(self.fmt.format(self.iter), **self.kwargs)

self.iter += 1
Expand Down Expand Up @@ -421,8 +424,9 @@ def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1,
print(fmt.format(runtime))

where ``runtime`` is the runtime since the last iterate.
step : positive int, optional
Number of iterations between prints.
step : positive int, list, optional
Number of iterations between prints or
list of iterations when to print.
cumulative : boolean, optional
Print the time since the initialization instead of the last call.

Expand All @@ -432,15 +436,15 @@ def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1,
Key word arguments passed to the print function.
"""
self.fmt = str(fmt)
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.iter = 0
self.cumulative = cumulative
self.start_time = time.time()
self.kwargs = kwargs

def __call__(self, _):
"""Print time elapsed from the previous iteration."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
current_time = time.time()

print(self.fmt.format(current_time - self.start_time),
Expand Down Expand Up @@ -484,8 +488,9 @@ def __init__(self, func=None, fmt='{!r}', step=1, **kwargs):
print(fmt.format(x))

where ``x`` is the input to the callback.
step : positive int, optional
Number of iterations between prints.
step : positive int, list, optional
Number of iterations between prints or
list of iterations when to print.

Other Parameters
----------------
Expand Down Expand Up @@ -522,13 +527,13 @@ def __init__(self, func=None, fmt='{!r}', step=1, **kwargs):
raise TypeError('`func` must be `callable` or `None`')

self.fmt = str(fmt)
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.iter = 0
self.kwargs = kwargs

def __call__(self, result):
"""Print the current value."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
if self.func is not None:
result = self.func(result)

Expand Down Expand Up @@ -588,8 +593,9 @@ def __init__(self, title=None, step=1, saveto=None, **kwargs):
where ``cur_iter_num`` is the current iteration number.
For the default ``None``, the title format ``'Iterate {}'``
is used.
step : positive int, optional
Number of iterations between plots.
step : positive int, list, optional
Number of iterations between plots or
list of iterations when to plot.
saveto : str or callable, optional
Format string for the name of the file(s) where
iterates are saved.
Expand Down Expand Up @@ -643,7 +649,7 @@ def __init__(self, title=None, step=1, saveto=None, **kwargs):
self.saveto = saveto
self.saveto_formatter = getattr(self.saveto, 'format', self.saveto)

self.step = step
_setupShouldEvaluateAtStep(self, step)
self.fig = kwargs.pop('fig', None)
self.iter = 0
self.space_of_last_x = None
Expand All @@ -656,7 +662,7 @@ def __call__(self, x):
update_in_place = (self.space_of_last_x == x_space)
self.space_of_last_x = x_space

if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
title = self.title_formatter(self.iter)

if self.saveto is None:
Expand Down Expand Up @@ -707,8 +713,9 @@ def __init__(self, saveto, step=1, impl='pickle', **kwargs):
filename = saveto.format(cur_iter_num)

where ``cur_iter_num`` is the current iteration number.
step : positive int, optional
Number of iterations between saves.
step : positive int, list, optional
Number of iterations between saves or
list of iterations when to save.
impl : {'pickle', 'numpy', 'numpy_txt'}, optional
The format to store the iterates in. Numpy formats are only usable
if the data can be converted to an array via `numpy.asarray`.
Expand Down Expand Up @@ -741,14 +748,14 @@ def __init__(self, saveto, step=1, impl='pickle', **kwargs):
except AttributeError:
self.saveto_formatter = self.saveto

self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.impl = str(impl).lower()
self.kwargs = kwargs
self.iter = 0

def __call__(self, x):
"""Save the current iterate."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
file_path = self.saveto_formatter(self.iter)
folder_path = os.path.dirname(os.path.realpath(file_path))

Expand Down Expand Up @@ -823,7 +830,7 @@ class CallbackShowConvergence(Callback):
"""Displays a convergence plot."""

def __init__(self, functional, title='convergence', logx=False, logy=False,
**kwargs):
step=1, **kwargs):
"""Initialize a new instance.

Parameters
Expand All @@ -848,6 +855,7 @@ def __init__(self, functional, title='convergence', logx=False, logy=False,
self.logx = logx
self.logy = logy
self.kwargs = kwargs
_setupShouldEvaluateAtStep(self, step)
self.iter = 0

import matplotlib.pyplot as plt
Expand All @@ -863,11 +871,12 @@ def __init__(self, functional, title='convergence', logx=False, logy=False,

def __call__(self, x):
"""Implement ``self(x)``."""
if self.logx:
it = self.iter + 1
else:
it = self.iter
self.ax.scatter(it, self.functional(x), **self.kwargs)
if self.should_evaluate_at_step(self.iter):
if self.logx:
it = self.iter + 1
else:
it = self.iter
self.ax.scatter(it, self.functional(x), **self.kwargs)
self.iter += 1

def reset(self):
Expand Down Expand Up @@ -897,8 +906,9 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',

Parameters
----------
step : positive int, optional
Number of iterations between output.
step : positive int, list, optional
Number of iterations between output or
list of iterations when to output.
fmt_cpu : string, optional
Formating that should be applied. The CPU usage is printed as ::

Expand Down Expand Up @@ -944,7 +954,7 @@ def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',
... fmt_mem='RAM {}',
... fmt_swap='')
"""
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.fmt_cpu = str(fmt_cpu)
self.fmt_mem = str(fmt_mem)
self.fmt_swap = str(fmt_swap)
Expand All @@ -955,7 +965,7 @@ def __call__(self, _):

import psutil

if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
if self.fmt_cpu:
print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True)),
**self.kwargs)
Expand Down Expand Up @@ -996,22 +1006,23 @@ def __init__(self, niter, step=1, **kwargs):
----------
niter : positive int, optional
Total number of iterations.
step : positive int, optional
Number of iterations between output.
step : positive int, list, optional
Number of iterations between output or
list of iterations when to output.

Other Parameters
----------------
kwargs :
Further parameters passed to ``tqdm.tqdm``.
"""
self.niter = int(niter)
self.step = int(step)
_setupShouldEvaluateAtStep(self, step)
self.kwargs = kwargs
self.reset()

def __call__(self, _):
"""Update the progressbar."""
if self.iter % self.step == 0:
if self.should_evaluate_at_step(self.iter):
self.pbar.update(self.step)

self.iter += 1
Expand All @@ -1034,6 +1045,14 @@ def __repr__(self):
return '{}({})'.format(self.__class__.__name__,
inner_str)

def _setupShouldEvaluateAtStep(self, step):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use a different function name style?

try:
self.step = frozenset(int(i) for i in step)
self.should_evaluate_at_step = lambda i: i in self.step
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attributes should only be set in __init__ if possible, not in "free" functions like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to refactor the whole repetition and thought about returning self, so that the user does see that self is modified inside of this function. Otherwise there is a lot of code repetition and the inits are quite bloated if the user wants to take a look at them.
I also thought about a metaclass that will automatically setup the the init and call, but I guess that is a bit of an overkill

except TypeError:
self.step = int(step)
self.should_evaluate_at_step = lambda i: i % self.step == 0


if __name__ == '__main__':
from odl.util.testutils import run_doctests
Expand Down