From 578c7b1ff74fafa90830edf210d237a0b55b82c0 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 15 Jan 2025 14:43:13 -0500 Subject: [PATCH] made a somewhat robust mechanism for converting to the python argument outputs for permutect --- scripts/permutect/make_training_dataset.wdl | 2 +- .../permutect/PermutectArgumentConstants.java | 154 ++++++++++++++++++ .../PermutectBaseModelArgumentCollection.java | 79 +++++++++ .../permutect/PermutectPreprocessDataset.java | 29 ++-- .../permutect/PermutectTrainBaseModel.java | 66 ++++++++ ...utectTrainingParamsArgumentCollection.java | 55 +++++++ .../PermutectArgumentConstantsUnitTest.java | 132 +++++++++++++++ 7 files changed, 502 insertions(+), 15 deletions(-) create mode 100644 src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java create mode 100644 src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java create mode 100644 src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java create mode 100644 src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java create mode 100644 src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java diff --git a/scripts/permutect/make_training_dataset.wdl b/scripts/permutect/make_training_dataset.wdl index 4540e839f01..62a8bb1225a 100644 --- a/scripts/permutect/make_training_dataset.wdl +++ b/scripts/permutect/make_training_dataset.wdl @@ -123,7 +123,7 @@ task Preprocess { command <<< set -e - preprocess_dataset --training_datasets ~{training_dataset} --chunk_size ~{chunk_size} ~{"--sources " + source_label} --output train.tar + gatk PermutectPreprocessDataset --training-datasets ~{training_dataset} --chunk-size ~{chunk_size} ~{"--sources " + source_label} --output train.tar >>> runtime { diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java new file mode 100644 index 00000000000..6422bfb6251 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java @@ -0,0 +1,154 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import com.google.common.annotations.VisibleForTesting; +import org.broadinstitute.barclay.argparser.CommandLineArgumentParser; +import org.broadinstitute.barclay.argparser.CommandLineParser; +import org.broadinstitute.barclay.argparser.NamedArgumentDefinition; + +import java.lang.reflect.Field; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Map.entry; + +public class PermutectArgumentConstants { + + // Java-style (kebab case) without _K suffix + public static final String STATE_DICT_NAME = "model-state-dict"; + public static final String ARTIFACT_LOG_PRIORS_NAME = "artifact-log-priors"; + public static final String ARTIFACT_SPECTRA_STATE_DICT_NAME = "artifact-spectra-state-dict"; + public static final String HYPERPARAMS_NAME = "hyperparams"; + public static final String NUM_READ_FEATURES_NAME = "num-read-features"; + public static final String NUM_INFO_FEATURES_NAME = "num-info-features"; + public static final String REF_SEQUENCE_LENGTH_NAME = "ref-sequence-length"; + public static final String HIDDEN_LAYERS_NAME = "hidden-layers"; + public static final String NUM_BASE_FEATURES_NAME = "num-base-features"; + public static final 
String NUM_REF_ALT_FEATURES_NAME = "num-ref-alt-features"; + + public static final String SOURCES_NAME = "sources"; + public static final String SOURCE_NAME = "source"; + + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String OUTPUT_DIR_NAME = "output-dir"; + + public static final String READ_LAYERS_NAME = "read-layers"; + public static final String SELF_ATTENTION_HIDDEN_DIMENSION_NAME = "self-attention-hidden-dimension"; + public static final String NUM_SELF_ATTENTION_LAYERS_NAME = "num-self-attention-layers"; + + public static final String LEARNING_METHOD_NAME = "learning-method"; + + public static final String INFO_LAYERS_NAME = "info-layers"; + public static final String AGGREGATION_LAYERS_NAME = "aggregation-layers"; + public static final String CALIBRATION_LAYERS_NAME = "calibration-layers"; + public static final String REF_SEQ_LAYER_STRINGS_NAME = "ref-seq-layer-strings"; + public static final String DROPOUT_P_NAME = "dropout-p"; + public static final String LEARNING_RATE_NAME = "learning-rate"; + public static final String WEIGHT_DECAY_NAME = "weight-decay"; + public static final String BATCH_NORMALIZE_NAME = "batch-normalize"; + public static final String LEARN_ARTIFACT_SPECTRA_NAME = "learn-artifact-spectra"; + + public static final String TRAINING_DATASETS_NAME = "training-datasets"; + public static final String TRAIN_TAR_NAME = "train-tar"; + public static final String EVALUATION_TAR_NAME = "evaluation-tar"; + public static final String TEST_DATASET_NAME = "test-dataset"; + public static final String NORMAL_ARTIFACT_DATASETS_NAME = "normal-artifact-datasets"; + public static final String REWEIGHTING_RANGE_NAME = "reweighting-range"; + public static final String BATCH_SIZE_NAME = "batch-size"; + public static final String CHUNK_SIZE_NAME = "chunk-size"; + public static final String NUM_EPOCHS_NAME = "num-epochs"; + public static final String NUM_CALIBRATION_EPOCHS_NAME = "num-calibration-epochs"; + public static final String INFERENCE_BATCH_SIZE_NAME = "inference-batch-size"; + public static final String NUM_WORKERS_NAME = "num-workers"; + public static final String NUM_SPECTRUM_ITERATIONS_NAME = "num-spectrum-iterations"; + public static final String SPECTRUM_LEARNING_RATE_NAME = "spectrum-learning-rate"; + + public static final String DATASET_EDIT_TYPE_NAME = "dataset-edit"; + + public static final String TENSORBOARD_DIR_NAME = "tensorboard-dir"; + + public static final String INITIAL_LOG_VARIANT_PRIOR_NAME = "initial-log-variant-prior"; + public static final String INITIAL_LOG_ARTIFACT_PRIOR_NAME = "initial-log-artifact-prior"; + public static final String CONTIGS_TABLE_NAME = "contigs-table"; + public static final String GENOMIC_SPAN_NAME = "genomic-span"; + public static final String MAF_SEGMENTS_NAME = "maf-segments"; + public static final String NORMAL_MAF_SEGMENTS_NAME = "normal-maf-segments"; + public static final String GERMLINE_MODE_NAME = "germline-mode"; + public static final String NO_GERMLINE_MODE_NAME = "no-germline-mode"; + public static final String HET_BETA_NAME = "het-beta"; + + public static final String BASE_MODEL_NAME = "base-model"; + public static final String M3_MODEL_NAME = "permutect-model"; + public static final String PRETRAINED_MODEL_NAME = "pretrained-model"; + + @VisibleForTesting + static final Map<String, String> PERMUTECT_PYTHON_ARGUMENT_MAP = Collections.unmodifiableMap(generateArgumentMap()); + + + /** + * Takes in the command line parser for a permutect tool and converts and
returns a list of the corresponding arguments + * for the wrapped python script that are A) actually defined for the tool and B) have been set by the user. + * + * @param parser the command line parser for the tool in question from which to generate python arguments + */ + //TODO this might be easier to do by taking the input arguments directly + public static List<String> getPythonClassArgumentsFromToolParser(CommandLineParser parser) { + if (parser instanceof CommandLineArgumentParser argParser) { + List<String> pythonArgs = new ArrayList<>(); + for (Map.Entry<String, String> entry : PERMUTECT_PYTHON_ARGUMENT_MAP.entrySet()) { + NamedArgumentDefinition arg = argParser.getNamedArgumentDefinitionByAlias(entry.getKey()); + if (arg != null && arg.getHasBeenSet()) { // arg can be null if it is not actually a valid argument for the tool in question + pythonArgs.add("--" + entry.getValue()); + + //TODO double check the toString() method for the argument value + if (arg.isFlag()) { + continue; // flags don't have values + } else if (arg.isCollection()) { + // The permutect python argument parser expects a sequential list of strings following a list argument + ((Collection<?>) arg.getArgumentValue()).forEach(value -> pythonArgs.add(value.toString())); + } else { + pythonArgs.add(arg.getArgumentValue().toString()); + } + } + } + return pythonArgs; + + } else { + throw new IllegalArgumentException("command line parser is not a CommandLineArgumentParser"); + } + } + + /** + * Convert a java-style (kebab-case) argument name to the python-style (snake_case) name expected by the permutect scripts. + */ + private static String convertToPythonStyle(String javaStyle) { + return javaStyle.replace('-', '_'); + } + + /** + * Generate the java-to-python argument name map via reflection over the String constants declared in this class. + */ + public static Map<String, String> generateArgumentMap() { + return Stream.of(PermutectArgumentConstants.class.getDeclaredFields()) + .filter(field -> java.lang.reflect.Modifier.isStatic(field.getModifiers()) + && java.lang.reflect.Modifier.isFinal(field.getModifiers()) + && field.getType().equals(String.class)) + .collect(Collectors.toMap( + PermutectArgumentConstants::getFieldValue, // Java-style name + field -> convertToPythonStyle(getFieldValue(field)) // Python-style name + )); + } + + /** + * Safely get the value of a static final field.
+ */ + private static String getFieldValue(Field field) { + try { + return (String) field.get(null); + } catch (IllegalAccessException e) { + throw new RuntimeException("Unable to access field: " + field.getName(), e); + } + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java new file mode 100644 index 00000000000..b2c5fa99292 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java @@ -0,0 +1,79 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; + +import java.io.Serializable; +import java.util.List; + +public class PermutectBaseModelArgumentCollection implements Serializable { + private static final long serialVersionUID = 1L; + @Argument( + doc = "Optional pretrained base model to initialize training.", + fullName = PermutectArgumentConstants.PRETRAINED_MODEL_NAME, + optional = true + ) + public String pretrainedModelName = null; + + @Argument( + doc = "Dimensions of hidden layers in the read embedding subnetwork, including the dimension of the embedding itself. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.READ_LAYERS_NAME, + optional = false + ) + public List readLayers = null; + + @Argument( + doc = "Hidden dimension of transformer keys and values in the self-attention layers.", + fullName = PermutectArgumentConstants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME, + optional = false + ) + public String selfAttentionHiddenDimension = null; + + @Argument( + doc = "Number of symmetric gated MLP self-attention layers.", + fullName = PermutectArgumentConstants.NUM_SELF_ATTENTION_LAYERS_NAME, + optional = false + ) + public String numSelfAttentionLayers = null; + + @Argument( + doc = "Dimensions of hidden layers in the info embedding subnetwork, including the dimension of the embedding itself. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.INFO_LAYERS_NAME, + optional = false + ) + public List infoLayers = null; + + @Argument( + doc = "Dimensions of hidden layers in the aggregation subnetwork, excluding the dimension of input from lower subnetworks and the dimension (1) of the output logit. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, + optional = false + ) + public List aggregationLayers = null; + + @Argument( + doc = "List of strings specifying convolution layers of the reference sequence embedding. For example: convolution/kernel_size=3/out_channels=64 pool/kernel_size=2 leaky_relu convolution/kernel_size=3/dilation=2/out_channels=5 leaky_relu flatten linear/out_features=10.", + fullName = PermutectArgumentConstants.REF_SEQ_LAYER_STRINGS_NAME, + optional = false + ) + public List refSeqLayerStrings = null; + + @Argument( + doc = "Dropout probability (default: 0.0).", + fullName = PermutectArgumentConstants.DROPOUT_P_NAME, + optional = true + ) + public String dropoutP = "0.0"; + + @Argument( + doc = "Magnitude of data augmentation by randomly weighted average of read embeddings. 
A value of x yields random weights between 1 - x and 1 + x (default: 0.3).", + fullName = PermutectArgumentConstants.REWEIGHTING_RANGE_NAME, + optional = true + ) + public String reweightingRange = "0.3"; + + @Argument( + doc = "Flag to turn on batch normalization.", + fullName = PermutectArgumentConstants.BATCH_NORMALIZE_NAME, + optional = true + ) + public Boolean batchNormalize = false; +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java index e3307f1ba6f..f2684c40486 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java @@ -1,6 +1,7 @@ package org.broadinstitute.hellbender.tools.permutect; import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; import org.broadinstitute.barclay.argparser.BetaFeature; import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; import org.broadinstitute.barclay.help.DocumentedFeature; @@ -28,45 +29,45 @@ public class PermutectPreprocessDataset extends CommandLineProgram { //TODO handle lists for this? Make it a gatk list? @Argument( doc = "List of plain text data files.", - fullName = "training-datasets" + fullName = PermutectArgumentConstants.TRAINING_DATASETS_NAME ) public String trainingDatasetName = null; @Argument( doc = "Size in bytes of output binary data files. Default is 2e9.", - fullName = "chunk-size", + fullName = PermutectArgumentConstants.CHUNK_SIZE_NAME, optional = true ) public String chunkSizeName = null; @Argument( doc = "Integer sources corresponding to plain text data files for distinguishing different sequencing conditions.", - fullName = "sources", + fullName = PermutectArgumentConstants.SOURCES_NAME, optional = true ) public String sources = null; @Argument( doc = "Path to output tarfile of training data.", - fullName = "output" + fullName = PermutectArgumentConstants.OUTPUT_NAME ) public String outputTarGz = null; + // Shared argument collections whose arguments are included in this tool + @ArgumentCollection + PermutectBaseModelArgumentCollection baseArgumentCollection = new PermutectBaseModelArgumentCollection(); + @ArgumentCollection + PermutectTrainingParamsArgumentCollection trainingParamsArgumentCollection = new PermutectTrainingParamsArgumentCollection(); + @Override protected Object doWork() { - - //TODO this is where I check the environment - PythonScriptExecutor executor = new PythonScriptExecutor(true); - final List arguments = new ArrayList<>(); - arguments.add("--training_datasets=" + trainingDatasetName); - if (chunkSizeName != null) { arguments.add("--chunk_size=" + chunkSizeName);} - if (sources != null) { arguments.add("--sources=" + sources);} - arguments.add("--output=" + CopyNumberArgumentValidationUtils.getCanonicalPath(outputTarGz)); + List<String> pythonifiedArguments = PermutectArgumentConstants.getPythonClassArgumentsFromToolParser(getCommandLineParser()); return executor.executeScript( - new Resource(PERMUTECT_PREPREOCESS_DATASET_SCRIPT, PermutectPreprocessDataset.class), + new Resource(PERMUTECT_PREPREOCESS_DATASET_SCRIPT, PermutectPreprocessDataset.class), null, - arguments); + pythonifiedArguments); } + } \ No newline at end of file
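To make the translation concrete, here is a small, self-contained sketch (not part of the patch; the class name, literal values, and file names below are invented for illustration) of the snake_case argument vector that a PermutectPreprocessDataset invocation like the one in the WDL above is expected to hand to the wrapped preprocess_dataset script:

import java.util.List;

// Hypothetical sketch, not part of the patch: mirrors the kebab-case -> snake_case
// conversion in PermutectArgumentConstants.convertToPythonStyle and shows the shape of
// the argument vector handed to the python script.
public class PermutectArgumentTranslationSketch {

    private static String toPythonStyle(final String javaStyleName) {
        return javaStyleName.replace('-', '_');
    }

    public static void main(final String[] args) {
        // Placeholder values; only the name translation and the flag/value pairing matter here.
        final List<String> expectedPythonArgs = List.of(
                "--" + toPythonStyle("training-datasets"), "dataset.txt",
                "--" + toPythonStyle("chunk-size"), "2000000000",
                "--" + toPythonStyle("output"), "train.tar");
        expectedPythonArgs.forEach(System.out::println); // --training_datasets, dataset.txt, --chunk_size, ...
    }
}

The relative order of different arguments follows the reflective map's iteration order and is not guaranteed; only the pairing of each --flag with the value(s) that immediately follow it is.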
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java new file mode 100644 index 00000000000..7d38f66d3de --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java @@ -0,0 +1,66 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.ArrayList; +import java.util.List; + +@CommandLineProgramProperties( + summary = "Train the Permutect read set representation model.", + oneLineSummary = "Train the Permutect read set representation model", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectTrainBaseModel extends CommandLineProgram { + + public static final String TRAIN_BASE_MODEL_PY = "train_base_model.py"; + + @Argument( + doc = "Options [SUPERVISED, SEMISUPERVISED, SUPERVISED_CLUSTERING, AFFINE, MASK_PREDICTION, AUTOENCODER, DEEPSAD, MARS].", + fullName = PermutectArgumentConstants.LEARNING_METHOD_NAME, + optional = true + ) + public String learningMethod = null; + + @Argument( + doc = "Tarfile of training/validation datasets produced by preprocess_dataset.", + fullName = PermutectArgumentConstants.TRAIN_TAR_NAME, + optional = false + ) + public String trainTar = null; + + @Argument( + doc = "Output location for the saved model file.", + fullName = PermutectArgumentConstants.OUTPUT_NAME, + optional = false + ) + public String outputModel = null; + + @Argument( + doc = "Output TensorBoard directory.", + fullName = PermutectArgumentConstants.TENSORBOARD_DIR_NAME, + optional = true + ) + public String tensorboardDir = null; + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List<String> pythonifiedArguments = PermutectArgumentConstants.getPythonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(TRAIN_BASE_MODEL_PY, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java new file mode 100644 index 00000000000..15e51de7430 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java @@ -0,0 +1,55 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; + +public class PermutectTrainingParamsArgumentCollection { + @Argument( + doc = "Learning rate for the model.", + fullName = PermutectArgumentConstants.LEARNING_RATE_NAME, + optional = true + ) + public String learningRate = "0.001"; + + @Argument( + doc = "Weight decay for the optimizer.", + fullName =
PermutectArgumentConstants.WEIGHT_DECAY_NAME, + optional = true + ) + public String weightDecay = "0.0"; + + @Argument( + doc = "Batch size for training.", + fullName = PermutectArgumentConstants.BATCH_SIZE_NAME, + optional = true + ) + public String batchSize = "64"; + + @Argument( + doc = "Number of subprocesses devoted to data loading, including reading from memory map, collating batches, and transferring to GPU.", + fullName = PermutectArgumentConstants.NUM_WORKERS_NAME, + optional = true + ) + public String numWorkers = "0"; + + @Argument( + doc = "Number of epochs for primary training loop.", + fullName = PermutectArgumentConstants.NUM_EPOCHS_NAME, + optional = false + ) + public String numEpochs; + + @Argument( + doc = "Number of calibration-only epochs.", + fullName = PermutectArgumentConstants.NUM_CALIBRATION_EPOCHS_NAME, + optional = true + ) + public String numCalibrationEpochs = "0"; + + @Argument( + doc = "Batch size when performing model inference (not training).", + fullName = PermutectArgumentConstants.INFERENCE_BATCH_SIZE_NAME, + optional = true + ) + public String inferenceBatchSize = "8192"; + +} diff --git a/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java new file mode 100644 index 00000000000..b859c5129a0 --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java @@ -0,0 +1,132 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.CommandLineParser; +import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.testutils.ArgumentsBuilder; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class PermutectArgumentConstantsUnitTest extends GATKBaseTest { + + private class DummyPermutectArgCollection { + @Argument(fullName = PermutectArgumentConstants.NUM_EPOCHS_NAME, doc = "argument from an argument collection", optional = true) + private String Arg3 = null; + } + + private class DummyPermutectWrapper extends CommandLineProgram { + + @Argument(fullName = "dummy-argument", doc = "not in python argument list", optional = true) + private String Arg1 = null; + + // TMP_DIR_NAME = "tmp_dir" // this is a representative inherited argument that is present in the python argument list + + @Argument(fullName = PermutectArgumentConstants.OUTPUT_NAME, doc = "a standard permutect argument", optional = false) + private String Arg2 = null; + + @Argument(fullName = PermutectArgumentConstants.INFO_LAYERS_NAME, doc = "in python argument list", optional = true) + private String Arg3 = null; + + @Argument(fullName = PermutectArgumentConstants.BASE_MODEL_NAME, doc = "in python argument list, has a GATK-defined default value that is overwritten on the cli", optional = true) + private String Arg4 = "THIS_SHOULD_NOT_BE_HERE"; + + @Argument(fullName = PermutectArgumentConstants.BATCH_SIZE_NAME, doc = "in python argument list, has a GATK-defined default value but is not specified on the cli", optional = true) + private String Arg4b = "THIS_SHOULD_NOT_BE_HERE"; + + @Argument(fullName =
PermutectArgumentConstants.DROPOUT_P_NAME, doc = "flag argument", optional = true) + private boolean Arg5 = false; + + @Argument(fullName = PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, doc = "list argument, optional", optional = true) + private List<String> Arg6 = new ArrayList<>(); + + @ArgumentCollection + DummyPermutectArgCollection args = new DummyPermutectArgCollection(); + + @Override + protected Object doWork() { return null; } + } + + @Test + public void testGetPythonClassArgumentsFromToolParser() { + ArgumentsBuilder builder = new ArgumentsBuilder(); + builder.add(PermutectArgumentConstants.OUTPUT_NAME, "output"); + builder.add("dummy-argument", "THIS_SHOULD_NOT_BE_HERE"); + builder.add(PermutectArgumentConstants.INFO_LAYERS_NAME, "info_layers"); + builder.add(PermutectArgumentConstants.BASE_MODEL_NAME, "base_model"); + builder.addFlag(PermutectArgumentConstants.DROPOUT_P_NAME); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg1"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg2"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg3"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg4"); + builder.add(PermutectArgumentConstants.NUM_EPOCHS_NAME, "num_epochs"); + CommandLineParser parser = new DummyPermutectWrapper().getCommandLineParser(); + final boolean parseSucceeded = parser.parseArguments(new PrintStream(System.err), builder.getArgsArray()); + + List<String> pyArgs = PermutectArgumentConstants.getPythonClassArgumentsFromToolParser(parser); + Assert.assertTrue(pyArgs.contains("--output")); + Assert.assertTrue(pyArgs.contains("output")); + Assert.assertEquals(pyArgs.indexOf("output") - 1, pyArgs.indexOf("--output")); + + Assert.assertTrue(pyArgs.contains("--info_layers")); + Assert.assertTrue(pyArgs.contains("info_layers")); + Assert.assertEquals(pyArgs.indexOf("info_layers") - 1, pyArgs.indexOf("--info_layers")); + + Assert.assertTrue(pyArgs.contains("--dropout_p")); + + Assert.assertTrue(pyArgs.contains("--aggregation_layers")); + Assert.assertTrue(pyArgs.contains("agg1")); + Assert.assertTrue(pyArgs.contains("agg2")); + Assert.assertTrue(pyArgs.contains("agg3")); + Assert.assertTrue(pyArgs.contains("agg4")); + Assert.assertEquals(pyArgs.indexOf("agg1") - 1, pyArgs.indexOf("--aggregation_layers")); + Assert.assertEquals(pyArgs.indexOf("agg2") - 1, pyArgs.indexOf("agg1")); + Assert.assertEquals(pyArgs.indexOf("agg3") - 1, pyArgs.indexOf("agg2")); + Assert.assertEquals(pyArgs.indexOf("agg4") - 1, pyArgs.indexOf("agg3")); + + Assert.assertTrue(pyArgs.contains("--num_epochs")); + Assert.assertTrue(pyArgs.contains("num_epochs")); + Assert.assertEquals(pyArgs.indexOf("num_epochs") - 1, pyArgs.indexOf("--num_epochs")); + + Assert.assertFalse(pyArgs.contains("--dummy-argument")); + Assert.assertFalse(pyArgs.contains("THIS_SHOULD_NOT_BE_HERE")); + + Assert.assertTrue(pyArgs.contains("--base_model")); + Assert.assertTrue(pyArgs.contains("base_model")); + Assert.assertEquals(pyArgs.indexOf("base_model") - 1, pyArgs.indexOf("--base_model")); + + Assert.assertFalse(pyArgs.contains("--tmp_dir")); + Assert.assertFalse(pyArgs.contains("tmp_dir")); + + Assert.assertFalse(pyArgs.contains("--batch_size")); + } + + @Test + public void testGenerateArgumentMap() { + final Map<String, String> conversionMap = PermutectArgumentConstants.PERMUTECT_PYTHON_ARGUMENT_MAP; + + Assert.assertNotNull(conversionMap); + Assert.assertTrue(conversionMap.entrySet().size() > 30); // assert that the map is not empty and that reflection picked up a substantial number of arguments; the exact count is subject to change
+ + for (Map.Entry<String, String> entry : conversionMap.entrySet()) { + Assert.assertNotNull(entry.getKey()); + Assert.assertFalse(entry.getKey().contains("_")); + + // python-style argument names use underscores and must not contain hyphens + Assert.assertNotNull(entry.getValue()); + Assert.assertFalse(entry.getValue().contains("-")); + } + + // various illegal fields that could have snuck into the reflective map by accident and that we want to make sure are absent + Assert.assertFalse(conversionMap.containsKey("PERMUTECT_PYTHON_ARGUMENT_MAP")); + Assert.assertFalse(conversionMap.containsKey("dragen-mode")); + Assert.assertFalse(conversionMap.containsKey("getPythonClassArgumentsFromToolParser")); + Assert.assertFalse(conversionMap.containsKey("serialVersionUID")); + } +} \ No newline at end of file
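Both wrapper tools in this patch follow the same shape. As a rough, hypothetical sketch of how a future Permutect wrapper could plug into this mechanism (the tool name PermutectExampleWrapper, its summary strings, and the script name example_permutect_step.py are invented placeholders, not part of the patch):

package org.broadinstitute.hellbender.tools.permutect;

import java.util.List;

import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

/** Hypothetical skeleton of an additional Permutect wrapper following the pattern in this patch. */
@CommandLineProgramProperties(
        summary = "Example Permutect wrapper (illustration only).",
        oneLineSummary = "Example Permutect wrapper (illustration only)",
        programGroup = VariantFilteringProgramGroup.class
)
@BetaFeature
public class PermutectExampleWrapper extends CommandLineProgram {

    // Placeholder script name; a real tool would name the actual Permutect script it wraps.
    public static final String EXAMPLE_SCRIPT = "example_permutect_step.py";

    // Shared argument collections expose the common Permutect arguments on this tool's command line.
    @ArgumentCollection
    PermutectBaseModelArgumentCollection baseArgs = new PermutectBaseModelArgumentCollection();

    @ArgumentCollection
    PermutectTrainingParamsArgumentCollection trainingArgs = new PermutectTrainingParamsArgumentCollection();

    @Override
    protected Object doWork() {
        final PythonScriptExecutor executor = new PythonScriptExecutor(true);
        // Translate every user-set Barclay argument into the snake_case form the python script expects.
        final List<String> pythonArgs =
                PermutectArgumentConstants.getPythonClassArgumentsFromToolParser(getCommandLineParser());
        return executor.executeScript(new Resource(EXAMPLE_SCRIPT, PermutectExampleWrapper.class), null, pythonArgs);
    }
}

Because Barclay matches arguments by their fullName, any @Argument whose name appears in PermutectArgumentConstants is picked up automatically by getPythonClassArgumentsFromToolParser, and arguments the user does not set are simply omitted from the python command line.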