Skip to content

Commit

Permalink
Removed bounds_enforcing_decorator_factory from repo, resolving #756
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Hearin committed Nov 27, 2017
1 parent 39fc1bb commit 41ebd71
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 87 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

- Added `rp_pi_tpcf_jackknife` function.

- Removed obsolete bounds_enforcing_decorator_factory function - see https://github.com/astropy/halotools/issues/756


0.5 (2017-05-31)
----------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def _decorate_baseline_method(self):
"and the baseline model must have a method named ``%s``")
raise HalotoolsError(msg % self._method_name_to_decorate)

@model_helpers.bounds_enforcing_decorator_factory(0, 1)
def percentile_splitting_function(self, prim_haloprop):
"""
Method returns the fraction of halos that are ``type-2``
Expand Down Expand Up @@ -266,16 +265,14 @@ def percentile_splitting_function(self, prim_haloprop):
self._split_abscissa, self._split_ordinates, k=3)
result = spline_function(prim_haloprop)

result = np.where(result < 0, 0., result)
result = np.where(result > 1, 1., result)
return result

@model_helpers.bounds_enforcing_decorator_factory(-1, 1)
def assembias_strength(self, prim_haloprop):
"""
Method returns the strength of assembly bias as a function of the primary halo property.
The `bounds_enforcing_decorator_factory` guarantees that the assembly bias
strength is enforced to be between -1 and 1.
Parameters
----------
prim_haloprop : array_like
Expand All @@ -296,6 +293,8 @@ def assembias_strength(self, prim_haloprop):
else:
result = spline_function(prim_haloprop)

result = np.where(result < -1, -1., result)
result = np.where(result > 1, 1., result)
return result

def _get_assembias_param_dict_key(self, ipar):
Expand Down
70 changes: 2 additions & 68 deletions halotools/empirical_models/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline as spline
from scipy.special import gammaincc, gamma, expi
from warnings import warn

from ..utils.array_utils import custom_len
from ..custom_exceptions import HalotoolsError


__all__ = ('solve_for_polynomial_coefficients', 'polynomial_from_table',
'enforce_periodicity_of_box', 'custom_spline', 'create_composite_dtype',
'bind_default_kwarg_mixin_safe',
'custom_incomplete_gamma', 'bounds_enforcing_decorator_factory')
'enforce_periodicity_of_box', 'custom_spline', 'create_composite_dtype',
'bind_default_kwarg_mixin_safe', 'custom_incomplete_gamma')

__author__ = ['Andrew Hearin', 'Surhud More', 'Johannes Ulf Lange']

Expand Down Expand Up @@ -445,67 +443,3 @@ def custom_incomplete_gamma(a, x):
else:
return gammaincc(a, x) * gamma(a)
custom_incomplete_gamma.__author__ = ['Surhud More', 'Johannes Ulf Lange']


def bounds_enforcing_decorator_factory(lower_bound, upper_bound, warning=True):
r"""
Function returns a decorator that can be used to clip the values
of an original function to produce a modified function whose
values are replaced by the input ``lower_bound`` and ``upper_bound`` whenever
the original function returns out of range values.
Parameters
-----------
lower_bound : float or int
Lower bound defining the output decorator
upper_bound : float or int
Upper bound defining the output decorator
warning : bool, optional
If True, decorator will raise a warning for cases where the values of the
undecorated function fall outside the boundaries. Default is True.
Returns
--------
decorator : object
Python decorator used to apply to any function for which you wish to
enforce that that the returned values of the original function are modified
to be bounded by ``lower_bound`` and ``upper_bound``.
Examples
--------
>>> def original_function(x): return x + 4
>>> lower_bound, upper_bound = 0, 5
>>> decorator = bounds_enforcing_decorator_factory(lower_bound, upper_bound)
>>> modified_function = decorator(original_function)
>>> assert original_function(3) == 7
>>> assert modified_function(3) == upper_bound
>>> assert original_function(-10) == -6
>>> assert modified_function(-10) == lower_bound
>>> assert original_function(0) == modified_function(0) == 4
"""

def decorator(input_func):

def output_func(*args, **kwargs):

unbounded_result = np.array(input_func(*args, **kwargs))
lower_bounded_result = np.where(unbounded_result < lower_bound, lower_bound, unbounded_result)
bounded_result = np.where(lower_bounded_result > upper_bound, upper_bound, lower_bounded_result)

if warning is True:
raise_warning = np.any(unbounded_result != bounded_result)
if raise_warning is True:
func_name = input_func.__name__
msg = ("The " + func_name + " function \nreturned at least one value that was "
"outside the range (%.2f, %.2f)\n. The bounds_enforcing_decorator_factory "
"manually set all such values equal to \nthe appropriate boundary condition.\n")
warn(msg)

return bounded_result

return output_func

return decorator
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .. import model_defaults, model_helpers
from ..smhm_models import Behroozi10SmHm
from ..assembias_models import HeavisideAssembias
from ..model_helpers import bounds_enforcing_decorator_factory

from ...utils.array_utils import custom_len
from ... import sim_manager
Expand Down Expand Up @@ -139,7 +138,6 @@ def _initialize_param_dict(self,
for key, value in zip(self._ordinates_keys, quiescent_fraction_ordinates):
self.param_dict[key] = value

@bounds_enforcing_decorator_factory(0, 1, warning=False)
def mean_quiescent_fraction(self, **kwargs):
"""
"""
Expand All @@ -160,6 +158,9 @@ def mean_quiescent_fraction(self, **kwargs):

fraction = spline_function(np.log10(prim_haloprop))

fraction = np.where(fraction < 0, 0., fraction)
fraction = np.where(fraction > 1, 1., fraction)

return fraction

def mc_sfr_designation(self, seed=None, **kwargs):
Expand Down
13 changes: 1 addition & 12 deletions halotools/empirical_models/tests/test_model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from astropy.utils.misc import NumpyRNGContext

from ..model_helpers import custom_spline, create_composite_dtype
from ..model_helpers import bounds_enforcing_decorator_factory, enforce_periodicity_of_box
from ..model_helpers import enforce_periodicity_of_box
from ..model_helpers import call_func_table, bind_default_kwarg_mixin_safe

from ...custom_exceptions import HalotoolsError
Expand Down Expand Up @@ -77,17 +77,6 @@ def __init__(self, d):
assert substr in err.value.args[0]


def test_bounds_enforcing_decorator_factory():
"""
"""
def f(x):
return x
decorator = bounds_enforcing_decorator_factory(0, 1, warning=True)
decorated_f = decorator(f)
result = decorated_f(-1)
assert result == 0


def test_enforce_periodicity_of_box():
""" Verify that enforce_periodicity_of_box results in all points located
inside [0, Lbox]
Expand Down

0 comments on commit 41ebd71

Please sign in to comment.