Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Compiler Errors #133

Merged
merged 5 commits into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/albatross/Common
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@
#include "src/utils/map_utils.hpp"
#include "src/cereal/eigen.hpp"

#include "src/details/traits.hpp"
#include "src/details/has_any_macros.hpp"
#include "src/details/error_handling.hpp"

#endif
9 changes: 0 additions & 9 deletions include/albatross/src/cereal/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@

namespace albatross {

/*
* This little trick was borrowed from cereal, you can think of it as
* a function that will always return false ... but that doesn't
* get resolved until template instantiation, which when combined
* with a static assert let's you include a static assert that
* only triggers with a particular template parameter is used.
*/
template <class T> struct delay_static_assert : std::false_type {};

/*
* The following helper functions let you inspect a type and cereal Archive
* and determine if the type has a valid serialization method for that Archive
Expand Down
23 changes: 17 additions & 6 deletions include/albatross/src/core/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ using Insights = std::map<std::string, std::string>;

template <typename ModelType> class ModelBase : public ParameterHandlingMixin {

template <typename X, typename Y, typename Z> friend class Prediction;
friend class JointPredictor;
friend class MarginalPredictor;
friend class MeanPredictor;

template <typename T, typename FeatureType> friend class fit_model_type;

Expand All @@ -46,15 +48,21 @@ template <typename ModelType> class ModelBase : public ParameterHandlingMixin {
!has_valid_fit<ModelType, FeatureType>::value,
int>::type = 0>
void _fit(const std::vector<FeatureType> &features,
const MarginalDistribution &targets) const = delete; // Invalid fit
const MarginalDistribution &targets) const
ALBATROSS_FAIL(FeatureType,
"The ModelType *almost* has a _fit_impl method for "
"FeatureType, but it appears to be invalid");

template <typename FeatureType,
typename std::enable_if<
!has_possible_fit<ModelType, FeatureType>::value &&
!has_valid_fit<ModelType, FeatureType>::value,
int>::type = 0>
void _fit(const std::vector<FeatureType> &features,
const MarginalDistribution &targets) const = delete;
const MarginalDistribution &targets) const
ALBATROSS_FAIL(
FeatureType,
"The ModelType is missing a _fit_impl method for FeatureType.");

template <
typename PredictFeatureType, typename FitType, typename PredictType,
Expand All @@ -73,9 +81,12 @@ template <typename ModelType> class ModelBase : public ParameterHandlingMixin {
typename std::enable_if<!has_valid_predict<ModelType, PredictFeatureType,
FitType, PredictType>::value,
int>::type = 0>
PredictType predict_(
const std::vector<PredictFeatureType> &features, const FitType &fit,
PredictTypeIdentity<PredictType> &&) const = delete; // No valid predict.
PredictType predict_(const std::vector<PredictFeatureType> &features,
const FitType &fit,
PredictTypeIdentity<PredictType> &&) const
ALBATROSS_FAIL(PredictFeatureType,
"The ModelType is missing a _predict_impl method for "
"PredictFeatureType, FitType, PredictType.");

public:
/*
Expand Down
198 changes: 122 additions & 76 deletions include/albatross/src/core/prediction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,129 +19,175 @@ namespace albatross {
// which behave different conditional on the type of predictions desired.
template <typename T> struct PredictTypeIdentity { typedef T type; };

template <typename ModelType, typename FeatureType, typename FitType>
class Prediction {

/*
* MeanPredictor is responsible for determining if a valid form of
* predicting exists for a given set of model, feature, fit. The
* primary goal of the class is to consolidate all the logic required
* to decide if different predict types are available. For example,
* by inspecting this class for a _mean method you can determine if
* any valid mean prediction method exists.
*/
class MeanPredictor {
public:
Prediction(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features)
: model_(model), fit_(fit), features_(features) {}

Prediction(const ModelType &model, const FitType &fit,
std::vector<FeatureType> &&features)
: model_(model), fit_(fit), features_(std::move(features)) {}

// Mean
template <typename DummyType = FeatureType,
template <typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<
has_valid_predict_mean<ModelType, DummyType, FitType>::value,
has_valid_predict_mean<ModelType, FeatureType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return model_.predict_(features_, fit_,
PredictTypeIdentity<Eigen::VectorXd>());
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
return model.predict_(features, fit,
PredictTypeIdentity<Eigen::VectorXd>());
}

template <
typename DummyType = FeatureType,
typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<
!has_valid_predict_mean<ModelType, DummyType, FitType>::value &&
has_valid_predict_marginal<ModelType, DummyType, FitType>::value,
!has_valid_predict_mean<ModelType, FeatureType, FitType>::value &&
has_valid_predict_marginal<ModelType, FeatureType,
FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return model_
.predict_(features_, fit_, PredictTypeIdentity<MarginalDistribution>())
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
return model
.predict_(features, fit, PredictTypeIdentity<MarginalDistribution>())
.mean;
}

