Skip to content

Commit

Permalink
Add common orders of Matérn covariance function (#402)
Browse files Browse the repository at this point in the history
* Add matern covariance functions

* Add matern covariance function tests

Co-authored-by: kleeman <[email protected]>
  • Loading branch information
peddie and akleeman authored Jan 24, 2023
1 parent d43ff45 commit 9dc845b
Show file tree
Hide file tree
Showing 3 changed files with 402 additions and 0 deletions.
89 changes: 89 additions & 0 deletions include/albatross/src/covariance_functions/radial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,94 @@ class Exponential : public CovarianceFunction<Exponential<DistanceMetricType>> {
DistanceMetricType distance_metric_;
};

inline double matern_32_covariance(double distance, double length_scale,
double sigma = 1.) {
if (length_scale <= 0.) {
return 0.;
}
assert(distance >= 0.);
const double sqrt_3_d = std::sqrt(3.) * distance / length_scale;
return sigma * sigma * (1 + sqrt_3_d) * exp(-sqrt_3_d);
}

template <class DistanceMetricType>
class Matern32 : public CovarianceFunction<Matern32<DistanceMetricType>> {
public:
// The Matern nu = 3/2 radial function is not positive definite
// when the distance is an angular (or great circle) distance.
static_assert(!std::is_base_of<AngularDistance, DistanceMetricType>::value,
"Matern32 covariance with AngularDistance is not PSD.");

ALBATROSS_DECLARE_PARAMS(matern_32_length_scale, sigma_matern_32);

Matern32(double length_scale_ = default_length_scale,
double sigma_matern_32_ = default_radial_sigma)
: distance_metric_() {
matern_32_length_scale = {length_scale_, PositivePrior()};
sigma_matern_32 = {sigma_matern_32_, NonNegativePrior()};
};

std::string name() const {
return "matern_32[" + this->distance_metric_.get_name() + "]";
}

template <typename X,
typename std::enable_if<
has_call_operator<DistanceMetricType, X &, X &>::value,
int>::type = 0>
double _call_impl(const X &x, const X &y) const {
double distance = this->distance_metric_(x, y);
return matern_32_covariance(distance, matern_32_length_scale.value,
sigma_matern_32.value);
}

DistanceMetricType distance_metric_;
};

inline double matern_52_covariance(double distance, double length_scale,
double sigma = 1.) {
if (length_scale <= 0.) {
return 0.;
}
assert(distance >= 0.);
const double sqrt_5_d = std::sqrt(5.) * distance / length_scale;
return sigma * sigma * (1 + sqrt_5_d + sqrt_5_d * sqrt_5_d / 3.) *
exp(-sqrt_5_d);
}

template <class DistanceMetricType>
class Matern52 : public CovarianceFunction<Matern52<DistanceMetricType>> {
public:
// The Matern nu = 5/2 radial function is not positive definite
// when the distance is an angular (or great circle) distance.
static_assert(!std::is_base_of<AngularDistance, DistanceMetricType>::value,
"Matern52 covariance with AngularDistance is not PSD.");

ALBATROSS_DECLARE_PARAMS(matern_52_length_scale, sigma_matern_52);

Matern52(double length_scale_ = default_length_scale,
double sigma_matern_52_ = default_radial_sigma)
: distance_metric_() {
matern_52_length_scale = {length_scale_, PositivePrior()};
sigma_matern_52 = {sigma_matern_52_, NonNegativePrior()};
};

std::string name() const {
return "matern_52[" + this->distance_metric_.get_name() + "]";
}

template <typename X,
typename std::enable_if<
has_call_operator<DistanceMetricType, X &, X &>::value,
int>::type = 0>
double _call_impl(const X &x, const X &y) const {
double distance = this->distance_metric_(x, y);
return matern_52_covariance(distance, matern_52_length_scale.value,
sigma_matern_52.value);
}

DistanceMetricType distance_metric_;
};

} // namespace albatross
#endif
31 changes: 31 additions & 0 deletions python/gpytorch_covariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import gpytorch.kernels as k
import torch
from gpytorch.functions import MaternCovariance


def tens(x):
return torch.tensor([x], dtype=torch.float64)


def matern(points, lengthscale, order=2.5):
m = k.MaternKernel(order, lengthscale=lengthscale)
mc = MaternCovariance()
xs = tens(points).t()
return mc.apply(xs, xs, tens(lengthscale), tens(order),
lambda x1, x2: m.covar_dist(x1, x2))


LENGTHSCALE=22.2

POINTS = [-100, -10, -5, -2, -1, -1e-2, -1e-5, 0, 1e-5, 1e-2, 1, 2, 5, 10, 100]

if __name__ == "__main__":
torch.set_printoptions(precision=16)
print(f"Length scale: {tens(LENGTHSCALE)}")
print(f"Evaluation points ({len(POINTS)}):")
print(POINTS)
print("Matern 5/2:")
print(matern(POINTS, LENGTHSCALE, order=2.5))
print("Matern 3/2:")
print(matern(POINTS, LENGTHSCALE, order=1.5))

Loading

0 comments on commit 9dc845b

Please sign in to comment.