Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class parameter set, the second try #196

Merged
merged 3 commits into from
Nov 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions spotpy/algorithms/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def __init__(self, spot_setup, dbname=None, dbformat=None, dbinit=True,
self.parnames = param_info['name']

# Create a type to hold the parameter values using a namedtuple
self.partype = parameter.get_namedtuple_from_paramnames(
self.setup, self.parnames)
self.partype = parameter.ParameterSet(param_info)

# use alt_objfun if alt_objfun is defined in objectivefunctions,
# else self.setup.objectivefunction
Expand Down
2 changes: 1 addition & 1 deletion spotpy/gui/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def run(self, _=None):
sim = self.setup.simulation(parset)
objf = as_scalar(self.setup.objectivefunction(sim, self.setup.evaluation()))
label = ('{:0.4g}=M('.format(objf)
+ ', '.join('{f}={v:0.4g}'.format(f=f, v=v) for f, v in zip(parset._fields, parset))
+ ', '.join('{f}={v:0.4g}'.format(f=f, v=v) for f, v in zip(parset.name, parset))
+ ')')
self.lines.extend(self.ax.plot(sim, '-', label=label))
self.ax.legend()
Expand Down
174 changes: 145 additions & 29 deletions spotpy/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

if sys.version_info[0] >= 3:
unicode = str

from collections import namedtuple
from itertools import cycle


Expand Down Expand Up @@ -517,6 +515,146 @@ def __init__(self, *args, **kwargs):
"""
super(Triangular, self).__init__(rnd.triangular, 'Triangular', *args, **kwargs)


class ParameterSet(object):
"""
A Pickable parameter set to use named parameters in a setup
Is not created by a user directly, but in algorithm.
Older versions used a namedtuple, which is not pickable.
An instance of ParameterSet is sent to the users setup.simulate method.
Usage:
>>> ps = ParameterSet(...)
Update values by arguments or keyword arguments
>>> ps(0, 1, 2)
>>> ps(a=1, c=2)
Assess parameter values of this parameter set
>>> ps[0] == ps['a'] == ps.a
A parameter set is a sequence:
>>> list(ps)
Assess the parameter set properties as arrays
>>> [ps.maxbound, ps.minbound, ps.optguess, ps.step, ps.random]
"""
def __init__(self, param_info):
"""
Creates a set of parameters from a parameter info array.
To create the parameter set from a setup use either:
>>> setup = ...
>>> ps = ParameterSet(get_parameters_array(setup))
or you can just use a function for this:
>>> ps = create_set(setup)
:param param_info: A record array containing the properties of the parameters
of this set.
"""
self.__lookup = dict(("p" + x if x.isdigit() else x, i) for i, x in enumerate(param_info['name']))
self.__info = param_info

def __call__(self, *values, **kwargs):
"""
Populates the values ('random') of the parameter set with new data
:param values: Contains the new values or omitted.
If given, the number of values needs to match the number
of parameters
:param kwargs: Can be used to set only single parameter values
:return:
"""
if values:
if len(self.__info) != len(values):
raise ValueError('Given values do are not the same length as the parameter set')
self.__info['random'][:] = values
for k in kwargs:
try:
self.__info['random'][self.__lookup[k]] = kwargs[k]
except KeyError:
raise TypeError('{} is not a parameter of this set'.format(k))
return self

def __len__(self):
return len(self.__info['random'])

def __iter__(self):
return iter(self.__info['random'])

def __getitem__(self, item):
"""
Provides item access
>>> ps[0] == ps['a']
:raises: KeyError, IndexError and TypeError
"""
if type(item) is str:
item = self.__lookup[item]
return self.__info['random'][item]

def __setitem__(self, key, value):
"""
Provides setting of item
>>> ps[0] = 1
>>> ps['a'] = 2
"""
if key in self.__lookup:
key = self.__lookup[key]
self.__info['random'][key] = value

def __getattr__(self, item):
"""
Provides the attribute access like
>>> print(ps.a)
"""
if item.startswith('_'):
raise AttributeError('{} is not a member of this parameter set'.format(item))
elif item in self.__lookup:
return self.__info['random'][self.__lookup[item]]
elif item in self.__info.dtype.names:
return self.__info[item]
else:
raise AttributeError('{} is not a member of this parameter set'.format(item))

def __setattr__(self, key, value):
"""
Provides setting of attributes
>>> ps.a = 2
"""
# Allow normal usage
if key.startswith('_') or key not in self.__lookup:
return object.__setattr__(self, key, value)
else:
self.__info['random'][self.__lookup[key]] = value

def __str__(self):
return 'parameters({})'.format(
', '.join('{}={:g}'.format(k, self.__info['random'][i])
for i, k in enumerate(self.__info['name'])
)
)

def __repr__(self):
return 'spotpy.parameter.ParameterSet()'

def __dir__(self):
"""
Helps to show the field names in an interactive environment like IPython.
See: http://ipython.readthedocs.io/en/stable/config/integrating.html
:return: List of method names and fields
"""
attrs = [attr for attr in vars(type(self)) if not attr.startswith('_')]
return attrs + list(self.__info['name']) + list(self.__info.dtype.names)


def get_classes():
keys = []
current_module = sys.modules[__name__]
Expand All @@ -525,6 +663,7 @@ def get_classes():
keys.append(key)
return keys


def generate(parameters):
"""
This function generates a parameter set from a list of parameter objects. The parameter set
Expand Down Expand Up @@ -561,7 +700,6 @@ def get_parameters_array(setup, unaccepted_parameter_types=()):
# function
param_arrays = []
# Get parameters defined with the setup class
#setup_parameters = checked_parameter_types(, unaccepted_parameter_types)
setup_parameters = get_parameters_from_setup(setup)
check_parameter_types(setup_parameters, unaccepted_parameter_types)
param_arrays.append(
Expand All @@ -587,8 +725,7 @@ def find_constant_parameters(parameter_array):
return (parameter_array['maxbound'] - parameter_array['minbound'] == 0.0)



def create_set(setup, valuetype='optguess', **kwargs):
def create_set(setup, valuetype='random', **kwargs):
"""
Returns a named tuple holding parameter values, to be used with the simulation method of a setup
Expand All @@ -611,34 +748,13 @@ def create_set(setup, valuetype='optguess', **kwargs):
params = get_parameters_array(setup)

# Create the namedtuple from the parameter names
partype = get_namedtuple_from_paramnames(setup, params['name'])

# Use the generated values from the distribution
pardict = dict(zip(params['name'], params[valuetype]))

# Overwrite parameters with keyword arguments
pardict.update(kwargs)
partype = ParameterSet(params)

# Return the namedtuple with fitting names
return partype(**pardict)


def get_namedtuple_from_paramnames(owner, parnames):
"""
Returns the namedtuple classname for parameter names
:param owner: Owner of the parameters, usually the spotpy setup
:param parnames: Sequence of parameter names
:return: Class
"""

# Get name of owner class
typename = type(owner).__name__
parnames = ["p" + x if x.isdigit() else x for x in list(parnames)]
return namedtuple('Par_' + typename, # Type name created from the setup name
parnames) # get parameter names
return partype(*params[valuetype], **kwargs)


def get_constant_indices(setup, unaccepted_parameter_types=(Constant)):
def get_constant_indices(setup, unaccepted_parameter_types=(Constant,)):
"""
Returns a list of the class defined parameters, and
overwrites the names of the parameters.
Expand Down
76 changes: 73 additions & 3 deletions spotpy/unittests/test_setup_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,86 @@ class SpotSetupParameterList(SpotSetupBase):
def __init__(self):
self.parameters = [parameter.Uniform(name, -1, 1) for name in 'abcd']


class SpotSetupMixedParameterList(SpotSetupBase):
"""
A Test case with two parameters as class parameters (a,b)
and 2 given from the parameter function
"""
a = parameter.Uniform(0, 1)
b = parameter.Uniform(1, 2)

def parameters(self):
return parameter.generate([parameter.Uniform(name, -1, 1) for name in 'cd'])



class TestParameterSet(unittest.TestCase):
def setUp(self):
model = SpotSetupParameterFunction()
param_info = model.parameters()
self.ps = parameter.ParameterSet(param_info)

def test_create(self):
self.assertEqual(type(self.ps), parameter.ParameterSet)

def test_assign(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(list(self.ps), values)
# Test if wrong number of parameters raises
with self.assertRaises(ValueError):
self.ps(*values[:-1])

def test_iter(self):
values = [1] * len(self.ps)
self.ps(*values)
ps_values = list(self.ps)
self.assertEquals(values, ps_values)

def test_getitem(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(self.ps['a'], 1.0)
self.assertEquals(self.ps[0], 1.0)

def test_getattr(self):
values = [1] * len(self.ps)
self.ps(*values)

with self.assertRaises(AttributeError):
_ = self.ps.__x

self.assertEquals(self.ps.a, 1.0)
self.assertEquals(list(self.ps.random), list(self.ps), 'Access to random variable does not equal list of names')

with self.assertRaises(AttributeError):
_ = self.ps.x

def test_setattr(self):
self.ps.a = 2
self.assertEquals(self.ps[0], 2)

def test_dir(self):
values = [1] * len(self.ps)
self.ps(*values)

attrs = dir(self.ps)
for param in self.ps.name:
self.assertIn(param, attrs, 'Attribute {} not found in {}'.format(param, self.ps))
for prop in ['maxbound', 'minbound', 'name', 'optguess', 'random', 'step']:
self.assertIn(prop, attrs, 'Property {} not found in {}'.format(prop, self.ps))

def test_str(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(str(self.ps), 'parameters(a=1, b=1, c=1, d=1)')

def test_repr(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(repr(self.ps), 'spotpy.parameter.ParameterSet()')


class TestSetupVariants(unittest.TestCase):
def setUp(self):
# Get all Setups from this module
Expand All @@ -102,8 +172,8 @@ def test_exists(self):
self.assertGreater(len(self.objects), 0)

def parameter_count_test(self, o):
params = parameter.create_set(o)
param_names = ','.join(pn for pn in params._fields)
params = parameter.create_set(o, valuetype='optguess')
param_names = ','.join(pn for pn in params.name)
self.assertEqual(len(params), 4, '{} should have 4 parameters, but found only {} ({})'
.format(o, len(params), param_names))
self.assertEqual(param_names, 'a,b,c,d', '{} Parameter names should be "a,b,c,d" but got "{}"'
Expand Down