diff --git a/.gitignore b/.gitignore
index b38e78e18..71ec16e6a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,5 +15,5 @@ src/svtk/svtk/utils/utils.pyc
 src/svtk/svtk/vcfcluster.pyc
 /inputs/build/
 /inputs/values/google_cloud.*.json
-/carrot/.runs.json
-/carrot/.configs.json
+/tests/carrot/.runs.json
+/tests/carrot/.configs.json
diff --git a/carrot/ExpansionHunterDenovo/casecontrol/eval.wdl b/tests/carrot/ExpansionHunterDenovo/casecontrol/eval.wdl
similarity index 100%
rename from carrot/ExpansionHunterDenovo/casecontrol/eval.wdl
rename to tests/carrot/ExpansionHunterDenovo/casecontrol/eval.wdl
diff --git a/carrot/ExpansionHunterDenovo/casecontrol/eval_input_defaults.json b/tests/carrot/ExpansionHunterDenovo/casecontrol/eval_input_defaults.json
similarity index 100%
rename from carrot/ExpansionHunterDenovo/casecontrol/eval_input_defaults.json
rename to tests/carrot/ExpansionHunterDenovo/casecontrol/eval_input_defaults.json
diff --git a/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/eval_input.json b/tests/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/eval_input.json
similarity index 100%
rename from carrot/ExpansionHunterDenovo/casecontrol/simulated_data/eval_input.json
rename to tests/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/eval_input.json
diff --git a/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/test_input.json b/tests/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/test_input.json
similarity index 100%
rename from carrot/ExpansionHunterDenovo/casecontrol/simulated_data/test_input.json
rename to tests/carrot/ExpansionHunterDenovo/casecontrol/simulated_data/test_input.json
diff --git a/carrot/ExpansionHunterDenovo/casecontrol/test_input_defaults.json b/tests/carrot/ExpansionHunterDenovo/casecontrol/test_input_defaults.json
similarity index 100%
rename from carrot/ExpansionHunterDenovo/casecontrol/test_input_defaults.json
rename to tests/carrot/ExpansionHunterDenovo/casecontrol/test_input_defaults.json
diff --git a/carrot/README.md b/tests/carrot/README.md
similarity index 100%
rename from carrot/README.md
rename to tests/carrot/README.md
diff --git a/carrot/carrot_helper.py b/tests/carrot/carrot_helper.py
similarity index 99%
rename from carrot/carrot_helper.py
rename to tests/carrot/carrot_helper.py
index e9474b292..ff46a3128 100644
--- a/carrot/carrot_helper.py
+++ b/tests/carrot/carrot_helper.py
@@ -19,7 +19,7 @@
 
 # The directories where the WDLs
 # and their tests are located.
-WDLS_DIR_RELATIVE = "../wdl"
+WDLS_DIR_RELATIVE = "../../wdl"
 WDLS_DIR = "wdl"
 WDLS_TEST_DIR = "wdl_test"
 
diff --git a/tests/utilities/README.md b/tests/utilities/README.md
new file mode 100644
index 000000000..8461833e7
--- /dev/null
+++ b/tests/utilities/README.md
@@ -0,0 +1,20 @@
+
+# Installation
+
+1. Create a virtual environment:
+
+   ```shell
+   virtualenv .venv
+   ```
+
+2. Activate the virtual environment:
+
+   ```shell
+   source .venv/bin/activate
+   ```
+
+3. Install requirements:
+
+   ```shell
+   pip install -r requirements.txt
+   ```
\ No newline at end of file
diff --git a/tests/utilities/default_downsampling_regions.bed b/tests/utilities/default_downsampling_regions.bed
new file mode 100644
index 000000000..0a62a2a4c
--- /dev/null
+++ b/tests/utilities/default_downsampling_regions.bed
@@ -0,0 +1,5 @@
+chr1	42000000	44000000
+chr4	156467305	156545919
+chr6	32000000	33000000
+chr16	21000000	23000000
+chrX	123409545	123667508
diff --git a/tests/utilities/generate_test_data.py b/tests/utilities/generate_test_data.py
new file mode 100644
index 000000000..225745931
--- /dev/null
+++ b/tests/utilities/generate_test_data.py
@@ -0,0 +1,284 @@
+import argparse
+import transformers
+import json
+import logging
+import os
+
+from pathlib import Path
+from dataclasses import dataclass
+from google.cloud import storage
+from typing import Callable, List, Type, Union
+
+
+logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
+
+
+@dataclass
+class Region:
+    chr: str
+    start: int
+    end: int
+
+
+@dataclass
+class Handler:
+    transformer: Union[transformers.BaseTransformer, Type[transformers.BaseTransformer]]
+    callback: Callable[..., dict]
+
+
+SUBJECT_WORKFLOW_INPUTS = {
+    "GatherSampleEvidence": {
+        "bam_or_cram_file": Handler(
+            transformers.CramDownsampler,
+            lambda cram, index: {"bam_or_cram_file": cram, "bam_or_cram_index": index}
+        ),
+        "preprocessed_intervals": Handler(
+            transformers.BedToIntervalListConverter,
+            lambda x: {"preprocessed_intervals": x, "melt_metrics_intervals": x}
+        ),
+        "sd_locs_vcf": Handler(
+            transformers.VcfDownsampler,
+            lambda x: {"sd_locs_vcf": x}
+        ),
+        "primary_contigs_list": Handler(
+            transformers.PrimaryContigsDownsampler,
+            lambda x: {"primary_contigs_list": x}
+        ),
+        "primary_contigs_fai": Handler(
+            transformers.PrimaryContigsDownsampler,
+            lambda x: {"primary_contigs_fai": x}
+        ),
+        "wham_include_list_bed_file": Handler(
+            transformers.BedDownsampler,
+            lambda x: {"wham_include_list_bed_file": x}
+        )
+    }
+}
+
+
+WORKFLOW_INPUTS_TO_DROP = {
+    "GatherSampleEvidence": [
+        "melt_docker",
+        "melt_metrics_intervals",
+        "melt_standard_vcf_header"
+    ]
+}
+
+
+def parse_target_regions(input_filename):
+    regions = []
+    with open(input_filename, "r") as f:
+        for line in f:
+            cols = line.strip().split("\t")
+            if len(cols) != 3:
+                raise ValueError(
+                    f"Invalid line in {input_filename}. "
+                    f"Expected a line with three columns, "
+                    f"chr, start, and stop positions, found: {repr(line.strip())}")
+            regions.append(Region(str(cols[0]), int(cols[1]), int(cols[2])))
+    return regions
+
+
+def localize_file(input_filename, output_filename):
+    if os.path.isfile(output_filename):
+        logging.info(f"A local copy of {input_filename} already exists, skipping localization.")
+        return
+    if input_filename.startswith("gs://"):
+        logging.info(f"Localizing from GCP; blob: {input_filename} ...")
+        download_blob_from_gs(input_filename, output_filename)
+        logging.info(f"Finished localizing blob {input_filename}")
+    else:
+        raise NotImplementedError()
+
+
+def initialize_transformers(
+        working_dir: str,
+        reference_fasta: str,
+        reference_index: str,
+        sequence_dict_filename: str,
+        picard_path: str):
+    for _, inputs in SUBJECT_WORKFLOW_INPUTS.items():
+        for _, handler in inputs.items():
+            handler.transformer = handler.transformer(
+                working_dir=working_dir,
+                callback=handler.callback,
+                reference_fasta=reference_fasta,
+                reference_index=reference_index,
+                sequence_dict_filename=sequence_dict_filename,
+                picard_path=picard_path
+            )
+
+
+def update_workflow_json(
+        working_dir: str, input_filename: str, output_filename: str, output_filename_prefix: str,
+        regions: List[Region], bucket_name: str = None, blob_name: str = None):
+    with open(input_filename, "r") as f:
+        workflow_inputs = json.load(f)
+
+    updated_workflow_inputs = {}
+
+    for k, v in workflow_inputs.items():
+        # Example of the following split:
+        # k="a.b.c" --> workflow_name="a.b", input_var="c"
+        workflow_name, input_var = k.rsplit(".", maxsplit=1)
+
+        try:
+            handler = SUBJECT_WORKFLOW_INPUTS[workflow_name][input_var]
+        except KeyError:
+            if workflow_name in WORKFLOW_INPUTS_TO_DROP and input_var in WORKFLOW_INPUTS_TO_DROP[workflow_name]:
+                logging.info(f"Dropping {k}.")
+            else:
+                updated_workflow_inputs[k] = v
+                logging.info(f"Leaving {k} unchanged.")
+            continue
+
+        logging.info(f"Processing input {k}.")
+        workflow_input_local_filename = Path(working_dir).joinpath(Path(v).name)
+        localize_file(v, workflow_input_local_filename)
+        updated_files = handler.transformer.transform(
+            input_filename=workflow_input_local_filename,
+            output_prefix=output_filename_prefix,
+            regions=regions
+        )
+
+        for varname, filename in updated_files.items():
+            input_key = f"{workflow_name}.{varname}"
+            updated_workflow_inputs[input_key] = filename
+            if bucket_name is not None and blob_name is not None:
+                logging.info(f"Uploading downsampled file {filename} to bucket {bucket_name}.")
+                blob = upload_to_gs_blob(filename, bucket_name, blob_name)
+                logging.info(f"Finished uploading {filename}.")
+                updated_workflow_inputs[input_key] = blob
+
+    logging.info(f"Creating output JSON {output_filename}.")
+    with open(output_filename, "w") as f:
+        json.dump(updated_workflow_inputs, f, indent=4)
+    logging.info(f"Finished creating output JSON {output_filename}.")
+
+
+def upload_to_gs_blob(filename, bucket_name, blob_name):
+    path = Path(filename)
+    blob_name = blob_name + "/" + path.name
+    blob = storage.Client().bucket(bucket_name).blob(blob_name)
+    blob.upload_from_filename(filename)
+    return "gs://" + bucket_name + "/" + blob_name
+
+
+def download_blob_from_gs(gs_link, local_filename):
+    bucket_name, blob_name = gs_link.split("/", maxsplit=3)[2:]
+    blob = storage.Client().bucket(bucket_name).blob(blob_name)
+    blob.download_to_filename(local_filename)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="This is a utility script to downsample the inputs of a workflow "
+                    "in order to run the workflow faster or prepare data for unit testing. "
+                    "In addition to other inputs, the script takes a JSON file containing the inputs to a workflow, "
+                    "downsamples the inputs according to the defined rules (see `SUBJECT_WORKFLOW_INPUTS`), "
+                    "pushes the downsampled files to a given cloud storage bucket, and creates a new JSON "
+                    "file with the updated downsampled inputs. "
+                    "This script needs samtools version 1.19.2 or newer installed and added to PATH.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument(
+        "-w", "--working-dir",
+        default=os.getcwd(),
+        help="Sets the working directory where the downsampled files will be stored before being pushed to cloud storage."
+    )
+
+    parser.add_argument(
+        "input_workflow_json",
+        help="Sets a JSON filename containing the inputs to a workflow."
+    )
+
+    parser.add_argument(
+        "--output-workflow-json",
+        help="Sets a JSON filename containing the updated input arguments from the input JSON. "
+             "The default value is a JSON file created in the working directory with the same "
+             "name as the input JSON with an added prefix."
+    )
+
+    parser.add_argument(
+        "picard_path",
+        help="Sets the absolute path to `picard.jar`. "
+             "You may download picard.jar from `https://github.com/broadinstitute/picard/releases`."
+    )
+
+    parser.add_argument(
+        "--output-filename-prefix",
+        default="downsampled_",
+        help="Sets a prefix to be added to all the output files generated."
+    )
+
+    parser.add_argument(
+        "--bucket-name",
+        help="Sets the cloud bucket name where the downsampled files will be pushed. "
+             "The script skips uploading to the cloud if a value for this argument is not provided."
+    )
+
+    parser.add_argument(
+        "--blob-name",
+        help="Sets the cloud blob name where the downsampled files will be pushed. "
+             "The script skips uploading to the cloud if a value for this argument is not provided."
+    )
+
+    this_script_folder = os.path.dirname(os.path.abspath(__file__))
+    gatk_sv_path = os.path.dirname(os.path.dirname(this_script_folder))
+
+    parser.add_argument(
+        "--target-regions",
+        default=os.path.join(gatk_sv_path, "tests", "utilities", "default_downsampling_regions.bed"),
+        help="Sets a BED filename containing target regions for downsampling, "
+             "such that the downsampled files contain data from the input files overlapping these regions."
+    )
+
+    parser.add_argument(
+        "--reference-fasta",
+        default="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta",
+        help="Sets the path to the reference FASTA file."
+    )
+
+    parser.add_argument(
+        "--reference-fasta-index",
+        default="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta.fai",
+        help="Sets the path to the index file of the reference FASTA file."
+    )
+
+    parser.add_argument(
+        "--reference-dict",
+        default="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict",
+        help="Sets the path to the reference dictionary file."
+    )
+
+    args = parser.parse_args()
+
+    regions = parse_target_regions(args.target_regions)
+    logging.info(f"Found {len(regions)} target regions for downsampling.")
+
+    output_workflow_json = args.output_workflow_json
+    if not output_workflow_json:
+        output_workflow_json = os.path.join(
+            args.working_dir,
+            f"{args.output_filename_prefix}{Path(args.input_workflow_json).name}"
+        )
+
+    Path(args.working_dir).mkdir(parents=True, exist_ok=True)
+
+    sequence_dict_local_filename = Path(args.working_dir).joinpath(Path(args.reference_dict).name)
+    localize_file(args.reference_dict, sequence_dict_local_filename)
+    initialize_transformers(args.working_dir, args.reference_fasta, args.reference_fasta_index,
+                            sequence_dict_local_filename, args.picard_path)
+
+    update_workflow_json(
+        working_dir=args.working_dir,
+        input_filename=args.input_workflow_json,
+        output_filename=output_workflow_json,
+        output_filename_prefix=args.output_filename_prefix,
+        regions=regions,
+        bucket_name=args.bucket_name,
+        blob_name=args.blob_name)
+
+    logging.info("All processes finished successfully.")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/utilities/requirements.txt b/tests/utilities/requirements.txt
new file mode 100644
index 000000000..6e16d8d3b
--- /dev/null
+++ b/tests/utilities/requirements.txt
@@ -0,0 +1,3 @@
+google-cloud-storage
+pysam
+tqdm
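Note: `SUBJECT_WORKFLOW_INPUTS` is the single place where downsampling rules are declared, so supporting another workflow input only requires registering one more `Handler`. Below is a minimal sketch, assuming the `tests/utilities` modules are importable from the current directory; `manta_region_bed` is an illustrative input name, not one handled by this change, and in practice the entry would be added directly to the dict in `generate_test_data.py`:

```python
# Hypothetical registration of one more GatherSampleEvidence input for downsampling.
import transformers
from generate_test_data import SUBJECT_WORKFLOW_INPUTS, Handler

SUBJECT_WORKFLOW_INPUTS["GatherSampleEvidence"]["manta_region_bed"] = Handler(
    transformers.BedDownsampler,        # transformer class; initialize_transformers() instantiates it
    lambda x: {"manta_region_bed": x},  # maps the downsampled BED back to the workflow input
)
```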
diff --git a/tests/utilities/transformers.py b/tests/utilities/transformers.py
new file mode 100644
index 000000000..9b52c8b3a
--- /dev/null
+++ b/tests/utilities/transformers.py
@@ -0,0 +1,248 @@
+import logging
+import os
+import pysam
+import subprocess
+import uuid
+
+from typing import Callable, List
+from dataclasses import dataclass
+from pathlib import Path
+from tqdm import tqdm
+
+
+logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
+
+
+@dataclass
+class Region:
+    chr: str
+    start: int
+    end: int
+
+    @staticmethod
+    def to_file(working_dir, regions):
+        filename = str(uuid.uuid4())
+        filename = os.path.join(working_dir, filename + ".bed")
+        with open(filename, "w") as regions_file:
+            for r in regions:
+                regions_file.write("\t".join([str(r.chr), str(r.start), str(r.end)]) + "\n")
+        return filename
+
+
+class BaseTransformer:
+    def __init__(self, working_dir: str, callback: Callable[..., dict]):
+        # Convert the string to an absolute path and make sure it exists.
+        self.working_dir = Path(working_dir).resolve(strict=True)
+        self.callback = callback
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        """
+        The file types should include the '.' prefix to match Path().suffix output
+        (e.g., it should return '.cram' instead of 'cram').
+        """
+        raise NotImplementedError()
+
+    def get_output_filename(self, input_filename, output_prefix):
+        return str(self.working_dir.joinpath(f"{output_prefix}{Path(input_filename).name}"))
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs):
+        raise NotImplementedError()
+
+
+class BedToIntervalListConverter(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str], dict], sequence_dict_filename: str, picard_path: str, **kwargs):
+        super().__init__(working_dir, callback)
+        self.sequence_dict_filename = sequence_dict_filename
+        self.picard_path = picard_path
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".interval_list"]
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        regions_filename = Region.to_file(self.working_dir, regions)
+
+        subprocess.run(
+            ["java", "-jar", self.picard_path, "BedToIntervalList",
+             "-I", regions_filename, "-O", output_filename, "-SD", self.sequence_dict_filename],
+            check=True)
+
+        os.remove(regions_filename)
+        return self.callback(output_filename)
+
+
+class CramDownsampler(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str, str], dict], reference_fasta: str, reference_index: str, **kwargs):
+        super().__init__(working_dir, callback)
+        self.reference_fasta = reference_fasta
+        self.reference_index = reference_index
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".cram"]
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        include_discordant_reads = kwargs.get("include_discordant_reads", False)
+        output_filename_unsorted = self.get_output_filename(input_filename, f"unsorted_{output_prefix}")
+        with pysam.AlignmentFile(input_filename, "rc") as cram:
+            header = cram.header
+            references = cram.references
+
+        reads_for_second_pass = {}
+        with pysam.AlignmentFile(output_filename_unsorted, "wc", header=header, reference_names=references) as output_cram_file:
+            for region in regions:
+                generator = self.read_pair_generator(input_filename, region=f"{region.chr}:{region.start}-{region.end}")
+                try:
+                    while True:
+                        r1, r2 = next(generator)
+                        output_cram_file.write(r1)
+                        output_cram_file.write(r2)
+                except StopIteration as e:
+                    reads_for_second_pass = reads_for_second_pass | e.value
+
+            if include_discordant_reads:
+                for r1, r2 in self.get_discordant_reads(input_filename, reads_for_second_pass):
+                    output_cram_file.write(r1)
+                    output_cram_file.write(r2)
+
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        pysam.sort("-o", output_filename, output_filename_unsorted)
+        os.remove(output_filename_unsorted)
+        index_filename = f"{output_filename}.crai"
+        pysam.index(output_filename, index_filename)
+        return self.callback(output_filename, index_filename)
+
+    @staticmethod
+    def read_pair_generator(filename: str, region: str = None):
+        """
+        Generate read pairs in a CRAM file.
+        If the `region` string is provided, it only generates the reads overlapping this region.
+        """
+        read_dict = {}
+        with pysam.AlignmentFile(filename, "rc") as cram:
+            for read in cram.fetch(region=region):
+                q_name = read.query_name
+                if not read.is_proper_pair or read.is_secondary or read.is_supplementary:
+                    continue
+                if q_name not in read_dict:
+                    read_dict[q_name] = [None, None]
+                    read_dict[q_name][0 if read.is_read1 else 1] = read
+                else:
+                    if read.is_read1:
+                        yield read, read_dict[q_name][1]
+                    else:
+                        yield read_dict[q_name][0], read
+                    del read_dict[q_name]
+
+        return read_dict
+
+    @staticmethod
+    def get_discordant_reads(filename: str, discordant_reads: dict):
+        logging.info(f"Linearly scanning {filename} for discordant pairs of {len(discordant_reads)} reads.")
+        with pysam.AlignmentFile(filename, "rc") as cram:
+            for read in tqdm(cram.fetch(), desc="Iterating on reads", unit=" read", dynamic_ncols=True,
+                             bar_format="{desc}: {n:,} [{elapsed}] {rate_fmt}"):
+                query_name = read.query_name
+                pair = discordant_reads.get(query_name)
+                if pair is not None:
+                    if (read.is_read1 and pair[0] is not None) or (not read.is_read1 and pair[1] is not None):
+                        # It is the same read as the one seen before.
+                        continue
+                    if read.is_read1:
+                        yield read, pair[1]
+                    else:
+                        yield pair[0], read
+                    del discordant_reads[query_name]
+        if len(discordant_reads) > 0:
+            logging.warning(f"Did not find discordant pairs for {len(discordant_reads)} reads.")
+        logging.info("Finished linear scanning.")
+
+
+class VcfDownsampler(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
+        super().__init__(working_dir, callback)
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".vcf"]
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        with \
+                pysam.VariantFile(input_filename) as input_file, \
+                pysam.VariantFile(output_filename, "w", header=input_file.header) as output_file:
+            for record in input_file:
+                for region in regions:
+                    if record.contig == region.chr and region.start <= record.pos <= region.end:
+                        output_file.write(record)
+                        break
+        return self.callback(output_filename)
+
+
+class IntervalListDownsampler(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
+        super().__init__(working_dir, callback)
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".interval_list"]
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        # Note that this algorithm is not efficient.
+        with open(input_filename, "r") as input_file, open(output_filename, "w") as output_file:
+            for line in input_file:
+                if line.startswith("@"):
+                    output_file.write(line)
+                else:
+                    cols = line.rstrip().split()
+                    chr, start, end = cols[0], int(cols[1]), int(cols[2])
+                    for region in regions:
+                        if chr == region.chr and max(start, region.start) < min(end, region.end):
+                            output_file.write(line)
+                            break
+        return self.callback(output_filename)
+
+
+class BedDownsampler(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
+        super().__init__(working_dir, callback)
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".bed"]
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        # Note that this algorithm is not efficient.
+        with open(input_filename, "r") as input_file, open(output_filename, "w") as output_file:
+            for line in input_file:
+                cols = line.rstrip().split()
+                chr, start, end = cols[0], int(cols[1]), int(cols[2])
+                for region in regions:
+                    if chr == region.chr and max(start, region.start) < min(end, region.end):
+                        output_file.write(line)
+                        break
+        return self.callback(output_filename)
+
+
+class PrimaryContigsDownsampler(BaseTransformer):
+    def __init__(self, working_dir, callback: Callable[[str], dict], delimiter: str = "\t", **kwargs):
+        super().__init__(working_dir, callback)
+        self.delimiter = delimiter
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return []
+
+    def transform(self, input_filename: str, output_prefix: str, regions: List[Region], **kwargs) -> dict:
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        include_chrs = set([r.chr for r in regions])
+        with open(input_filename, "r") as input_file, open(output_filename, "w") as output_file:
+            for line in input_file:
+                cols = line.strip().split(self.delimiter)
+                if cols[0] in include_chrs:
+                    output_file.write(self.delimiter.join(cols) + "\n")
+        return self.callback(output_filename)
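The transformer classes can also be exercised on their own, which is handy when debugging a single downsampling rule. Below is a minimal sketch, assuming the packages in `requirements.txt` are installed; the BED path is a placeholder, and the working directory must already exist because `BaseTransformer` resolves it with `strict=True`:

```python
# Subset a placeholder BED file to one of the default downsampling regions.
from transformers import BedDownsampler, Region

regions = [Region(chr="chr1", start=42000000, end=44000000)]

downsampler = BedDownsampler(
    working_dir=".",
    callback=lambda path: {"wham_include_list_bed_file": path},
)
outputs = downsampler.transform(
    input_filename="example.bed",  # placeholder input file
    output_prefix="downsampled_",
    regions=regions,
)
print(outputs)  # {"wham_include_list_bed_file": "<working dir>/downsampled_example.bed"}
```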