Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Hotfix/query #182

Merged
merged 8 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions big_scape/cli/cli_common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,6 @@ def common_cluster_query(fn):
"based on the generic 'mix' weights."
),
),
click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others. "
"Will also use BiG-SCAPE v1 legacy_weights for distance calculations. "
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use "
"at your own risk, as BGC classes may have changed. All antiSMASH "
"classes that this legacy mode does not recognize will be grouped in "
"'others'."
),
),
click.option(
"--alignment_mode",
type=click.Choice(["global", "glocal", "local", "auto"]),
Expand Down Expand Up @@ -339,7 +324,7 @@ def common_cluster_query(fn):
"-db",
"--db_path",
type=click.Path(path_type=Path, dir_okay=False),
help="Path to sqlite db output file. (default: output_dir/data_sqlite.db).",
help="Path to sqlite db output file. (default: output_dir/output_dir.db).",
),
# TODO: implement cand_cluster here and LCS-ext
click.option(
Expand Down
2 changes: 1 addition & 1 deletion big_scape/cli/cli_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def validate_output_paths(ctx) -> None:
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

if "db_path" in ctx.obj and ctx.obj["db_path"] is None:
db_path = ctx.obj["output_dir"] / Path("data_sqlite.db")
db_path = ctx.obj["output_dir"] / Path(f"{ctx.obj['output_dir'].name}.db")
ctx.obj["db_path"] = db_path

if "log_path" in ctx.obj and ctx.obj["log_path"] is None:
Expand Down
15 changes: 15 additions & 0 deletions big_scape/cli/cluster_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@
)
# binning parameters
@click.option("--no_mix", is_flag=True, help=("Don't run the all-vs-all analysis."))
@click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others. "
"Will also use BiG-SCAPE v1 legacy_weights for distance calculations. "
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use "
"at your own risk, as BGC classes may have changed. All antiSMASH "
"classes that this legacy mode does not recognize will be grouped in "
"'others'."
),
)
# networking parameters
@click.option(
"--include_singletons",
Expand Down
12 changes: 6 additions & 6 deletions big_scape/cli/query_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,13 @@
),
)
@click.option(
"--skip_propagation",
"--propagate",
is_flag=True,
help=(
"Only generate edges between the query and reference BGCs. If not set, "
"BiG-SCAPE will also propagate edge generation to reference BGCs. "
"Warning: if the database already contains all edges, this will not work, "
"and the output will still showcase all edges between nodes "
"in the query connected component."
"By default, BiG-SCAPE will only generate edges between the query and reference"
" BGCs. With the propagate flag, BiG-SCAPE will go through multiple cycles of "
"edge generation until no new reference BGCs are connected to the query "
"connected component."
),
)
@click.pass_context
Expand All @@ -74,6 +73,7 @@ def query(ctx, *args, **kwarg):
ctx.obj.update(ctx.params)
ctx.obj["no_mix"] = None
ctx.obj["hybrids_off"] = False
ctx.obj["legacy_classify"] = False
ctx.obj["mode"] = "Query"

