use a temp table for gbk ids in record loading
adraismawur committed Dec 20, 2024
1 parent 441b668 commit 15f1c94
Showing 6 changed files with 120 additions and 28 deletions.
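
The change applies one pattern to every record loader: rather than filtering each SELECT with column.in_(<Python list of ids>), which expands into one bind parameter per id (and, for large inputs, presumably gets slow or bumps into SQLite's bind-parameter limit), the gbk ids are written once into a TEMPORARY table and each query filters with IN (SELECT gbk_id FROM <temp table>). A minimal, self-contained sketch of that idea, using illustrative table and column names rather than BiG-SCAPE's actual schema:

    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        insert,
        select,
    )

    # in-memory SQLite; the default SingletonThreadPool reuses one connection,
    # so the TEMPORARY table stays visible across the statements below
    engine = create_engine("sqlite://")
    metadata = MetaData()

    gbk = Table(
        "gbk",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("path", String),
    )
    record = Table(
        "bgc_record",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("gbk_id", Integer, ForeignKey("gbk.id")),
    )
    temp_ids = Table(
        "temp_gbk_ids",
        metadata,
        Column("gbk_id", Integer, ForeignKey("gbk.id"), primary_key=True),
        prefixes=["TEMPORARY"],
    )
    metadata.create_all(engine)

    with engine.begin() as conn:
        conn.execute(insert(gbk), [{"id": i, "path": f"file_{i}.gbk"} for i in range(1, 6)])
        conn.execute(insert(record), [{"id": i, "gbk_id": i} for i in range(1, 6)])
        # write the ids of interest once...
        conn.execute(insert(temp_ids), [{"gbk_id": i} for i in (1, 3, 5)])
        # ...then let the database filter via a subselect instead of
        # record.c.gbk_id.in_([1, 3, 5]) with one bind parameter per id
        query = select(record.c.id, record.c.gbk_id).where(
            record.c.gbk_id.in_(select(temp_ids.c.gbk_id))
        )
        print(conn.execute(query).all())  # rows for gbk ids 1, 3 and 5 only

In the commit itself, the temporary table is created once in GBK.load_many and handed down through every load_all call, as the diffs below show.
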
14 changes: 10 additions & 4 deletions big_scape/genbank/candidate_cluster.py
@@ -7,6 +7,7 @@

# from dependencies
from Bio.SeqFeature import SeqFeature
from sqlalchemy import Table, select

# from other modules
from big_scape.data import DB
@@ -169,7 +170,7 @@ def __repr__(self) -> str:
return f"{self.parent_gbk} Candidate cluster {self.number} {self.nt_start}-{self.nt_stop} "

@staticmethod
def load_all(region_dict: dict[int, Region]):
def load_all(region_dict: dict[int, Region], temp_gbk_id_table: Table = None):
"""Load all CandidateCluster objects from the database
This function populates the CandidateCluster lists in the Regions provided in
@@ -198,10 +199,15 @@ def load_all(region_dict: dict[int, Region]):
record_table.c.product,
)
.where(record_table.c.record_type == "cand_cluster")
.where(record_table.c.parent_id.in_(region_dict.keys()))
.compile()
)

if temp_gbk_id_table is not None:
candidate_cluster_select_query = candidate_cluster_select_query.where(
record_table.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id))
)

candidate_cluster_select_query = candidate_cluster_select_query.compile()

cursor_result = DB.execute(candidate_cluster_select_query)

candidate_cluster_dict = {}
Expand Down Expand Up @@ -230,4 +236,4 @@ def load_all(region_dict: dict[int, Region]):
# add to dictionary
candidate_cluster_dict[result.id] = new_candidate_cluster

ProtoCluster.load_all(candidate_cluster_dict)
ProtoCluster.load_all(candidate_cluster_dict, temp_gbk_id_table)
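
Two details are worth noting in hunks like the one above: the old per-parent filter (record_table.c.parent_id.in_(region_dict.keys())) is dropped in favor of the gbk-level restriction, and the temp-table filter is only appended when a table was actually passed in, with .compile() moved after that decision. SQLAlchemy's select() is generative, so .where() returns a new statement rather than mutating the existing one, which is why the query is reassigned. A small sketch of that conditional composition (illustrative names, not the real schema):

    from sqlalchemy import Column, Integer, MetaData, String, Table, select

    metadata = MetaData()
    record = Table(
        "bgc_record",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("gbk_id", Integer),
        Column("record_type", String),
    )
    temp_ids = Table("temp_gbk_ids", metadata, Column("gbk_id", Integer, primary_key=True))

    query = select(record.c.id, record.c.gbk_id).where(record.c.record_type == "region")

    temp_gbk_id_table = temp_ids  # or None when no extra filtering is wanted
    if temp_gbk_id_table is not None:
        # .where() does not modify the statement in place; capture the returned copy
        query = query.where(record.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id)))

    # compile only once the statement is fully assembled
    print(query.compile())  # SQL now contains IN (SELECT temp_gbk_ids.gbk_id ...)
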
12 changes: 9 additions & 3 deletions big_scape/genbank/cds.py
@@ -10,6 +10,7 @@
from Bio.SeqFeature import SeqFeature
from Bio.Seq import Seq
from Bio import BiopythonWarning
from sqlalchemy import Table, select

# from other modules
from big_scape.errors import InvalidGBKError
@@ -320,7 +321,7 @@ def len_nt_overlap(cds_a: CDS, cds_b: CDS) -> int:
return max(0, right - left)

@staticmethod
def load_all(gbk_dict: dict[int, GBK]) -> None:
def load_all(gbk_dict: dict[int, GBK], temp_gbk_id_table: Table = None) -> None:
"""Load all Region objects from the database
This function populates the region objects in the GBKs provided in the input
@@ -349,10 +350,15 @@ def load_all(gbk_dict: dict[int, GBK]) -> None:
cds_table.c.aa_seq,
)
.order_by(cds_table.c.orf_num)
.where(cds_table.c.gbk_id.in_(gbk_dict.keys()))
.compile()
)

if temp_gbk_id_table is not None:
region_select_query = region_select_query.where(
cds_table.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id))
)

region_select_query = region_select_query.compile()

cursor_result = DB.execute(region_select_query)

for result in cursor_result.all():
79 changes: 68 additions & 11 deletions big_scape/genbank/gbk.py
@@ -95,6 +95,61 @@ def batch_hash(gbks: list[GBK], n: int):
return temp_table


def create_temp_gbk_id_table(gbks: list[GBK]) -> Table:
"""Create a temporary table with ids of given gbks
Args:
gbks (list[GBK]): the gbks to include in the connected component
Returns:
Table: the temporary table
"""

# generate a short random string
temp_table_name = "temp_" + "".join(random.choices(string.ascii_lowercase, k=10))

temp_table = Table(
temp_table_name,
DB.metadata,
Column(
"gbk_id",
Integer,
ForeignKey(DB.metadata.tables["gbk"].c.id),
primary_key=True,
nullable=False,
),
prefixes=["TEMPORARY"],
)

DB.metadata.create_all(DB.engine)

if DB.engine is None:
raise RuntimeError("DB engine is None")

cursor = DB.engine.raw_connection().driver_connection.cursor()

insert_query = f"""
INSERT INTO {temp_table_name} (gbk_id) VALUES (?);
"""

def batch_hash(gbks: list[GBK], n: int):
l = len(gbks)
for ndx in range(0, l, n):
yield [gbk._db_id for gbk in gbks[ndx : min(ndx + n, l)]]

for hash_batch in batch_hash(gbks, 1000):
cursor.executemany(insert_query, [(x,) for x in hash_batch]) # type: ignore

cursor.close()

DB.commit()

if DB.metadata is None:
raise ValueError("DB metadata is None")

return temp_table
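
A stripped-down sketch of what this new helper does under the hood: it goes through the raw DBAPI cursor and inserts the ids with executemany, 1,000 per batch. The sqlite3 example below is illustrative only (standalone in-memory database, hypothetical table name), not BiG-SCAPE's DB wrapper:

    import sqlite3

    def batches(ids: list[int], n: int):
        """Yield successive chunks of n ids."""
        for start in range(0, len(ids), n):
            yield ids[start : start + n]

    con = sqlite3.connect(":memory:")
    cur = con.cursor()
    cur.execute("CREATE TEMPORARY TABLE temp_gbk_ids (gbk_id INTEGER PRIMARY KEY)")

    gbk_ids = list(range(1, 2501))  # stand-in for [gbk._db_id for gbk in gbks]
    for batch in batches(gbk_ids, 1000):
        cur.executemany(
            "INSERT INTO temp_gbk_ids (gbk_id) VALUES (?);",
            [(gbk_id,) for gbk_id in batch],
        )
    con.commit()

    print(cur.execute("SELECT COUNT(*) FROM temp_gbk_ids").fetchone())  # (2500,)
    cur.close()
    con.close()
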


class GBK:
"""
Class to describe a given GBK file
@@ -357,9 +412,11 @@ def load_many(input_gbks: list[GBK]) -> list[GBK]:
# load GBK regions. This will also populate all record levels below region
# e.g. candidate cluster, protocore if they exist

Region.load_all(gbk_dict)
temp_gbk_id_table = create_temp_gbk_id_table(input_gbks)

CDS.load_all(gbk_dict)
Region.load_all(gbk_dict, temp_gbk_id_table)

CDS.load_all(gbk_dict, temp_gbk_id_table)

return list(gbk_dict.values())

@@ -695,15 +752,15 @@ def collapse_hybrids_in_cand_clusters(
for number in cand_cluster.proto_clusters.keys()
]
merged_protocluster = MergedProtoCluster.merge(protoclusters)
merged_tmp_proto_clusters[merged_protocluster.number] = (
merged_protocluster
)
merged_tmp_proto_clusters[
merged_protocluster.number
] = merged_protocluster

# update the protocluster old:new ids for the merged protoclusters of this cand_cluster
for proto_cluster_num in cand_cluster.proto_clusters.keys():
merged_protocluster_ids[proto_cluster_num] = (
merged_protocluster.number
)
merged_protocluster_ids[
proto_cluster_num
] = merged_protocluster.number

# now we build a new version of the tmp_proto_clusters dict that contains the merged protoclusters
# as well as protoclusters which did not need merging, with updated unique IDs/numbers
@@ -717,9 +774,9 @@
# this protocluster has been merged, so we need to add it to
# the dict with its new protocluster number
new_proto_cluster_num = merged_protocluster_ids[proto_cluster_num]
updated_tmp_proto_clusters[new_proto_cluster_num] = (
merged_tmp_proto_clusters[new_proto_cluster_num]
)
updated_tmp_proto_clusters[
new_proto_cluster_num
] = merged_tmp_proto_clusters[new_proto_cluster_num]
updated_proto_cluster_dict[new_proto_cluster_num] = None
else:
# protoclusters which have not been merged are added to the dict as is
15 changes: 12 additions & 3 deletions big_scape/genbank/proto_cluster.py
@@ -7,6 +7,7 @@

# from dependencies
from Bio.SeqFeature import SeqFeature
from sqlalchemy import Table, select

# from other modules
from big_scape.data import DB
@@ -179,7 +180,10 @@ def __repr__(self) -> str:
return f"{self.parent_gbk} ProtoCluster {self.number} {self.nt_start}-{self.nt_stop} "

@staticmethod
def load_all(candidate_cluster_dict: dict[int, CandidateCluster]):
def load_all(
candidate_cluster_dict: dict[int, CandidateCluster],
temp_gbk_id_table: Table = None,
):
"""Load all ProtoCluster objects from the database
This function populates the CandidateCluster objects in the GBKs provided in the
@@ -210,10 +214,15 @@ def load_all(candidate_cluster_dict: dict[int, CandidateCluster]):
record_table.c.merged,
)
.where(record_table.c.record_type == "protocluster")
.where(record_table.c.parent_id.in_(candidate_cluster_dict.keys()))
.compile()
)

if temp_gbk_id_table is not None:
protocluster_select_query = protocluster_select_query.where(
record_table.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id))
)

protocluster_select_query = protocluster_select_query.compile()

cursor_result = DB.execute(protocluster_select_query)

protocluster_dict = {}
14 changes: 11 additions & 3 deletions big_scape/genbank/proto_core.py
@@ -7,6 +7,7 @@

# from dependencies
from Bio.SeqFeature import SeqFeature
from sqlalchemy import Table, select

# from other modules
from big_scape.data import DB
@@ -110,7 +111,9 @@ def __repr__(self) -> str:
)

@staticmethod
def load_all(protocluster_dict: dict[int, ProtoCluster]):
def load_all(
protocluster_dict: dict[int, ProtoCluster], temp_gbk_id_table: Table = None
):
"""Load all ProtoCore objects from the database
This function populates the region objects in the GBKs provided in the input
@@ -141,10 +144,15 @@ def load_all(protocluster_dict: dict[int, ProtoCluster]):
record_table.c.merged,
)
.where(record_table.c.record_type == "proto_core")
.where(record_table.c.parent_id.in_(protocluster_dict.keys()))
.compile()
)

if temp_gbk_id_table is not None:
region_select_query = region_select_query.where(
record_table.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id))
)

region_select_query = region_select_query.compile()

cursor_result = DB.execute(region_select_query)

for result in cursor_result.all():
14 changes: 10 additions & 4 deletions big_scape/genbank/region.py
@@ -8,6 +8,7 @@
# from dependencies
from Bio.SeqFeature import SeqFeature
from Bio.SeqRecord import SeqRecord
from sqlalchemy import Table, select

# from other modules
from big_scape.data import DB
@@ -262,7 +263,7 @@ def __repr__(self):
return f"{self.parent_gbk} Region {self.number} {self.nt_start}-{self.nt_stop} "

@staticmethod
def load_all(gbk_dict: dict[int, GBK]) -> None:
def load_all(gbk_dict: dict[int, GBK], temp_gbk_id_table: Table = None) -> None:
"""Load all Region objects from the database
This function populates the region objects in the GBKs provided in the input
@@ -292,10 +293,15 @@ def load_all(gbk_dict: dict[int, GBK]) -> None:
record_table.c.product,
)
.where(record_table.c.record_type == "region")
.where(record_table.c.gbk_id.in_(gbk_dict.keys()))
.compile()
)

if temp_gbk_id_table is not None:
region_select_query = region_select_query.where(
record_table.c.gbk_id.in_(select(temp_gbk_id_table.c.gbk_id))
)

region_select_query = region_select_query.compile()

cursor_result = DB.execute(region_select_query)

region_dict = {}
@@ -320,4 +326,4 @@ def load_all(gbk_dict: dict[int, GBK]) -> None:
# add to dictionary
region_dict[result.id] = new_region

CandidateCluster.load_all(region_dict)
CandidateCluster.load_all(region_dict, temp_gbk_id_table)
