Skip to content

Commit

Permalink
Deduplicate records and CPX variant IDs in ResolveComplexVariants (#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalker174 authored Aug 8, 2023
1 parent 08f3961 commit 0119b3f
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 118 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/bin/env python

import argparse
import sys
from typing import Callable, Dict, List, Optional, Set, Text, Tuple

import pysam

# INFO key read by get_members() to obtain a record's member variant IDs.
MEMBERS_KEY = "MEMBERS"
# INFO key read by is_unresolved() to decide whether a record is unresolved.
UNRESOLVED_KEY = "UNRESOLVED"


def update_header(header: pysam.VariantHeader) -> None:
    """Add the UNRESOLVED and UNRESOLVED_TYPE INFO definitions to ``header``.

    The header is modified in place through ``header.add_line``; nothing is
    returned.
    """
    info_lines = (
        '##INFO=<ID=UNRESOLVED,Number=0,Type=Flag,Description="Variant is unresolved.">',
        '##INFO=<ID=UNRESOLVED_TYPE,Number=1,Type=String,Description="Class of unresolved variant.">',
    )
    for info_line in info_lines:
        header.add_line(info_line)


def is_unresolved(record: pysam.VariantRecord):
    """Return the value stored under UNRESOLVED_KEY in the record's INFO.

    The lookup uses ``info.get`` with ``None`` as the fallback, so the result
    is whatever ``get`` yields (truthy for an unresolved record) rather than a
    strict bool.
    """
    info = record.info
    return info.get(UNRESOLVED_KEY, None)


def is_resolved(record: pysam.VariantRecord):
    """Return True when ``is_unresolved(record)`` is falsy, otherwise False."""
    unresolved = is_unresolved(record)
    return not unresolved


def get_members(record: pysam.VariantRecord):
    """Return the record's MEMBERS_KEY INFO value as a list.

    A tuple value is converted to a list, any other non-None value is wrapped
    in a one-element list, and ``None`` gives an empty list. The key is read
    with ``record.info[...]``, so whatever that lookup raises for a missing key
    propagates to the caller.
    """
    members = record.info[MEMBERS_KEY]
    if isinstance(members, tuple):
        return list(members)
    if members is not None:
        return [members]
    return []


def get_vids_and_members_sets(vcf: pysam.VariantFile,
                              predicate: Callable) -> Tuple[Set, Set]:
    """Collect IDs and member IDs of every record accepted by ``predicate``.

    Parameters
    ----------
    vcf
        Open variant file; it is iterated once and then rewound with
        ``vcf.reset()`` so the caller can iterate it again.
    predicate
        Callable taking a record; records for which it returns a truthy value
        are collected.

    Returns
    -------
    A ``(vids_set, members_set)`` tuple: the set of ``record.id`` values and
    the union of ``get_members(record)`` over the selected records.
    """
    vids_set = set()
    members_set = set()
    for record in vcf:
        if predicate(record):
            vids_set.add(record.id)
            members_set.update(get_members(record))
    # Rewind the file so a later pass (e.g. writing) starts from the first record.
    vcf.reset()
    return vids_set, members_set


def write_vcf(header: pysam.VariantHeader,
              all_vcf: pysam.VariantFile,
              inv_vcf: pysam.VariantFile,
              inv_resolved_vids_set: Set,
              inv_resolved_members_set: Set,
              all_unresolved_vids_set: Set,
              all_unresolved_members_set: Set) -> None:
    """Write the header and the selected records from both VCFs to stdout.

    Records from ``all_vcf`` are written unless their ID is both in
    ``all_unresolved_vids_set`` and in ``inv_resolved_members_set``. Records
    from ``inv_vcf`` are written when their ID is in ``inv_resolved_vids_set``
    and every member is in ``all_unresolved_members_set``. Output order is the
    header, then the ``all_vcf`` records, then the ``inv_vcf`` records.
    """
    out = sys.stdout
    out.write(str(header))

    for record in all_vcf:
        vid = record.id
        # Dropped only when unresolved in the ALL vcf and a member of a resolved
        # INV variant; kept when resolved in ALL, or unresolved in both VCFs.
        superseded = vid in all_unresolved_vids_set and vid in inv_resolved_members_set
        if not superseded:
            out.write(str(record))

    for record in inv_vcf:
        if record.id not in inv_resolved_vids_set:
            # Not a resolved variant in the INV vcf
            continue
        for member in get_members(record):
            if member not in all_unresolved_members_set:
                # At least one member is not unresolved in the ALL vcf
                break
        else:
            # Resolved in the INV vcf and every member unresolved in the ALL vcf
            out.write(str(record))


def __parse_arguments(argv: List[Text]) -> argparse.Namespace:
    """Parse the command line given as a full argv list (program name first).

    Both ``--all-vcf`` and ``--inv-only-vcf`` are required string options.
    When no arguments follow the program name, the help text is printed and
    the process exits with status 0.
    """
    # noinspection PyTypeChecker
    parser = argparse.ArgumentParser(
        description="Integrates inversion-only and all-SV VCFs from the complex resolve module. "
                    "Unsorted output is written to stdout.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    vcf_options = (
        ("--all-vcf", "Complex-resolved VCF containing all SVs"),
        ("--inv-only-vcf", "Complex-resolved VCF containing only inversions"),
    )
    for flag, help_text in vcf_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    cli_args = argv[1:]
    if not cli_args:
        # Nothing beyond the program name: show usage and exit successfully.
        parser.print_help()
        sys.exit(0)
    return parser.parse_args(cli_args)


def main(argv: Optional[List[Text]] = None):
    """Integrate the all-SV and inversion-only VCFs and write the result to stdout.

    ``argv`` defaults to ``sys.argv``. Both input VCFs are opened, the
    UNRESOLVED INFO definitions are added to the all-SV header, the resolved
    set of the inversion VCF and the unresolved set of the all-SV VCF are
    collected, and ``write_vcf`` emits the combined output.
    """
    arguments = __parse_arguments(sys.argv if argv is None else argv)
    with pysam.VariantFile(arguments.all_vcf) as all_vcf, \
            pysam.VariantFile(arguments.inv_only_vcf) as inv_vcf:
        # The output header is the all-SV header, extended in place.
        header = all_vcf.header
        update_header(header)
        inv_vids, inv_members = get_vids_and_members_sets(inv_vcf, is_resolved)
        all_vids, all_members = get_vids_and_members_sets(all_vcf, is_unresolved)
        write_vcf(header, all_vcf, inv_vcf, inv_vids, inv_members, all_vids, all_members)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
115 changes: 40 additions & 75 deletions src/svtk/svtk/cli/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import argparse
import sys
import subprocess
import numpy as np
import string
from collections import deque
from operator import attrgetter
import itertools
Expand All @@ -36,6 +34,17 @@
]


