Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: STUD-7: Metropolis-Hastings sampler implementation #533

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package space.kscience.kmath.samplers

import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.distributions.Distribution1D
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.Float64Buffer
import kotlin.math.*

/**
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
* target distribution [targetDist].
*
* The normal distribution is used as the proposal function.
*
* params:
* - targetDist: function close to the density of the sampled distribution;
* - initialState: initial value of the chain of sampled values;
* - proposalStd: standard deviation of the proposal function.
*/
public class MetropolisHastingsSampler(
public val targetDist: (arg : Double) -> Double,
public val initialState : Double = 0.0,
public val proposalStd : Double = 1.0,
) : BlockingDoubleSampler {
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
var currentState = initialState
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd)

override fun nextBufferBlocking(size: Int): Float64Buffer {
val acceptanceProb = generator.nextDoubleBuffer(size)

return Float64Buffer(size) {index ->
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0)
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState))

currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
currentState
}
}

override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package space.kscience.kmath.samplers
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.DefaultGenerator
import space.kscience.kmath.stat.invoke
import space.kscience.kmath.stat.mean
import kotlin.math.exp
import kotlin.math.pow
import kotlin.test.Test
import kotlin.test.assertEquals

class TestMetropolisHastingsSampler {

@Test
fun samplingNormalTest() {
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg)
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)

assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)

fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg)
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)

assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2)
}

@Test
fun samplingExponentialTest() {
fun expDist(arg : Double, param : Double) : Double {
if (arg < 0.0) { return 0.0 }
return param * exp(-param * arg)
}

fun expDist1(arg : Double) = expDist(arg, 0.5)
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)

assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2)

fun expDist2(arg : Double) = expDist(arg, 2.0)
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)

assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)

}

@Test
fun samplingRayleighTest() {
fun rayleighDist(arg : Double, sigma : Double) : Double {
if (arg < 0.0) { return 0.0 }

val expArg = (arg / sigma).pow(2)
return arg * exp(-expArg / 2.0) / sigma.pow(2)
}

fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0)
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)

assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)

fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0)
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)

assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
}
}
Loading