Skip to content

Commit

Permalink
Update Test script for rosenbrock setup #213
Browse files Browse the repository at this point in the history
  • Loading branch information
thouska committed May 17, 2019
1 parent ac76d02 commit ea0be90
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions spotpy/unittests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,93 +26,90 @@ def setUp(self):
# How many digits to match in case of floating point answers
self.tolerance = 7
#Create samplers for every algorithm:
self.spot_setup = spot_setup()
self.rep = 987
self.timeout = 10 #Given in Seconds

self.parallel = os.environ.get('SPOTPY_PARALLEL', 'seq')
self.dbformat = "ram"

def test_mc(self):
sampler=spotpy.algorithms.mc(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.mc(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_lhs(self):
sampler=spotpy.algorithms.lhs(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.lhs(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_mle(self):
sampler=spotpy.algorithms.mle(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.mle(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_mcmc(self):
sampler=spotpy.algorithms.mcmc(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.mcmc(spot_setup(used_algorithm='mcmc'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_demcz(self):
sampler=spotpy.algorithms.demcz(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.demcz(spot_setup(used_algorithm='demcz'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep, convergenceCriteria=0)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_dream(self):
sampler=spotpy.algorithms.dream(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.dream(spot_setup(used_algorithm='dream'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_sceua(self):
sampler=spotpy.algorithms.sceua(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.sceua(spot_setup(used_algorithm='sceua'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertLessEqual(len(results), self.rep) #Sceua save per definition not all sampled runs

def test_abc(self):
sampler=spotpy.algorithms.abc(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.abc(spot_setup(used_algorithm='abc'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_fscabc(self):
sampler=spotpy.algorithms.fscabc(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.fscabc(spot_setup(used_algorithm='fscabc'),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_rope(self):
sampler=spotpy.algorithms.rope(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.rope(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_sa(self):
sampler=spotpy.algorithms.sa(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.sa(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_list(self):
#generate a List sampler input
print(self.spot_setup.simulation)
sampler=spotpy.algorithms.mc(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat='csv', sim_timeout=self.timeout)
sampler=spotpy.algorithms.mc(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat='csv', sim_timeout=self.timeout)
sampler.sample(self.rep)

print(self.spot_setup.simulation)
sampler=spotpy.algorithms.list_sampler(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.list_sampler(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep)
results = sampler.getdata()
self.assertEqual(len(results), self.rep)

def test_fast(self):
sampler=spotpy.algorithms.fast(self.spot_setup,parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler=spotpy.algorithms.fast(spot_setup(),parallel=self.parallel, dbname='Rosen', dbformat=self.dbformat, sim_timeout=self.timeout)
sampler.sample(self.rep, M=5)
results = sampler.getdata()
self.assertEqual(len(results), self.rep) #Si values should be returned
Expand Down

0 comments on commit ea0be90

Please sign in to comment.