Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of github.com:beast-dev/beast-mcmc into hmc-…
Browse files Browse the repository at this point in the history
…clock
  • Loading branch information
msuchard committed Jun 16, 2022
2 parents f56d120 + 5c7f22e commit 859f33f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 18 deletions.
9 changes: 7 additions & 2 deletions src/dr/app/checkpoint/BeastCheckpointer.java
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ protected boolean writeStateToFile(File file, long state, double lnL, MarkovChai
//check up front if there are any TreeParameterModel objects
for (Model model : Model.CONNECTED_MODEL_SET) {
if (model instanceof TreeParameterModel) {
//System.out.println("\nDetected TreeParameterModel: " + ((TreeParameterModel) model).toString());
if (DEBUG) {
System.out.println("\nSave TreeParameterModel: " + model.getClass().getSimpleName());
}
traitModels.add((TreeParameterModel) model);
}
}
Expand Down Expand Up @@ -505,7 +507,7 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln

// load the tree models last as we get the node heights from the tree (not the parameters which
// which may not be associated with the right node
Set<String> expectedTreeModelNames = new HashSet<String>();
Set<String> expectedTreeModelNames = new LinkedHashSet<>();

//store list of TreeModels for debugging purposes
ArrayList<TreeModel> treeModelList = new ArrayList<TreeModel>();
Expand All @@ -529,6 +531,9 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln

//first add all TreeParameterModels to a list
if (model instanceof TreeParameterModel) {
if (DEBUG) {
System.out.println("\nLoad TreeParameterModel: " + model.getClass().getSimpleName());
}
traitModels.add((TreeParameterModel)model);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,58 @@

import dr.inference.model.Parameter;

import static java.lang.Math.exp;
import static java.lang.Math.log;

public class CriticalBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel {

public CriticalBirthDeathSerialSamplingModel(String modelName,
Parameter birthDeathRate,
Parameter serialSamplingRate,
Parameter treatmentProbability,
Parameter samplingFractionAtPresent,
Parameter originTime,
boolean condition,
Type units) {

super(modelName,
birthDeathRate, birthDeathRate, serialSamplingRate,
treatmentProbability, samplingFractionAtPresent, originTime,
condition, units);
private int n_events;

public CriticalBirthDeathSerialSamplingModel(
String modelName,
Parameter birthRate,
Parameter deathRate,
Parameter serialSamplingRate,
Parameter treatmentProbability,
Parameter samplingFractionAtPresent,
Parameter originTime,
boolean condition,
Type units) {

super(modelName, birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingFractionAtPresent, originTime, condition, units);
n_events = 0;
}


@Override
public double processInterval(int model, double tYoung, double tOld, int nLineages) {
// TODO Do something different
return super.processInterval(model, tYoung, tOld, nLineages);
}
}

@Override
public double processOrigin(int model, double rootAge) {
double lambda = lambda();
double rho = rho();
double mu = mu();
double v = exp(-(lambda - mu) * rootAge);
double p_n = log(lambda*rho + (lambda*(1-rho) - mu)* v) - log(1- v);
return -2*logq(rootAge) + (n_events-1)*p_n;
}

@Override
public double processCoalescence(int model, double tOld) {
n_events += 1;
return 0;
}

@Override
public double processSampling(int model, double tOld) {
return 0;
}

@Override
public double logConditioningProbability() {
return -log(n_events);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
package dr.evomodelxml.speciation;

import dr.evolution.util.Units;
import dr.evomodel.speciation.CriticalBirthDeathSerialSamplingModel;
import dr.evomodel.speciation.NewBirthDeathSerialSamplingModel;
import dr.evoxml.util.XMLUnits;
import dr.inference.model.Parameter;
Expand Down Expand Up @@ -76,6 +77,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

Logger.getLogger("dr.evomodel").info(citeThisModel);

if (psi.getParameterValue(0) < Double.MIN_VALUE){
return new CriticalBirthDeathSerialSamplingModel(modelName, lambda, mu, psi, r, rho, origin, condition, units);
}

return new NewBirthDeathSerialSamplingModel(modelName, lambda, mu, psi, r, rho, origin, condition, units);
}

Expand Down
4 changes: 2 additions & 2 deletions src/dr/inference/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ public int getListenerCount() {


// set to store all created models
final static Set<Model> FULL_MODEL_SET = new HashSet<Model>();
final static Set<Model> CONNECTED_MODEL_SET = new HashSet<Model>();
final static Set<Model> FULL_MODEL_SET = new LinkedHashSet<Model>();
final static Set<Model> CONNECTED_MODEL_SET = new LinkedHashSet<Model>();

}

0 comments on commit 859f33f

Please sign in to comment.