diff --git a/pyeed/core/alignment.py b/pyeed/core/alignment.py
index 60891af9..5358416e 100644
--- a/pyeed/core/alignment.py
+++ b/pyeed/core/alignment.py
@@ -1,8 +1,11 @@
 import sdRDM
+from rich.status import Status, Console
 from tqdm import tqdm
 from itertools import combinations
 from typing import List, Optional, Union, Tuple, TYPE_CHECKING
 from pydantic import Field, validator
+from IPython.display import clear_output
+
 from sdRDM.base.listplus import ListPlus
 from sdRDM.base.utils import forge_signature, IDGenerator
 from Bio.Align import Alignment as BioAlignment
@@ -367,8 +370,16 @@ def from_sequences(
             input_sequences=sequences,
         )
 
-        if aligner is not None:
-            return alignment.align(aligner, **kwargs)
+        with Status("Running ClustalOmega...", console=Console(force_terminal=False)):
+
+            if aligner is not None:
+
+                result = alignment.align(aligner, **kwargs)
+
+                clear_output()
+                if result:
+                    print("✅ Alignment completed")
+                    return result
 
         return alignment
 
diff --git a/pyeed/core/proteininfo.py b/pyeed/core/proteininfo.py
index 513f4d27..bf5bc198 100644
--- a/pyeed/core/proteininfo.py
+++ b/pyeed/core/proteininfo.py
@@ -1,23 +1,26 @@
-import re
 import os
+import asyncio
 from typing import List, Optional
+from IPython.display import clear_output
+from rich.status import Status, Console
+from concurrent.futures import ThreadPoolExecutor
 import warnings
 from pydantic import Field
 from sdRDM.base.listplus import ListPlus
 from sdRDM.base.utils import forge_signature, IDGenerator
-from Bio.Blast import NCBIWWW, NCBIXML
-
-
-from .dnainfo import DNAInfo
-from .proteinregion import ProteinRegion
-from .abstractsequence import AbstractSequence
-from .site import Site
-from .citation import Citation
-from .span import Span
-from .proteinregiontype import ProteinRegionType
-from .substrate import Substrate
-from .dnaregion import DNARegion
-from .proteinsitetype import ProteinSiteType
+from Bio.Blast import NCBIXML
+
+
+from pyeed.core.dnainfo import DNAInfo
+from pyeed.core.proteinregion import ProteinRegion
+from pyeed.core.abstractsequence import AbstractSequence
+from pyeed.core.site import Site
+from pyeed.core.citation import Citation
+from pyeed.core.span import Span
+from pyeed.core.proteinregiontype import ProteinRegionType
+from pyeed.core.substrate import Substrate
+from pyeed.core.dnaregion import DNARegion
+from pyeed.core.proteinsitetype import ProteinSiteType
 from pyeed.container.abstract_container import Blastp
 
 
@@ -164,10 +167,13 @@ def add_to_substrates(
 
     @classmethod
    def get_id(cls, protein_id: str) -> "ProteinInfo":
-        from pyeed.fetch import NCBIProteinFetcher
+        from pyeed.fetch.proteinfetcher import ProteinFetcher
+        import nest_asyncio
+
+        nest_asyncio.apply()
 
         """
-        This method creates a 'ProteinInfo' object from a given NCBI ID.
+        This method creates a 'ProteinInfo' object from a given protein accession ID.
 
         Args:
             protein_id (str): ID of the protein in NCBI or UniProt database.
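
The reworked `get_id` and `get_ids` below wrap the asynchronous `ProteinFetcher` behind the existing synchronous classmethods by combining `nest_asyncio.apply()` with `asyncio.run()`, which also keeps the calls usable inside notebooks where an event loop is already running. A minimal sketch of that pattern, with a stand-in coroutine instead of the real fetcher (the accession string is arbitrary):

import asyncio

import nest_asyncio

nest_asyncio.apply()  # lets asyncio.run() work inside an already running loop, e.g. Jupyter


async def fake_fetch(ids):
    # stand-in for ProteinFetcher(ids=ids).fetch(); the real coroutine performs HTTP requests
    await asyncio.sleep(0)
    return [f"record for {accession}" for accession in ids]


print(asyncio.run(fake_fetch(["EXAMPLE_ACCESSION"])))  # ['record for EXAMPLE_ACCESSION']
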
@@ -180,89 +186,146 @@ def get_id(cls, protein_id: str) -> "ProteinInfo":
             warnings.warn("For getting multiple sequences by ID use `get_ids` instead.")
             return cls.get_ids(protein_id)
 
-        return NCBIProteinFetcher(protein_id).fetch(cls)[0]
+        sequences = asyncio.run(ProteinFetcher(ids=[protein_id]).fetch(quiet=True))[0]
+        clear_output()
+        return sequences
 
     @classmethod
-    def get_ids(
-        cls, accession_ids: List[str], email: str = None, api_key: str = None
-    ) -> List["ProteinInfo"]:
-        from pyeed.fetch import NCBIProteinFetcher
+    def get_ids(cls, accession_ids: List[str]) -> List["ProteinInfo"]:
+        from pyeed.fetch.proteinfetcher import ProteinFetcher
+        import nest_asyncio
+
+        nest_asyncio.apply()
 
-        proteins = NCBIProteinFetcher(accession_ids, email, api_key).fetch(cls)
+        return asyncio.run(
+            ProteinFetcher(ids=accession_ids).fetch(force_terminal=False)
+        )
+
+    @classmethod
+    def from_sequence(
+        cls,
+        sequence: str,
+        exact_match: bool = True,
+        database: str = "nr",
+        matrix: str = "BLOSUM62",
+    ):
+        """
+        Creates a 'ProteinInfo' object from a given protein sequence by
+        performing a BLAST search on the NCBI server.
+
+        Args:
+            sequence (str): The protein sequence to search for.
+            exact_match (bool, optional): If True, only exact matches will be considered.
+                If False, approximate matches will also be included. Defaults to True.
+            database (str, optional): The database to search against. Must be one of
+                the supported databases: 'nr', 'swissprot', 'pdb', 'refseq_protein'.
+                Defaults to 'nr'.
 
-        return proteins
+        Returns:
+            ProteinInfo: A 'ProteinInfo' object representing the protein sequence
+                found in the database.
+
+        Raises:
+            AssertionError: If the specified database is not supported.
+        """
+
+        import nest_asyncio
+        from pyeed.fetch.blast import Blast, NCBIDataBase, BlastProgram
+        from pyeed.fetch.proteinfetcher import ProteinFetcher
+
+        nest_asyncio.apply()
+
+        assert (
+            database in NCBIDataBase
+        ), f"Database needs to be one of {NCBIDataBase.__members__.keys()}"
+
+        identity = 1 if exact_match else 0
+
+        blaster = Blast(
+            query=sequence,
+            n_hits=1,
+            identity=identity,
+            matrix=matrix,
+        )
 
-    def ncbi_blastp(
+        with Status("Running BLAST", console=Console(force_terminal=False)) as status:
+            result = asyncio.run(
+                blaster.async_run(
+                    NCBIDataBase.NR.value,
+                    BlastProgram.BLASTP.value,
+                )
+            )
+            clear_output()
+
+            accession = blaster.extract_accession(result)
+
+            status.update("Fetching protein data")
+
+            if accession:
+                return asyncio.run(
+                    ProteinFetcher(ids=accession).fetch(force_terminal=False)
+                )[0]
+
+        return
+
+    def ncbi_blast(
         self,
         n_hits: int,
         e_value: float = 10.0,
-        api_key: str = None,
         database: str = "nr",
+        matrix: str = "BLOSUM62",
+        identity: float = 0.0,
         **kwargs,
     ) -> List["ProteinInfo"]:
-        """Run protein blast for a `ProteinInfo`.
-        Additional keyword arguments can be pass according to the blast [specifications](https://biopython.org/docs/1.75/api/Bio.Blast.NCBIWWW.html).
+        """
+        Runs a BLAST search using the NCBI BLAST service to find similar protein sequences.
 
         Args:
-            n_hits (int): Number of hits to return.
-            e_value (float, optional): E-value threshold. Defaults to 10.0.
-            api_key (str, optional): NCBI API key for sequence retrieval. Defaults to None.
-            database (str, optional): Database to search. Defaults to "nr" (Non Redundant).
-
+            n_hits (int): The number of hits to retrieve.
+            e_value (float, optional): The maximum E-value threshold for reporting hits. Defaults to 10.0.
+            database (str, optional): The database to search against. Defaults to "nr".
+            matrix (str, optional): The substitution matrix to use. Defaults to "BLOSUM62".
+            identity (float, optional): The minimum sequence identity threshold for reporting hits. Defaults to 0.0.
+            **kwargs: Additional keyword arguments.
 
         Returns:
-            List[ProteinInfo]: List of 'ProteinInfo' objects that are the result of the blast search.
-        """
-        from pyeed.fetch import NCBIProteinFetcher
-
-        print("🏃🏼‍♀️ Running PBLAST")
-        print(f"╭── protein name: {self.name}")
-        print(f"├── accession: {self.source_id}")
-        print(f"├── organism: {self.organism.name}")
-        print(f"├── e-value: {e_value}")
-        print(f"╰── max hits: {n_hits}")
-
-        result_handle = NCBIWWW.qblast(
-            "blastp",
-            database,
-            self.sequence,
-            hitlist_size=n_hits,
-            expect=e_value,
-            **kwargs,
-        )
-        blast_record = NCBIXML.read(result_handle)
-
-        accessions = self._get_accessions(blast_record)
-        uniprot_accessions = self._filter_uniprot_accessions(accessions)
-        ncbi_accessions = list(set(accessions) - set(uniprot_accessions))
+            List[ProteinInfo]: A list of ProteinInfo objects representing the similar protein sequences found.
 
-        print(f"🔍 Found {len(ncbi_accessions)} NCBI accessions")
-        print(f"🔍 Found {len(uniprot_accessions)} UniProt accessions")
+        Raises:
+            AssertionError: If the specified database is not supported.
 
-        protein_infos = NCBIProteinFetcher(
-            foreign_id=ncbi_accessions, api_key=api_key
-        ).fetch(ProteinInfo)
-        protein_infos.insert(0, self)
+        Example:
+            protein_info = ProteinInfo()
+            similar_proteins = protein_info.ncbi_blast(n_hits=10, e_value=0.001, database="swissprot")
+        """
 
-        if uniprot_accessions:
-            from pyeed.fetch.uniprotmapper import UniprotFetcher
+        from pyeed.fetch.proteinfetcher import ProteinFetcher
+        from pyeed.fetch.blast import Blast, NCBIDataBase, BlastProgram
+        import nest_asyncio
 
-            uniprot_proteins = UniprotFetcher(foreign_id=uniprot_accessions).fetch()
-            protein_infos.extend(uniprot_proteins)
+        nest_asyncio.apply()
 
-        print("🎉 Done\n")
-        return protein_infos
+        assert database in NCBIDataBase
 
-    def _filter_uniprot_accessions(self, accessions: List[str]) -> List[str]:
-        uniprot_pattern = re.compile(
-            r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
+        program = BlastProgram.BLASTP.value
+        executor = ThreadPoolExecutor(max_workers=1)
+        blaster = Blast(
+            query=self.sequence,
+            n_hits=n_hits,
+            evalue=e_value,
+            matrix=matrix,
+            identity=identity,
         )
-        return [
-            uniprot_pattern.match(acc)[0]
-            for acc in accessions
-            if uniprot_pattern.match(acc)
-        ]
+        with Status(
+            "Running BLAST", console=Console(force_terminal=False, force_jupyter=True)
+        ):
+            result = asyncio.run(blaster.async_run(database, program, executor))
+        clear_output()
+
+        accessions = blaster.extract_accession(result)
+
+        return asyncio.run(ProteinFetcher(ids=accessions).fetch(force_terminal=False))
 
     def blastp(
         self,
@@ -320,3 +383,10 @@ def from_ncbi(self):
 
     def from_accessions(self):
         raise DeprecationWarning("This method is deprecated. Use `get_ids` instead.")
+
+
+if __name__ == "__main__":
+    seq_string = "MSDRNIRVEPVVGRAVEEQDVEIVERKGLGHPDSLCDGIAEHVSQALARAYIDRVGKVLHYNTDETQLVAGTAAPAFGGGEVVDPIYLLITGRATKEYEGTKIPAETIALRAAREYINETLPFLEFGTDVVVDVKLGEGSGDLQEVFGEDGKQVPMSNDTSFGVGHAPLTETERIVLEAERALNGDYSDDNPAVGQDIKVMGKREGDDIDVTVAVAMVDRYVDDLDGYEAAVAGVREFVADLATDYTDRNVSVHVNTADDYDEGAIYLTTTGTSAEQGDDGSVGRGNRSNGLITPNRSMSMEATSGKNPVNHIGKIYNLLSTEIARTVVDEVDGIREIRIRLLSQIGQPIDKPHVADANLVTEDGIEIADIEDEVEAIIDAELENVTSITERVIDGELTTF"
+
+    seq = ProteinInfo.from_sequence(seq_string)
+    print(seq)
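
Taken together, the changes above give `ProteinInfo` two sequence-centric entry points: `from_sequence` resolves a raw sequence to a database entry via BLAST, and `ncbi_blast` searches for homologs of an existing entry. A rough usage sketch, assuming the API lands as shown (the import path follows the `pyeed.core` layout used above, the query sequence and thresholds are purely illustrative, and a live connection to the NCBI BLAST service is required):

from pyeed.core.proteininfo import ProteinInfo

# resolve a raw sequence to a database record; exact_match=False relaxes the identity filter
protein = ProteinInfo.from_sequence(
    "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIE", exact_match=False
)

if protein is not None:
    # search for similar sequences and fetch them as ProteinInfo objects
    homologs = protein.ncbi_blast(n_hits=5, e_value=1e-3, database="swissprot")
    print([hit.source_id for hit in homologs])
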
Use `get_ids` instead.") + + +if __name__ == "__main__": + seq_string = "MSDRNIRVEPVVGRAVEEQDVEIVERKGLGHPDSLCDGIAEHVSQALARAYIDRVGKVLHYNTDETQLVAGTAAPAFGGGEVVDPIYLLITGRATKEYEGTKIPAETIALRAAREYINETLPFLEFGTDVVVDVKLGEGSGDLQEVFGEDGKQVPMSNDTSFGVGHAPLTETERIVLEAERALNGDYSDDNPAVGQDIKVMGKREGDDIDVTVAVAMVDRYVDDLDGYEAAVAGVREFVADLATDYTDRNVSVHVNTADDYDEGAIYLTTTGTSAEQGDDGSVGRGNRSNGLITPNRSMSMEATSGKNPVNHIGKIYNLLSTEIARTVVDEVDGIREIRIRLLSQIGQPIDKPHVADANLVTEDGIEIADIEDEVEAIIDAELENVTSITERVIDGELTTF" + + seq = ProteinInfo.from_sequence(seq_string) + print(seq) diff --git a/pyeed/fetch/blast.py b/pyeed/fetch/blast.py new file mode 100644 index 00000000..11fcec32 --- /dev/null +++ b/pyeed/fetch/blast.py @@ -0,0 +1,164 @@ +import io +import asyncio +import logging +from typing import List +from enum import Enum, EnumMeta +from pydantic import BaseModel, Field +from Bio.Blast import NCBIWWW, NCBIXML +from Bio.Blast.Record import Blast as BlastRecord +from concurrent.futures import ThreadPoolExecutor + + +LOGGER = logging.getLogger(__name__) + + +class MetaEnum(EnumMeta): + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True + + +class BaseEnum(Enum, metaclass=MetaEnum): + pass + + +class BlastProgram(BaseEnum): + BLASTP = "blastp" + BLASTN = "blastn" + BLASTX = "blastx" + TBLASTN = "tblastn" + TBLASTX = "tblastx" + + +class NCBIDataBase(BaseEnum): + NR = "nr" + UNIPROTKB = "swissprot" + PDB = "pdb" + REFSEQ = "refseq_protein" + + +class SubstitutionMatrix(BaseEnum): + BLOSUM45 = "BLOSUM45" + BLOSUM62 = "BLOSUM62" + BLOSUM80 = "BLOSUM80" + PAM30 = "PAM30" + PAM70 = "PAM70" + + +class Blast(BaseModel): + + query: str = Field( + description="The query sequence", + default=None, + ) + + n_hits: int = Field( + description="Maximum number of hits to return", + default=100, + ) + + evalue: float = Field( + description="Expectation value (E) to safe hits", + default=10, + ) + + matrix: str = Field( + description="Substitution matrix", + default=SubstitutionMatrix.BLOSUM62.value, + ) + + identity: float = Field( + description="Minimum identity to accept hit", + default=0.0, + ge=0.0, + le=1, + ) + + def run(self, program: str, ncbi_db: str) -> io.StringIO: + + assert ( + program in BlastProgram + ), f"Invalid program: {program}, valid programs: {BlastProgram}" + assert ( + ncbi_db in NCBIDataBase + ), f"Invalid database: {ncbi_db}, valid databases: {NCBIDataBase}" + + return NCBIWWW.qblast( + program, + ncbi_db, + self.query, + expect=self.evalue, + matrix_name=self.matrix, + hitlist_size=self.n_hits, + ) + + async def async_run( + self, + ncbi_db: str, + program: str = BlastProgram.BLASTP.value, + foreign_executor: ThreadPoolExecutor = None, + ) -> io.StringIO: + + assert program in BlastProgram + assert ncbi_db in NCBIDataBase + + if not foreign_executor: + executor = ThreadPoolExecutor() + else: + executor = foreign_executor + + loop = asyncio.get_running_loop() + + with executor as pool: + result = await loop.run_in_executor(pool, self.run, program, ncbi_db) + + if not foreign_executor: + executor.shutdown() + + return result + + def read(self, result: io.StringIO) -> BlastRecord: + return NCBIXML.read(result) + + def extract_accession(self, record: io.StringIO) -> List[str]: + + record = NCBIXML.read(record) + + hits = [] + for hit in record.alignments: + if hit.hsps[0].identities / hit.hsps[0].align_length > self.identity: + hits.append(hit.accession) + + return hits + + +if __name__ == "__main__": + from rich.status import Status + + async def main(): + seq = 
"MRNINVQLNPLSDIEKLQVELVERKGLGHPDYIADAVAEEASRKLSLYYLKKYGVILHHNLDKTLVVGGQATPRFKGGDVIQPIYIVVAGRATTEVKTESGIEQIPVGTIIIESVKEWIRNNFRYLDAEKHLIVDYKIGKGSTDLVGIFEAGKRVPLSNDTSFGVGFAPFTKLEKLVYETERHLNSKQFKAKLPEVGEDIKVMGLRRGNEVDLTIAMATISELIEDVNHYINVKEQAKNKILDLASKIAPDYDVRIYVNTGDKIDKNILYLTVTGTSAEHGDDGMTGRGNRGVGLITPMRPMSLEATAGKNPVNHVGKLYNVLANLIANKIAQEVKDVKFSQVQVLGQIGRPIDDPLIANVDVITYDGKLNDETKNEISGIVDEMLSSFNKLTELILEGKATLF" + blast = Blast(query=seq, n_hits=10) + executor = ThreadPoolExecutor(max_workers=4) + + blast_task = asyncio.create_task( + blast.async_run( + NCBIDataBase.UNIPROTKB.value, BlastProgram.BLASTP.value, executor + ) + ) + + with Status("Running BLAST"): + result = await blast_task + + parsed_result = blast.parse(result) + print(parsed_result) + + executor.shutdown() + + return parsed_result + + results = asyncio.run(main()) + + print(results) diff --git a/pyeed/fetch/ncbiproteinmapper.py b/pyeed/fetch/ncbiproteinmapper.py index abcea76e..81c37c34 100644 --- a/pyeed/fetch/ncbiproteinmapper.py +++ b/pyeed/fetch/ncbiproteinmapper.py @@ -22,18 +22,22 @@ class NCBIProteinMapper: def __init__(self): pass - def _to_seq_records(self, data: dict) -> List[SeqRecord]: + def _to_seq_records(self, responses: List[str]) -> List[SeqRecord]: """ Converts the fetched data to a list of `Bio.SeqRecord.SeqRecord` objects. """ - return SeqIO.parse(io.StringIO(data[0]), "gb") + records = [] + for response in responses: + records.extend(SeqIO.parse(io.StringIO(response), "gb")) - def map(self, seq_records: List[str]) -> List[ProteinInfo]: + return records + + def map(self, responses: List[str]) -> List[ProteinInfo]: """ Maps the fetched data to an instance of the `ProteinInfo` class. """ - seq_records = self._to_seq_records(seq_records) + seq_records = self._to_seq_records(responses) protein_infos = [] for record in seq_records: diff --git a/pyeed/fetch/proteinfetcher.py b/pyeed/fetch/proteinfetcher.py index 561e6770..c270dde8 100644 --- a/pyeed/fetch/proteinfetcher.py +++ b/pyeed/fetch/proteinfetcher.py @@ -23,7 +23,7 @@ def __init__(self, ids: List[str]): # self.ncbi_key = ncbi_key #TODO: Add NCBI key to NCBI requester nest_asyncio.apply() - async def fetch(self, force_terminal: bool = False): + async def fetch(self, **console_kwargs): """ Fetches protein data from various databases based on the provided IDs. 
diff --git a/pyeed/fetch/proteinfetcher.py b/pyeed/fetch/proteinfetcher.py
index 561e6770..c270dde8 100644
--- a/pyeed/fetch/proteinfetcher.py
+++ b/pyeed/fetch/proteinfetcher.py
@@ -23,7 +23,7 @@ def __init__(self, ids: List[str]):
         # self.ncbi_key = ncbi_key #TODO: Add NCBI key to NCBI requester
         nest_asyncio.apply()
 
-    async def fetch(self, force_terminal: bool = False):
+    async def fetch(self, **console_kwargs):
         """
         Fetches protein data from various databases based on the provided IDs.
 
@@ -40,7 +40,9 @@ async def fetch(self, force_terminal: bool = False):
         """
 
         db_entries = SortIDs.sort(self.ids)
 
-        with Progress(console=Console(force_terminal=force_terminal)) as progress:
+        with Progress(
+            console=Console(**console_kwargs),
+        ) as progress:
             requesters: List[AsyncRequester] = []
             for db_name, db_ids in db_entries.items():
@@ -112,9 +114,10 @@ async def fetch(self, force_terminal: bool = False):
             )
 
             # map data to objects
-            ncbi_response, uniprot_response = self.identify_data_source(responses)
+            ncbi_responses, uniprot_response = self.identify_data_source(responses)
+
+            ncbi_entries = NCBIProteinMapper().map(ncbi_responses)
 
-            ncbi_entries = NCBIProteinMapper().map(ncbi_response)
             uniprot_entries = [
                 UniprotMapper().map(*resp) for resp in uniprot_response.values()
             ]
@@ -136,12 +139,14 @@ async def fetch(self, force_terminal: bool = False):
                 task_id=task_id,
                 progress=progress,
                 batch_size=1,
-                rate_limit=10,
+                rate_limit=50,
                 n_concurrent=20,
             )
 
             taxonomies = await tax_requester.make_request()
 
+            progress.update(task_id, completed=len(unique_tax_ids))
+
             # map taxonomy data to objects
             organisms = [TaxonomyMapper().map(entry) for entry in taxonomies]
 
@@ -166,23 +171,24 @@ def identify_data_source(self, responses: List[List[str]]):
         """
         uniprot = {}
+        ncbi = None
         for response in responses:
             if response[0].startswith("LOCUS"):
                 ncbi = response
-                print("NCBI detected")
 
             elif response[0].startswith("{"):
-                print("INTERPRO detected")
                 uniprot[DBPattern.INTERPRO.name] = [
                     json.loads(entry) for entry in response
                 ]
 
             elif response[0].startswith("["):
-                print("UNIPROT detected")
                 uniprot[DBPattern.UNIPROT.name] = [
                     json.loads(entry)[0] for entry in response
                 ]
             else:
                 LOGGER.warning(f"Response could not be mapped to mapper: {response[0]}")
+        if not ncbi:
+            ncbi = []
+
         if uniprot:
             uniprot_dict = self.sort_uniprot_by_id(uniprot)
         else:
diff --git a/pyeed/fetch/requester.py b/pyeed/fetch/requester.py
index 8ea29027..c41d2c49 100644
--- a/pyeed/fetch/requester.py
+++ b/pyeed/fetch/requester.py
@@ -51,7 +51,7 @@ def _create_progress(self):
             self.progress = Progress(disable=True)
             self.task_id = self.progress.add_task("Requesting data...", total=len(self.ids))
 
-    async def send_request(self, args: RequestArgs):
+    async def send_request(self, args: RequestArgs) -> str:
         """
         Sends an asynchronous HTTP GET request to the specified URL using the
         provided AsyncClient.
@@ -63,7 +63,6 @@ async def send_request(self, args: RequestArgs):
         Returns:
             str: The response text from the request.
         """
-
         client = args.client
         url = args.url
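
The `fetch(**console_kwargs)` signature above forwards every keyword straight into `rich.console.Console`, which is how the callers in `proteininfo.py` pass `quiet=True` or `force_terminal=False` to silence or redirect the progress display. A minimal sketch of that forwarding pattern outside of pyeed:

import asyncio

from rich.console import Console
from rich.progress import Progress


async def fetch_like(ids, **console_kwargs):
    # any keyword (quiet, force_terminal, force_jupyter, ...) configures the progress console
    with Progress(console=Console(**console_kwargs)) as progress:
        task_id = progress.add_task("Requesting data...", total=len(ids))
        for _ in ids:
            await asyncio.sleep(0)  # stand-in for the actual HTTP request
            progress.advance(task_id)
    return ids


print(asyncio.run(fetch_like(["A", "B", "C"], quiet=True)))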