-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds on-the-fly enabling of function argument casting
- Loading branch information
1 parent
2513172
commit 24565c8
Showing
5 changed files
with
214 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import functools | ||
import inspect | ||
import sys | ||
import types | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
CAST_TYPES = [list, pd.Series] | ||
|
||
|
||
def type_casting(func): | ||
"""Type casting | ||
This decorator casts input arguments of types [list, pandas.Series] to numpy.ndarray | ||
so the algorithms accept these input arguments. | ||
As a bonus, the decorator casts the return value of the algorithm to the type of the first | ||
array-like input argument. | ||
Parameters | ||
---------- | ||
module_or_func : [types.ModuleType, types.FunctionType] | ||
Module or function that has to be type casted. Can be None. | ||
Returns | ||
------- | ||
function | ||
Decorated function. | ||
""" | ||
@functools.wraps(func) | ||
def func_wrapper(*args, **kwargs): | ||
output_type = None | ||
new_args = [] | ||
|
||
for arg in args: | ||
input_type = type(arg) | ||
|
||
if input_type in CAST_TYPES: | ||
new_args.append(np.asarray(arg)) | ||
|
||
if output_type is None: | ||
# Type of first array-like argument is used for output casting | ||
output_type = input_type | ||
else: | ||
new_args.append(arg) | ||
|
||
new_kwargs = dict() | ||
for key, value in kwargs.items(): | ||
input_type = type(value) | ||
|
||
if input_type in CAST_TYPES: | ||
new_kwargs[key] = np.asarray(value) | ||
|
||
if output_type is None: | ||
# Type of first array-like argument is used for output casting | ||
output_type = input_type | ||
else: | ||
new_kwargs[key] = value | ||
|
||
output = func(*new_args, **new_kwargs) | ||
if output_type is not None and isinstance(output, np.ndarray): | ||
return output_type(output) | ||
else: | ||
return output | ||
|
||
return func_wrapper | ||
|
||
|
||
def enable_type_casting(module_or_func=None): | ||
"""Enable type casting | ||
This method enables casting of input arguments to numpy.ndarray so the algorithms accept | ||
array-like input arguments of types list and pandas.Series. | ||
As a bonus, the return value of the algorithm is casted to the type of the first array-like input argument. | ||
Parameters | ||
---------- | ||
module_or_func : [types.ModuleType, types.FunctionType] | ||
Module or function that has to be type casted. Can be None. | ||
Returns | ||
------- | ||
function | ||
Decorated function. | ||
""" | ||
if module_or_func is None: | ||
# Because sys.modules changes during this operation we cannot loop over sys.modules directly | ||
key_values = [(key, value) for key, value in sys.modules.items()] | ||
for key, value in key_values: | ||
# @TODO this if statement might not cover all cases (or too much cases) | ||
if key.startswith('sweat.algorithms') and key != 'sweat.algorithms.utils': | ||
enable_type_casting(module_or_func=value) | ||
|
||
elif isinstance(module_or_func, types.ModuleType): | ||
for name, obj in [(name, obj) for name, obj in inspect.getmembers(module_or_func)]: | ||
if inspect.isfunction(obj) and inspect.getmodule(obj).__package__ == module_or_func.__package__: | ||
func = getattr(module_or_func, name) | ||
setattr(module_or_func, name, type_casting(func)) | ||
|
||
elif isinstance(module_or_func, types.FunctionType): | ||
return type_casting(module_or_func) | ||
|
||
else: | ||
raise ValueError('enable_type_casting takes arguments of types [ModuleType, FunctionType]') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import importlib | ||
import sys | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from sweat.algorithms import utils | ||
from sweat.algorithms.metrics import power | ||
|
||
|
||
@pytest.fixture() | ||
def reload_power_module(): | ||
yield | ||
key_values = [(key, value) for key, value in sys.modules.items()] | ||
for key, value in key_values: | ||
if key.startswith('sweat.algorithms'): | ||
importlib.reload(value) | ||
|
||
|
||
def test_enable_type_casting_module(reload_power_module): | ||
pwr = [1, 2, 3] | ||
wap = [1, 2, 3] | ||
weight = 80 | ||
threshold_power = 80 | ||
|
||
assert isinstance(power.wpk(np.asarray(pwr), weight), np.ndarray) | ||
assert isinstance(power.relative_intensity(np.asarray(wap), threshold_power), np.ndarray) | ||
|
||
with pytest.raises(TypeError): | ||
power.wpk(pwr, weight) | ||
|
||
with pytest.raises(TypeError): | ||
power.relative_intensity(wap, threshold_power) | ||
|
||
doc_string = power.wpk.__doc__ | ||
utils.enable_type_casting(power) | ||
|
||
assert isinstance(power.wpk(pwr, weight), list) | ||
assert isinstance(power.wpk(pd.Series(pwr), weight), pd.Series) | ||
assert isinstance(power.wpk(np.asarray(pwr), weight), np.ndarray) | ||
|
||
assert isinstance(power.relative_intensity(wap, threshold_power), list) | ||
assert isinstance(power.relative_intensity(pd.Series(wap), threshold_power), pd.Series) | ||
assert isinstance(power.relative_intensity(np.asarray(wap), threshold_power), np.ndarray) | ||
|
||
assert power.wpk.__doc__ == doc_string | ||
|
||
def test_enable_type_casting_func(reload_power_module): | ||
pwr = [1, 2, 3] | ||
wap = [1, 2, 3] | ||
weight = 80 | ||
threshold_power = 80 | ||
|
||
assert isinstance(power.wpk(np.asarray(pwr), weight), np.ndarray) | ||
assert isinstance(power.relative_intensity(np.asarray(wap), threshold_power), np.ndarray) | ||
|
||
with pytest.raises(TypeError): | ||
power.wpk(pwr, weight) | ||
|
||
with pytest.raises(TypeError): | ||
power.relative_intensity(wap, threshold_power) | ||
|
||
doc_string = power.wpk.__doc__ | ||
wpk = utils.enable_type_casting(power.wpk) | ||
|
||
assert isinstance(wpk(pwr, weight), list) | ||
assert isinstance(wpk(pd.Series(pwr), weight), pd.Series) | ||
assert isinstance(wpk(np.asarray(pwr), weight), np.ndarray) | ||
|
||
with pytest.raises(TypeError): | ||
power.relative_intensity(wap, threshold_power) | ||
|
||
assert power.wpk.__doc__ == doc_string | ||
|
||
|
||
def test_enable_type_casting_all(reload_power_module): | ||
pwr = [1, 2, 3] | ||
wap = [1, 2, 3] | ||
weight = 80 | ||
threshold_power = 80 | ||
|
||
assert isinstance(power.wpk(np.asarray(pwr), weight), np.ndarray) | ||
assert isinstance(power.relative_intensity(np.asarray(wap), threshold_power), np.ndarray) | ||
|
||
with pytest.raises(TypeError): | ||
power.wpk(pwr, weight) | ||
|
||
with pytest.raises(TypeError): | ||
power.relative_intensity(wap, threshold_power) | ||
|
||
doc_string = power.wpk.__doc__ | ||
utils.enable_type_casting() | ||
|
||
assert isinstance(power.wpk(pwr, weight), list) | ||
assert isinstance(power.wpk(pd.Series(pwr), weight), pd.Series) | ||
assert isinstance(power.wpk(np.asarray(pwr), weight), np.ndarray) | ||
|
||
assert isinstance(power.relative_intensity(wap, threshold_power), list) | ||
assert isinstance(power.relative_intensity(pd.Series(wap), threshold_power), pd.Series) | ||
assert isinstance(power.relative_intensity(np.asarray(wap), threshold_power), np.ndarray) | ||
|
||
assert power.wpk.__doc__ == doc_string | ||
|
||
|
||
def test_enable_type_casting_error(): | ||
with pytest.raises(ValueError): | ||
utils.enable_type_casting('covfefe') |