class RecordNamer:
    """Generate sequential variant IDs of the form ``<prefix>_<contig>_<counter>``.

    The counter starts at 0, is shared across all contigs, and is zero-padded
    to ``num_digits`` characters. An underscore is always inserted after the
    prefix, even if the prefix already ends with one.
    """

    def __init__(self, prefix='CPX_', num_digits=6):
        self.prefix = prefix
        self.num_digits = num_digits
        self.count = 0

    def get_next_id(self, record):
        """Return the next ID for ``record`` (uses ``record.contig``) and advance the counter."""
        index = self.count
        self.count = index + 1
        serial = str(index).zfill(self.num_digits)
        return f"{self.prefix}_{record.contig}_{serial}"


def _merge_records(vcf, cpx_records, cpx_record_ids):
"""
r1, r2 : iter of pysam.VariantRecord
Expand Down Expand Up @@ -81,18 +90,6 @@ def _next_cpx():
curr_cpx = _next_cpx()


def remove_CPX_from_INV(resolve_CPX, resolve_INV):
    """
    Return list of inversion calls not overlapped by list of complex calls

    resolve_CPX : iterable of records with ``chrom``, ``pos`` and ``stop``
    resolve_INV : iterable of records with ``chrom``, ``pos`` and ``stop``

    An inversion is dropped when it is on the same chromosome as a complex
    call and their [pos, stop] intervals intersect (endpoints inclusive).
    """
    cpx_intervals = [(cpx.chrom, cpx.pos, cpx.stop) for cpx in resolve_CPX]

    def _overlaps_cpx(inv):
        # Compare against the inversion itself; the comprehension variable used
        # to build cpx_intervals is not in scope here.
        return any(chrom == inv.chrom and start <= inv.stop and inv.pos <= end
                   for chrom, start, end in cpx_intervals)

    return [inv for inv in resolve_INV if not _overlaps_cpx(inv)]


def multisort(xs, specs):
for key, reverse in reversed(specs):
xs.sort(key=attrgetter(key), reverse=reverse)
Expand Down Expand Up @@ -137,17 +134,8 @@ def clusters_cleanup(clusters):
return deque(cluster_single_cleanup(cluster) for cluster in clusters)


def get_random_string(random_string_len):
    """
    Build a string of the requested length from upper-case ASCII letters and
    digits, drawing each character with np.random.choice
    """
    alphabet = list(string.ascii_uppercase + string.digits)
    chars = []
    for _ in range(random_string_len):
        # One draw per character, in order, from numpy's global random state.
        chars.append(np.random.choice(alphabet))
    return ''.join(chars)


def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, variant_prefix='CPX_',
min_rescan_support=4, pe_blacklist=None, quiet=False,
SR_only_cutoff=1000, random_resolved_id_length=10):
def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, resolved_record_namer, unresolved_record_namer,
min_rescan_support=4, pe_blacklist=None, quiet=False, SR_only_cutoff=1000):
"""
Resolve complex SV from CNV intervals and BCA breakpoints.
Yields all resolved events, simple or complex, in sorted order.
Expand All @@ -157,8 +145,10 @@ def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, variant_prefix='CPX_
cytobands : pysam.TabixFile
disc_pairs : pysam.TabixFile
mei_bed : pybedtools.BedTool
variant_prefix : str
Prefix to assign to resolved variants
resolved_record_namer: RecordNamer
RecordNamer object for resolved variants
unresolved_record_namer: RecordNamer
RecordNamer object for unresolved variants
min_rescan_support : int
Number of pairs required to count a sample as
supported during PE rescan
Expand All @@ -182,14 +172,8 @@ def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, variant_prefix='CPX_
'identified ' + str(len(clusters)) + ' candidate complex clusters ' +
'during first pass', flush=True)

