Skip to content

Commit

Permalink
Merge pull request #197 from medema-group/hotfix/empty-ref-bins
Browse files Browse the repository at this point in the history
Hotfix/empty ref bins
  • Loading branch information
nlouwen authored Oct 23, 2024
2 parents 4f9dd28 + 4201df4 commit 1210a0a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 264 deletions.
55 changes: 17 additions & 38 deletions big_scape/comparison/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ def add_records(self, record_list: list[BGCRecord]):
if None in self.record_ids:
raise ValueError("Record in bin has no db id!")

def cull_singletons(self, cutoff: float, ref_only: bool = False):
"""Culls singletons for given cutoff, i.e. records which have either no edges
in the database, or all edges have a distance above/equal to the cutoff
def cull_singletons(self, cutoff: float, run_id: int):
"""Culls singletons for given cutoff, i.e. records that do not occur in any
connected components
Args:
cutoff (float): distance cutoff
ref_only (False): if true only reference singletons are culled
run_id (int): id of the current run
Raises:
RuntimeError: DB.metadata is None
Expand All @@ -182,45 +182,24 @@ def cull_singletons(self, cutoff: float, ref_only: bool = False):
if not DB.metadata:
raise RuntimeError("DB.metadata is None")

distance_table = DB.metadata.tables["distance"]
cc_table = DB.metadata.tables["connected_component"]

# get all distances/edges in the table for the records in this bin and
# with distances below the cutoff
# get all record ids that occur in a connected component
select_statement = (
select(distance_table.c.record_a_id, distance_table.c.record_b_id)
.where(distance_table.c.record_a_id.in_(self.record_ids))
.where(distance_table.c.record_b_id.in_(self.record_ids))
.where(distance_table.c.distance < cutoff)
.where(distance_table.c.edge_param_id == self.edge_param_id)
select(cc_table.c.record_id)
.where(cc_table.c.cutoff == cutoff)
.where(cc_table.c.bin_label == self.label)
.where(cc_table.c.run_id == run_id)
)

edges = DB.execute(select_statement).fetchall()

# get all record_ids in the edges
edge_record_ids: set[int] = set()
for edge in edges:
edge_record_ids.update(edge)

if ref_only:
singleton_record_ids = self.record_ids - edge_record_ids
self.source_records = [
record
for record in self.source_records
if (record._db_id in edge_record_ids)
or (
record._db_id in singleton_record_ids
and record.parent_gbk.source_type != SOURCE_TYPE.REFERENCE
)
]
self.record_ids = {record._db_id for record in self.source_records}
connected_records = set(DB.execute(select_statement).scalars())

else:
self.record_ids = edge_record_ids
self.source_records = [
record
for record in self.source_records
if record._db_id in edge_record_ids
]
self.record_ids = connected_records
self.source_records = [
record
for record in self.source_records
if record._db_id in connected_records
]

