-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #533 from Vasilev-Ilya/STUD-7_metropolis_hastings_…
…sampler Draft: STUD-7: Metropolis-Hastings sampler implementation
- Loading branch information
Showing
2 changed files
with
126 additions
and
0 deletions.
There are no files selected for viewing
50 changes: 50 additions & 0 deletions
50
kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright 2018-2024 KMath contributors. | ||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. | ||
*/ | ||
|
||
package space.kscience.kmath.samplers | ||
|
||
import space.kscience.kmath.chains.BlockingDoubleChain | ||
import space.kscience.kmath.distributions.Distribution1D | ||
import space.kscience.kmath.distributions.NormalDistribution | ||
import space.kscience.kmath.random.RandomGenerator | ||
import space.kscience.kmath.structures.Float64Buffer | ||
import kotlin.math.* | ||
|
||
/** | ||
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling | ||
* target distribution [targetDist]. | ||
* | ||
* The normal distribution is used as the proposal function. | ||
* | ||
* params: | ||
* - targetDist: function close to the density of the sampled distribution; | ||
* - initialState: initial value of the chain of sampled values; | ||
* - proposalStd: standard deviation of the proposal function. | ||
*/ | ||
public class MetropolisHastingsSampler( | ||
public val targetDist: (arg : Double) -> Double, | ||
public val initialState : Double = 0.0, | ||
public val proposalStd : Double = 1.0, | ||
) : BlockingDoubleSampler { | ||
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { | ||
var currentState = initialState | ||
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd) | ||
|
||
override fun nextBufferBlocking(size: Int): Float64Buffer { | ||
val acceptanceProb = generator.nextDoubleBuffer(size) | ||
|
||
return Float64Buffer(size) {index -> | ||
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0) | ||
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState)) | ||
|
||
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState | ||
currentState | ||
} | ||
} | ||
|
||
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) | ||
} | ||
|
||
} |
76 changes: 76 additions & 0 deletions
76
...stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright 2018-2024 KMath contributors. | ||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. | ||
*/ | ||
|
||
package space.kscience.kmath.samplers | ||
import space.kscience.kmath.distributions.NormalDistribution | ||
import space.kscience.kmath.operations.Float64Field | ||
import space.kscience.kmath.random.DefaultGenerator | ||
import space.kscience.kmath.stat.invoke | ||
import space.kscience.kmath.stat.mean | ||
import kotlin.math.exp | ||
import kotlin.math.pow | ||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
|
||
class TestMetropolisHastingsSampler { | ||
|
||
@Test | ||
fun samplingNormalTest() { | ||
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg) | ||
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0) | ||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) | ||
|
||
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) | ||
|
||
fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg) | ||
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0) | ||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) | ||
|
||
assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2) | ||
} | ||
|
||
@Test | ||
fun samplingExponentialTest() { | ||
fun expDist(arg : Double, param : Double) : Double { | ||
if (arg < 0.0) { return 0.0 } | ||
return param * exp(-param * arg) | ||
} | ||
|
||
fun expDist1(arg : Double) = expDist(arg, 0.5) | ||
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0) | ||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) | ||
|
||
assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2) | ||
|
||
fun expDist2(arg : Double) = expDist(arg, 2.0) | ||
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0) | ||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) | ||
|
||
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) | ||
|
||
} | ||
|
||
@Test | ||
fun samplingRayleighTest() { | ||
fun rayleighDist(arg : Double, sigma : Double) : Double { | ||
if (arg < 0.0) { return 0.0 } | ||
|
||
val expArg = (arg / sigma).pow(2) | ||
return arg * exp(-expArg / 2.0) / sigma.pow(2) | ||
} | ||
|
||
fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0) | ||
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0) | ||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) | ||
|
||
assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2) | ||
|
||
fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0) | ||
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0) | ||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000) | ||
|
||
assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2) | ||
} | ||
} |