Skip to content

Commit

Permalink
better thorney xml requires branch rate model - slight change at root
Browse files Browse the repository at this point in the history
  • Loading branch information
JT committed Sep 1, 2021
1 parent a6fba82 commit d54d67b
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 192 deletions.
3 changes: 1 addition & 2 deletions src/dr/app/beast/development_parsers.properties
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ dr.evomodelxml.bigfasttree.BigFastTreeModelParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.ConstrainedBranchLengthProviderParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.ConstrainedTreeModelParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.ConstraintsTreeLikelihoodParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.StrictClockBranchLengthLikelihoodParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.BranchLengthLikelihoodParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.PoissonBranchLengthLikelihoodParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.ThorneyTreeLikelihoodParser
dr.evomodelxml.treelikelihood.thorneytreelikelihood.UniformSubtreePruneRegraftParser
dr.inferencexml.operators.RepeatOperatorParser
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.*;
import org.apache.commons.math.special.Gamma;
import org.apache.commons.math.util.FastMath;

public class StrictClockBranchLengthLikelihoodDelegate extends AbstractModel implements ThorneyBranchLengthLikelihoodDelegate {
private final Parameter rate;
public class PoissonBranchLengthLikelihoodDelegate extends AbstractModel implements ThorneyBranchLengthLikelihoodDelegate {
private final BranchRateModel branchRateModel;
private final double scale;

public StrictClockBranchLengthLikelihoodDelegate(String name, Parameter rate, double scale){
public PoissonBranchLengthLikelihoodDelegate(String name, BranchRateModel branchRateModel, double scale){
super(name);
this.rate = rate;
addVariable(rate);
this.branchRateModel = branchRateModel;
addModel(branchRateModel);
this.scale = scale;
}

@Override
public double getLogLikelihood(double mutations, double time) {
return SaddlePointExpansion.logPoissonProbability(time*rate.getValue(0)*scale, (int) Math.round(mutations));
public double getLogLikelihood(double mutations, double time, Tree tree , NodeRef node) {
double rate = this.branchRateModel.getBranchRate(tree, node);
return SaddlePointExpansion.logPoissonProbability(time*rate*scale, (int) Math.round(mutations));
}

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {

fireModelChanged(this,index);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;

import dr.inference.model.Parameter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;


public interface ThorneyBranchLengthLikelihoodDelegate {
double getLogLikelihood(double observed,double expected);
double getLogLikelihood(double observed, double expected, Tree tree, NodeRef node);

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,18 @@

public class ThorneyTreeLikelihood extends AbstractModelLikelihood implements Reportable {

public ThorneyTreeLikelihood(String name, TreeModel treeModel, BranchLengthProvider branchLengthProvider, ThorneyBranchLengthLikelihoodDelegate thorneyBranchLengthLikelihoodDelegate, BranchRateModel branchRateModel) {
public ThorneyTreeLikelihood(String name, TreeModel treeModel, BranchLengthProvider branchLengthProvider, ThorneyBranchLengthLikelihoodDelegate thorneyBranchLengthLikelihoodDelegate) {

super(name);

this.treeModel = treeModel;
addModel(treeModel);
this.thorneyBranchLengthLikelihoodDelegate = thorneyBranchLengthLikelihoodDelegate;
this.branchLengthProvider = branchLengthProvider;
this.branchRateModel = branchRateModel;
assert thorneyBranchLengthLikelihoodDelegate instanceof Model;
addModel((Model) thorneyBranchLengthLikelihoodDelegate);

if(this.thorneyBranchLengthLikelihoodDelegate instanceof Model & this.branchRateModel!=null){
throw new IllegalArgumentException("Can't use a branch rate likelihood with a rate parameter in combination with a branch rate model. Please remove either the parameter or the model");
}
if(this.branchRateModel!=null){
this.likelihoodDelegate=new branchRateModelLikelihood();
addModel(this.branchRateModel);
}else{
this.likelihoodDelegate=new additiveRateLikelihood();
assert thorneyBranchLengthLikelihoodDelegate instanceof Model;
addModel((Model) thorneyBranchLengthLikelihoodDelegate);

}

updateNode = new boolean[treeModel.getNodeCount()];
Arrays.fill(updateNode, true);
Expand All @@ -85,9 +75,7 @@ public ThorneyTreeLikelihood(String name, TreeModel treeModel, BranchLengthProvi
cachedRootChild1 = treeModel.getChild(treeModel.getRoot(), 0).getNumber();
cachedRootChild2 = treeModel.getChild(treeModel.getRoot(), 1).getNumber();
}
public ThorneyTreeLikelihood(String name, TreeModel treeModel, BranchLengthProvider branchLengthProvider, ThorneyBranchLengthLikelihoodDelegate thorneyBranchLengthLikelihoodDelegate) {
this(name, treeModel, branchLengthProvider, thorneyBranchLengthLikelihoodDelegate, null);
}



/**
Expand Down Expand Up @@ -201,13 +189,9 @@ protected void handleModelChangedEvent(Model model, Object object, int index) {
} else if (model == thorneyBranchLengthLikelihoodDelegate) {
if (index == -1) {
updateAllNodes();
}
} else if (model == branchRateModel) {
if (index == -1) {
updateAllNodes();
} else {
updateNode(treeModel.getNode(index));
}
} else {
updateNode(treeModel.getNode(index));
}
}
else{
throw new RuntimeException("Unknown componentChangedEvent");
Expand Down Expand Up @@ -250,41 +234,36 @@ protected void acceptState() {
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
// private double calculateLogLikelihood(NodeRef node, int root, int rootChild1, int rootChild2) {
// int nodeIndex = node.getNumber();
//
// if (updateNode[nodeIndex]) {
// double logL;
// if(nodeIndex==root || nodeIndex==rootChild2){
// logL=0;
// }else{
// double time = treeModel.getBranchLength(node);
// double mutations = branchLengthProvider.getBranchLength(treeModel, node);
//
// if (nodeIndex == rootChild1) {
// // sum the branches on both sides of the root
// NodeRef node2 = treeModel.getNode(rootChild2);
// time += treeModel.getBranchLength(node2);
// mutations += branchLengthProvider.getBranchLength(treeModel, node2);
// }
//// gamma.setScale(1.0);
//// branchLogL[i] = gamma.logPdf(x);
// if(this.branchRateModel!=null) {
// logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, time); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);
// }else{
// logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, time); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);
//
// }
// }
//
// for (int i = 0; i < treeModel.getChildCount(node); i++) {
// logL += calculateLogLikelihood(treeModel.getChild(node, i),root,rootChild1,rootChild2);
// }
// branchLogL[nodeIndex] = logL;
// updateNode[nodeIndex] = false;
// }
// return branchLogL[nodeIndex];
// }
public double calculateLogLikelihood(NodeRef node, int root, int rootChild1, int rootChild2) {
int nodeIndex = node.getNumber();

if (updateNode[nodeIndex]) {
double logL;
if(nodeIndex==root){
logL=0;
}else{
double time = treeModel.getBranchLength(node);
double mutations = branchLengthProvider.getBranchLength(treeModel, node);

// if (nodeIndex == rootChild1) {
// // sum the branches on both sides of the root
// NodeRef node2 = treeModel.getNode(rootChild2);
// time += treeModel.getBranchLength(node2);
// mutations += branchLengthProvider.getBranchLength(treeModel, node2);
// }
// gamma.setScale(1.0);
// branchLogL[i] = gamma.logPdf(x);
logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, time,treeModel,node); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);

}
for (int i = 0; i < treeModel.getChildCount(node); i++) {
logL += this.calculateLogLikelihood(treeModel.getChild(node, i),root,rootChild1,rootChild2);
}
branchLogL[nodeIndex] = logL;
updateNode[nodeIndex] = false;
}
return branchLogL[nodeIndex];
}
private double calculateLogLikelihoodLinear() {

// makeDirty();
Expand All @@ -301,25 +280,25 @@ private double calculateLogLikelihoodLinear() {

double logL = 0.0;
for (int i = 0; i < treeModel.getNodeCount(); i++) {
if (updateNode[i] && i != root && i != rootChild2) {
if (updateNode[i] && i != root) {
NodeRef node = treeModel.getNode(i);
// skip the root and the second child of the root (this is added to the first child)

double time = treeModel.getBranchLength(node);
double mutations = branchLengthProvider.getBranchLength(treeModel, node);

if (i == rootChild1) {
// sum the branches on both sides of the root
NodeRef node2 = treeModel.getNode(rootChild2);
time += treeModel.getBranchLength(node2);
mutations += branchLengthProvider.getBranchLength(treeModel, node2);
}
// double mean = expected * sequenceLength;

// gamma.setScale(1.0);
// branchLogL[i] = gamma.logPdf(x);
// if (i == rootChild1) {
// // sum the branches on both sides of the root
// NodeRef node2 = treeModel.getNode(rootChild2);
// time += treeModel.getBranchLength(node2);
// mutations += branchLengthProvider.getBranchLength(treeModel, node2);
// }
//// double mean = expected * sequenceLength;
//
//// gamma.setScale(1.0);
//// branchLogL[i] = gamma.logPdf(x);

branchLogL[i] = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations,time);
branchLogL[i] = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations,time,treeModel,node);
//SaddlePointExpansion.logPoissonProbability(mean,(int)x); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);
}
updateNode[i] = false;
Expand All @@ -337,7 +316,7 @@ private double calculateLogLikelihood() {
updateNode[rootChild1] = updateNode[rootChild1] || updateNode[rootChild2] ;
assert updateNode[root];

return likelihoodDelegate.calculateLogLikelihood(treeModel.getRoot(), root, rootChild1, rootChild2);
return calculateLogLikelihood(treeModel.getRoot(), root, rootChild1, rootChild2);

}
public final Model getModel() {
Expand Down Expand Up @@ -375,57 +354,23 @@ public double calculateLogLikelihood(NodeRef node, int root, int rootChild1, int

if (updateNode[nodeIndex]) {
double logL;
if(nodeIndex==root || nodeIndex==rootChild2){
if(nodeIndex==root){
logL=0;
}else{
double time = treeModel.getBranchLength(node);
double mutations = branchLengthProvider.getBranchLength(treeModel, node);

if (nodeIndex == rootChild1) {
// sum the branches on both sides of the root
NodeRef node2 = treeModel.getNode(rootChild2);
time += treeModel.getBranchLength(node2);
mutations += branchLengthProvider.getBranchLength(treeModel, node2);
}
// if (nodeIndex == rootChild1) {
// // sum the branches on both sides of the root
// NodeRef node2 = treeModel.getNode(rootChild2);
// time += treeModel.getBranchLength(node2);
// mutations += branchLengthProvider.getBranchLength(treeModel, node2);
// }
// gamma.setScale(1.0);
// branchLogL[i] = gamma.logPdf(x);
logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, time); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);

}

for (int i = 0; i < treeModel.getChildCount(node); i++) {
logL += this.calculateLogLikelihood(treeModel.getChild(node, i),root,rootChild1,rootChild2);
}
branchLogL[nodeIndex] = logL;
updateNode[nodeIndex] = false;
}
return branchLogL[nodeIndex];
}
}
private class branchRateModelLikelihood implements LikelihoodDelegate{

public double calculateLogLikelihood(NodeRef node, int root, int rootChild1, int rootChild2) {
int nodeIndex = node.getNumber();

if (updateNode[nodeIndex]) {
double logL;
if(nodeIndex==root || nodeIndex==rootChild2){
logL=0;
}else{
double expectation = treeModel.getBranchLength(node) * branchRateModel.getBranchRate(treeModel, node);
double mutations = branchLengthProvider.getBranchLength(treeModel, node);
logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, time,treeModel,node); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);

if (nodeIndex == rootChild1) {
// sum the branches on both sides of the root
NodeRef node2 = treeModel.getNode(rootChild2);
expectation += treeModel.getBranchLength(node2)*branchRateModel.getBranchRate(treeModel, node2);
mutations += branchLengthProvider.getBranchLength(treeModel, node2);
}
// gamma.setScale(1.0);
// branchLogL[i] = gamma.logPdf(x);
logL = thorneyBranchLengthLikelihoodDelegate.getLogLikelihood(mutations, expectation); //SaddlePointExpansion.logBinomialProbability((int)x, sequenceLength, expected, 1.0D - expected);
}

for (int i = 0; i < treeModel.getChildCount(node); i++) {
logL += this.calculateLogLikelihood(treeModel.getChild(node, i),root,rootChild1,rootChild2);
}
Expand All @@ -447,7 +392,7 @@ public double calculateLogLikelihood(NodeRef node, int root, int rootChild1, int

private final ThorneyBranchLengthLikelihoodDelegate thorneyBranchLengthLikelihoodDelegate;
private final BranchLengthProvider branchLengthProvider;
private final BranchRateModel branchRateModel;


//private final double[][] distanceMatrix;

Expand Down

This file was deleted.

Loading

0 comments on commit d54d67b

Please sign in to comment.