# workflow validations
Expand Down
50 changes: 48 additions & 2 deletions big_scape/comparison/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def num_pairs(self) -> int:
# find a collection of gbks with more than one subrecord
member_table = (
select(func.count(record_table.c.gbk_id).label("rec_count"))
.where(record_table.c.record_type == self.record_type.value)
.where(record_table.c.id.in_(self.record_ids))
.group_by(record_table.c.gbk_id)
.having(func.count() > 1)
.subquery()
Expand Down Expand Up @@ -406,8 +406,9 @@ def __init__(
label: str,
edge_param_id: int,
weights: str,
record_type: Optional[RECORD_TYPE],
):
super().__init__(label, edge_param_id, weights)
super().__init__(label, edge_param_id, weights, record_type)
self.reference_records: set[BGCRecord] = set()
self.done_records: set[BGCRecord] = set()
self.working_query_records: set[BGCRecord] = set()
Expand All @@ -426,6 +427,9 @@ def generate_pairs(
if record_a == record_b:
continue

if record_a.parent_gbk == record_b.parent_gbk:
continue

if legacy_sorting:
sorted_a, sorted_b = sorted((record_a, record_b), key=sort_name_key)
if sorted_a._db_id is None or sorted_b._db_id is None:
Expand Down Expand Up @@ -455,6 +459,48 @@ def num_pairs(self) -> int:

num_pairs = num_query_records * num_ref_records

# delete pairs originating from the same parent gbk
if self.record_type is not None and self.record_type != RECORD_TYPE.REGION:
query_ids = [record._db_id for record in self.working_query_records]
ref_ids = [record._db_id for record in self.working_ref_records]

if DB.metadata is None:
raise RuntimeError("DB.metadata is None")

rec_table = DB.metadata.tables["bgc_record"]

            # construct two tables that hold the gbk id and the number of subrecords
# present in the set of query and ref records respectively
query_gbk = (
select(
rec_table.c.gbk_id,
func.count(rec_table.c.gbk_id).label("query_count"),
)
.where(rec_table.c.id.in_(query_ids))
.group_by(rec_table.c.gbk_id)
.subquery()
)

ref_gbk = (
select(
rec_table.c.gbk_id,
func.count(rec_table.c.gbk_id).label("ref_count"),
)
.where(rec_table.c.id.in_(ref_ids))
.group_by(rec_table.c.gbk_id)
.subquery()
)

# now we can join the two tables and obtain the number of links between
# records from the same gbks by multiplying their counts
same_gbk_query = select(
func.sum(query_gbk.c.query_count * ref_gbk.c.ref_count)
).join(ref_gbk, query_gbk.c.gbk_id == ref_gbk.c.gbk_id)

same_gbks = DB.execute(same_gbk_query).scalar_one()
if same_gbks:
num_pairs -= same_gbks

return num_pairs

def add_records(self, record_list: list[BGCRecord]) -> None:
Expand Down
83 changes: 49 additions & 34 deletions big_scape/distances/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def calculate_distances_query(
max_cutoff = max(run["gcf_cutoffs"])
edge_param_id = bs_comparison.get_edge_param_id(run, weights)

query_bin = bs_comparison.QueryRecordPairGenerator("Query", edge_param_id, weights)
query_bin = bs_comparison.QueryRecordPairGenerator(
"Query", edge_param_id, weights, run["record_type"]
)
query_bin.add_records(query_records)

missing_query_bin = bs_comparison.QueryMissingRecordPairGenerator(query_bin)
Expand All @@ -54,9 +56,13 @@ def calculate_distances_query(

query_connected_component = next(
bs_network.get_connected_components(
max_cutoff, edge_param_id, query_bin, run["run_id"]
)
max_cutoff, edge_param_id, query_bin, run["run_id"], query_record
),
None,
)
if query_connected_component is None:
# no nodes are connected even with the highest cutoffs in the run
return query_bin

query_nodes = bs_network.get_nodes_from_cc(query_connected_component, query_records)

Expand Down Expand Up @@ -159,47 +165,56 @@ def calculate_distances(run: dict, bin: bs_comparison.RecordPairGenerator):
# fetches the current number of singleton ref <-> connected ref pairs from the database
num_pairs = bin.num_pairs()

# if there are no more singleton ref <-> connected ref pairs, then break and exit
if num_pairs == 0:
break

logging.info("Calculating distances for %d pairs", num_pairs)
if num_pairs > 0:
adraismawur marked this conversation as resolved.
Show resolved Hide resolved
logging.info("Calculating distances for %d pairs", num_pairs)

save_batch = []
num_edges = 0
save_batch = []
num_edges = 0

with tqdm.tqdm(total=num_pairs, unit="edge", desc="Calculating distances") as t:
with tqdm.tqdm(
total=num_pairs, unit="edge", desc="Calculating distances"
) as t:

def callback(edges):
nonlocal num_edges
nonlocal save_batch
batch_size = run["cores"] * 100000
for edge in edges:
num_edges += 1
t.update(1)
save_batch.append(edge)
if len(save_batch) > batch_size:
bs_comparison.save_edges_to_db(save_batch, commit=True)
save_batch = []
def callback(edges):
nonlocal num_edges
nonlocal save_batch
batch_size = run["cores"] * 100000
for edge in edges:
num_edges += 1
t.update(1)
save_batch.append(edge)
if len(save_batch) > batch_size:
bs_comparison.save_edges_to_db(save_batch, commit=True)
save_batch = []

bs_comparison.generate_edges(
bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
)
bs_comparison.generate_edges(
bin,
run["alignment_mode"],
run["extend_strategy"],
run["cores"],
run["cores"] * 2,
callback,
)

bs_comparison.save_edges_to_db(save_batch)
bs_comparison.save_edges_to_db(save_batch)

bs_data.DB.commit()
bs_data.DB.commit()

logging.info("Generated %d edges", num_edges)
logging.info("Generated %d edges", num_edges)

if run["skip_propagation"]:
if not run["propagate"]:
# in this case we only want one iteration, the Query -> Ref edges
break

if isinstance(bin, bs_comparison.MissingRecordPairGenerator):
# in this case we only need edges within one cc, no cycles needed
break

if isinstance(bin, bs_comparison.QueryMissingRecordPairGenerator):
# use the num_pairs from the parent bin because in a partial database,
# all distances for the first cycle(s) might already be present.
# we still only want to stop when no other connected nodes are discovered.
if bin.bin.num_pairs() == 0:
break

bin.cycle_records(max(run["gcf_cutoffs"]))
22 changes: 13 additions & 9 deletions big_scape/file_input/load_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,21 +409,16 @@ def get_all_bgc_records_query(
"""Get all BGC records from the working list of GBKs

Args:
gbks (list[GBK]): list of GBK objects
run (dict): run parameters
gbks (list[GBK]): list of GBK objects

Returns:
list[bs_gbk.BGCRecord]: list of BGC records
list[bs_gbk.BGCRecord], bs_gbk.BGCRecord: list of BGC records, query BGC record
"""
all_bgc_records: list[bs_gbk.BGCRecord] = []
for gbk in gbks:
if gbk.region is not None:
gbk_records = bs_gbk.bgc_record.get_sub_records(
gbk.region, run["record_type"]
)
if gbk.source_type == bs_enums.SOURCE_TYPE.QUERY:
query_record_type = run["record_type"]

query_record_type = run["record_type"]
query_record_number = run["query_record_number"]

Expand All @@ -435,15 +430,24 @@ def get_all_bgc_records_query(
query_record = query_sub_records[0]

else:
query_record = [
matching_query_records = [
record
for record in query_sub_records
if record.number == query_record_number
][0]
]
if len(matching_query_records) == 0:
raise RuntimeError(
f"Could not find {query_record_type.value} number {query_record_number} in query GBK. "
"Depending on config settings, overlapping records will be merged and take on the lower number."
)
query_record = matching_query_records[0]

all_bgc_records.append(query_record)

else:
gbk_records = bs_gbk.bgc_record.get_sub_records(
gbk.region, run["record_type"]
)
all_bgc_records.extend(gbk_records)

return all_bgc_records, query_record
46 changes: 25 additions & 21 deletions big_scape/network/families.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,35 +370,39 @@ def run_family_assignments_query(
# get_connected_components returns a list of connected components, but we only
# want the first one, so we use next()

try:
query_connected_component = next(
bs_network.get_connected_components(
cutoff, query_bin.edge_param_id, query_bin, run["run_id"]
)
)

cc_cutoff[cutoff] = query_connected_component

logging.debug(
"Found connected component with %d edges",
len(query_connected_component),
)

regions_families = generate_families(
query_connected_component, query_bin.label, cutoff, run["run_id"]
)

# save families to database
save_to_db(regions_families)
query_connected_component = next(
bs_network.get_connected_components(
cutoff,
query_bin.edge_param_id,
query_bin,
run["run_id"],
query_record,
),
None,
)

except StopIteration:
if query_connected_component is None:
logging.warning(
"No connected components found for %s bin at cutoff %s",
query_bin.label,
cutoff,
)
continue

cc_cutoff[cutoff] = query_connected_component

logging.debug(
"Found connected component with %d edges",
len(query_connected_component),
)

regions_families = generate_families(
query_connected_component, query_bin.label, cutoff, run["run_id"]
)

# save families to database
save_to_db(regions_families)

DB.commit()

# no connected components found
Expand Down
Loading
Loading