# resolved_idx = unresolved_idx = 1

if not variant_prefix.endswith('_'):
variant_prefix += '_'

cpx_records = deque()
cpx_record_ids = set()
np.random.seed(1) # arbitrary fixed seed for reproducibility

for cluster in clusters:
# Print status for each cluster
Expand All @@ -212,22 +196,15 @@ def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, variant_prefix='CPX_
for record in cluster:
cpx = ComplexSV([record], cytobands, mei_bed, SR_only_cutoff)
cpx_record_ids = cpx_record_ids.union(cpx.record_ids)

# Assign random string as resolved ID to handle sharding
cpx.vcf_record.id = variant_prefix + \
get_random_string(random_resolved_id_length)
cpx.vcf_record.id = resolved_record_namer.get_next_id(cpx.vcf_record)
cpx_records.append(cpx.vcf_record)
# resolved_idx += 1
outcome = 'treated as separate unrelated insertions'
else:
cpx = ComplexSV(cluster, cytobands, mei_bed, SR_only_cutoff)
cpx_record_ids = cpx_record_ids.union(cpx.record_ids)
if cpx.svtype == 'UNR':
# Assign random string as unresolved ID to handle sharding
unresolved_vid = 'UNRESOLVED_' + \
get_random_string(random_resolved_id_length)
for record in cpx.records:
record.info['EVENT'] = unresolved_vid
record.info['EVENT'] = unresolved_record_namer.get_next_id(cpx.vcf_record)
record.info['UNRESOLVED'] = True
cpx_records.append(record)
# unresolved_idx += 1
Expand All @@ -252,8 +229,7 @@ def resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, variant_prefix='CPX_
'The following records were merged into the INS record: ' + \
', '.join(cnv_ids_to_append)
else:
cpx.vcf_record.id = variant_prefix + \
get_random_string(random_resolved_id_length)
cpx.vcf_record.id = resolved_record_namer.get_next_id(cpx.vcf_record)
cpx_records.append(cpx.vcf_record)
if 'CPX_TYPE' in cpx.vcf_record.info.keys():
outcome = 'resolved as ' + \
Expand Down Expand Up @@ -310,15 +286,13 @@ def cluster_cleanup(clusters_v2):


