diff --git a/neurom/features/bifurcation.py b/neurom/features/bifurcation.py index 6c628d15..423a3ef2 100644 --- a/neurom/features/bifurcation.py +++ b/neurom/features/bifurcation.py @@ -32,6 +32,7 @@ from neurom import morphmath from neurom.exceptions import NeuroMError from neurom.core.dataformat import COLS +from neurom.core.morphology import Section from neurom.features.section import section_mean_radius @@ -84,7 +85,7 @@ def remote_bifurcation_angle(bif_point): bif_point.children[1].points[-1]) -def bifurcation_partition(bif_point): +def bifurcation_partition(bif_point, iterator_type=Section.ipreorder): """Calculate the partition at a bifurcation point. We first ensure that the input point has only two children. @@ -94,12 +95,12 @@ def bifurcation_partition(bif_point): """ _raise_if_not_bifurcation(bif_point) - n = float(sum(1 for _ in bif_point.children[0].ipreorder())) - m = float(sum(1 for _ in bif_point.children[1].ipreorder())) + n, m = partition_pair(bif_point, iterator_type=iterator_type) + return max(n, m) / min(n, m) -def partition_asymmetry(bif_point, uylings=False): +def partition_asymmetry(bif_point, uylings=False, iterator_type=Section.ipreorder): """Calculate the partition asymmetry at a bifurcation point. By default partition asymmetry is defined as in https://www.ncbi.nlm.nih.gov/pubmed/18568015. @@ -113,8 +114,7 @@ def partition_asymmetry(bif_point, uylings=False): """ _raise_if_not_bifurcation(bif_point) - n = float(sum(1 for _ in bif_point.children[0].ipreorder())) - m = float(sum(1 for _ in bif_point.children[1].ipreorder())) + n, m = partition_pair(bif_point, iterator_type=iterator_type) if n == m == 1: # By definition the asymmetry A(1, 1) is zero @@ -125,16 +125,17 @@ def partition_asymmetry(bif_point, uylings=False): return abs(n - m) / abs(n + m - c) -def partition_pair(bif_point): +def partition_pair(bif_point, iterator_type=Section.ipreorder): """Calculate the partition pairs at a bifurcation point. The number of nodes in each child tree is counted. The partition pairs is the number of bifurcations in the two child subtrees at each branch point. """ - n = float(sum(1 for _ in bif_point.children[0].ipreorder())) - m = float(sum(1 for _ in bif_point.children[1].ipreorder())) - return (n, m) + return ( + float(len(list(iterator_type(bif_point.children[0])))), + float(len(list(iterator_type(bif_point.children[1])))), + ) def sibling_ratio(bif_point, method='first'):