def get_query_source_record_ids(self) -> list[int]:
"""Return a list of record ids of all QUERY source type records in this bin
Expand Down
2 changes: 2 additions & 0 deletions big_scape/diagnostics/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# from python
import logging
import sys


def init_logger(run) -> None: # pragma: no cover
Expand Down Expand Up @@ -36,3 +37,4 @@ def init_logger_file(run) -> None: # pragma: no cover
file_handler = logging.FileHandler(run["log_path"])
file_handler.setFormatter(log_formatter)
root_logger.addHandler(file_handler)
logging.info(" ".join(sys.argv))
44 changes: 18 additions & 26 deletions big_scape/run_bigscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,12 @@ def signal_handler(sig, frame):
mix_bin = bs_comparison.generate_mix_bin(all_bgc_records, run)

for cutoff in run["gcf_cutoffs"]:
# cull singletons ref only
mix_bin.cull_singletons(cutoff, ref_only=True)
if not run["include_singletons"]:
# cull all records, ref and query
mix_bin.cull_singletons(cutoff, ref_only=False)
if len(mix_bin.record_ids) == 0:
logging.info(
f"Network {mix_bin.label} with cutoff {cutoff} is empty after culling singletons"
)
continue
mix_bin.cull_singletons(cutoff, run["run_id"])
if len(mix_bin.record_ids) == 0:
logging.info(
f"Network '{mix_bin.label}' with cutoff {cutoff} is empty after culling singletons"
)
continue
legacy_prepare_bin_output(run, cutoff, mix_bin)
legacy_generate_bin_output(run, cutoff, mix_bin)

Expand All @@ -301,14 +297,12 @@ def signal_handler(sig, frame):

for bin in legacy_class_bins:
for cutoff in run["gcf_cutoffs"]:
bin.cull_singletons(cutoff, ref_only=True)
if not run["include_singletons"]:
bin.cull_singletons(cutoff)
if len(bin.record_ids) == 0:
logging.info(
f"Network '{bin.label}' with cutoff {cutoff} is empty after culling singletons"
)
continue
bin.cull_singletons(cutoff, run["run_id"])
if len(bin.record_ids) == 0:
logging.info(
f"Network '{bin.label}' with cutoff {cutoff} is empty after culling singletons"
)
continue
legacy_prepare_bin_output(run, cutoff, bin)
legacy_generate_bin_output(run, cutoff, bin)

Expand All @@ -319,14 +313,12 @@ def signal_handler(sig, frame):

for bin in as_class_bins:
for cutoff in run["gcf_cutoffs"]:
bin.cull_singletons(cutoff, ref_only=True)
if not run["include_singletons"]:
bin.cull_singletons(cutoff)
if len(bin.record_ids) == 0:
logging.info(
f"Network '{bin.label}' with cutoff {cutoff} is empty after culling singletons"
)
continue
bin.cull_singletons(cutoff, run["run_id"])
if len(bin.record_ids) == 0:
logging.info(
f"Network '{bin.label}' with cutoff {cutoff} is empty after culling singletons"
)
continue
legacy_prepare_bin_output(run, cutoff, bin)
legacy_generate_bin_output(run, cutoff, bin)

Expand Down
210 changes: 10 additions & 200 deletions test/comparison/test_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
RecordPairGenerator,
ConnectedComponentPairGenerator,
QueryRecordPairGenerator,
save_edge_to_db,
get_record_category,
get_legacy_weights_from_category,
as_class_bin_generator,
Expand Down Expand Up @@ -399,11 +398,11 @@ def test_cull_singletons_cutoff(self):
"""Tests whether singletons are correctly culled"""

bs_data.DB.create_in_mem()
query_gbk = create_mock_gbk(0, bs_enums.SOURCE_TYPE.QUERY)
query_gbk = create_mock_gbk(1, bs_enums.SOURCE_TYPE.QUERY)
# query -> test_path_0.gbk, rec_id 1
query_gbk.save_all()
ref_gbks = [
create_mock_gbk(i, bs_enums.SOURCE_TYPE.REFERENCE) for i in range(1, 5)
create_mock_gbk(i, bs_enums.SOURCE_TYPE.REFERENCE) for i in range(2, 6)
]
# ref[0] -> test_path_1.gbk, rec_id 2
# ref[1] -> test_path_2.gbk, rec_id 3
Expand All @@ -413,214 +412,25 @@ def test_cull_singletons_cutoff(self):
source_records.append(ref_gbk.region)
ref_gbk.save_all()

new_bin = RecordPairGenerator("Test", weights="mix", edge_param_id=1)
new_bin = RecordPairGenerator("mix", edge_param_id=1)
new_bin.add_records(source_records)

# making query <-> ref_1 edge with distance 0.0

save_edge_to_db(
(
query_gbk.region._db_id,
ref_gbks[0].region._db_id,
0.0,
1.0,
1.0,
1.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
# making connected components for records 1, 2, 3
bs_data.DB.execute_raw_query(
"INSERT INTO connected_component VALUES "
"(1, 1, 0.5, 'mix', 1), "
"(1, 2, 0.5, 'mix', 1), "
"(3, 3, 0.5, 'mix', 1);"
)

save_edge_to_db(
(
query_gbk.region._db_id,
ref_gbks[1].region._db_id,
1.0,
0.0,
0.0,
0.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
)

save_edge_to_db(
(
ref_gbks[0].region._db_id,
ref_gbks[1].region._db_id,
0.0,
1.0,
1.0,
1.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
)

# edges below cutoff (distance 0.0 < 0.5):
# query <-> ref_1 | rec_id 1 <-> rec_id 2
# ref_1 <-> ref_2 | rec_id 2 <-> rec_id 3

new_bin.cull_singletons(0.5)
new_bin.cull_singletons(0.5, 1)

expected_records = [source_records[0], source_records[1], source_records[2]]
# expected_record_ids = set([1, 2, 3])

actual_records = new_bin.source_records
# actual_record_ids = new_bin.record_ids

self.assertEqual(expected_records, actual_records)

def test_cull_singletons_ref_only(self):
"""Tests whether singletons are correctly culled"""

bs_data.DB.create_in_mem()

query_gbks = [
create_mock_gbk(i, bs_enums.SOURCE_TYPE.QUERY) for i in range(0, 3)
]

ref_gbks = [
create_mock_gbk(i, bs_enums.SOURCE_TYPE.REFERENCE) for i in range(0, 3)
]

all_gbks = query_gbks + ref_gbks
# ref[0] -> test_path_1.gbk, rec_id 2
# ref[1] -> test_path_2.gbk, rec_id 3

source_records = []
for gbk in all_gbks:
source_records.append(gbk.region)
gbk.save_all()

new_bin = RecordPairGenerator("Test", weights="mix", edge_param_id=1)
new_bin.add_records(source_records)

# making query_1 <-> ref_1 edge with distance 0.0

save_edge_to_db(
(
query_gbks[0].region._db_id,
ref_gbks[0].region._db_id,
0.0,
1.0,
1.0,
1.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
)

save_edge_to_db(
(
query_gbks[1].region._db_id,
ref_gbks[1].region._db_id,
1.0,
0.0,
0.0,
0.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
)

save_edge_to_db(
(
ref_gbks[0].region._db_id,
ref_gbks[1].region._db_id,
0.0,
1.0,
1.0,
1.0,
1,
bs_comparison.ComparableRegion(
0,
0,
0,
0,
0,
0,
0,
0,
False,
),
)
)

# edges below cutoff (distance 0.0 < 0.5):
# query_1 <-> ref_1 | rec_id 1 <-> rec_id 2
# ref_1 <-> ref_2 | rec_id 2 <-> rec_id 3
# query_2, query_3, ref_3 are singletons

pre_cull_records = len(new_bin.source_records)

new_bin.cull_singletons(0.5, ref_only=True)

actual_records_post_ref_cull = len(new_bin.source_records)

new_bin.cull_singletons(0.5, ref_only=False)

actual_records_post_full_cull = len(new_bin.source_records)

seen_data = [
pre_cull_records,
actual_records_post_ref_cull,
actual_records_post_full_cull,
]
expected_data = [6, 5, 3]

self.assertEqual(seen_data, expected_data)


class TestMixComparison(TestCase):
def clean_db(self):
Expand Down

0 comments on commit 1210a0a

Please sign in to comment.