Skip to content

Commit

Permalink
Merge pull request #196 from philippkraft/classParameterSet
Browse files Browse the repository at this point in the history
Class parameter set, the second try
  • Loading branch information
philippkraft authored Nov 12, 2018
2 parents 91fecef + 0e0e423 commit e085278
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 35 deletions.
3 changes: 1 addition & 2 deletions spotpy/algorithms/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def __init__(self, spot_setup, dbname=None, dbformat=None, dbinit=True,
self.parnames = param_info['name']

# Create a type to hold the parameter values using a namedtuple
self.partype = parameter.get_namedtuple_from_paramnames(
self.setup, self.parnames)
self.partype = parameter.ParameterSet(param_info)

# use alt_objfun if alt_objfun is defined in objectivefunctions,
# else self.setup.objectivefunction
Expand Down
2 changes: 1 addition & 1 deletion spotpy/gui/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def run(self, _=None):
sim = self.setup.simulation(parset)
objf = as_scalar(self.setup.objectivefunction(sim, self.setup.evaluation()))
label = ('{:0.4g}=M('.format(objf)
+ ', '.join('{f}={v:0.4g}'.format(f=f, v=v) for f, v in zip(parset._fields, parset))
+ ', '.join('{f}={v:0.4g}'.format(f=f, v=v) for f, v in zip(parset.name, parset))
+ ')')
self.lines.extend(self.ax.plot(sim, '-', label=label))
self.ax.legend()
Expand Down
174 changes: 145 additions & 29 deletions spotpy/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

if sys.version_info[0] >= 3:
unicode = str

from collections import namedtuple
from itertools import cycle


Expand Down Expand Up @@ -517,6 +515,146 @@ def __init__(self, *args, **kwargs):
"""
super(Triangular, self).__init__(rnd.triangular, 'Triangular', *args, **kwargs)


class ParameterSet(object):
"""
A Pickable parameter set to use named parameters in a setup
Is not created by a user directly, but in algorithm.
Older versions used a namedtuple, which is not pickable.
An instance of ParameterSet is sent to the users setup.simulate method.
Usage:
>>> ps = ParameterSet(...)
Update values by arguments or keyword arguments
>>> ps(0, 1, 2)
>>> ps(a=1, c=2)
Assess parameter values of this parameter set
>>> ps[0] == ps['a'] == ps.a
A parameter set is a sequence:
>>> list(ps)
Assess the parameter set properties as arrays
>>> [ps.maxbound, ps.minbound, ps.optguess, ps.step, ps.random]
"""
def __init__(self, param_info):
"""
Creates a set of parameters from a parameter info array.
To create the parameter set from a setup use either:
>>> setup = ...
>>> ps = ParameterSet(get_parameters_array(setup))
or you can just use a function for this:
>>> ps = create_set(setup)
:param param_info: A record array containing the properties of the parameters
of this set.
"""
self.__lookup = dict(("p" + x if x.isdigit() else x, i) for i, x in enumerate(param_info['name']))
self.__info = param_info

def __call__(self, *values, **kwargs):
"""
Populates the values ('random') of the parameter set with new data
:param values: Contains the new values or omitted.
If given, the number of values needs to match the number
of parameters
:param kwargs: Can be used to set only single parameter values
:return:
"""
if values:
if len(self.__info) != len(values):
raise ValueError('Given values do are not the same length as the parameter set')
self.__info['random'][:] = values
for k in kwargs:
try:
self.__info['random'][self.__lookup[k]] = kwargs[k]
except KeyError:
raise TypeError('{} is not a parameter of this set'.format(k))
return self

def __len__(self):
return len(self.__info['random'])

def __iter__(self):
return iter(self.__info['random'])

def __getitem__(self, item):
"""
Provides item access
>>> ps[0] == ps['a']
:raises: KeyError, IndexError and TypeError
"""
if type(item) is str:
item = self.__lookup[item]
return self.__info['random'][item]

def __setitem__(self, key, value):
"""
Provides setting of item
>>> ps[0] = 1
>>> ps['a'] = 2
"""
if key in self.__lookup:
key = self.__lookup[key]
self.__info['random'][key] = value

def __getattr__(self, item):
"""
Provides the attribute access like
>>> print(ps.a)
"""
if item.startswith('_'):
raise AttributeError('{} is not a member of this parameter set'.format(item))
elif item in self.__lookup:
return self.__info['random'][self.__lookup[item]]
elif item in self.__info.dtype.names:
return self.__info[item]
else:
raise AttributeError('{} is not a member of this parameter set'.format(item))

def __setattr__(self, key, value):
"""
Provides setting of attributes
>>> ps.a = 2
"""
# Allow normal usage
if key.startswith('_') or key not in self.__lookup:
return object.__setattr__(self, key, value)
else:
self.__info['random'][self.__lookup[key]] = value

def __str__(self):
return 'parameters({})'.format(
', '.join('{}={:g}'.format(k, self.__info['random'][i])
for i, k in enumerate(self.__info['name'])
)
)

def __repr__(self):
return 'spotpy.parameter.ParameterSet()'

def __dir__(self):
"""
Helps to show the field names in an interactive environment like IPython.
See: http://ipython.readthedocs.io/en/stable/config/integrating.html
:return: List of method names and fields
"""
attrs = [attr for attr in vars(type(self)) if not attr.startswith('_')]
return attrs + list(self.__info['name']) + list(self.__info.dtype.names)


def get_classes():
keys = []
current_module = sys.modules[__name__]
Expand All @@ -525,6 +663,7 @@ def get_classes():
keys.append(key)
return keys


def generate(parameters):
"""
This function generates a parameter set from a list of parameter objects. The parameter set
Expand Down Expand Up @@ -561,7 +700,6 @@ def get_parameters_array(setup, unaccepted_parameter_types=()):
# function
param_arrays = []
# Get parameters defined with the setup class
#setup_parameters = checked_parameter_types(, unaccepted_parameter_types)
setup_parameters = get_parameters_from_setup(setup)
check_parameter_types(setup_parameters, unaccepted_parameter_types)
param_arrays.append(
Expand All @@ -587,8 +725,7 @@ def find_constant_parameters(parameter_array):
return (parameter_array['maxbound'] - parameter_array['minbound'] == 0.0)



def create_set(setup, valuetype='optguess', **kwargs):
def create_set(setup, valuetype='random', **kwargs):
"""
Returns a named tuple holding parameter values, to be used with the simulation method of a setup
Expand All @@ -611,34 +748,13 @@ def create_set(setup, valuetype='optguess', **kwargs):
params = get_parameters_array(setup)

# Create the namedtuple from the parameter names
partype = get_namedtuple_from_paramnames(setup, params['name'])

# Use the generated values from the distribution
pardict = dict(zip(params['name'], params[valuetype]))

# Overwrite parameters with keyword arguments
pardict.update(kwargs)
partype = ParameterSet(params)

# Return the namedtuple with fitting names
return partype(**pardict)


def get_namedtuple_from_paramnames(owner, parnames):
"""
Returns the namedtuple classname for parameter names
:param owner: Owner of the parameters, usually the spotpy setup
:param parnames: Sequence of parameter names
:return: Class
"""

# Get name of owner class
typename = type(owner).__name__
parnames = ["p" + x if x.isdigit() else x for x in list(parnames)]
return namedtuple('Par_' + typename, # Type name created from the setup name
parnames) # get parameter names
return partype(*params[valuetype], **kwargs)


def get_constant_indices(setup, unaccepted_parameter_types=(Constant)):
def get_constant_indices(setup, unaccepted_parameter_types=(Constant,)):
"""
Returns a list of the class defined parameters, and
overwrites the names of the parameters.
Expand Down
76 changes: 73 additions & 3 deletions spotpy/unittests/test_setup_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,86 @@ class SpotSetupParameterList(SpotSetupBase):
def __init__(self):
self.parameters = [parameter.Uniform(name, -1, 1) for name in 'abcd']


class SpotSetupMixedParameterList(SpotSetupBase):
"""
A Test case with two parameters as class parameters (a,b)
and 2 given from the parameter function
"""
a = parameter.Uniform(0, 1)
b = parameter.Uniform(1, 2)

def parameters(self):
return parameter.generate([parameter.Uniform(name, -1, 1) for name in 'cd'])



class TestParameterSet(unittest.TestCase):
def setUp(self):
model = SpotSetupParameterFunction()
param_info = model.parameters()
self.ps = parameter.ParameterSet(param_info)

def test_create(self):
self.assertEqual(type(self.ps), parameter.ParameterSet)

def test_assign(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(list(self.ps), values)
# Test if wrong number of parameters raises
with self.assertRaises(ValueError):
self.ps(*values[:-1])

def test_iter(self):
values = [1] * len(self.ps)
self.ps(*values)
ps_values = list(self.ps)
self.assertEquals(values, ps_values)

def test_getitem(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(self.ps['a'], 1.0)
self.assertEquals(self.ps[0], 1.0)

def test_getattr(self):
values = [1] * len(self.ps)
self.ps(*values)

with self.assertRaises(AttributeError):
_ = self.ps.__x

self.assertEquals(self.ps.a, 1.0)
self.assertEquals(list(self.ps.random), list(self.ps), 'Access to random variable does not equal list of names')

with self.assertRaises(AttributeError):
_ = self.ps.x

def test_setattr(self):
self.ps.a = 2
self.assertEquals(self.ps[0], 2)

def test_dir(self):
values = [1] * len(self.ps)
self.ps(*values)

attrs = dir(self.ps)
for param in self.ps.name:
self.assertIn(param, attrs, 'Attribute {} not found in {}'.format(param, self.ps))
for prop in ['maxbound', 'minbound', 'name', 'optguess', 'random', 'step']:
self.assertIn(prop, attrs, 'Property {} not found in {}'.format(prop, self.ps))

def test_str(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(str(self.ps), 'parameters(a=1, b=1, c=1, d=1)')

def test_repr(self):
values = [1] * len(self.ps)
self.ps(*values)
self.assertEquals(repr(self.ps), 'spotpy.parameter.ParameterSet()')


class TestSetupVariants(unittest.TestCase):
def setUp(self):
# Get all Setups from this module
Expand All @@ -102,8 +172,8 @@ def test_exists(self):
self.assertGreater(len(self.objects), 0)

def parameter_count_test(self, o):
params = parameter.create_set(o)
param_names = ','.join(pn for pn in params._fields)
params = parameter.create_set(o, valuetype='optguess')
param_names = ','.join(pn for pn in params.name)
self.assertEqual(len(params), 4, '{} should have 4 parameters, but found only {} ({})'
.format(o, len(params), param_names))
self.assertEqual(param_names, 'a,b,c,d', '{} Parameter names should be "a,b,c,d" but got "{}"'
Expand Down

0 comments on commit e085278

Please sign in to comment.