def resolve_complex_sv_v2(resolve_INV, cytobands, disc_pairs,
mei_bed, variant_prefix='CPX_', min_rescan_support=4,
pe_blacklist=None, quiet=False, SR_only_cutoff=1000,
random_resolved_id_length=10):
mei_bed, resolved_record_namer, unresolved_record_namer,
min_rescan_support=4, pe_blacklist=None, quiet=False,
SR_only_cutoff=1000):
linked_INV = cluster_INV(resolve_INV)
clusters_v2 = link_cpx_V2(linked_INV, cpx_dist=2000)
clusters_v2 = cluster_cleanup(clusters_v2)

np.random.seed(0) # arbitrary fixed seed for reproducibility

# Print number of candidate clusters identified
if not quiet:
now = datetime.datetime.now()
Expand Down Expand Up @@ -349,29 +323,22 @@ def resolve_complex_sv_v2(resolve_INV, cytobands, disc_pairs,
for record in cluster:
cpx = ComplexSV([record], cytobands, mei_bed, SR_only_cutoff)
cpx_record_ids_v2.update(cpx.record_ids)

# Assign random string as resolved ID to handle sharding
cpx.vcf_record.id = variant_prefix + '_' + \
get_random_string(random_resolved_id_length)
cpx.vcf_record.id = resolved_record_namer.get_next_id(cpx.vcf_record)
cpx_records_v2.append(cpx.vcf_record)
# resolved_idx += 1
outcome = 'treated as separate unrelated insertions'
else:
cpx = ComplexSV(cluster, cytobands, mei_bed, SR_only_cutoff)
cpx_record_ids_v2.update(cpx.record_ids)
if cpx.svtype == 'UNR':
# Assign random string as unresolved ID to handle sharding
unresolved_vid = 'UNRESOLVED_' + \
get_random_string(random_resolved_id_length)
for record in cpx.records:
record.info['EVENT'] = unresolved_vid
record.info['EVENT'] = unresolved_record_namer.get_next_id(cpx.vcf_record)
record.info['UNRESOLVED'] = True
cpx_records_v2.append(record)
# unresolved_idx += 1
outcome = 'is unresolved'
else:
cpx.vcf_record.id = variant_prefix + '_' + \
get_random_string(random_resolved_id_length)
cpx.vcf_record.id = resolved_record_namer.get_next_id(cpx.vcf_record)
cpx_records_v2.append(cpx.vcf_record)
if 'CPX_TYPE' in cpx.vcf_record.info.keys():
outcome = 'resolved as ' + \
Expand Down Expand Up @@ -417,11 +384,6 @@ def main(argv):
parser.add_argument('--cytobands', help='Cytoband file. Required to '
'correctly classify interchromosomal events.',
required=True)
# parser.add_argument('--bincov', help='Bincov file.', required=True)
# parser.add_argument('--medianfile', help='Medianfile', required=True)
# parser.add_argument('--famfile', help='Fam file', required=True)
# parser.add_argument('--cutoffs', help='Random forest cutoffs',
# required=True)
parser.add_argument('--min-rescan-pe-support', type=int, default=4,
help='Minumum discordant pairs required during '
'single-ender rescan.')
Expand All @@ -433,6 +395,10 @@ def main(argv):
help='Unresolved complex breakpoints and CNV.')
parser.add_argument('-p', '--prefix', default='CPX_',
help='Variant prefix [CPX_]')
parser.add_argument('-d', '--variant-id-digits', type=int, default=6,
help='Number of digits in variant IDs.')
parser.add_argument('-t', '--temp-dir', type=str, default=None,
help='Temporary directory path for vcf sorting. [Default uses TMPDIR environment variable]')
parser.add_argument('-q', '--quiet', default=False,
help='Disable progress logging to stderr.')

