Extract Cohort optimizations [VS-493] [VS-1516] #9055

Merged
merged 16 commits on Dec 10, 2024
4 changes: 3 additions & 1 deletion .dockstore.yml
@@ -173,6 +173,7 @@ workflows:
branches:
- master
- ah_var_store
- vs_1516_yolo
tags:
- /.*/
- name: GvsImportGenomes
@@ -241,6 +242,7 @@ workflows:
branches:
- master
- ah_var_store
- vs_1516_yolo
tags:
- /.*/
- name: GvsWithdrawSamples
@@ -314,7 +316,7 @@ workflows:
branches:
- master
- ah_var_store
- vs_1490_fix_curate_input_array_files
- vs_1516_yolo
tags:
- /.*/
- name: GvsIngestTieout
2 changes: 1 addition & 1 deletion scripts/variantstore/wdl/GvsUtils.wdl
@@ -74,7 +74,7 @@ task GetToolVersions {
String cloud_sdk_slim_docker = "gcr.io/google.com/cloudsdktool/cloud-sdk:435.0.0-slim"
String variants_docker = "us-central1-docker.pkg.dev/broad-dsde-methods/gvs/variants:2024-11-25-alpine-913039adf8f4"
String variants_nirvana_docker = "us.gcr.io/broad-dsde-methods/variantstore:nirvana_2022_10_19"
String gatk_docker = "us-central1-docker.pkg.dev/broad-dsde-methods/gvs/gatk:2024-11-24-gatkbase-1807487d5912"
String gatk_docker = "us-central1-docker.pkg.dev/broad-dsde-methods/gvs/gatk:2024-11-24-gatkbase-5b5c307bdb5e"
String real_time_genomics_docker = "docker.io/realtimegenomics/rtg-tools:latest"
String gotc_imputation_docker = "us.gcr.io/broad-gotc-prod/imputation-bcf-vcf:1.0.5-1.10.2-0.1.16-1649948623"
String plink_docker = "us-central1-docker.pkg.dev/broad-dsde-methods/gvs/plink2:2024-04-23-slim-a0a65f52cc0e"
ExtractCohortEngine.java
@@ -6,7 +6,6 @@
import htsjdk.variant.vcf.VCFHeader;
import org.apache.avro.generic.GenericRecord;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.math.LongRange;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.engine.FeatureContext;
@@ -19,9 +18,9 @@
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gvs.bigquery.AvroFileReader;
import org.broadinstitute.hellbender.utils.gvs.bigquery.StorageAPIAvroReader;
import org.broadinstitute.hellbender.utils.gvs.bigquery.TableReference;
import org.broadinstitute.hellbender.utils.gvs.bigquery.AvroFileReader;
import org.broadinstitute.hellbender.utils.gvs.localsort.AvroSortingCollectionCodec;
import org.broadinstitute.hellbender.utils.gvs.localsort.SortingCollection;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
@@ -86,6 +85,15 @@ public class ExtractCohortEngine {

private final Consumer<VariantContext> variantContextConsumer;

private static class VariantIterables {
public Iterable<GenericRecord> vets;
Collaborator
Maybe we should brainstorm on a better name for the vets table and related? It now collides with the new name for VQSR Lite

Collaborator (Author)
ugh yes, good point

public Iterable<GenericRecord> refRanges;
public VariantIterables(Iterable<GenericRecord> vets, Iterable<GenericRecord> refRanges) {
this.vets = vets;
this.refRanges = refRanges;
}
}
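
The holder above exists so the source-specific builders can hand back their two record streams instead of driving the extraction themselves, as the reshaped traverse() below shows. On a Java 16+ toolchain the same holder could be written as a one-line record; a sketch under that assumption, not the PR's code:

// Hypothetical record equivalent of VariantIterables (requires Java 16+):
record VariantIterables(Iterable<GenericRecord> vets, Iterable<GenericRecord> refRanges) {}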

List<String> getFilterSetInfoTableFields() {
return SchemaUtils.YNG_FIELDS;
}
@@ -234,6 +242,18 @@ private void processBytesScanned(StorageAPIAvroReader reader) {
}

public void traverse() {

SortedSet<Long> sampleIdsToExtract = new TreeSet<>(this.sampleIdToName.keySet());
VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);
VariantIterables variantIterables;
if (fqRangesExtractVetTable != null) {
variantIterables = createVariantIterablesFromUnsortedExtractTableBigQueryRanges(fqRangesExtractVetTable, fqRangesExtractRefTable, vbs);
} else if (vetRangesFQDataSet != null) {
variantIterables = createVariantIterablesFromUnsortedBigQueryRanges(vetRangesFQDataSet, sampleIdsToExtract, vbs);
} else {
variantIterables = createVariantsIterablesFromUnsortedAvroRanges(vetAvroFileName, refRangesAvroFileName, vbs, presortedAvroFiles);
}

// First allele here is the ref, followed by the alts associated with that ref. We need this because at this
// point the alleles haven't been joined and remapped to one reference allele.
final Map<Long, Map<Allele, Map<Allele, Double>>> fullScoreMap = new HashMap<>();
@@ -260,16 +280,26 @@ public void traverse() {

// get filter info (vqslod/sensitivity & yng values)
try (StorageAPIAvroReader reader = new StorageAPIAvroReader(filterSetInfoTableRef, rowRestrictionWithFilterSetName, projectID)) {

long recordsProcessed = 0;
long recordsDropped = 0;
for (final GenericRecord queryRow : reader) {
if (++recordsProcessed % 100000 == 0) {
logger.info("Processed " + recordsProcessed + " filter set info records, dropped " + recordsDropped + ".");
}
final ExtractCohortFilterRecord filterRow = new ExtractCohortFilterRecord(queryRow, getVQScoreFieldName(), getScoreFieldName());

final long location = filterRow.getLocation();
final Allele ref = Allele.create(filterRow.getRefAllele(), true);
final Allele alt = Allele.create(filterRow.getAltAllele(), false);

if (!vbs.containsVariant(location, location + Math.max(ref.length(), alt.length()))) {
++recordsDropped;
continue;
}
final Double score = filterRow.getScore();
final Double vqsScore = filterRow.getVqScore();
final String yng = filterRow.getYng();
final Allele ref = Allele.create(filterRow.getRefAllele(), true);
final Allele alt = Allele.create(filterRow.getAltAllele(), false);

fullScoreMap.putIfAbsent(location, new HashMap<>());
fullScoreMap.get(location).putIfAbsent(ref, new HashMap<>());
fullScoreMap.get(location).get(ref).put(alt, score);
Expand All @@ -280,15 +310,25 @@ public void traverse() {
fullYngMap.get(location).putIfAbsent(ref, new HashMap<>());
fullYngMap.get(location).get(ref).put(alt, yng);
}
logger.info("Processed " + recordsProcessed + " filter set info records, dropped " + recordsDropped + ".");
processBytesScanned(reader);
}
}

// load site-level filter data into data structure
if (filterSetSiteTableRef != null) {
try (StorageAPIAvroReader reader = new StorageAPIAvroReader(filterSetSiteTableRef, rowRestrictionWithFilterSetName, projectID)) {
long recordsProcessed = 0;
long recordsDropped = 0;
for (final GenericRecord queryRow : reader) {
if (++recordsProcessed % 10000 == 0) {
logger.info("Processed " + recordsProcessed + " filter set sites records, dropped " + recordsDropped + ".");
}
long location = Long.parseLong(queryRow.get(SchemaUtils.LOCATION_FIELD_NAME).toString());
if (!vbs.containsVariant(location, location + 1)) {
++recordsDropped;
continue;
}
List<String> filters = Arrays.asList(queryRow.get(SchemaUtils.FILTERS).toString().split(","));
siteFilterMap.put(location, filters);
}
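
Both pruning loops above rely on the VariantBitSet, which is presumably populated while the vet data is scanned into the iterables at the top of traverse(): a filter record whose location window contains no extracted variant can never match, so it is dropped before it ever reaches the score/YNG maps. A minimal sketch of such a structure, assuming end-exclusive windows and locations encoded as contiguous longs (hypothetical class, not the PR's VariantBitSet):

import java.util.BitSet;

// Sketch of a location-window bit set (hypothetical; the PR's VariantBitSet is the real one).
final class LocationBitSet {
    private final long minLocation;
    private final long span;
    private final BitSet bits;

    LocationBitSet(long minLocation, long maxLocation) {
        this.minLocation = minLocation;
        this.span = maxLocation - minLocation + 1;
        this.bits = new BitSet((int) span);
    }

    // Called while streaming vet rows: mark this location as carrying a variant.
    void setVariant(long location) {
        bits.set((int) (location - minLocation));
    }

    // True if any position in the end-exclusive window [start, end) is marked.
    // Filter rows whose window is empty can be skipped without changing output.
    boolean containsVariant(long start, long end) {
        int from = (int) Math.max(0, start - minLocation);
        int to = (int) Math.min(span, end - minLocation);
        if (from >= to) return false;
        int next = bits.nextSetBit(from);
        return next >= 0 && next < to;
    }
}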
@@ -320,17 +360,7 @@ public void traverse() {
throw new GATKException("Can not process cross-contig boundaries for Ranges implementation");
}

SortedSet<Long> sampleIdsToExtract = new TreeSet<>(this.sampleIdToName.keySet());
if (fqRangesExtractVetTable != null) {
createVariantsFromUnsortedExtractTableBigQueryRanges(fqRangesExtractVetTable, fqRangesExtractRefTable,
sampleIdsToExtract, minLocation, maxLocation, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);
} else if (vetRangesFQDataSet != null) {
createVariantsFromUnsortedBigQueryRanges(vetRangesFQDataSet, sampleIdsToExtract, minLocation, maxLocation,
fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);
} else {
createVariantsFromUnsortedAvroRanges(vetAvroFileName, refRangesAvroFileName, sampleIdsToExtract, minLocation,
maxLocation, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested, presortedAvroFiles);
}
createVariantsFromSortedRanges(sampleIdsToExtract, variantIterables, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);

logger.debug("Finished Initializing Reader");

@@ -1049,25 +1079,16 @@ private SortingCollection<GenericRecord> createSortedReferenceRangeCollectionFro
}


private void createVariantsFromUnsortedBigQueryRanges(
private VariantIterables createVariantIterablesFromUnsortedBigQueryRanges(
final String fqDatasetName,
final SortedSet<Long> sampleIdsToExtract,
final Long minLocation,
final Long maxLocation,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullScoreMap,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullVQScoreMap,
final Map<Long, Map<Allele, Map<Allele, String>>> fullYngMap,
final Map<String, Integer> samplePloidyMap,
final Map<Long, List<String>> siteFilterMap,
final boolean noVQScoreFilteringRequested) {
VariantBitSet vbs) {

// We could handle this by making a map of BitSets or something, but it seems unnecessary to support this
if (!SchemaUtils.decodeContig(minLocation).equals(SchemaUtils.decodeContig(maxLocation))) {
throw new GATKException("Can not process cross-contig boundaries");
}

VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);

SortingCollection<GenericRecord> sortedVet = createSortedVetCollectionFromBigQuery(projectID,
fqDatasetName,
sampleIdsToExtract,
@@ -1085,32 +1106,22 @@ private void createVariantsFromUnsortedBigQueryRanges(
localSortMaxRecordsInRam,
vbs);

createVariantsFromSortedRanges(sampleIdsToExtract, sortedVet, sortedReferenceRange, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);
return new VariantIterables(sortedVet, sortedReferenceRange);
}

//
// BEGIN REF RANGES COHORT EXTRACT
//
private void createVariantsFromUnsortedExtractTableBigQueryRanges(
private VariantIterables createVariantIterablesFromUnsortedExtractTableBigQueryRanges(
final String fqVetTable,
final String fqRefTable,
final SortedSet<Long> sampleIdsToExtract,
final Long minLocation,
final Long maxLocation,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullScoreMap,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullVQScoreMap,
final Map<Long, Map<Allele, Map<Allele, String>>> fullYngMap,
final Map<String, Integer> samplePloidyMap,
final Map<Long, List<String>> siteFilterMap,
final boolean noVQScoreFilteringRequested) {
VariantBitSet vbs) {

// We could handle this by making a map of BitSets or something, but it seems unnecessary to support this
if (!SchemaUtils.decodeContig(minLocation).equals(SchemaUtils.decodeContig(maxLocation))) {
throw new GATKException("Can not process cross-contig boundaries");
}

VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);

SortingCollection<GenericRecord> sortedVet = createSortedVetCollectionFromExtractTableBigQuery(projectID,
fqVetTable,
minLocation,
@@ -1126,7 +1137,7 @@ private void createVariantsFromUnsortedExtractTableBigQueryRanges(
localSortMaxRecordsInRam,
vbs);

createVariantsFromSortedRanges(sampleIdsToExtract, sortedVet, sortedReferenceRange, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);
return new VariantIterables(sortedVet, sortedReferenceRange);
}

private SortingCollection<GenericRecord> createSortedVetCollectionFromExtractTableBigQuery(final String projectID,
@@ -1185,18 +1196,10 @@ private SortingCollection<GenericRecord> createSortedReferenceRangeCollectionFro
//
// END REF RANGES COHORT EXTRACT
//
private void createVariantsFromUnsortedAvroRanges(
private VariantIterables createVariantsIterablesFromUnsortedAvroRanges(
final GATKPath vetAvroFileName,
final GATKPath refRangesAvroFileName,
final SortedSet<Long> sampleIdsToExtract,
final Long minLocation,
final Long maxLocation,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullScoreMap,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullVQScoreMap,
final Map<Long, Map<Allele, Map<Allele, String>>> fullYngMap,
final Map<String, Integer> samplePloidyMap,
final Map<Long, List<String>> siteFilterMap,
final boolean noVQScoreFilteringRequested,
VariantBitSet vbs,
final boolean presortedAvroFiles) {

final AvroFileReader vetReader = new AvroFileReader(vetAvroFileName);
Expand All @@ -1209,8 +1212,6 @@ private void createVariantsFromUnsortedAvroRanges(
sortedVet = vetReader;
sortedReferenceRange = refRangesReader;
} else {
VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);

SortingCollection<GenericRecord> localSortedVet = getAvroSortingCollection(vetReader.getSchema(), localSortMaxRecordsInRam);
addToVetSortingCollection(localSortedVet, vetReader, vbs);

@@ -1221,13 +1222,11 @@
sortedReferenceRange = localSortedReferenceRange;
}

createVariantsFromSortedRanges(sampleIdsToExtract, sortedVet, sortedReferenceRange, fullScoreMap, fullVQScoreMap, fullYngMap, samplePloidyMap, siteFilterMap, noVQScoreFilteringRequested);

return new VariantIterables(sortedVet, sortedReferenceRange);
}
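
One note on the Avro path above: when the inputs are already coordinate-sorted, the readers are passed through untouched and the disk-backed SortingCollection pass is skipped entirely. The decision reduces to a generic "sort only when the producer didn't" helper; a simplified in-memory sketch (hypothetical names; the real code spills to disk rather than buffering):

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Simplified stand-in for the presorted-vs-local-sort decision; the PR's code
// uses a disk-backed SortingCollection instead of this in-memory buffer.
final class MaybeSort {
    static <T> Iterable<T> maybeSort(Iterable<T> rows, boolean presorted, Comparator<T> byLocation) {
        if (presorted) {
            return rows; // trust the producer's ordering; no second pass over the data
        }
        List<T> buffer = new ArrayList<>();
        rows.forEach(buffer::add);
        buffer.sort(byLocation);
        return buffer;
    }
}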

void createVariantsFromSortedRanges(final SortedSet<Long> sampleIdsToExtract,
final Iterable<GenericRecord> sortedVet,
Iterable<GenericRecord> sortedReferenceRange,
final VariantIterables variantIterables,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullScoreMap,
final Map<Long, Map<Allele, Map<Allele, Double>>> fullVQScoreMap,
final Map<Long, Map<Allele, Map<Allele, String>>> fullYngMap,
@@ -1255,9 +1254,9 @@ void createVariantsFromSortedRanges(final SortedSet<Long> sampleIdsToExtract,
referenceCache.put(sampleId, new TreeSet<>());
}

Iterator<GenericRecord> sortedReferenceRangeIterator = sortedReferenceRange.iterator();
Iterator<GenericRecord> sortedReferenceRangeIterator = variantIterables.refRanges.iterator();

for (final GenericRecord sortedRow : sortedVet) {
for (final GenericRecord sortedRow : variantIterables.vets) {
final ExtractCohortRecord vetRow = new ExtractCohortRecord(sortedRow);
long variantLocation = vetRow.getLocation();
long variantSample = vetRow.getSampleId();
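
The single consumer co-iterates the two sorted streams: it walks the vet rows and advances the reference-range iterator only as far as each variant's location, caching reference state per sample along the way. A stripped-down sketch of that merge pattern (hypothetical Row type; the real method also handles filters, ploidy, and scores):

import java.util.Iterator;

// Two-stream merge over location-sorted inputs (simplified; hypothetical types, Java 16+).
final class CoIterationSketch {
    record Row(long location, long sampleId) {}

    static void walk(Iterable<Row> variants, Iterable<Row> refRanges) {
        Iterator<Row> refs = refRanges.iterator();
        Row pendingRef = refs.hasNext() ? refs.next() : null;
        for (Row variant : variants) {
            // Drain reference ranges up to this variant's location, caching per-sample state.
            while (pendingRef != null && pendingRef.location() <= variant.location()) {
                // ...record pendingRef in a per-sample reference cache...
                pendingRef = refs.hasNext() ? refs.next() : null;
            }
            // ...emit the variant using the cached reference state...
        }
    }
}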