From a59601d633a96a4dbb17c961fa957d5f3d7d1bdb Mon Sep 17 00:00:00 2001 From: Matthew Andres Moreno Date: Thu, 21 Mar 2024 18:33:04 -0400 Subject: [PATCH] Impl gen3sis reconstruction generation Note workaround for https://github.com/project-gen3sis/R-package/issues/70 --- ...age=1+what=make_reconstructions_gen3sis.sh | 299 ++++++++++++------ requirements.in | 2 +- requirements.txt | 2 +- 3 files changed, 204 insertions(+), 99 deletions(-) diff --git a/pipeline-gen3sis/stage=1+what=make_reconstructions_gen3sis.sh b/pipeline-gen3sis/stage=1+what=make_reconstructions_gen3sis.sh index 91cbd1382..807ddaafe 100755 --- a/pipeline-gen3sis/stage=1+what=make_reconstructions_gen3sis.sh +++ b/pipeline-gen3sis/stage=1+what=make_reconstructions_gen3sis.sh @@ -59,12 +59,46 @@ sleep "${SLEEP_DURATION}" done PYSCRIPT=$(cat << HEREDOC +import multiprocessing.pool as mpp + +# https://stackoverflow.com/a/65854996 +def istarmap(self, func, iterable, chunksize=1): + """starmap-version of imap""" + self._check_running() + if chunksize < 1: + raise ValueError( + "Chunksize must be 1+, not {0:n}".format( + chunksize)) + + task_batches = mpp.Pool._get_tasks(func, iterable, chunksize) + result = mpp.IMapIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mpp.starmapstar, + task_batches), + result._set_length + )) + return (item for chunk in result for item in chunk) + +mpp.Pool.istarmap = istarmap + + +import functools import gc import glob +import itertools as it import logging import multiprocessing import os - +import pathlib +import random + +import alifedata_phyloinformatics_convert as apc +from hstrat import _auxiliary_lib as hstrat_aux +from hstrat import hstrat +from keyname import keyname as kn +import numpy as np import pandas as pd from retry import retry from tqdm import tqdm @@ -89,102 +123,173 @@ logging.info(f"""first globbed phylogeny path is { globbed_phylogeny_paths[0] }""") -# read_csv_with_retry = retry( -# tries=10, -# delay=1, -# max_delay=10, -# backoff=2, -# jitter=(0, 4), -# logger=logging, -# )( -# # wrap is workaround for retry compatibility -# lambda *args, **kwargs: pd.read_csv(*args, **kwargs) -# ) - -# collated_audit_df = pd.concat( -# ( -# read_csv_with_retry(audit_path) -# for audit_path in tqdm( -# globbed_audit_paths, -# desc="audit_paths", -# mininterval=10, -# ) -# ), -# ignore_index=True, -# join="outer", -# ) -# logging.info( -# "collated audit dataframe constructed " -# f"with {len(collated_audit_df)} rows" -# ) -# collated_audit_path = ( -# "${STAGE_PATH}/latest/a=collated-reconstruction-audits+ext=.csv" -# ) - -# retry( -# tries=10, -# delay=1, -# max_delay=10, -# backoff=2, -# jitter=(0, 4), -# logger=logging, -# )( -# # wrap is workaround for retry compatibility -# lambda path, index: collated_audit_df.to_csv(path, index=index) -# )(collated_audit_path, index=False) - -# logging.info(f"collated audit written to {collated_audit_path}") - -# del collated_audit_df -# gc.collect() -# logging.info(f"collated_audit_df cleared from memory") - -# collated_provlog_path = collated_audit_path + ".provlog.yaml" - -# # adapted from https://stackoverflow.com/a/74214157 -# def read_file_bytes(path: str, size: int = -1) -> bytes: -# fd = os.open(path, os.O_RDONLY) -# try: -# if size == -1: -# size = os.fstat(fd).st_size -# return os.read(fd, size) -# finally: -# os.close(fd) - -# @retry( -# tries=10, -# delay=1, -# max_delay=10, -# backoff=2, -# jitter=(0, 4), -# logger=logging, -# ) -# def do_collate_provlogs(): -# with 
multiprocessing.Pool(processes=None) as pool: -# contents = [*pool.imap( -# read_file_bytes, -# ( -# f"{audit_path}.provlog.yaml" -# for audit_path in tqdm( -# globbed_audit_paths, -# desc="provlog_files", -# mininterval=10, -# ) -# ), -# )] -# logging.info("contents read in from provlogs") - -# with open(collated_provlog_path, "wb") as f_out: -# f_out.writelines( -# tqdm( -# contents, -# desc="provlog_contents", -# mininterval=10, -# ), -# ) -# do_collate_provlogs() - -# logging.info(f"collated provlog written to {collated_provlog_path}") +open_retry = retry( + tries=10, delay=1, max_delay=10, backoff=2, jitter=(0, 4), logger=logging, +)(open) + +def reconstruct_one( + template_path: str, recency_proportional_resolution: int +) -> None: + template_df = retry( + tries=10, delay=1, max_delay=10, backoff=2, jitter=(0, 4), logger=logging, + )( + apc.RosettaTree.from_nexus, + )( + pathlib.Path(template_path), + ).as_alife + template_df = hstrat_aux.alifestd_to_working_format(template_df) + assert hstrat_aux.alifestd_validate(template_df) + assert hstrat_aux.alifestd_has_contiguous_ids(template_df) + assert hstrat_aux.alifestd_is_topologically_sorted(template_df) + + # Gen3sis sometimes produces negative branch lengths... set them to 0 + # https://github.com/project-gen3sis/R-package/issues/70 + template_df = hstrat_aux.alifestd_coerce_chronological_consistency( + template_df, mutate=True + ) + collapsed_df = hstrat_aux.alifestd_collapse_unifurcations( + template_df, + mutate=True, + ) + collapsed_df = hstrat_aux.alifestd_to_working_format(collapsed_df) + + attrs = kn.unpack(kn.rejoin(template_path.replace( + "/", "+", + ))) + hstrat_aux.seed_random( random.Random( + f"{ attrs['seed'] } " + f"{ recency_proportional_resolution } " + ).randrange(2**32) ) + + seed_column = hstrat.HereditaryStratigraphicColumn( + hstrat.recency_proportional_resolution_algo.Policy( + int(recency_proportional_resolution) + ), + stratum_differentia_bit_width=8, + ) + extant_population = hstrat.descend_template_phylogeny_alifestd( + collapsed_df, + seed_column, + ) + + reconstruction_postprocesses = ("naive",) + tree_ensemble = hstrat.build_tree_trie_ensemble( + extant_population, + trie_postprocessors=[ + # naive + hstrat.CompoundTriePostprocessor( + postprocessors=[ + hstrat.AssignOriginTimeNaiveTriePostprocessor(), + hstrat.AssignDestructionTimeYoungestPlusOneTriePostprocessor(), + ], + ), + ], + ) + logging.info(f"tree_ensemble has size {len(tree_ensemble)}") + + reconstruction_dfs = [*map( + functools.partial( + hstrat_aux.alifestd_assign_root_ancestor_token, + root_ancestor_token="None", + ), + tree_ensemble, + )] + logging.info(f"reconstruction_dfs has size {len(reconstruction_dfs)}") + + # check data validity + for postprocess, reconstruction_df in zip( + reconstruction_postprocesses, reconstruction_dfs + ): + assert hstrat_aux.alifestd_validate(reconstruction_df), postprocess + + reconstruction_filenames = [*map( + lambda postprocess: kn.pack({ + **kn.unpack(kn.rejoin( + template_path.replace("/", "+"), + )), + **{ + "a" : "reconstructed-tree", + "trie-postprocess" : postprocess, + "subsampling-fraction" : 1, + "ext" : ".csv.gz", + }, + }), + reconstruction_postprocesses, + )] + logging.info(f"""reconstruction_filenames has size { + len(reconstruction_filenames) + }""") + + def setup_reconstruction_paths(): + return [ + kn.chop( + f"${BATCH_PATH}/" + f"""epoch={ + 0 + }+resolution={ + recency_proportional_resolution + }+subsampling_fraction={ + 1 + }+seed={ + attrs['seed'] + }+treatment={ + 
kn.unpack(kn.rejoin( + template_path.replace("/", "+"), + ))["treatment"] + }/""" + f"{reconstruction_filename}", + mkdir=True, + logger=logging, + ) + for reconstruction_filename in reconstruction_filenames + ] + reconstruction_paths = retry( + tries=10, delay=1, max_delay=10, backoff=2, jitter=(0, 4), logger=logging, + )(setup_reconstruction_paths)() + logging.info(f"""reconstruction_paths has size { + len(reconstruction_paths) + }""") + + for reconstruction_path, reconstruction_df in zip( + reconstruction_paths, reconstruction_dfs + ): + retry( + tries=10, delay=1, max_delay=10, backoff=2, jitter=(0, 4), logger=logging, + )(reconstruction_df.to_csv)(reconstruction_path, index=False) + logging.info(f"wrote reconstructed tree to {reconstruction_path}") + + provlog_path = f"{reconstruction_path}.provlog.yaml" + @retry( + tries=10, delay=1, max_delay=10, backoff=2, jitter=(0, 4), logger=logging, + ) + def do_save(): + with open_retry(provlog_path, "a+") as provlog_file: + provlog_file.write( + f"""- + a: {provlog_path} + batch: ${BATCH} + date: $(date --iso-8601=seconds) + hostname: $(hostname) + revision: ${REVISION} + runmode: ${RUNMODE} + user: $(whoami) + uuid: $(uuidgen) + slurm_job_id: ${SLURM_JOB_ID-none} + stage: 1 + stage 1 batch path: $(readlink -f "${BATCH_PATH}") + template_path: {template_path} + """, + ) + do_save() + +cpu_count = multiprocessing.cpu_count() +logging.info(f"cpu_count {cpu_count}") +with multiprocessing.Pool(processes=cpu_count) as pool: + args = [*it.product(globbed_phylogeny_paths, [3, 10, 33, 100])] + [*tqdm( + pool.istarmap(reconstruct_one, args), + total=len(args), + )] logging.info("PYSCRIPT complete") diff --git a/requirements.in b/requirements.in index 0ed8c7e5a..b7a599c08 100644 --- a/requirements.in +++ b/requirements.in @@ -1,6 +1,6 @@ alifedata_phyloinformatics_convert==0.16.2 ALifeStdDev==0.2.4 -hstrat==1.11.5 +hstrat==1.11.7 iterpop==0.4.1 j2cli==0.3.10 jinja2==3.0.0 diff --git a/requirements.txt b/requirements.txt index 35aade576..05de2489f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,7 @@ gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via lyncs-setuptools -hstrat==1.11.5 +hstrat==1.11.7 # via -r requirements.in idna==3.6 # via yarl
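
A minimal usage sketch of the istarmap monkey-patch vendored at the top of
the script: unlike pool.starmap, which blocks until every task finishes,
istarmap yields results as they complete, so the tqdm progress bar updates
live. This assumes the mpp.Pool.istarmap patch above has already been
executed; the work function is a hypothetical stand-in for reconstruct_one.

    import multiprocessing

    from tqdm import tqdm

    def work(x: int, y: int) -> int:
        # stand-in for reconstruct_one(template_path, resolution)
        return x * y

    if __name__ == "__main__":
        args = [(i, i + 1) for i in range(1000)]
        with multiprocessing.Pool() as pool:
            # istarmap lazily yields one result per argument tuple, letting
            # tqdm report progress incrementally instead of all at once
            for _ in tqdm(pool.istarmap(work, args), total=len(args)):
                pass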
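A self-contained sketch of the reconstruction flow exercised by
reconstruct_one, including the workaround for gen3sis issue #70 (negative
branch lengths manifest as children with origin times earlier than their
parents). The toy phylogeny is illustrative only, and hstrat.build_tree
stands in here for the build_tree_trie_ensemble call with explicit
postprocessors used in the script proper.

    import pandas as pd

    from hstrat import _auxiliary_lib as hstrat_aux
    from hstrat import hstrat

    # toy alife-standard phylogeny: node 1 predates its parent, mimicking
    # the negative branch lengths gen3sis sometimes emits (issue #70)
    template_df = pd.DataFrame({
        "id": [0, 1, 2],
        "ancestor_list": ["[None]", "[0]", "[0]"],
        "origin_time": [5, 3, 7],
    })
    template_df = hstrat_aux.alifestd_to_working_format(template_df)
    assert hstrat_aux.alifestd_validate(template_df)

    # the issue #70 workaround: clamp offending branches so that no node
    # is older than its ancestor
    template_df = hstrat_aux.alifestd_coerce_chronological_consistency(
        template_df, mutate=True,
    )

    # annotate the template's extant tips with hereditary stratigraphic
    # columns by simulating inheritance down the (corrected) phylogeny...
    seed_column = hstrat.HereditaryStratigraphicColumn(
        hstrat.recency_proportional_resolution_algo.Policy(10),
        stratum_differentia_bit_width=8,
    )
    extant_population = hstrat.descend_template_phylogeny_alifestd(
        template_df,
        seed_column,
    )

    # ...then reconstruct a phylogeny from those annotations alone
    reconstruction_df = hstrat.build_tree(
        extant_population, version_pin=hstrat.__version__,
    )
    assert hstrat_aux.alifestd_validate(reconstruction_df)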