diff --git a/scripts/mutect2_wdl/mutect3_training_data.wdl b/scripts/mutect2_wdl/mutect3_training_data.wdl index 0a6a6d07d94..444dfb76da1 100644 --- a/scripts/mutect2_wdl/mutect3_training_data.wdl +++ b/scripts/mutect2_wdl/mutect3_training_data.wdl @@ -23,6 +23,8 @@ workflow Mutect3TrainingData { String? realignment_extra_args String? m2_extra_args String? m2_extra_filtering_args + String? normal_artifact_extra_args + String? split_intervals_extra_args File? truth_vcf File? truth_vcf_idx Boolean? make_bamout @@ -36,6 +38,11 @@ workflow Mutect3TrainingData { String m2_extra_args_with_training_mode = select_first([m2_extra_args, ""]) + " --training-data-mode --training-data-mode-ref-downsample " + ref_downsample + Runtime small_runtime = {"gatk_docker": gatk_docker, "gatk_override": gatk_override, + "max_retries": 2, "preemptible": 0, "cpu": 2, + "machine_mem": 4000, "command_mem": 3500, + "disk": 100, "boot_disk_size": 12} + # call on the tumor (with normal if present) to get tumor read data and M2 filtering call m2.Mutect2 as Tumor { input: @@ -138,14 +145,48 @@ workflow Mutect3TrainingData { gatk_docker = gatk_docker, preemptible = preemptible } - } + call m2.SplitIntervals as Split { + input: + intervals = intervals, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + scatter_count = scatter_count, + split_intervals_extra_args = split_intervals_extra_args, + runtime_params = small_runtime + } + scatter (subintervals in Split.interval_files ) { + call GetNormalArtifactData { + input: + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + tumor_reads = select_first([normal_bam]), + tumor_reads_index = select_first([normal_bai]), + normal_reads = tumor_bam, + normal_reads_index = tumor_bai, + intervals = subintervals, + preemptible = preemptible, + max_retries = max_retries, + extra_args = normal_artifact_extra_args, + gatk_override = gatk_override, + gatk_docker = gatk_docker + } + } + call MergeNormalArtifactData { + input: + input_tables = GetNormalArtifactData.table, + runtime_params = small_runtime + } + } output { File tumor_table = select_first([TumorConcordanceTable.table, TumorTable.table]) File? normal_table = NormalTable.table + File? normal_artifact_table = MergeNormalArtifactData.merged_table } } @@ -209,7 +250,7 @@ task MakeTableFromMutect2 { gatk --java-options "-Xmx2g" SelectVariants -V ~{filtered_vcf} --restrict-alleles-to BIALLELIC -O biallelic.vcf gatk --java-options "-Xmx2g" VariantsToTable -V biallelic.vcf \ - -F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -GF DP -F FILTER -GF FRS \ + -F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -F HEC -F HAPDOM -F HAPCOMP -GF DP -F FILTER -GF FRS \ --show-filtered \ -O output.table } @@ -245,7 +286,7 @@ task MakeTableFromConcordance { for file in ~{tpfp} ~{ftnfn}; do gatk --java-options "-Xmx2g" SelectVariants -V $file --restrict-alleles-to BIALLELIC -O biallelic.vcf gatk --java-options "-Xmx2g" VariantsToTable -V biallelic.vcf \ - -F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -GF DP -F FILTER -GF FRS \ + -F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -F HEC -F HAPDOM -F HAPCOMP -GF DP -F FILTER -GF FRS \ --show-filtered \ -O tmp.table @@ -269,4 +310,106 @@ task MakeTableFromConcordance { output { File table = "output.table" } +} + +task GetNormalArtifactData { + input { + File? intervals + File ref_fasta + File ref_fai + File ref_dict + File tumor_reads + File tumor_reads_index + File? normal_reads + File? 
normal_reads_index + String? extra_args + + File? gatk_override + String? gcs_project_for_requester_pays + + # runtime + String gatk_docker + Int? mem + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Boolean use_ssd = false + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 3500 + Int command_mem = machine_mem - 500 + + parameter_meta{ + intervals: {localization_optional: true} + ref_fasta: {localization_optional: true} + ref_fai: {localization_optional: true} + ref_dict: {localization_optional: true} + tumor_reads: {localization_optional: true} + tumor_reads_index: {localization_optional: true} + normal_reads: {localization_optional: true} + normal_reads_index: {localization_optional: true} + } + + command <<< + set -e + + export GATK_LOCAL_JAR=~{default="/root/gatk.jar" gatk_override} + + if [[ ! -z "~{normal_reads}" ]]; then + gatk --java-options "-Xmx~{command_mem}m" GetSampleName -R ~{ref_fasta} -I ~{normal_reads} -O normal_name.txt -encode \ + ~{"--gcs-project-for-requester-pays " + gcs_project_for_requester_pays} + normal_sample="`cat normal_name.txt`" + fi + + gatk --java-options "-Xmx~{command_mem}m" GetNormalArtifactData \ + -R ~{ref_fasta} ~{"-L " + intervals} -I ~{tumor_reads} -I ~{normal_reads} -O normal_artifact.table \ + -normal $normal_sample \ + ~{extra_args} ~{"--gcs-project-for-requester-pays " + gcs_project_for_requester_pays} + >>> + + runtime { + docker: gatk_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File table = "normal_artifact.table" + } +} + +task MergeNormalArtifactData { + input { + Array[File] input_tables + Runtime runtime_params + } + + command { + set -e + export GATK_LOCAL_JAR=~{default="/root/gatk.jar" runtime_params.gatk_override} + + gatk --java-options "-Xmx~{runtime_params.command_mem}m" GatherNormalArtifactData \ + -I ~{sep=' -I ' input_tables} \ + -O normal_artifact.table + } + + runtime { + docker: runtime_params.gatk_docker + bootDiskSizeGb: runtime_params.boot_disk_size + memory: runtime_params.machine_mem + " MB" + disks: "local-disk " + runtime_params.disk + " HDD" + preemptible: runtime_params.preemptible + maxRetries: runtime_params.max_retries + cpu: runtime_params.cpu + } + + output { + File merged_table = "normal_artifact.table" + } } \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/AssemblyComplexity.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/AssemblyComplexity.java index ce312b9e948..91403e9cf9f 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/AssemblyComplexity.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/AssemblyComplexity.java @@ -7,6 +7,7 @@ import htsjdk.variant.variantcontext.VariantContext; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.mutable.MutableInt; +import org.apache.commons.lang3.tuple.Triple; import org.broadinstitute.barclay.help.DocumentedFeature; import org.broadinstitute.gatk.nativebindings.smithwaterman.SWOverhangStrategy; import org.broadinstitute.hellbender.engine.FeatureContext; @@ -40,14 +41,22 @@ public AssemblyComplexity() { } @Override public Map annotate(final 
ReferenceContext ref, - final FeatureContext features, - final VariantContext vc, - final AlleleLikelihoods likelihoods, - final AlleleLikelihoods fragmentLikelihoods, - final AlleleLikelihoods haplotypeLikelihoods) { - + final FeatureContext features, + final VariantContext vc, + final AlleleLikelihoods likelihoods, + final AlleleLikelihoods fragmentLikelihoods, + final AlleleLikelihoods haplotypeLikelihoods) { + final Triple annotations = annotate(vc, haplotypeLikelihoods); final Map result = new HashMap<>(); + result.put(GATKVCFConstants.HAPLOTYPE_EQUIVALENCE_COUNTS_KEY , annotations.getLeft()); + result.put(GATKVCFConstants.HAPLOTYPE_COMPLEXITY_KEY , annotations.getMiddle()); + result.put(GATKVCFConstants.HAPLOTYPE_DOMINANCE_KEY , annotations.getRight()); + return result; + } + + public static Triple annotate(final VariantContext vc, final AlleleLikelihoods haplotypeLikelihoods) { + // count best-read support for each haplotype final Map haplotypeSupportCounts = haplotypeLikelihoods.alleles().stream() .collect(Collectors.toMap(hap -> hap, label -> new MutableInt(0))); @@ -69,8 +78,6 @@ public Map annotate(final ReferenceContext ref, .mapToInt(n->n) .toArray(); - result.put(GATKVCFConstants.HAPLOTYPE_EQUIVALENCE_COUNTS_KEY, equivalenceCounts); - // we're going to calculate the complexity of this variant's haplotype (that is, the variant-supporting haplotype // with the most reads) versus the closest (in terms of edit distance) germline haplotype. The haplotype // with the greatest read support is considered germline, and as a heuristic we consider the second-most-supported @@ -95,8 +102,6 @@ public Map annotate(final ReferenceContext ref, return germlineHaplotypes.stream().mapToInt(gh -> editDistance(gh, mostSupportedHaplotypeWithAllele, vc.getStart())).min().getAsInt(); }).toArray(); - result.put(GATKVCFConstants.HAPLOTYPE_COMPLEXITY_KEY, editDistances); - // measure which proportion of reads supporting each alt allele fit the most-supported haplotype for that allele final double[] haplotypeDominance = IntStream.range(0, vc.getNAlleles() - 1).mapToDouble(altAlleleIndex -> { final int[] counts = haplotypesByDescendingSupport.stream() @@ -106,9 +111,7 @@ public Map annotate(final ReferenceContext ref, return MathUtils.arrayMax(counts) / (double) MathUtils.sum(counts); }).toArray(); - result.put(GATKVCFConstants.HAPLOTYPE_DOMINANCE_KEY, haplotypeDominance); - - return result; + return Triple.of(equivalenceCounts, editDistances, haplotypeDominance); } @@ -129,9 +132,12 @@ private static boolean containsAltAllele(final EventMap eventMap, final VariantC final List overlapping = eventMap.getOverlappingEvents(vc.getStart()); if (overlapping.isEmpty()) { return false; + } else if (overlapping.get(0).getStart() != vc.getStart()) { + return false; } else { final VariantContext eventMapVC = overlapping.get(0); final int excessBases = vc.getReference().length() - eventMapVC.getReference().length(); + return equalBasesExcludingSuffix(eventMapVC.getAlternateAllele(0).getBases(), vc.getAlternateAllele(altAlleleIndex).getBases(), excessBases); } @@ -142,6 +148,9 @@ private static boolean containsAltAllele(final EventMap eventMap, final VariantC private static boolean equalBasesExcludingSuffix(final byte[] eventMapBases, final byte[] variantContextBases, final int suffixSize) { if (eventMapBases.length + suffixSize != variantContextBases.length) { return false; + } else if (eventMapBases.length > variantContextBases.length) { + return false; // edge case -- event map is longer, though minimal, because 
it is a MNP + // even if the leading bases match, let's call this not a match } else { for (int n = 0; n < eventMapBases.length; n++) { if (eventMapBases[n] != variantContextBases[n]) { diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FeaturizedReadSets.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FeaturizedReadSets.java index 8bb9a7e09fc..c870d9693ee 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FeaturizedReadSets.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FeaturizedReadSets.java @@ -38,9 +38,9 @@ public class FeaturizedReadSets implements JumboGenotypeAnnotation { private static final int DEFAULT_MAX_REF_COUNT = Integer.MAX_VALUE; - private static final int FEATURES_PER_READ = 11; + public static final int FEATURES_PER_READ = 11; - private final SmithWatermanAligner aligner = SmithWatermanAligner.getAligner(SmithWatermanAligner.Implementation.JAVA); + private static final SmithWatermanAligner aligner = SmithWatermanAligner.getAligner(SmithWatermanAligner.Implementation.JAVA); // downsample ref reads to this count if needed private final int maxRefCount; @@ -55,7 +55,7 @@ public FeaturizedReadSets() { @Override public void annotate(final ReferenceContext ref, - final FeatureContext fatures, + final FeatureContext features, final VariantContext vc, final Genotype g, final GenotypeBuilder gb, @@ -70,46 +70,73 @@ public void annotate(final ReferenceContext ref, return; } + final List>> readVectorsByAllele = getReadVectors(vc, Collections.singletonList(g.getSampleName()), + likelihoods, haplotypeLikelihoods, maxRefCount, Integer.MAX_VALUE); + + // flatten twice: over all reads supporting each allele and over all alleles + // we can partition by allele with the countsInAlleleOrder annotation + // and by read using the constant feature vector size + final int[] flattenedTensorInAlleleOrder = readVectorsByAllele.stream() + .flatMap(listOfLists -> listOfLists.stream().flatMap(List::stream)) + .mapToInt(n -> n) + .toArray(); + + final int[] countsInAlleleOrder = readVectorsByAllele.stream().mapToInt(List::size).toArray(); + + gb.attribute(GATKVCFConstants.FEATURIZED_READ_SETS_KEY, flattenedTensorInAlleleOrder); + gb.attribute(GATKVCFConstants.FEATURIZED_READ_SETS_COUNTS_KEY, countsInAlleleOrder); + } + + public static List>> getReadVectors(final VariantContext vc, + final Collection samples, + final AlleleLikelihoods likelihoods, + final AlleleLikelihoods haplotypeLikelihoods, + final int refDownsample, + final int altDownsample) { + return getReadVectors(vc, samples, likelihoods, haplotypeLikelihoods, refDownsample, altDownsample, Collections.emptyMap()); + } + + // returns Lists (in allele order) of lists of read vectors supporting each allele + public static List>> getReadVectors(final VariantContext vc, + final Collection samples, + final AlleleLikelihoods likelihoods, + final AlleleLikelihoods haplotypeLikelihoods, + final int refDownsample, + final int altDownsample, + final Map altDownsampleMap) { final Map> readsByAllele = likelihoods.alleles().stream() .collect(Collectors.toMap(a -> a, a -> new ArrayList<>())); - Utils.stream(likelihoods.bestAllelesBreakingTies()) + samples.stream().flatMap(s -> likelihoods.bestAllelesBreakingTies(s).stream()) .filter(ba -> ba.isInformative()) .forEach(ba -> readsByAllele.get(ba.allele).add(ba.evidence)); // downsample if necessary final Allele refAllele = 
likelihoods.alleles().stream().filter(Allele::isReference).findFirst().get(); - if (readsByAllele.get(refAllele).size() > maxRefCount) { - Collections.shuffle(readsByAllele.get(refAllele)); - readsByAllele.put(refAllele, readsByAllele.get(refAllele).subList(0, maxRefCount)); + for (final Allele allele : likelihoods.alleles()) { + final int downsample = allele.isReference() ? refDownsample : altDownsampleMap.getOrDefault(allele, altDownsample); + if (readsByAllele.get(allele).size() > downsample) { + Collections.shuffle(readsByAllele.get(allele)); + readsByAllele.put(allele, readsByAllele.get(allele).subList(0, downsample)); + } } final Map bestHaplotypes = new HashMap<>(); - haplotypeLikelihoods.bestAllelesBreakingTies().stream().forEach(ba -> - ba.evidence.getReads().forEach(read -> bestHaplotypes.put(read, ba.allele))); - - final List stringsInAlleleOrder = vc.getAlleles().stream() - .map(allele -> { - final List reads = readsByAllele.get(allele); - final List flattened = new ArrayList<>(reads.size() * FEATURES_PER_READ); - reads.forEach(read -> flattened.addAll(featurize(read, vc, bestHaplotypes))); - return StringUtils.join(flattened, ","); - }).collect(Collectors.toList()); - + samples.stream().flatMap(s -> haplotypeLikelihoods.bestAllelesBreakingTies(s).stream()) + .forEach(ba -> ba.evidence.getReads().forEach(read -> bestHaplotypes.put(read, ba.allele))); - final String annotation = AnnotationUtils.encodeAnyASListWithRawDelim(stringsInAlleleOrder); - - gb.attribute(GATKVCFConstants.FEATURIZED_READ_SETS_KEY, annotation); + return vc.getAlleles().stream() + .map(allele -> readsByAllele.get(allele).stream().map(read -> featurize(read, vc, bestHaplotypes)).collect(Collectors.toList())) + .collect(Collectors.toList()); } - @Override public List getKeyNames() { - return Collections.singletonList(GATKVCFConstants.FEATURIZED_READ_SETS_KEY); + return Arrays.asList(GATKVCFConstants.FEATURIZED_READ_SETS_KEY, GATKVCFConstants.FEATURIZED_READ_SETS_COUNTS_KEY); } - private List featurize(final GATKRead read, final VariantContext vc, final Map bestHaplotypes) { + private static List featurize(final GATKRead read, final VariantContext vc, final Map bestHaplotypes) { final List result = new ArrayList<>(); result.add(read.getMappingQuality()); result.add(BaseQuality.getBaseQuality(read, vc).orElse(DEFAULT_BASE_QUALITY)); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/ReferenceBases.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/ReferenceBases.java index c7696633755..9c2d99b97b7 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/ReferenceBases.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/ReferenceBases.java @@ -30,8 +30,8 @@ @DocumentedFeature(groupName=HelpConstants.DOC_CAT_ANNOTATORS, groupSummary=HelpConstants.DOC_CAT_ANNOTATORS_SUMMARY, summary="Annotate with local reference bases (REF_BASES)") public class ReferenceBases implements InfoFieldAnnotation { - private int NUM_BASES_ON_EITHER_SIDE = 10; - private int REFERENCE_CONTEXT_LENGTH = 2*NUM_BASES_ON_EITHER_SIDE + 1; + private static final int NUM_BASES_ON_EITHER_SIDE = 10; + private static final int REFERENCE_CONTEXT_LENGTH = 2*NUM_BASES_ON_EITHER_SIDE + 1; protected final OneShotLogger warning = new OneShotLogger(this.getClass()); @@ -46,6 +46,14 @@ public Map annotate(final ReferenceContext ref, warning.warn("REF_BASES requires the reference to annotate, none was provided"); return Collections.emptyMap(); } + + final 
String bases = annotate(ref, vc); + return Collections.singletonMap(GATKVCFConstants.REFERENCE_BASES_KEY, bases ); + + } + + public static String annotate(final ReferenceContext ref, final VariantContext vc) { + Utils.nonNull(ref); final int basesToDiscardInFront = Math.max(vc.getStart() - ref.getWindow().getStart() - NUM_BASES_ON_EITHER_SIDE, 0); final String allBases = new String(ref.getBases()); final int endIndex = Math.min(basesToDiscardInFront + 2 * NUM_BASES_ON_EITHER_SIDE + 1, allBases.length()); @@ -54,7 +62,7 @@ public Map annotate(final ReferenceContext ref, localBases = String.join("", localBases, StringUtils.repeat("N", REFERENCE_CONTEXT_LENGTH - localBases.length())); } - return Collections.singletonMap(GATKVCFConstants.REFERENCE_BASES_KEY, localBases ); + return localBases; } public static String getNMiddleBases(final String bases, final int n){ diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GatherNormalArtifactData.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GatherNormalArtifactData.java new file mode 100644 index 00000000000..ea4f75f02c9 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GatherNormalArtifactData.java @@ -0,0 +1,44 @@ +package org.broadinstitute.hellbender.tools.walkers.mutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.cmdline.programgroups.CoverageAnalysisProgramGroup; +import org.broadinstitute.hellbender.exceptions.UserException; +import org.broadinstitute.hellbender.utils.io.IOUtils; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +@CommandLineProgramProperties( + summary="Combine output files from GetNormalArtifactData in the order defined by a sequence dictionary", + oneLineSummary = "Combine output files from GetNormalArtifactData in the order defined by a sequence dictionary", + programGroup = CoverageAnalysisProgramGroup.class +) +public class GatherNormalArtifactData extends CommandLineProgram { + + @Argument(fullName = StandardArgumentDefinitions.INPUT_LONG_NAME, shortName = StandardArgumentDefinitions.INPUT_SHORT_NAME, + doc = "an output of GetNormalArtifactData") + final List input = null; + + @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, + doc = "output") + final File output = null; + + @Override + protected Object doWork() { + + try ( NormalArtifactRecord.NormalArtifactWriter writer = new NormalArtifactRecord.NormalArtifactWriter(IOUtils.fileToPath(output)) ) { + for (final File inputFile : input) { + writer.writeAllRecords(NormalArtifactRecord.readFromFile(inputFile)); + } + } catch (IOException e){ + throw new UserException(String.format("Encountered an IO exception while writing to %s.", output)); + } + + return "SUCCESS"; + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GetNormalArtifactData.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GetNormalArtifactData.java new file mode 100644 index 00000000000..528bf60b5c6 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/GetNormalArtifactData.java @@ -0,0 +1,175 @@ +package 
org.broadinstitute.hellbender.tools.walkers.mutect; + +import htsjdk.samtools.SAMFileHeader; +import htsjdk.variant.variantcontext.VariantContext; +import htsjdk.variant.vcf.VCFConstants; +import htsjdk.variant.vcf.VCFHeader; +import it.unimi.dsi.fastutil.bytes.ByteArrayList; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.distribution.BinomialDistribution; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.RandomGeneratorFactory; +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.cmdline.programgroups.CoverageAnalysisProgramGroup; +import org.broadinstitute.hellbender.engine.AlignmentContext; +import org.broadinstitute.hellbender.engine.FeatureContext; +import org.broadinstitute.hellbender.engine.LocusWalker; +import org.broadinstitute.hellbender.engine.ReferenceContext; +import org.broadinstitute.hellbender.engine.filters.ReadFilter; +import org.broadinstitute.hellbender.exceptions.UserException; +import org.broadinstitute.hellbender.tools.walkers.contamination.CalculateContamination; +import org.broadinstitute.hellbender.tools.walkers.contamination.PileupSummary; +import org.broadinstitute.hellbender.utils.BaseUtils; +import org.broadinstitute.hellbender.utils.MathUtils; +import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.activityprofile.ActivityProfileState; +import org.broadinstitute.hellbender.utils.pileup.PileupElement; +import org.broadinstitute.hellbender.utils.pileup.ReadPileup; +import org.broadinstitute.hellbender.utils.read.ReadUtils; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + *

<h3>Usage example</h3>
 + *
 + * <pre>
 + * gatk GetNormalArtifactData \
 + *   -I tumor.bam \
 + *   -I normal.bam \
 + *   -normal normal_sample \
 + *   -L intervals.list \
 + *   -O normal-artifact.table
 + * </pre>
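+ *
+ * Per-interval output tables from a scattered run over split intervals can be combined with GatherNormalArtifactData, as in the accompanying mutect3_training_data.wdl workflow.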
+ * + */ +@CommandLineProgramProperties( + summary = "Collects data for training normal artifact filter", + oneLineSummary = "Collects data for training normal artifact filter", + programGroup = CoverageAnalysisProgramGroup.class) +@DocumentedFeature +public class GetNormalArtifactData extends LocusWalker { + @Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, + shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, + doc="The output table", optional=false) + private File outputTable; + + @Argument(fullName = M2ArgumentCollection.NORMAL_SAMPLE_LONG_NAME, shortName = M2ArgumentCollection.NORMAL_SAMPLE_SHORT_NAME, + doc = "BAM sample name of normal. May be URL-encoded as output by GetSampleName with -encode argument.") + protected List normalSamples = new ArrayList<>(); + + public static final String ERROR_PROB_NAME = "error-prob"; + private static double DEFAULT_ERROR_PROB = 0.001; + @Argument(fullName = ERROR_PROB_NAME, doc = "Error probability for p-values", optional = true) + protected double errorProb = DEFAULT_ERROR_PROB; + + private List data = new ArrayList<>(); + + private SAMFileHeader header; + + private final Random rng = Utils.getRandomGenerator(); + + @Override + public boolean requiresReads() { + return true; + } + + @Override + public boolean requiresReference() { + return true; + } + + @Override + public boolean requiresIntervals() { + return false; + } + + @Override + public boolean requiresFeatures() { + return false; + } + + @Override + public List getDefaultReadFilters() { + return Mutect2Engine.makeStandardMutect2ReadFilters(); + } + + @Override + public void onTraversalStart() { + header = getHeaderForReads(); + } + + @Override + public void apply(AlignmentContext alignmentContext, ReferenceContext referenceContext, FeatureContext featureContext) { + final ReadPileup pileup = alignmentContext.getBasePileup(); + final ReadPileup normalPileup = pileup.makeFilteredPileup(pe -> normalSamples.contains(ReadUtils.getSampleName(pe.getRead(), header))); + final byte refBase = referenceContext.getBase(); + final int[] normalCounts = getBaseCounts(normalPileup, refBase); + + final int bestNormalAllele = MathUtils.maxElementIndex(normalCounts); + final int normalAltCount = normalCounts[bestNormalAllele]; + + // skip cases of no evidence in normal or likely germline + if (normalAltCount == 0 || normalAltCount > 0.2 * normalPileup.size()) { + return; + } + + final ReadPileup tumorPileup = pileup.makeFilteredPileup(pe -> !normalSamples.contains(ReadUtils.getSampleName(pe.getRead(), header))); + final int tumorAltCount = getBaseCounts(tumorPileup, refBase)[bestNormalAllele]; + + // p value for this tumor alt count or greater + // we don't want to bloat our data with a lot of sites with a sequencing error in the normal and little or nothing in the tumor + // at the same time, we must include them. Thus we downsample and record the downsampling in order to upsample later + // when the p value is not significant + final double tumorPValue = 1 - new BinomialDistribution(tumorPileup.size(), errorProb).cumulativeProbability(tumorAltCount - 1); + final double downsampleProb = Math.max(1 - tumorPValue, 0.05); + + if (rng.nextDouble() > downsampleProb) { + return; + } else if (tumorAltCount > 0.5 * tumorPileup.size()) { + return; + } + + final String type = bestNormalAllele < 4 ? 
"SNV" : "INDEL"; + data.add(new NormalArtifactRecord(normalAltCount, normalPileup.size(), tumorAltCount, tumorPileup.size(), downsampleProb, type)); + + } + + @Override + public Object onTraversalSuccess() { + NormalArtifactRecord.writeToFile(data, outputTable); + return "SUCCESS"; + } + + /** + * Get counts of A, C, G, T, before insertion start, before deletion start in order, which returns a int[6] vector with counts according + * to BaseUtils.simpleBaseToBaseIndex for each base, and with indices 4 and 5 for pileup elements preceding insertions + * and deletions. Insertions and deletions themselves are not counted to avoid overcounting. + */ + private static int[] getBaseCounts(final ReadPileup pileup, final byte refBase) { + final int[] counts = new int[6]; + + for (final PileupElement pe : pileup) { + if (pe.isDeletion()) { + continue; + } else if (pe.isBeforeInsertion()) { + counts[4]++; + } else if (pe.isBeforeDeletionStart()) { + counts[5]++; + } else if (pe.getBase() != refBase) { + final int index = BaseUtils.simpleBaseToBaseIndex(pe.getBase()); + if (index != -1) { + counts[index]++; + } + } + } + + return counts; + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/M2ArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/M2ArgumentCollection.java index 924061a9630..cd835c51097 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/M2ArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/M2ArgumentCollection.java @@ -49,8 +49,6 @@ public class M2ArgumentCollection extends AssemblyBasedCallerArgumentCollection public static final String PCR_SNV_QUAL_LONG_NAME = "pcr-snv-qual"; public static final String PCR_INDEL_QUAL_LONG_NAME = "pcr-indel-qual"; public static final String F1R2_TAR_GZ_NAME = "f1r2-tar-gz"; - public static final String TRAINING_DATA_MODE_LONG_NAME = "training-data-mode"; - public static final String TRAINING_DATA_MODE_REF_DOWNSAMPLE_LONG_NAME = "training-data-mode-ref-downsample"; public static final double DEFAULT_AF_FOR_TUMOR_ONLY_CALLING = 5e-8; public static final double DEFAULT_AF_FOR_TUMOR_NORMAL_CALLING = 1e-6; @@ -70,6 +68,19 @@ public class M2ArgumentCollection extends AssemblyBasedCallerArgumentCollection public static final String LOD_BAND_LONG_NAME = "gvcf-lod-band"; public static final String LOD_BAND_SHORT_NAME = "LODB"; + /* + Mutect3 parameters + */ + public static final String MUTECT3_TRAINING_MODE_LONG_NAME = "mutect3-training-mode"; + public static final String MUTECT3_TRAINING_NON_ARTIFACT_RATIO = "mutect3-non-artifact-ratio"; + public static final String MUTECT3_REF_DOWNSAMPLE_LONG_NAME = "mutect3-ref-downsample"; + public static final String MUTECT3_ALT_DOWNSAMPLE_LONG_NAME = "mutect3-alt-downsample"; + public static final String MUTECT3_DATASET_LONG_NAME = "mutect3-dataset"; + + public static final int DEFAULT_MUTECT3_REF_DOWNSAMPLE = 10; + public static final int DEFAULT_MUTECT3_ALT_DOWNSAMPLE = 20; + public static final int DEFAULT_MUTECT3_NON_ARTIFACT_RATIO = 20; + @Override protected int getDefaultMaxMnpDistance() { return 1; } @@ -104,8 +115,6 @@ public ReadThreadingAssembler createReadThreadingAssembler(){ @Argument(fullName = NORMAL_SAMPLE_LONG_NAME, shortName = NORMAL_SAMPLE_SHORT_NAME, doc = "BAM sample name of normal(s), if any. 
May be URL-encoded as output by GetSampleName with -encode argument.", optional = true) protected List normalSamples = new ArrayList<>(); - //TODO: END OF HACK ALERT - /***************************************/ // Reference Metadata inputs /***************************************/ @@ -159,16 +168,34 @@ public double getDefaultAlleleFrequency() { public Boolean mitochondria = false; /** - * Training data mode collects data on variant- and artifact-supporting read sets for fitting a deep learning filtering model + * If true, collect Mutect3 data for learning; otherwise collect data for generating calls with a pre-trained model + */ + @Argument(fullName = MUTECT3_TRAINING_MODE_LONG_NAME, optional = true, doc="Collect Mutect3 data for learning.") + public Boolean mutect3TrainingDataMode = false; + + /** + * Downsample ref reads for Mutect3 data + */ + @Argument(fullName = MUTECT3_REF_DOWNSAMPLE_LONG_NAME, optional = true, doc="Downsample ref reads to this count when generating a Mutect3 dataset.") + public int maxRefCountForMutect3 = DEFAULT_MUTECT3_REF_DOWNSAMPLE; + + /** + * Downsample alt reads for Mutect3 data + */ + @Argument(fullName = MUTECT3_ALT_DOWNSAMPLE_LONG_NAME, optional = true, doc="Downsample alt reads to this count for Mutect3 training datasets.") + public int maxAltCountForMutect3 = DEFAULT_MUTECT3_ALT_DOWNSAMPLE; + + /** + * Number of non-artifact data per artifact datum in Mutect3 training */ - @Argument(fullName = TRAINING_DATA_MODE_LONG_NAME, optional = true, doc="Output VCF contains featurized sets of reads for training a deep variant filter.") - public Boolean trainingDataMode = false; + @Argument(fullName = MUTECT3_TRAINING_NON_ARTIFACT_RATIO, optional = true, doc="Number of non-artifact data per artifact datum in Mutect3 training.") + public int mutect3NonArtifactRatio = DEFAULT_MUTECT3_NON_ARTIFACT_RATIO; /** - * Downsample ref reads in training data mode + * Destination for Mutect3 data collection */ - @Argument(fullName = TRAINING_DATA_MODE_REF_DOWNSAMPLE_LONG_NAME, optional = true, doc="Downsample ref reads to this count in training data mode.") - public int maxRefCountInTrainingData = Integer.MAX_VALUE; + @Argument(fullName = MUTECT3_DATASET_LONG_NAME, optional = true, doc="Destination for Mutect3 data collection") + public File mutect3Dataset; /** * Only variants with tumor LODs exceeding this threshold will be written to the VCF, regardless of filter status. 
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2.java index 97f621745ae..697ae1b33c3 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2.java @@ -285,12 +285,6 @@ public Collection makeVariantAnnotations(){ annotations.add(new OriginalAlignment()); } - if (MTAC.trainingDataMode) { - annotations.add(new FeaturizedReadSets(MTAC.maxRefCountInTrainingData)); - annotations.add(new AssemblyComplexity()); - annotations.add(new ReferenceBases()); - } - return annotations; } @@ -312,7 +306,7 @@ public void closeTool() { vcfWriter.close(); } if (m2Engine != null) { - m2Engine.shutdown(); + m2Engine.close(); } } } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java index e4611bbe726..f29d913bad7 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java @@ -64,7 +64,7 @@ /** * Created by davidben on 9/15/16. */ -public final class Mutect2Engine implements AssemblyRegionEvaluator { +public final class Mutect2Engine implements AssemblyRegionEvaluator, AutoCloseable { private static final List STANDARD_MUTECT_INFO_FIELDS = Arrays.asList(GATKVCFConstants.NORMAL_LOG_10_ODDS_KEY, GATKVCFConstants.TUMOR_LOG_10_ODDS_KEY, GATKVCFConstants.NORMAL_ARTIFACT_LOG_10_ODDS_KEY, GATKVCFConstants.EVENT_COUNT_IN_HAPLOTYPE_KEY, GATKVCFConstants.IN_PON_KEY, GATKVCFConstants.POPULATION_AF_KEY, @@ -376,12 +376,14 @@ public void writeExtraOutputs(final File statsTable) { }); } - public void shutdown() { + @Override + public void close() { likelihoodCalculationEngine.close(); aligner.close(); haplotypeBAMWriter.ifPresent(HaplotypeBAMWriter::close); assembledEventMapVcfOutputWriter.ifPresent(writer -> {assembledEventMapVariants.get().forEach(writer::add); writer.close();}); referenceReader.close(); + genotypingEngine.close(); } @Override @@ -409,7 +411,7 @@ public ActivityProfileState isActive(final AlignmentContext context, final Refer if (tumorLogOdds < MTAC.getInitialLogOdds()) { return new ActivityProfileState(refInterval, 0.0); - } else if (MTAC.trainingDataMode) { + } else if (MTAC.mutect3TrainingDataMode) { return new ActivityProfileState(ref.getInterval(), 1.0); } else if (hasNormal() && !MTAC.genotypeGermlineSites) { final ReadPileup normalPileup = pileup.makeFilteredPileup(pe -> isNormalSample(ReadUtils.getSampleName(pe.getRead(), header))); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect3DatasetEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect3DatasetEngine.java new file mode 100644 index 00000000000..d92b772a9e0 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect3DatasetEngine.java @@ -0,0 +1,284 @@ +package org.broadinstitute.hellbender.tools.walkers.mutect; + +import htsjdk.variant.variantcontext.Allele; +import htsjdk.variant.variantcontext.Genotype; +import htsjdk.variant.variantcontext.VariantContext; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.commons.math3.util.FastMath; +import org.broadinstitute.hellbender.engine.ReferenceContext; +import 
org.broadinstitute.hellbender.exceptions.UserException; +import org.broadinstitute.hellbender.tools.walkers.annotator.AssemblyComplexity; +import org.broadinstitute.hellbender.tools.walkers.annotator.FeaturizedReadSets; +import org.broadinstitute.hellbender.tools.walkers.annotator.ReferenceBases; +import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine; +import org.broadinstitute.hellbender.utils.IndexRange; +import org.broadinstitute.hellbender.utils.MathUtils; +import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods; +import org.broadinstitute.hellbender.utils.haplotype.Haplotype; +import org.broadinstitute.hellbender.utils.read.Fragment; +import org.broadinstitute.hellbender.utils.read.GATKRead; +import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants; +import org.broadinstitute.hellbender.utils.variant.VariantContextGetters; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.*; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class Mutect3DatasetEngine implements AutoCloseable { + + public static final int CAPACITY = 100000; + + private enum VariantType { + SNV, INSERTION, DELETION + } + + private enum Label { + ARTIFACT, VARIANT, UNLABELED, IGNORE + } + + // number of features for each vectorized read + private static final int FEATURES_PER_READ = FeaturizedReadSets.FEATURES_PER_READ; + + // number of additional variant features for assembly complexity, reference context + private static final int NUM_EXTRA_FEATURES = 9; + + // threshold of negative log-10 population allele frequency to consider something an artifact for the purposes of training data + // we want to be really sure we don't get germline variants + // TODO: is this really necessary? + private static final double RARE_POPAF_THRESHOLD = 5.9; + + // very cautious threshold of negative log-10 population allele frequency to consider something germline for training data. + // There are so many germline variants we can be wasteful! + private static final double COMMON_POPAF_THRESHOLD = 1; + + // below this tumor log odds we don't consider it an artifact, just a sequencing error + private static final double TLOD_THRESHOLD = 6.0; + + private final int maxRefCount; + private final int maxAltCount; + + // TODO: is this necessary? + private static final int MIN_REF = 5; + + private final PrintWriter printWriter; + + // number of nonartifact data to keep for each artifact datum + private final int nonArtifactPerArtifact; + + // are we generating dataset for training a model or for filtering calls with a pre-trained model? + private final boolean trainingMode; + + private final Set normalSamples; + + // simple method to balance data: for each k-alt-read artifact there are + // nonArtifactPerArtifact (downsampled) k-alt-read non-artifacts. 
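+    // for each variant type, a queue of alt counts from artifact data that have not yet been matched by downsampled non-artifact data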
+ private final EnumMap> unmatchedArtifactAltCounts; + + + public Mutect3DatasetEngine(final File datasetFile, final boolean trainingMode, final int maxRefCount, final int maxAltCount, final int nonArtifactPerArtifact, final Set normalSamples) { + try { + printWriter = new PrintWriter(new FileWriter(Utils.nonNull(datasetFile))); + } catch (IOException ex) { + throw new UserException.BadInput("Could not create dataset file writer"); + } + + this.normalSamples = Utils.nonNull(normalSamples); + this.trainingMode = trainingMode; + this.nonArtifactPerArtifact = nonArtifactPerArtifact; + this.maxRefCount = maxRefCount; + this.maxAltCount = maxAltCount; + + unmatchedArtifactAltCounts = new EnumMap<>(VariantType.class); + for (final VariantType type : VariantType.values()) { + unmatchedArtifactAltCounts.put(type, new ArrayBlockingQueue<>(CAPACITY)); + } + } + + // add one datum per alt allele + public void addData(final ReferenceContext ref, final VariantContext vc, final AlleleLikelihoods likelihoods, + final AlleleLikelihoods logFragmentLikelihoods) { + final String refBases = ReferenceBases.annotate(ref, vc); + final String refAllele = vc.getReference().getBaseString(); + final String contig = vc.getContig(); + final int position = vc.getStart(); + final Set tumorSamples = likelihoods.samples().stream().filter(sample -> !normalSamples.contains(sample)).collect(Collectors.toSet()); + final int numAlt = vc.getNAlleles() - 1; + + + // the variant has already been annotated, so we have POPAF and AD + final double[] popafs = VariantContextGetters.getAttributeAsDoubleArray(vc, GATKVCFConstants.POPULATION_AF_KEY); + //final double[] altPopulationAFs = MathUtils.applyToArray(popafs, x -> Math.pow(10, -x )); + final double[] tumorLods = Mutect2FilteringEngine.getTumorLogOdds(vc); + final int[] tumorADs = sumADsOverSamples(vc, tumorSamples); + final int[] normalADs = sumADsOverSamples(vc, normalSamples); + final int tumorDepth = (int) MathUtils.sum(tumorADs); + final int normalDepth = (int) MathUtils.sum(normalADs); + final boolean hasNormal = normalDepth > 0; + + final List