Expand All @@ -451,7 +417,10 @@ def main(argv):
for line in CPX_INFO:
vcf.header.add_line(line)

resolved_pipe = subprocess.Popen(['vcf-sort', '-c'],
sort_command = ['bcftools', 'sort']
if args.temp_dir:
sort_command.extend(['--temp-dir', args.temp_dir])
resolved_pipe = subprocess.Popen(sort_command,
stdin=subprocess.PIPE,
stdout=args.resolved)

Expand All @@ -465,9 +434,6 @@ def main(argv):
blacklist = pysam.TabixFile(args.pe_blacklist)
else:
blacklist = None
# cutoffs = pd.read_table(args.cutoffs)
# rdtest = svu.RdTest(args.bincov, args.medianfile, args.famfile,
# list(vcf.header.samples), cutoffs)

if args.discfile is not None:
disc_pairs = pysam.TabixFile(args.discfile)
Expand All @@ -481,10 +447,11 @@ def main(argv):
resolved_records = []
unresolved_records = []
resolve_INV = []
# cpx_dist = 20000
resolved_record_namer = RecordNamer(prefix=args.prefix + '_CPX', num_digits=args.variant_id_digits)
unresolved_record_namer = RecordNamer(prefix=args.prefix + '_UNRES', num_digits=args.variant_id_digits)

for record in resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, args.prefix,
args.min_rescan_pe_support, blacklist, args.quiet):
for record in resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, resolved_record_namer,
unresolved_record_namer, args.min_rescan_pe_support, blacklist, args.quiet):
# Move members to existing variant IDs unless variant is complex
if record.info['SVTYPE'] != 'CPX' and args.prefix not in record.id:
# Don't alter MEMBERS if the prefix of record.id is already in MEMBERS
Expand All @@ -500,7 +467,6 @@ def main(argv):
else:
resolved_records.append(record)

# out_rec = resolve_complex_sv(vcf, cytobands, disc_pairs, mei_bed, args.prefix, args.min_rescan_pe_support, blacklist)
# Print status
if not args.quiet:
now = datetime.datetime.now()
Expand All @@ -510,9 +476,8 @@ def main(argv):

# RLC: As of Sept 19, 2018, only considering inversion single-enders in second-pass
# due to too many errors in second-pass linking and variant reporting
cpx_records_v2 = resolve_complex_sv_v2(resolve_INV,
cytobands, disc_pairs, mei_bed, args.prefix,
args.min_rescan_pe_support, blacklist, args.quiet)
cpx_records_v2 = resolve_complex_sv_v2(resolve_INV, cytobands, disc_pairs, mei_bed, resolved_record_namer,
unresolved_record_namer, args.min_rescan_pe_support, blacklist, args.quiet)

for record in cpx_records_v2:
# Move members to existing variant IDs unless variant is complex
Expand Down
Loading

0 comments on commit 0119b3f

Please sign in to comment.