Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improved robustness and error messages in TrainVariantAnnotationsMode…
Browse files Browse the repository at this point in the history
…l for edge cases with insufficient negative training data.
samuelklee committed Nov 21, 2022
1 parent c357388 commit 21a93d4
Showing 1 changed file with 27 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -531,9 +531,6 @@ private void doModelingWorkForVariantType(final VariantType variantType) {

final double[] labeledTrainingAndVariantTypeScores = VariantAnnotationsScorer.readScores(labeledTrainingAndVariantTypeScoresFile);
final List<Boolean> isNegativeTrainingFromLabeledTrainingAndVariantType = Arrays.stream(labeledTrainingAndVariantTypeScores).boxed().map(s -> s < scoreThreshold).collect(Collectors.toList());
final int numNegativeTrainingFromLabeledTrainingAndVariantType = numPassingFilter(isNegativeTrainingFromLabeledTrainingAndVariantType);
logger.info(String.format("Selected %d labeled %s sites below score threshold of %.4f for negative-model training...",
numNegativeTrainingFromLabeledTrainingAndVariantType, variantTypeString, scoreThreshold));

logger.info(String.format("Scoring %d unlabeled %s sites...", numUnlabeledVariantType, variantTypeString));
final File unlabeledVariantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, unlabeledAnnotations, isUnlabeledVariantType);
@@ -549,12 +546,17 @@ private void doModelingWorkForVariantType(final VariantType variantType) {
final int numNegativeTrainingAndVariantType = negativeTrainingAndVariantTypeAnnotations.length;
final List<Boolean> isNegativeTrainingAndVariantType = Collections.nCopies(numNegativeTrainingAndVariantType, true);

logger.info(String.format("Training %s negative model with %d negative-training sites x %d annotations %s...",
variantTypeString, numNegativeTrainingAndVariantType, annotationNames.size(), annotationNames));
final File negativeTrainingAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
annotationNames, negativeTrainingAndVariantTypeAnnotations, isNegativeTrainingAndVariantType);
trainAndSerializeModel(negativeTrainingAnnotationsFile, outputPrefixTag + NEGATIVE_TAG);
logger.info(String.format("%s negative model trained and serialized with output prefix \"%s\".", variantTypeString, outputPrefix + outputPrefixTag + NEGATIVE_TAG));
if (numNegativeTrainingAndVariantType > 0) {
logger.info(String.format("Training %s negative model with %d negative-training sites x %d annotations %s...",
variantTypeString, numNegativeTrainingAndVariantType, annotationNames.size(), annotationNames));
final File negativeTrainingAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
annotationNames, negativeTrainingAndVariantTypeAnnotations, isNegativeTrainingAndVariantType);
trainAndSerializeModel(negativeTrainingAnnotationsFile, outputPrefixTag + NEGATIVE_TAG);
logger.info(String.format("%s negative model trained and serialized with output prefix \"%s\".", variantTypeString, outputPrefix + outputPrefixTag + NEGATIVE_TAG));
} else {
throw new UserException.BadInput(String.format("Attempted to train %s negative model, " +
"but no suitable sites with scores below the specified threshold were found in the provided annotations.", variantTypeString));
}

logger.info(String.format("Re-scoring %d %s calibration sites...", numLabeledCalibrationAndVariantType, variantTypeString));
final File labeledCalibrationAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isLabeledCalibrationAndVariantType);
@@ -690,13 +692,23 @@ private static double[][] concatenateLabeledAndUnlabeledNegativeTrainingData(fin
final double[][] unlabeledAnnotations,
final List<Boolean> isNegativeTrainingFromLabeledTrainingAndVariantType,
final List<Boolean> isNegativeTrainingFromUnlabeledVariantType) {
final File negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile =
LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isNegativeTrainingFromLabeledTrainingAndVariantType);
final double[][] negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile);
final double[][] negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations;
if (numPassingFilter(isNegativeTrainingFromLabeledTrainingAndVariantType) > 0) {
final File negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile =
LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isNegativeTrainingFromLabeledTrainingAndVariantType);
negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile);
} else {
negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations = new double[0][];
}

final File negativeTrainingFromUnlabeledVariantTypeAnnotationsFile =
LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, unlabeledAnnotations, isNegativeTrainingFromUnlabeledVariantType);
final double[][] negativeTrainingFromUnlabeledVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromUnlabeledVariantTypeAnnotationsFile);
final double[][] negativeTrainingFromUnlabeledVariantTypeAnnotations;
if (numPassingFilter(isNegativeTrainingFromUnlabeledVariantType) > 0) {
final File negativeTrainingFromUnlabeledVariantTypeAnnotationsFile =
LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, unlabeledAnnotations, isNegativeTrainingFromUnlabeledVariantType);
negativeTrainingFromUnlabeledVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromUnlabeledVariantTypeAnnotationsFile);
} else {
negativeTrainingFromUnlabeledVariantTypeAnnotations = new double[0][];
}

return Streams.concat(
Arrays.stream(negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations),

0 comments on commit 21a93d4

Please sign in to comment.