Commit cb813ed: Changes to Mutect2 to support Mutect3 (#7663)

davidbenjamin authored Mar 28, 2022
1 parent 7c06a6a commit cb813ed

Showing 16 changed files with 997 additions and 80 deletions.
149 changes: 146 additions & 3 deletions scripts/mutect2_wdl/mutect3_training_data.wdl
@@ -23,6 +23,8 @@ workflow Mutect3TrainingData {
String? realignment_extra_args
String? m2_extra_args
String? m2_extra_filtering_args
String? normal_artifact_extra_args
String? split_intervals_extra_args
File? truth_vcf
File? truth_vcf_idx
Boolean? make_bamout
@@ -36,6 +38,11 @@

String m2_extra_args_with_training_mode = select_first([m2_extra_args, ""]) + " --training-data-mode --training-data-mode-ref-downsample " + ref_downsample

Runtime small_runtime = {"gatk_docker": gatk_docker, "gatk_override": gatk_override,
"max_retries": 2, "preemptible": 0, "cpu": 2,
"machine_mem": 4000, "command_mem": 3500,
"disk": 100, "boot_disk_size": 12}

# call on the tumor (with normal if present) to get tumor read data and M2 filtering
call m2.Mutect2 as Tumor {
input:
@@ -138,14 +145,48 @@
gatk_docker = gatk_docker,
preemptible = preemptible
}
}

call m2.SplitIntervals as Split {
input:
intervals = intervals,
ref_fasta = ref_fasta,
ref_fai = ref_fai,
ref_dict = ref_dict,
scatter_count = scatter_count,
split_intervals_extra_args = split_intervals_extra_args,
runtime_params = small_runtime
}
scatter (subintervals in Split.interval_files) {
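# note: tumor and normal are deliberately swapped below so that the task
# collects artifact data from the normal sample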
call GetNormalArtifactData {
input:
ref_fasta = ref_fasta,
ref_fai = ref_fai,
ref_dict = ref_dict,
tumor_reads = select_first([normal_bam]),
tumor_reads_index = select_first([normal_bai]),
normal_reads = tumor_bam,
normal_reads_index = tumor_bai,
intervals = subintervals,
preemptible = preemptible,
max_retries = max_retries,
extra_args = normal_artifact_extra_args,
gatk_override = gatk_override,
gatk_docker = gatk_docker
}
}

call MergeNormalArtifactData {
input:
input_tables = GetNormalArtifactData.table,
runtime_params = small_runtime
}
}
output {
File tumor_table = select_first([TumorConcordanceTable.table, TumorTable.table])
File? normal_table = NormalTable.table
File? normal_artifact_table = MergeNormalArtifactData.merged_table
}
}

@@ -209,7 +250,7 @@ task MakeTableFromMutect2 {

gatk --java-options "-Xmx2g" SelectVariants -V ~{filtered_vcf} --restrict-alleles-to BIALLELIC -O biallelic.vcf
gatk --java-options "-Xmx2g" VariantsToTable -V biallelic.vcf \
-F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -GF DP -F FILTER -GF FRS \
-F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -F HEC -F HAPDOM -F HAPCOMP -GF DP -F FILTER -GF FRS \
--show-filtered \
-O output.table
}
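
The resulting output.table is tab-separated with a header row naming each requested field; -GF genotype fields appear once per sample as <sample>.<field>, with the sample name taken from the VCF. Roughly, and with an illustrative sample name (column order and naming may differ slightly):

    CHROM  POS  REF  ALT  POPAF  TLOD  STATUS  REF_BASES  HEC  HAPDOM  HAPCOMP  FILTER  tumor.DP  tumor.FRS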
@@ -245,7 +286,7 @@ task MakeTableFromConcordance {
for file in ~{tpfp} ~{ftnfn}; do
gatk --java-options "-Xmx2g" SelectVariants -V $file --restrict-alleles-to BIALLELIC -O biallelic.vcf
gatk --java-options "-Xmx2g" VariantsToTable -V biallelic.vcf \
-F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -GF DP -F FILTER -GF FRS \
-F CHROM -F POS -F REF -F ALT -F POPAF -F TLOD -F STATUS -F REF_BASES -F HEC -F HAPDOM -F HAPCOMP -GF DP -F FILTER -GF FRS \
--show-filtered \
-O tmp.table

@@ -269,4 +310,106 @@ task MakeTableFromConcordance {
output {
File table = "output.table"
}
}

task GetNormalArtifactData {
input {
File? intervals
File ref_fasta
File ref_fai
File ref_dict
File tumor_reads
File tumor_reads_index
File? normal_reads
File? normal_reads_index
String? extra_args

File? gatk_override
String? gcs_project_for_requester_pays

# runtime
String gatk_docker
Int? mem
Int? preemptible
Int? max_retries
Int? disk_space
Int? cpu
Boolean use_ssd = false
}

# Mem is in units of GB but our command and memory runtime values are in MB
Int machine_mem = if defined(mem) then mem * 1000 else 3500
Int command_mem = machine_mem - 500
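# e.g. mem = 4 gives machine_mem = 4000 and command_mem = 3500 (both in MB)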

parameter_meta{
intervals: {localization_optional: true}
ref_fasta: {localization_optional: true}
ref_fai: {localization_optional: true}
ref_dict: {localization_optional: true}
tumor_reads: {localization_optional: true}
tumor_reads_index: {localization_optional: true}
normal_reads: {localization_optional: true}
normal_reads_index: {localization_optional: true}
}

command <<<
set -e

export GATK_LOCAL_JAR=~{default="/root/gatk.jar" gatk_override}

if [[ ! -z "~{normal_reads}" ]]; then
gatk --java-options "-Xmx~{command_mem}m" GetSampleName -R ~{ref_fasta} -I ~{normal_reads} -O normal_name.txt -encode \
~{"--gcs-project-for-requester-pays " + gcs_project_for_requester_pays}
normal_sample="`cat normal_name.txt`"
fi
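# note: this workflow always supplies normal_reads (the original tumor bam),
# so normal_sample is always set before being passed to -normal below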

gatk --java-options "-Xmx~{command_mem}m" GetNormalArtifactData \
-R ~{ref_fasta} ~{"-L " + intervals} -I ~{tumor_reads} -I ~{normal_reads} -O normal_artifact.table \
-normal $normal_sample \
~{extra_args} ~{"--gcs-project-for-requester-pays " + gcs_project_for_requester_pays}
>>>

runtime {
docker: gatk_docker
bootDiskSizeGb: 12
memory: machine_mem + " MB"
disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD"
preemptible: select_first([preemptible, 10])
maxRetries: select_first([max_retries, 0])
cpu: select_first([cpu, 1])
}

output {
File table = "normal_artifact.table"
}
}

task MergeNormalArtifactData {
input {
Array[File] input_tables
Runtime runtime_params
}
command {
set -e
export GATK_LOCAL_JAR=~{default="/root/gatk.jar" runtime_params.gatk_override}
gatk --java-options "-Xmx~{runtime_params.command_mem}m" GatherNormalArtifactData \
-I ~{sep=' -I ' input_tables} \
-O normal_artifact.table
}

runtime {
docker: runtime_params.gatk_docker
bootDiskSizeGb: runtime_params.boot_disk_size
memory: runtime_params.machine_mem + " MB"
disks: "local-disk " + runtime_params.disk + " HDD"
preemptible: runtime_params.preemptible
maxRetries: runtime_params.max_retries
cpu: runtime_params.cpu
}
output {
File merged_table = "normal_artifact.table"
}
}
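
For reference, a sketch of the Runtime struct consumed by small_runtime and runtime_params above. The struct itself is defined in the imported mutect2.wdl; the field names come from the struct literal near the top of the workflow, and the types are inferred from usage:

    struct Runtime {
        String gatk_docker
        File? gatk_override
        Int max_retries
        Int preemptible
        Int cpu
        Int machine_mem
        Int command_mem
        Int disk
        Int boot_disk_size
    }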
src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/AssemblyComplexity.java
@@ -7,6 +7,7 @@
import htsjdk.variant.variantcontext.VariantContext;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.lang3.tuple.Triple;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.gatk.nativebindings.smithwaterman.SWOverhangStrategy;
import org.broadinstitute.hellbender.engine.FeatureContext;
@@ -40,14 +41,22 @@ public AssemblyComplexity() { }

@Override
public Map<String, Object> annotate(final ReferenceContext ref,
final FeatureContext features,
final VariantContext vc,
final AlleleLikelihoods<GATKRead, Allele> likelihoods,
final AlleleLikelihoods<Fragment, Allele> fragmentLikelihoods,
final AlleleLikelihoods<Fragment, Haplotype> haplotypeLikelihoods) {

final FeatureContext features,
final VariantContext vc,
final AlleleLikelihoods<GATKRead, Allele> likelihoods,
final AlleleLikelihoods<Fragment, Allele> fragmentLikelihoods,
final AlleleLikelihoods<Fragment, Haplotype> haplotypeLikelihoods) {
final Triple<int[], int[], double[]> annotations = annotate(vc, haplotypeLikelihoods);
final Map<String, Object> result = new HashMap<>();

result.put(GATKVCFConstants.HAPLOTYPE_EQUIVALENCE_COUNTS_KEY , annotations.getLeft());
result.put(GATKVCFConstants.HAPLOTYPE_COMPLEXITY_KEY , annotations.getMiddle());
result.put(GATKVCFConstants.HAPLOTYPE_DOMINANCE_KEY , annotations.getRight());
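// HAPLOTYPE_EQUIVALENCE_COUNTS_KEY, HAPLOTYPE_COMPLEXITY_KEY, and HAPLOTYPE_DOMINANCE_KEY
// are written to the VCF as HEC, HAPCOMP, and HAPDOM: the fields the WDL's
// VariantsToTable commands above request with -F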
return result;
}

public static Triple<int[], int[], double[]> annotate(final VariantContext vc, final AlleleLikelihoods<Fragment, Haplotype> haplotypeLikelihoods) {

// count best-read support for each haplotype
final Map<Haplotype, MutableInt> haplotypeSupportCounts = haplotypeLikelihoods.alleles().stream()
.collect(Collectors.toMap(hap -> hap, label -> new MutableInt(0)));
@@ -69,8 +78,6 @@ public Map<String, Object> annotate(final ReferenceContext ref,
.mapToInt(n->n)
.toArray();

result.put(GATKVCFConstants.HAPLOTYPE_EQUIVALENCE_COUNTS_KEY, equivalenceCounts);

// we're going to calculate the complexity of this variant's haplotype (that is, the variant-supporting haplotype
// with the most reads) versus the closest (in terms of edit distance) germline haplotype. The haplotype
// with the greatest read support is considered germline, and as a heuristic we consider the second-most-supported
@@ -95,8 +102,6 @@ public Map<String, Object> annotate(final ReferenceContext ref,
return germlineHaplotypes.stream().mapToInt(gh -> editDistance(gh, mostSupportedHaplotypeWithAllele, vc.getStart())).min().getAsInt();
}).toArray();

result.put(GATKVCFConstants.HAPLOTYPE_COMPLEXITY_KEY, editDistances);

// measure which proportion of reads supporting each alt allele fit the most-supported haplotype for that allele
final double[] haplotypeDominance = IntStream.range(0, vc.getNAlleles() - 1).mapToDouble(altAlleleIndex -> {
final int[] counts = haplotypesByDescendingSupport.stream()
@@ -106,9 +111,7 @@ public Map<String, Object> annotate(final ReferenceContext ref,
return MathUtils.arrayMax(counts) / (double) MathUtils.sum(counts);
}).toArray();

result.put(GATKVCFConstants.HAPLOTYPE_DOMINANCE_KEY, haplotypeDominance);

return result;
return Triple.of(equivalenceCounts, editDistances, haplotypeDominance);
}
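
Factoring the logic into a static overload lets code outside the annotation engine (e.g. Mutect3 training-data extraction) consume the raw per-allele arrays directly, without building a VCF INFO map. A minimal sketch of such a caller; the class and method here are hypothetical, and the import paths are assumed from GATK's package layout:

    import htsjdk.variant.variantcontext.VariantContext;
    import org.apache.commons.lang3.tuple.Triple;
    import org.broadinstitute.hellbender.tools.walkers.annotator.AssemblyComplexity;
    import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
    import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
    import org.broadinstitute.hellbender.utils.read.Fragment;

    final class AssemblyComplexityCaller {
        // hypothetical helper: pull per-alt-allele features without touching the INFO map
        static double[] dominance(final VariantContext vc,
                                  final AlleleLikelihoods<Fragment, Haplotype> haplotypeLikelihoods) {
            final Triple<int[], int[], double[]> features = AssemblyComplexity.annotate(vc, haplotypeLikelihoods);
            // getLeft() = haplotype equivalence counts (HEC), getMiddle() = edit distances (HAPCOMP)
            return features.getRight(); // per-alt-allele haplotype dominance (HAPDOM)
        }
    }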


@@ -129,9 +132,12 @@ private static boolean containsAltAllele(final EventMap eventMap, final VariantContext vc, final int altAlleleIndex) {
final List<VariantContext> overlapping = eventMap.getOverlappingEvents(vc.getStart());
if (overlapping.isEmpty()) {
return false;
} else if (overlapping.get(0).getStart() != vc.getStart()) {
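// an event that merely overlaps but starts at a different position is not a match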
return false;
} else {
final VariantContext eventMapVC = overlapping.get(0);
final int excessBases = vc.getReference().length() - eventMapVC.getReference().length();

return equalBasesExcludingSuffix(eventMapVC.getAlternateAllele(0).getBases(),
vc.getAlternateAllele(altAlleleIndex).getBases(), excessBases);
}
@@ -142,6 +148,9 @@ private static boolean equalBasesExcludingSuffix(final byte[] eventMapBases, final byte[] variantContextBases, final int suffixSize) {
private static boolean equalBasesExcludingSuffix(final byte[] eventMapBases, final byte[] variantContextBases, final int suffixSize) {
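// e.g. an event-map SNV REF=A ALT=G matches a variant context REF=ACT ALT=GCT
// when suffixSize == 2: the lengths agree (1 + 2 == 3) and the leading bases match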
if (eventMapBases.length + suffixSize != variantContextBases.length) {
return false;
} else if (eventMapBases.length > variantContextBases.length) {
return false; // edge case -- event map is longer, though minimal, because it is a MNP
// even if the leading bases match, let's call this not a match
} else {
for (int n = 0; n < eventMapBases.length; n++) {
if (eventMapBases[n] != variantContextBases[n]) {