template <
typename DummyType = FeatureType,
typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<
!has_valid_predict_mean<ModelType, DummyType, FitType>::value &&
!has_valid_predict_marginal<ModelType, DummyType,
!has_valid_predict_mean<ModelType, FeatureType, FitType>::value &&
!has_valid_predict_marginal<ModelType, FeatureType,
FitType>::value &&
has_valid_predict_joint<ModelType, DummyType, FitType>::value,
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return model_
.predict_(features_, fit_, PredictTypeIdentity<JointDistribution>())
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
return model
.predict_(features, fit, PredictTypeIdentity<JointDistribution>())
.mean;
}
};

// Marginal
template <typename DummyType = FeatureType,
class MarginalPredictor {
public:
template <typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<has_valid_predict_marginal<
ModelType, DummyType, FitType>::value,
ModelType, FeatureType, FitType>::value,
int>::type = 0>
MarginalDistribution marginal() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.marginal<T>()");
return model_.predict_(features_, fit_,
PredictTypeIdentity<MarginalDistribution>());
MarginalDistribution
_marginal(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
return model.predict_(features, fit,
PredictTypeIdentity<MarginalDistribution>());
}

template <
typename DummyType = FeatureType,
typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<
!has_valid_predict_marginal<ModelType, DummyType, FitType>::value &&
has_valid_predict_joint<ModelType, DummyType, FitType>::value,
!has_valid_predict_marginal<ModelType, FeatureType, FitType>::value &&
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
MarginalDistribution marginal() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.marginal<T>()");
const auto joint_pred = model_.predict_(
features_, fit_, PredictTypeIdentity<JointDistribution>());
MarginalDistribution
_marginal(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
const auto joint_pred =
model.predict_(features, fit, PredictTypeIdentity<JointDistribution>());
if (joint_pred.has_covariance()) {
Eigen::VectorXd diag = joint_pred.covariance.diagonal();
return MarginalDistribution(joint_pred.mean, diag.asDiagonal());
} else {
return MarginalDistribution(joint_pred.mean);
}
}
};

// Joint
template <typename DummyType = FeatureType,
class JointPredictor {
public:
template <typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<
has_valid_predict_joint<ModelType, DummyType, FitType>::value,
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
JointDistribution joint() const {
JointDistribution _joint(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
return model.predict_(features, fit,
PredictTypeIdentity<JointDistribution>());
}
};

template <typename ModelType, typename FeatureType, typename FitType>
class Prediction {

public:
Prediction(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features)
: model_(model), fit_(fit), features_(features) {}

Prediction(const ModelType &model, const FitType &fit,
std::vector<FeatureType> &&features)
: model_(model), fit_(fit), features_(std::move(features)) {}

// Mean
template <typename DummyType = FeatureType,
typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.joint<T>()");
return model_.predict_(features_, fit_,
PredictTypeIdentity<JointDistribution>());
"never do prediction.mean<T>()");
return MeanPredictor()._mean(model_, fit_, features_);
}

// CATCH FAILURE MODES
template <
typename DummyType = FeatureType,
typename std::enable_if<
!has_valid_predict_mean<ModelType, DummyType, FitType>::value &&
!has_valid_predict_marginal<ModelType, DummyType,
FitType>::value &&
!has_valid_predict_joint<ModelType, DummyType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const = delete; // No valid predict method found.
typename std::enable_if<!can_predict_mean<MeanPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
void mean() const
ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the "
"mean with FitType and FeatureType.");

// Marginal
template <
typename DummyType = FeatureType,
typename std::enable_if<
!has_valid_predict_marginal<ModelType, DummyType, FitType>::value &&
!has_valid_predict_joint<ModelType, DummyType, FitType>::value,
int>::type = 0>
Eigen::VectorXd
marginal() const = delete; // No valid predict marginal method found.
typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
MarginalDistribution marginal() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return MarginalPredictor()._marginal(model_, fit_, features_);
}

template <typename DummyType = FeatureType,
typename std::enable_if<
!has_valid_predict_joint<ModelType, DummyType, FitType>::value,
!can_predict_marginal<MarginalPredictor, ModelType, DummyType,
FitType>::value,
int>::type = 0>
Eigen::VectorXd
joint() const = delete; // No valid predict joint method found.
void marginal() const
ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the "
"marginal with FitType and FeatureType.");

// Joint
template <
typename DummyType = FeatureType,
typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
JointDistribution joint() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return JointPredictor()._joint(model_, fit_, features_);
}

template <
typename DummyType = FeatureType,
typename std::enable_if<!can_predict_joint<JointPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
void joint() const
ALBATROSS_FAIL(DummyType, "No valid predict method in ModelType for the "
"joint with FitType and FeatureType.");

template <typename PredictType>
PredictType get(PredictTypeIdentity<PredictType> =
Expand Down
Loading