diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 84c077c5..afe5e80c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -24,8 +24,9 @@ jobs: run: pip install flake8 flake8-bugbear - name: Lint with flake8 run: flake8 src - run-tutorial: - name: Run tutorial - random_small + # legacy testing of t-test + run-tutorial-ttest: + name: Run - random_small - t-test runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -39,23 +40,107 @@ jobs: cd tutorial move-dl data=random_small task=encode_data --cfg job move-dl data=random_small task=encode_data - - name: Train model and analyze latent space - run: | - cd tutorial - move-dl data=random_small task=random_small__latent --cfg job - move-dl data=random_small task=random_small__latent + # - name: Identify associations - t-test + # at least 4 refits needed for t-test - name: Identify associations - t-test run: | cd tutorial move-dl data=random_small task=random_small__id_assoc_ttest --cfg job move-dl data=random_small task=random_small__id_assoc_ttest task.training_loop.num_epochs=30 task.num_refits=4 + # categorical dataset pertubation - single and multiprocessed + run-tutorial-cat-pert-single: + name: Run - random_small - singleprocess + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: pip install . + - name: Prepare tutorial data + run: | + cd tutorial + move-dl data=random_small task=encode_data --cfg job + move-dl data=random_small task=encode_data + - name: Train model and analyze latent space + run: | + cd tutorial + move-dl data=random_small task=random_small__latent --cfg job + move-dl data=random_small task=random_small__latent task.training_loop.num_epochs=100 - name: Identify associations - bayes factors run: | cd tutorial move-dl data=random_small task=random_small__id_assoc_bayes --cfg job - move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=30 task.num_refits=20 - run-tutorial-cont: - name: Run tutorial - random_continuous + move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 + - name: Identify associations - bayes factors - w/o training + run: | + cd tutorial + move-dl data=random_small task=random_small__id_assoc_bayes --cfg job + move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 + run-tutorial-cat-pert-multi: + name: Run - random_small - multiprocess + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: pip install . 
+ - name: Prepare tutorial data + run: | + cd tutorial + move-dl data=random_small task=encode_data --cfg job + move-dl data=random_small task=encode_data + - name: Train model and analyze latent space - multiprocess + run: | + cd tutorial + move-dl data=random_small task=random_small__latent --cfg job + move-dl data=random_small task=random_small__latent task.training_loop.num_epochs=100 task.multiprocess=true + - name: Identify associations - bayes factors - multiprocess + run: | + cd tutorial + move-dl data=random_small task=random_small__id_assoc_bayes --cfg job + move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 task.multiprocess=true + - name: Identify associations - bayes factors - multiprocess w/o training + run: | + cd tutorial + move-dl data=random_small task=random_small__id_assoc_bayes --cfg job + move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 task.multiprocess=true + # continous dataset perturbation - single and multiprocessed + run-tutorial-cont-pert-multi: + name: Run - random_continuous - multiprocess + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: pip install . + - name: Prepare tutorial data + run: | + cd tutorial + move-dl data=random_continuous task=encode_data --cfg job + move-dl data=random_continuous task=encode_data + - name: Train model and analyze latent space - multiprocess + run: | + cd tutorial + move-dl data=random_continuous task=random_continuous__latent task.multiprocess=true --cfg job + move-dl data=random_continuous task=random_continuous__latent task.multiprocess=true + - name: Identify associations - bayes factors - multiprocess + run: | + cd tutorial + move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.multiprocess=true --cfg job + move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 task.multiprocess=true + - name: Identify associations - bayes factors - multiprocess w/o training + run: | + cd tutorial + move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.multiprocess=true --cfg job + move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 task.multiprocess=true + run-tutorial-cont-pert-single: + name: Run - random_continuous - singleprocess runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -73,21 +158,22 @@ jobs: cd tutorial move-dl data=random_continuous task=random_continuous__latent --cfg job move-dl data=random_continuous task=random_continuous__latent - - name: Identify associations - t-test + - name: Identify associations - bayes factors run: | cd tutorial - move-dl data=random_continuous task=random_continuous__id_assoc_ttest --cfg job - move-dl data=random_continuous task=random_continuous__id_assoc_ttest task.training_loop.num_epochs=30 task.num_refits=4 - - name: Identify associations - bayes factors + move-dl data=random_continuous task=random_continuous__id_assoc_bayes --cfg job + move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 + - name: Identify associations - bayes factors - w/o training (repeat) run: | cd tutorial move-dl data=random_continuous task=random_continuous__id_assoc_bayes --cfg job - move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.training_loop.num_epochs=30 task.num_refits=4 + move-dl 
data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 + # this reuses the same model trained in analyze latent space - name: Identify associations - KS run: | cd tutorial move-dl data=random_continuous task=random_continuous__id_assoc_ks --cfg job - move-dl data=random_continuous task=random_continuous__id_assoc_ks task.training_loop.num_epochs=30 task.num_refits=4 + move-dl data=random_continuous task=random_continuous__id_assoc_ks task.num_refits=1 publish: name: Publish package diff --git a/src/move/conf/schema.py b/src/move/conf/schema.py index 1f2fdeb6..7f23e4fd 100644 --- a/src/move/conf/schema.py +++ b/src/move/conf/schema.py @@ -135,6 +135,7 @@ class AnalyzeLatentConfig(TaskConfig): feature_names: list[str] = field(default_factory=list) reducer: dict[str, Any] = MISSING + multiprocess: bool = False @dataclass @@ -165,6 +166,7 @@ class IdentifyAssociationsConfig(TaskConfig): num_refits: int = MISSING sig_threshold: float = 0.05 save_refits: bool = False + multiprocess: bool = False @dataclass diff --git a/src/move/data/perturbations.py b/src/move/data/perturbations.py index bb91fc0d..cac02ef0 100644 --- a/src/move/data/perturbations.py +++ b/src/move/data/perturbations.py @@ -1,4 +1,8 @@ -__all__ = ["perturb_categorical_data", "perturb_continuous_data"] +__all__ = [ + "perturb_categorical_data", + "perturb_continuous_data_extended_one", + "perturb_continuous_data_extended", +] from pathlib import Path from typing import Literal, Optional, cast @@ -7,12 +11,69 @@ import torch from torch.utils.data import DataLoader +from move.core.logging import get_logger from move.data.dataloaders import MOVEDataset from move.data.preprocessing import feature_stats from move.visualization.dataset_distributions import plot_value_distributions ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"] +logger = get_logger(__name__) + + +def _build_dataloader( + cat_data, con_data, cat_shapes, con_shapes, batch_size, shuffle=False +): + # currently for continuous data only + dataset = MOVEDataset( + cat_data, + con_data, + cat_shapes, + con_shapes, + ) + + dataloader = DataLoader( + dataset, + shuffle=shuffle, + batch_size=batch_size, + ) + return dataloader + + +def _pertub_cont_feat_col( + baseline_dataset, start_idx, num_features, index_pert_feat, perturbation_type +): + + perturbed_con = baseline_dataset.con_all.clone() + target_dataset = perturbed_con[:, start_idx : start_idx + num_features] + logger.debug(f"Target dataset shape: {target_dataset.shape}") + logger.debug( + f"Changing to desired perturbation value for feature {index_pert_feat}" + ) + # Change the desired feature value by: + # ! one would only need the stats for a single feature? 
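# Illustrative sketch, not part of the patch: the "!" note above suggests that
# only the statistics of the single perturbed column are needed. A torch-only
# alternative could look like the hypothetical helper below; it skips whatever
# missing-value handling feature_stats applies to the full target dataset.
import torch

def _single_feature_stats(col: torch.Tensor) -> tuple[float, float, float]:
    """Minimum, maximum and standard deviation of one feature column."""
    return col.min().item(), col.max().item(), col.std().item()

# hypothetical usage inside the function body:
# min_v, max_v, std_v = _single_feature_stats(
#     perturbed_con[:, start_idx + index_pert_feat]
# )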
+ min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats( + target_dataset + ) + if perturbation_type == "minimum": + perturbed_con[:, start_idx + index_pert_feat] = torch.FloatTensor( + [min_feat_val_list[index_pert_feat]] + ) + elif perturbation_type == "maximum": + perturbed_con[:, start_idx + index_pert_feat] = torch.FloatTensor( + [max_feat_val_list[index_pert_feat]] + ) + elif perturbation_type == "plus_std": + perturbed_con[:, start_idx + index_pert_feat] += torch.FloatTensor( + [std_feat_val_list[index_pert_feat]] + ) + elif perturbation_type == "minus_std": + perturbed_con[:, start_idx + index_pert_feat] -= torch.FloatTensor( + [std_feat_val_list[index_pert_feat]] + ) + logger.debug(f"Perturbation successful for feature {index_pert_feat}") + return perturbed_con + def perturb_categorical_data( baseline_dataloader: DataLoader, @@ -44,7 +105,7 @@ slice_ = slice(*splits[target_idx : target_idx + 2]) target_shape = baseline_dataset.cat_shapes[target_idx] - num_features = target_shape[0] # CHANGE + num_features = target_shape[0] dataloaders = [] for i in range(num_features): @@ -53,27 +114,72 @@ baseline_dataset.num_samples, *target_shape ) target_dataset[:, i, :] = torch.FloatTensor(target_value) - perturbed_dataset = MOVEDataset( - perturbed_cat, - baseline_dataset.con_all, - baseline_dataset.cat_shapes, - baseline_dataset.con_shapes, - ) - perturbed_dataloader = DataLoader( - perturbed_dataset, - shuffle=False, + perturbed_dataloader = _build_dataloader( + cat_data=perturbed_cat, + con_data=baseline_dataset.con_all, + cat_shapes=baseline_dataset.cat_shapes, + con_shapes=baseline_dataset.con_shapes, batch_size=baseline_dataloader.batch_size, ) dataloaders.append(perturbed_dataloader) return dataloaders -def perturb_continuous_data( +def perturb_categorical_data_one( + baseline_dataloader: DataLoader, + cat_dataset_names: list[str], + target_dataset_name: str, + target_value: np.ndarray, + index_pert_feat: int, +) -> DataLoader: + """Add a perturbation to categorical data. Change the value of the feature at + index_pert_feat in the target dataset to the given target value.
+ + Args: + baseline_dataloader: Baseline dataloader + cat_dataset_names: List of categorical dataset names + target_dataset_name: Target categorical dataset to perturb + target_value: Target value + index_pert_feat: Index of the feature to perturb + + Returns: + One dataloader, with the feature at index_pert_feat perturbed + """ + + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + assert baseline_dataset.cat_shapes is not None + assert baseline_dataset.cat_all is not None + + target_idx = cat_dataset_names.index(target_dataset_name) + splits = np.cumsum( + [0] + [int.__mul__(*shape) for shape in baseline_dataset.cat_shapes] + ) + slice_ = slice(*splits[target_idx : target_idx + 2]) + + target_shape = baseline_dataset.cat_shapes[target_idx] + + i = index_pert_feat + perturbed_cat = baseline_dataset.cat_all.clone() + target_dataset = perturbed_cat[:, slice_].view( + baseline_dataset.num_samples, *target_shape + ) + target_dataset[:, i, :] = torch.FloatTensor(target_value) + perturbed_dataloader = _build_dataloader( + cat_data=perturbed_cat, + con_data=baseline_dataset.con_all, + cat_shapes=baseline_dataset.cat_shapes, + con_shapes=baseline_dataset.con_shapes, + batch_size=baseline_dataloader.batch_size, + ) + return perturbed_dataloader + + +def perturb_continuous_data_one( baseline_dataloader: DataLoader, con_dataset_names: list[str], target_dataset_name: str, target_value: float, -) -> list[DataLoader]: + index_pert_feat: int, # Index of the feature to perturb +) -> DataLoader: # changed from list[DataLoader] to a single DataLoader """Add perturbations to continuous data. For each feature in the target dataset, change its value to target. @@ -81,10 +187,10 @@ def perturb_continuous_data( baseline_dataloader: Baseline dataloader con_dataset_names: List of continuous dataset names target_dataset_name: Target continuous dataset to perturb - target_value: Target value + target_value: Target value.
Returns: - List of dataloaders containing all perturbed datasets + One dataloader, with the ith dataset perturbed """ baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) @@ -93,29 +199,19 @@ def perturb_continuous_data( target_idx = con_dataset_names.index(target_dataset_name) splits = np.cumsum([0] + baseline_dataset.con_shapes) - slice_ = slice(*splits[target_idx : target_idx + 2]) - - num_features = baseline_dataset.con_shapes[target_idx] - - dataloaders = [] - for i in range(num_features): - perturbed_con = baseline_dataset.con_all.clone() - target_dataset = perturbed_con[:, slice_] - target_dataset[:, i] = torch.FloatTensor([target_value]) - perturbed_dataset = MOVEDataset( - baseline_dataset.cat_all, - perturbed_con, - baseline_dataset.cat_shapes, - baseline_dataset.con_shapes, - ) - perturbed_dataloader = DataLoader( - perturbed_dataset, - shuffle=False, - batch_size=baseline_dataloader.batch_size, - ) - dataloaders.append(perturbed_dataloader) + start_idx = splits[target_idx] + + perturbed_con = baseline_dataset.con_all.clone() + perturbed_con[:, start_idx + index_pert_feat] = torch.FloatTensor([target_value]) + perturbed_dataloader = _build_dataloader( + cat_data=baseline_dataset.cat_all, + con_data=perturbed_con, + cat_shapes=baseline_dataset.cat_shapes, + con_shapes=baseline_dataset.con_shapes, + batch_size=baseline_dataloader.batch_size, + ) - return dataloaders + return perturbed_dataloader def perturb_continuous_data_extended( @@ -154,40 +250,27 @@ def perturb_continuous_data_extended( target_idx = con_dataset_names.index(target_dataset_name) # dataset index splits = np.cumsum([0] + baseline_dataset.con_shapes) - slice_ = slice(*splits[target_idx : target_idx + 2]) + start_idx = splits[target_idx] num_features = baseline_dataset.con_shapes[target_idx] dataloaders = [] perturbations_list = [] for i in range(num_features): - perturbed_con = baseline_dataset.con_all.clone() - target_dataset = perturbed_con[:, slice_] - # Change the desired feature value by: - min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats( - target_dataset - ) - if perturbation_type == "minimum": - target_dataset[:, i] = torch.FloatTensor([min_feat_val_list[i]]) - elif perturbation_type == "maximum": - target_dataset[:, i] = torch.FloatTensor([max_feat_val_list[i]]) - elif perturbation_type == "plus_std": - target_dataset[:, i] += torch.FloatTensor([std_feat_val_list[i]]) - elif perturbation_type == "minus_std": - target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]]) - - perturbations_list.append(target_dataset[:, i].numpy()) - - perturbed_dataset = MOVEDataset( - baseline_dataset.cat_all, - perturbed_con, - baseline_dataset.cat_shapes, - baseline_dataset.con_shapes, + perturbed_con = _pertub_cont_feat_col( + baseline_dataset=baseline_dataset, + start_idx=start_idx, + num_features=num_features, + index_pert_feat=i, + perturbation_type=perturbation_type, ) + perturbations_list.append(perturbed_con[:, start_idx + i].numpy()) - perturbed_dataloader = DataLoader( - perturbed_dataset, - shuffle=False, + perturbed_dataloader = _build_dataloader( + cat_data=baseline_dataset.cat_all, + con_data=perturbed_con, + cat_shapes=baseline_dataset.cat_shapes, + con_shapes=baseline_dataset.con_shapes, batch_size=baseline_dataloader.batch_size, ) dataloaders.append(perturbed_dataloader) @@ -201,3 +284,82 @@ def perturb_continuous_data_extended( fig.savefig(fig_path) return dataloaders + + +# We will keep the input almost the same, to make everything easier +# However, I have to 
introduce a variable that allows me to index the specific +# dataloader I want to create (index_pert_feat) +def perturb_continuous_data_extended_one( + baseline_dataloader: DataLoader, + con_dataset_names: list[str], + target_dataset_name: str, + perturbation_type: ContinuousPerturbationType, + index_pert_feat: int, +) -> ( + DataLoader +): # But we change the output from list[DataLoader] to just one DataLoader + """Add perturbations to continuous data. For each feature in the target + dataset, change the feature's value in all samples (in rows): + 1,2) substituting this feature in all samples by the feature's minimum/maximum value + 3,4) Adding/Substracting one standard deviation to the sample's feature value. + + Args: + baseline_dataloader: Baseline dataloader + con_dataset_names: List of continuous dataset names + target_dataset_name: Target continuous dataset to perturb + perturbation_type: 'minimum', 'maximum', 'plus_std' or 'minus_std'. + #output_subpath: path where the figure showing the perturbation will be saved + index_pert_feat: Index we want to perturb + + Returns: + - Dataloader with the ith feature (index_pert_feat) perturbed. + + Note: + This function was created so that it could generalize to non-normalized + datasets. Scaling is done per dataset, not per feature -> slightly different + stds feature to feature. + """ + logger.debug( + f"Inside perturb_continuous_data_extended_one for feature {index_pert_feat}" + ) + + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + assert baseline_dataset.con_shapes is not None + assert baseline_dataset.con_all is not None + + target_idx = con_dataset_names.index(target_dataset_name) # dataset index + splits = np.cumsum([0] + baseline_dataset.con_shapes) + start_idx = splits[target_idx] + + # Use it only if we want to perturb all features in the target dataset + num_features = baseline_dataset.con_shapes[target_idx] + + # Now, instead of the for loop that iterates over all the features we want to + # perturb, we do it only for one feature, the one indicated in index_pert_feat + logger.debug(f"Setting up perturbed_con for feature {index_pert_feat}") + + perturbed_con = _pertub_cont_feat_col( + baseline_dataset=baseline_dataset, + start_idx=start_idx, + num_features=num_features, + index_pert_feat=index_pert_feat, + perturbation_type=perturbation_type, + ) + + logger.debug( + f"Creating perturbed dataset and dataloader for feature {index_pert_feat}" + ) + + perturbed_dataloader = _build_dataloader( + cat_data=baseline_dataset.cat_all, + con_data=perturbed_con, + cat_shapes=baseline_dataset.cat_shapes, + con_shapes=baseline_dataset.con_shapes, + batch_size=baseline_dataloader.batch_size, + ) + + logger.debug( + f"Finished perturb_continuous_data_extended_one for feature {index_pert_feat}" + ) + + return perturbed_dataloader diff --git a/src/move/data/preprocessing.py b/src/move/data/preprocessing.py index 269774f7..3ddbd036 100644 --- a/src/move/data/preprocessing.py +++ b/src/move/data/preprocessing.py @@ -69,6 +69,7 @@ def scale(x: np.ndarray, log2: bool = False) -> tuple[FloatArray, BoolArray]: Args: x: 2D array with samples in its rows and features in its columns + log2: whether to apply log2 transformation to the input Returns: Tuple containing (1) scaled output and (2) a 1D mask marking columns diff --git a/src/move/tasks/analyze_latent.py b/src/move/tasks/analyze_latent.py index c6d59125..baef1db2 100644 --- a/src/move/tasks/analyze_latent.py +++ b/src/move/tasks/analyze_latent.py @@ -1,3 +1,6 @@ +# Now it is 
analyze_latent_efficient.py + + __all__ = ["analyze_latent"] import re @@ -8,7 +11,9 @@ import numpy as np import pandas as pd import torch +import torch.multiprocessing from sklearn.base import TransformerMixin +from torch.multiprocessing import Pool import move.visualization as viz from move.analysis.metrics import ( @@ -21,8 +26,8 @@ from move.data import io from move.data.dataloaders import MOVEDataset, make_dataloader from move.data.perturbations import ( - perturb_categorical_data, - perturb_continuous_data, + perturb_categorical_data_one, + perturb_continuous_data_one, ) from move.data.preprocessing import one_hot_encode_single from move.models.vae import VAE @@ -48,17 +53,18 @@ def find_feature_values( Tuple containing (1) index of dataset containing feature and (2) values corresponding to the feature """ - _dataset_index, feature_index = [None] * 2 - for _dataset_index, feature_names in enumerate(feature_names_lists): + dataset_index, feature_index = [None] * 2 + for _dataset_index, _feature_names in enumerate(feature_names_lists): try: - feature_index = feature_names.index(feature_name) + feature_index = _feature_names.index(feature_name) + dataset_index = _dataset_index except ValueError: continue break - if _dataset_index is not None and feature_index is not None: + if dataset_index is not None and feature_index is not None: return ( - _dataset_index, - np.take(feature_values[_dataset_index], feature_index, axis=1), + dataset_index, + np.take(feature_values[dataset_index], feature_index, axis=1), ) raise KeyError(f"Feature '{feature_name}' not in any dataset.") @@ -68,6 +74,97 @@ def _validate_task_config(task_config: AnalyzeLatentConfig) -> None: raise ValueError("Reducer class not specified properly.") +def _categorical_importance_worker(args): + """ + Worker function to calculate the importance of categorical features + """ + torch.set_num_threads(1) + + logger = get_logger(__name__) + + ( + test_dataloader, + categorical_names, + dataset_name, + na_value, + index_pert_feat, + num_features, + model, + z, + ) = args + + # Diff will store the differences between z and + # z_perturb for the perturbed feature index_pert_feat + diff = np.empty((num_features)) + + logger.debug(f"Perturbing feature {index_pert_feat} for {dataset_name}") + dataloader = perturb_categorical_data_one( + test_dataloader, + categorical_names, + dataset_name, + na_value, + index_pert_feat, + ) + logger.debug( + "Perturbation completed. 
Projecting perturbation on latent space for " + f"feature {index_pert_feat}, {dataset_name}" + ) + z_perturb = model.project(dataloader) + logger.debug(f"Calculating diff for feature {index_pert_feat}, {dataset_name}") + diff = np.sum(z_perturb - z, axis=1) + + logger.debug( + "Finished catagorical worker function for " + f"feature {index_pert_feat}, {dataset_name}" + ) + return index_pert_feat, diff + + +def _continuous_importance_worker(args): + """ + Worker function to calculate the importance of continuous features + """ + torch.set_num_threads(1) + + logger = get_logger(__name__) + + ( + test_dataloader, + continuous_names, + dataset_name, + index_pert_feat, + num_features, + model, + z, + ) = args + + # Diff will store the differences between z and z_perturb for the perturbed + # feature index_pert_feat + diff = np.empty((num_features)) + + logger.debug(f"Perturbing feature {index_pert_feat} for {dataset_name}") + dataloader = perturb_continuous_data_one( + test_dataloader, + continuous_names, + dataset_name, + 0.0, + index_pert_feat, + ) + logger.debug( + "Perturbation completed. Projecting perturbation on latent space for " + f"feature {index_pert_feat}, {dataset_name}" + ) + z_perturb = model.project(dataloader) + logger.debug(f"Calculating diff for feature {index_pert_feat}, {dataset_name}") + diff = np.sum(z_perturb - z, axis=1) + + logger.debug( + "Finished continuous worker function for " + f"feature {index_pert_feat}, {dataset_name}" + ) + return index_pert_feat, diff + + def analyze_latent(config: MOVEConfig) -> None: """Train one model to inspect its latent space projections.""" @@ -107,9 +204,10 @@ def analyze_latent(config: MOVEConfig) -> None: logger.debug(f"Model: {model}") + # If we already have a model, we reload it. Otherwise, we train it. model_path = output_path / "model.pt" if model_path.exists(): - logger.debug("Re-loading model") + logger.debug(f"Re-loading model from {model_path}") model.load_state_dict(torch.load(model_path)) model.to(device) else: @@ -230,18 +328,65 @@ def analyze_latent(config: MOVEConfig) -> None: logger.info("Computing feature importance") num_samples = len(cast(Sized, test_dataloader.sampler)) + + # START WITH IMPORTANCE FOR CATEGORICAL FEATURES. 
MADE CHANGES HERE for i, dataset_name in enumerate(config.data.categorical_names): logger.debug(f"Generating plot: feature importance '{dataset_name}'") na_value = one_hot_encode_single(mappings[dataset_name], None) - dataloaders = perturb_categorical_data( - test_dataloader, config.data.categorical_names, dataset_name, na_value - ) - num_features = len(dataloaders) + cat_dataset_names = config.data.categorical_names + target_idx = cat_dataset_names.index(dataset_name) + target_shape = test_dataset.cat_shapes[target_idx] + num_features = target_shape[0] # Number of features in the current dataset + + # We will use this inside the loop that iterates over all features: + # We create one diff per dataset, to not store all of them in memory z = model.project(test_dataloader) diffs = np.empty((num_samples, num_features)) - for j, dataloader in enumerate(dataloaders): - z_perturb = model.project(dataloader) - diffs[:, j] = np.sum(z_perturb - z, axis=1) + + if config.task.multiprocess: + args = [ + ( + test_dataloader, + config.data.categorical_names, + dataset_name, + na_value, + index_pert_feat, + num_features, + model, + z, + ) + for index_pert_feat in range(num_features) + ] + + with Pool(processes=torch.multiprocessing.cpu_count() - 1) as pool: + logger.debug("Inside the pool loop for categorical features") + # Map worker function to arguments + # We get the bayes_k matrix, filled for all the perturbed features + results = pool.map(_categorical_importance_worker, args) + + # Unpack results + for j, diff in results: + diffs[:, j] = diff + + else: + j = 0 # Index to keep count of the perturbed feature we are in + + for index_pert_feat in range(num_features): + dataloader = perturb_categorical_data_one( + test_dataloader, + config.data.categorical_names, + dataset_name, + na_value, + index_pert_feat, + ) + # We calculate the difference for each of the perturbed features, + # and store it in an object + + z_perturb = model.project(dataloader) + diffs[:, j] = np.sum(z_perturb - z, axis=1) + + j = j + 1 # Increase j for the next iteration + feature_mapping = { str(code): category for category, code in mappings[dataset_name].items() } @@ -252,20 +397,75 @@ def analyze_latent(config: MOVEConfig) -> None: fig.savefig(fig_path, bbox_inches="tight") fig_df = pd.DataFrame(diffs, columns=cat_names[i], index=df_index) fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t") + logger.info( + "Analysis for categorical features completed. 
" + "Starting analysis for continuous features" + ) + + # NOW, THE SAME BUT FOR CONTINUOUS DATA for i, dataset_name in enumerate(config.data.continuous_names): logger.debug(f"Generating plot: feature importance '{dataset_name}'") - dataloaders = perturb_continuous_data( - test_dataloader, config.data.continuous_names, dataset_name, 0.0 - ) - num_features = len(dataloaders) + con_dataset_names = config.data.continuous_names + target_idx = con_dataset_names.index(dataset_name) + + num_features = test_dataset.con_shapes[target_idx] + + # We will use this inside the loop that iterates over all features: + # We create one diff per dataset, to not store all of them in memory z = model.project(test_dataloader) diffs = np.empty((num_samples, num_features)) - for j, dataloader in enumerate(dataloaders): - z_perturb = model.project(dataloader) - diffs[:, j] = np.sum(z_perturb - z, axis=1) + + if config.task.multiprocess: + args = [ + ( + test_dataloader, + config.data.continuous_names, + dataset_name, + index_pert_feat, + num_features, + model, + z, + ) + for index_pert_feat in range(num_features) + ] + + with Pool(processes=torch.multiprocessing.cpu_count() - 1) as pool: + logger.debug("Inside the pool loop for continuous features") + # Map worker function to arguments + # We get the bayes_k matrix, filled for all the perturbed features + results = pool.map(_continuous_importance_worker, args) + + # Unpack results + + logger.debug(f"Unpacking results for dataset {dataset_name}") + for j, diff in results: + diffs[:, j] = diff + + logger.debug(f"Generating plot for {dataset_name}") + + else: + # Index to check the number of perturbed feature we are in now + j = 0 + + for index_pert_feat in range(num_features): + dataloader = perturb_continuous_data_one( + test_dataloader, + config.data.continuous_names, + dataset_name, + 0.0, + index_pert_feat, + ) + + z_perturb = model.project(dataloader) + diffs[:, j] = np.sum(z_perturb - z, axis=1) + + j = j + 1 + fig = viz.plot_continuous_feature_importance(diffs, con_list[i], con_names[i]) fig_path = str(output_path / f"feat_importance_{dataset_name}.png") fig.savefig(fig_path, bbox_inches="tight") fig_df = pd.DataFrame(diffs, columns=con_names[i], index=df_index) fig_df.to_csv(output_path / f"feat_importance_{dataset_name}.tsv", sep="\t") + + logger.info("Continuous features finished for all datasets") diff --git a/src/move/tasks/bayes_parallel.py b/src/move/tasks/bayes_parallel.py new file mode 100644 index 00000000..1db2b464 --- /dev/null +++ b/src/move/tasks/bayes_parallel.py @@ -0,0 +1,381 @@ +from pathlib import Path +from typing import Literal, Union, cast + +import hydra +import numpy as np +import torch +import torch.multiprocessing +from torch.multiprocessing import Pool +from torch.utils.data import DataLoader + +from move.conf.schema import IdentifyAssociationsBayesConfig, MOVEConfig +from move.core.logging import get_logger +from move.core.typing import BoolArray, FloatArray, IntArray +from move.data import io +from move.data.dataloaders import MOVEDataset +from move.data.perturbations import ( + ContinuousPerturbationType, + perturb_categorical_data_one, + perturb_continuous_data_extended_one, +) +from move.data.preprocessing import one_hot_encode_single +from move.models.vae import VAE + +# We can do three types of statistical tests. 
Multiprocessing is only implemented +# for bayes at the moment +TaskType = Literal["bayes", "ttest", "ks"] + +# Possible values for continuous pertrubation +CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"] + +logger = get_logger(__name__) + + +def _bayes_approach_worker(args): + """ + Worker function to calculate mean differences and Bayes factors for one feature. + """ + # Set the number of threads available: + # VERY IMPORTANT, TO AVOID CPU OVERSUBSCRIPTION + torch.set_num_threads(1) + + # Unpack arguments. + ( + config, + task_config, + baseline_dataloader, + num_samples, + num_continuous, + i, + models_path, + continuous_shapes, + categorical_shapes, + nan_mask, + feature_mask, + ) = args + # Initialize logging + logger.debug(f"Inside the worker function for num_perturbed {i}") + + # Now we are inside the num_perturbed loop, we will do this for each of the + # perturbed features + # Now, mean_diff will not have a first dimension for num_perturbed, because we will + # not store it for each perturbed feature + # we will use it in each loop for calculating the bayes factors, and then overwrite + # its content with a new perturbed feature + # mean_diff will contain the differences between the baseline and the perturbed + # reconstruction for feature i, taking into account + # all refits (all refits have the same importance) + # We also set up bayes_k, which has the same dimensions as mean_diff + mean_diff = np.zeros((num_samples, num_continuous)) + # Set the normalizer + # Divide by the number of refits. All the refits will have the same importance + normalizer = 1 / task_config.num_refits + + # Create perturbed dataloader for the current feature (i) + logger.debug(f"Creating perturbed dataloader for feature {i}") + if task_config.target_value in CONTINUOUS_TARGET_VALUE: + perturbed_dataloader = perturb_continuous_data_extended_one( + baseline_dataloader=baseline_dataloader, + con_dataset_names=config.data.continuous_names, + target_dataset_name=task_config.target_dataset, + perturbation_type=cast( + ContinuousPerturbationType, task_config.target_value + ), + index_pert_feat=i, + ) + else: + interim_path = Path(config.data.interim_data_path) + mappings = io.load_mappings(interim_path / "mappings.json") + target_mapping = mappings[task_config.target_dataset] + target_value = one_hot_encode_single(target_mapping, task_config.target_value) + perturbed_dataloader = perturb_categorical_data_one( + baseline_dataloader=baseline_dataloader, + cat_dataset_names=config.data.categorical_names, + target_dataset_name=task_config.target_dataset, + target_value=target_value, + index_pert_feat=i, + ) + logger.debug(f"created perturbed dataloader for feature {i}") + + # For each refit, reload baseline reconstruction (obtained in bayes_parallel + # function). 
Also, get the reconstruction for the perturbed dataloader + for j in range(task_config.num_refits): + + model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt" + reconstruction_path = ( + models_path / f"baseline_recon_{task_config.model.num_latent}_{j}.pt" + ) + if reconstruction_path.exists(): + logger.debug(f"Loading baseline reconstruction from {reconstruction_path}.") + baseline_recon = torch.load(reconstruction_path) + else: + raise FileNotFoundError("Baseline reconstruction not found.") + + logger.debug(f"Loading model {model_path}, using load function") + model: VAE = hydra.utils.instantiate( + task_config.model, + continuous_shapes=continuous_shapes, + categorical_shapes=categorical_shapes, + ) + device = torch.device("cuda" if task_config.model.cuda else "cpu") + logger.debug(f"Loading model from {model_path}") + model.load_state_dict(torch.load(model_path)) + logger.debug(f"Loaded model from {model_path}") + model.to(device) + model.eval() + + logger.debug(f"Reconstructing num_perturbed {i}, with model {model_path}") + _, perturb_recon = model.reconstruct( + perturbed_dataloader + ) # Instead of dataloaders[i], create the perturbed one here and + # use it only here + logger.debug( + f"Perturbed reconstruction succesful for feature {i}, model {model}" + ) + + # diff is a matrix with the same dimensions as perturb_recon and baseline_recon + # (rows are samples and columns all the continuous features) + # We calculate diff for each refit, and add it to mean_diff after dividing by + # the number of refits + logger.debug(f"Calculating diff for num_perturbed {i}, with model {model}") + diff = perturb_recon - baseline_recon # 2D: N x C + logger.debug( + f"Calculating mean_diff for num_perturbed {i}, with model {model}" + ) + mean_diff += diff * normalizer + logger.debug(f"Deleting model {model_path}, to see if I can free up space?") + del model + logger.debug(f"Deleted model {model_path} in worker {i} to save some space") + + logger.debug(f"mean_diff for feature {i}, calculated, using all refits") + mean_diff_shape = mean_diff.shape + logger.debug(f"Returning mean_diff for feature {i}. Its shape is {mean_diff_shape}") + + # Apply nan_mask to the result in mean_diff + mask = feature_mask | nan_mask # 2D: N x C + diff = np.ma.masked_array(mean_diff, mask=mask) + diff_shape = diff.shape + logger.debug(f"Calculated diff (masked) for feature {i}. Its shape is {diff_shape}") + prob = np.ma.compressed(np.mean(diff > 1e-8, axis=0)) # 1D: C + logger.debug( + f"prob ({prob.shape}) calculated for feature {i}." + " Starting to calculate bayes_k" + ) + + # Calculate bayes factor + bayes_k = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8) + + # Marc's masking approach (for subset of perturbed cont. features?) + # difference for only perturbed feature? + bayes_mask = np.zeros(np.shape(bayes_k.shape)) + if task_config.target_value in CONTINUOUS_TARGET_VALUE: + bayes_mask = ( + baseline_dataloader.dataset.con_all[0, :] + - perturbed_dataloader.dataset.con_all[0, :] + ) + + logger.debug( + f"bayes factor calculated for feature {i}. 
Worker function {i} finished" + ) + + # Return bayes_k and the index of the feature + return i, bayes_k, bayes_mask + + +def _bayes_approach_parallel( + config: MOVEConfig, + task_config: IdentifyAssociationsBayesConfig, + train_dataloader: DataLoader, + baseline_dataloader: DataLoader, + models_path: Path, + num_perturbed: int, + num_samples: int, + num_continuous: int, + nan_mask: BoolArray, + feature_mask: BoolArray, +) -> tuple[Union[IntArray, FloatArray], ...]: + """ + Calculate Bayes factors for all perturbed features in parallel. + + First, we train or reload the models (number of refits), and save the baseline + reconstruction. We train and get the reconstruction outside to make sure + that we use the same model and use the same baseline reconstruction for all + the worker functions. + """ + logger.debug("Inside the bayes_parallel function") + + assert task_config.model is not None + device = torch.device("cuda" if task_config.model.cuda else "cpu") + + # Train or reload models + logger.info("Training or reloading models") + # non-perturbed baseline dataset + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + + for j in range(task_config.num_refits): + # We create as many models (refits) as indicated in the config file + # For each j (number of refits) we train a different model, but on the same data + # Initialize model + model: VAE = hydra.utils.instantiate( + task_config.model, + continuous_shapes=baseline_dataset.con_shapes, + categorical_shapes=baseline_dataset.cat_shapes, + ) + logger.debug(f"Model: {model} (j={j})") + + # Define paths for the baseline reconstruction and for the model + reconstruction_path = ( + models_path / f"baseline_recon_{task_config.model.num_latent}_{j}.pt" + ) + model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt" + + if model_path.exists(): + # If the models were already created, we load them only if we need to get a + # baseline reconstruction. Otherwise, nothing needs to be done at this point + logger.debug(f"Model {model_path} already exists") + if not reconstruction_path.exists(): + logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}") + model.load_state_dict(torch.load(model_path)) + model.to(device) + logger.debug(f"Model {j} reloaded") + else: + logger.debug( + f"Baseline reconstruction for {reconstruction_path} already exists" + f", no need to load model {model_path} " + ) + else: + # If the models are not created yet, we have to train them, with the + # parameters we indicated in the config file + logger.debug(f"Training refit {j + 1}/{task_config.num_refits}") + model.to(device) + hydra.utils.call( + task_config.training_loop, + model=model, + train_dataloader=train_dataloader, + ) + # Save the refits, to use them later + if task_config.save_refits: + # pickle_protocol=4 is necessary for very big models + torch.save(model.state_dict(), model_path, pickle_protocol=4) + model.eval() + + # Calculate baseline reconstruction + # For each model j, we get a different reconstruction for the baseline. + # We haven't perturbed anything yet; we are just + # getting the reconstruction for the baseline. To make sure that we get + # the same reconstruction for each refit, we cannot + # do it inside each process because the results might be different + # ! here the logic is a bit off. If the reconstruction path exists, the model + # ! does not need to be loaded again.
+ + if reconstruction_path.exists(): + logger.debug( + f"Loading baseline reconstruction from {reconstruction_path}, " + "in the worker function" + ) + # baseline_recon = torch.load(reconstruction_path) + else: + _, baseline_recon = model.reconstruct(baseline_dataloader) + + # Save the baseline reconstruction for each model + logger.debug(f"Saving baseline reconstruction {j}") + torch.save(baseline_recon, reconstruction_path, pickle_protocol=4) + logger.debug(f"Saved baseline reconstruction {j}") + del model + + # Calculate Bayes factors + logger.info("Identifying significant features") + + # Define more arguments that are needed for the worker functions + continuous_shapes = baseline_dataset.con_shapes + categorical_shapes = baseline_dataset.cat_shapes + + logger.debug("Starting parallelization") + + # Define arguments for each worker, and iterate over models and perturbed features + args = [ + ( + config, + task_config, + baseline_dataloader, + num_samples, + num_continuous, + i, + models_path, + continuous_shapes, + categorical_shapes, + nan_mask, + feature_mask[:, [i]], + ) + for i in range(num_perturbed) + ] + + with Pool(processes=torch.multiprocessing.cpu_count() - 1) as pool: + logger.debug("Inside the pool loops") + # Map worker function to arguments + # We get the bayes_k matrix, filled for all the perturbed features + results = pool.map(_bayes_approach_worker, args) + + logger.info("Pool multiprocess completed. Calculating bayes_abs and bayes_p") + + bayes_k = np.empty((num_perturbed, num_continuous)) + bayes_mask = np.zeros(np.shape(bayes_k)) + # Get results in the correct order + for i, computed_bayes_k, mask_k in results: + logger.debug(f"{i} has bayes_k worker {computed_bayes_k}") + # computed_bayes_k: already normalized probability + # (log differences, i.e. Bayes factors) + bayes_k[i, :] = computed_bayes_k + bayes_mask[i, :] = mask_k + bayes_mask[bayes_mask != 0] = 1 + bayes_mask = np.array(bayes_mask, dtype=bool) + + # Calculate Bayes probabilities + bayes_abs = np.abs(bayes_k) # Dimensions are (num_perturbed, num_continuous) + + bayes_p = np.exp(bayes_abs) / (1.0 + np.exp(bayes_abs)) # 2D: P x C + + bayes_abs[bayes_mask] = np.min( + bayes_abs + ) # Bring feature_i feature_i associations to minimum + # Get only the significant associations: + # This will flatten the array, so we get all bayes_abs for all perturbed features + # vs all continuous features in one 1D array + # Then, we sort them, and get the indexes in the flattened array. So, we get a + # list of sorted indexes in the flattened array + sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: P*C + logger.debug(f"sort_ids are {sort_ids}") + # bayes_p is the array from which elements will be taken. + # sort_ids contains the indices that determine the order in which elements should + # be taken from bayes_p. + # This operation essentially rearranges the elements of bayes_p based on the + # sorting order specified by sort_ids + # np.take considers the input array as if it were flattened when extracting + # elements using the provided indices. + # So, even though sort_ids is obtained from a flattened version of bayes_abs, + # np.take understands how to map these indices + # correctly to the original shape of bayes_p. + prob = np.take(bayes_p, sort_ids) # 1D: P*C + logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]") + + # Sort bayes_k in descending order, aligning with the sorted bayes_abs.
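# Illustrative worked example, not part of the patch: the argsort / np.take /
# FDR steps used here (and in the serial _bayes_approach) with made-up numbers,
# P = C = 2 and a significance threshold of 0.05.
import numpy as np

bayes_k_demo = np.array([[4.6, -0.1], [0.5, -2.1]])               # P x C
bayes_abs_demo = np.abs(bayes_k_demo)
bayes_p_demo = np.exp(bayes_abs_demo) / (1.0 + np.exp(bayes_abs_demo))
sort_ids_demo = np.argsort(bayes_abs_demo, axis=None)[::-1]       # flat indices, descending
prob_demo = np.take(bayes_p_demo, sort_ids_demo)                  # ~[0.99, 0.89, 0.62, 0.52]
fdr_demo = np.cumsum(1 - prob_demo) / np.arange(1, prob_demo.size + 1)
idx_demo = np.argmin(np.abs(fdr_demo - 0.05))                     # cut-off closest to 0.05
significant_demo = sort_ids_demo[:idx_demo]                       # here only flat index 0 survives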
+ bayes_k = np.take(bayes_k, sort_ids) # 1D: N x C + + # Calculate FDR + fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D: P*C + idx = np.argmin(np.abs(fdr - task_config.sig_threshold)) + # idx will contain the index of the element in fdr that is closest + # to task_config.sig_threshold. + # This line essentially finds the index where the False Discovery Rate (fdr) is + # closest to the significance threshold + # (task_config.sig_threshold). + logger.debug(f"Index is {idx}") + logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]") + + # Return elements only up to idx. They will be the significant findings + # sort_ids[:idx]: Indices of features sorted by significance. + # prob[:idx]: Probabilities of significant associations for selected features. + # fdr[:idx]: False Discovery Rate values for selected features. + # bayes_k[:idx]: Bayes Factors indicating the strength of evidence for selected + # associations. + return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx] diff --git a/src/move/tasks/identify_associations.py b/src/move/tasks/identify_associations.py index a9daf069..7f1b8532 100644 --- a/src/move/tasks/identify_associations.py +++ b/src/move/tasks/identify_associations.py @@ -32,6 +32,7 @@ ) from move.data.preprocessing import one_hot_encode_single from move.models.vae import VAE +from move.tasks.bayes_parallel import _bayes_approach_parallel from move.visualization.dataset_distributions import ( plot_correlations, plot_cumulative_distributions, @@ -39,9 +40,15 @@ plot_reconstruction_movement, ) +# We can do three types of statistical tests. Multiprocessing is only implemented +# for bayes at the moment TaskType = Literal["bayes", "ttest", "ks"] + +# Possible values for continuous pertrubation CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"] +logger = get_logger(__name__) + def _get_task_type( task_config: IdentifyAssociationsConfig, @@ -71,12 +78,7 @@ def prepare_for_categorical_perturbation( config: MOVEConfig, interim_path: Path, baseline_dataloader: DataLoader, - cat_list: list[FloatArray], -) -> tuple[ - list[DataLoader], - BoolArray, - BoolArray, -]: +) -> list[DataLoader]: """ This function creates the required dataloaders and masks for further categorical association analysis. @@ -85,7 +87,6 @@ def prepare_for_categorical_perturbation( config: main configuration file interim_path: path where the intermediate outputs are saved baseline_dataloader: reference dataloader that will be perturbed - cat_list: list of arrays with categorical data Returns: dataloaders: all dataloaders, including baseline appended last. 
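# Illustrative example, not part of the patch: how the nan_mask / feature_mask
# that this refactor moves out of the prepare_* helpers (see the hunks below)
# feed the masked-array probability used by the Bayes approaches. Numbers are
# made up; NaNs are assumed to be encoded as zeros.
import numpy as np

mean_diff_demo = np.array([[0.4, -0.2], [0.0, 0.3], [0.5, 0.1]])        # N=3 samples x C=2 features
nan_mask_demo = np.array([[False, False], [True, False], [False, False]])
diff_demo = np.ma.masked_array(mean_diff_demo, mask=nan_mask_demo)
prob_demo = np.ma.compressed(np.mean(diff_demo > 1e-8, axis=0))         # P(positive shift) per feature
bayes_k_demo = np.log(prob_demo + 1e-8) - np.log(1 - prob_demo + 1e-8)  # log-odds (Bayes factor)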
@@ -95,7 +96,6 @@ def prepare_for_categorical_perturbation( # Read original data and create perturbed datasets task_config = cast(IdentifyAssociationsConfig, config.task) - logger = get_logger(__name__) # Loading mappings: mappings = io.load_mappings(interim_path / "mappings.json") @@ -113,34 +113,14 @@ def prepare_for_categorical_perturbation( ) dataloaders.append(baseline_dataloader) - baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) - - assert baseline_dataset.con_all is not None - orig_con = baseline_dataset.con_all - nan_mask = (orig_con == 0).numpy() # NaN values encoded as 0s - logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}") - - target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset) - target_dataset = cat_list[target_dataset_idx] - feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P - feature_mask |= np.sum(target_dataset, axis=2) == 0 - - return ( - dataloaders, - nan_mask, - feature_mask, - ) + return dataloaders def prepare_for_continuous_perturbation( config: MOVEConfig, output_subpath: Path, baseline_dataloader: DataLoader, -) -> tuple[ - list[DataLoader], - BoolArray, - BoolArray, -]: +) -> list[DataLoader]: """ This function creates the required dataloaders and masks for further continuous association analysis. @@ -163,7 +143,6 @@ def prepare_for_continuous_perturbation( """ # Read original data and create perturbed datasets - logger = get_logger(__name__) task_config = cast(IdentifyAssociationsConfig, config.task) dataloaders = perturb_continuous_data_extended( @@ -175,15 +154,9 @@ def prepare_for_continuous_perturbation( ) dataloaders.append(baseline_dataloader) - baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + logger.debug(f"Dataloaders length is {len(dataloaders)}") - assert baseline_dataset.con_all is not None - orig_con = baseline_dataset.con_all - nan_mask = (orig_con == 0).numpy() # NaN values encoded as 0s - logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}") - feature_mask = nan_mask - - return (dataloaders, nan_mask, feature_mask) + return dataloaders def _bayes_approach( @@ -203,16 +176,17 @@ def _bayes_approach( assert task_config.model is not None device = torch.device("cuda" if task_config.model.cuda else "cpu") - # Train models - logger = get_logger(__name__) - logger.info("Training models") + # Train or reload models + logger.info("Training or reloading models") mean_diff = np.zeros((num_perturbed, num_samples, num_continuous)) normalizer = 1 / task_config.num_refits - # Last appended dataloader is the baseline + # non-perturbed baseline dataset baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) for j in range(task_config.num_refits): + # We create as many models (refits) as indicated in the config file + # For each j (number of refits) we train a different model, but on the same data # Initialize model model: VAE = hydra.utils.instantiate( task_config.model, @@ -236,18 +210,31 @@ def _bayes_approach( model=model, train_dataloader=train_dataloader, ) + # Save the refits, to use them later if task_config.save_refits: - torch.save(model.state_dict(), model_path) + # pickle_protocol=4 is necessary for very big models + torch.save(model.state_dict(), model_path, pickle_protocol=4) model.eval() # Calculate baseline reconstruction - _, baseline_recon = model.reconstruct(baseline_dataloader) - min_feat, max_feat = np.zeros((num_perturbed, num_continuous)), np.zeros( - (num_perturbed, num_continuous) - ) - min_baseline, 
max_baseline = np.min(baseline_recon, axis=0), np.max( - baseline_recon, axis=0 + # For each model j, we get a different reconstruction for the baseline. + # We haven't perturbed anything yet, we are just + # getting the reconstruction for the baseline, to make sure that we get + # the same reconstruction for each refit, we cannot + # do it inside each process because the results might be different + reconstruction_path = ( + models_path / f"baseline_recon_{task_config.model.num_latent}_{j}.pt" ) + if reconstruction_path.exists(): + logger.debug(f"Loading baseline reconstruction from {reconstruction_path}.") + baseline_recon = torch.load(reconstruction_path) + else: + _, baseline_recon = model.reconstruct(baseline_dataloader) + + # # Save the baseline reconstruction for each model + # logger.debug(f"Saving baseline reconstruction {j}") + # torch.save(baseline_recon, reconstruction_path, pickle_protocol=4) + # logger.debug(f"Saved baseline reconstruction {j}") # Calculate perturb reconstruction => keep track of mean difference for i in range(num_perturbed): @@ -255,13 +242,6 @@ def _bayes_approach( diff = perturb_recon - baseline_recon # 2D: N x C mean_diff[i, :, :] += diff * normalizer - min_perturb, max_perturb = np.min(perturb_recon, axis=0), np.max( - perturb_recon, axis=0 - ) - min_feat[i, :], max_feat[i, :] = np.min( - [min_baseline, min_perturb], axis=0 - ), np.max([max_baseline, max_perturb], axis=0) - # Calculate Bayes factors logger.info("Identifying significant features") bayes_k = np.empty((num_perturbed, num_continuous)) @@ -270,7 +250,8 @@ def _bayes_approach( mask = feature_mask[:, [i]] | nan_mask # 2D: N x C diff = np.ma.masked_array(mean_diff[i, :, :], mask=mask) # 2D: N x C prob = np.ma.compressed(np.mean(diff > 1e-8, axis=0)) # 1D: C - bayes_k[i, :] = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8) + computed_bayes_k = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8) + bayes_k[i, :] = computed_bayes_k if task_config.target_value in CONTINUOUS_TARGET_VALUE: bayes_mask[i, :] = ( baseline_dataloader.dataset.con_all[0, :] @@ -281,23 +262,53 @@ def _bayes_approach( bayes_mask = np.array(bayes_mask, dtype=bool) # Calculate Bayes probabilities - bayes_abs = np.abs(bayes_k) - bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs)) # 2D: N x C + bayes_abs = np.abs(bayes_k) # Dimensions are (num_perturbed, num_continuous) + + bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs)) # 2D: P x C + bayes_abs[bayes_mask] = np.min( bayes_abs ) # Bring feature_i feature_i associations to minimum - sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: N x C - prob = np.take(bayes_p, sort_ids) # 1D: N x C + # Get only the significant associations: + # This will flatten the array, so we get all bayes_abs for all perturbed features + # vs all continuous features in one 1D array + # Then, we sort them, and get the indexes in the flattened array. So, we get an + # list of sorted indexes in the flatenned array + sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: P*C + logger.debug(f"sort_ids are {sort_ids}") + # bayes_p is the array from which elements will be taken. + # sort_ids contains the indices that determine the order in which elements should + # be taken from bayes_p. + # This operation essentially rearranges the elements of bayes_p based on the + # sorting order specified by sort_ids + # np.take considers the input array as if it were flattened when extracting + # elements using the provided indices. 
+ # So, even though sort_ids is obtained from a flattened version of bayes_abs, + # np.take understands how to map these indices + # correctly to the original shape of bayes_p. + prob = np.take(bayes_p, sort_ids) # 1D: P*C logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]") - # Sort Bayes + # Sort bayes_k in descending order, aligning with the sorted bayes_abs. bayes_k = np.take(bayes_k, sort_ids) # 1D: N x C # Calculate FDR fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D idx = np.argmin(np.abs(fdr - task_config.sig_threshold)) + # idx will contain the index of the element in fdr that is closest + # to task_config.sig_threshold. + # This line essentially finds the index where the False Discovery Rate (fdr) is + # closest to the significance threshold + # (task_config.sig_threshold). + logger.debug(f"Index is {idx}") logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]") + # Return elements only up to idx. They will be the significant findings + # sort_ids[:idx]: Indices of features sorted by significance. + # prob[:idx]: Probabilities of significant associations for selected features. + # fdr[:idx]: False Discovery Rate values for selected features. + # bayes_k[:idx]: Bayes Factors indicating the strength of evidence for selected + # associations. return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx] @@ -321,7 +332,6 @@ def _ttest_approach( device = torch.device("cuda" if task_config.model.cuda else "cpu") # Train models - logger = get_logger(__name__) logger.info("Training models") pvalues = np.empty( ( @@ -485,7 +495,6 @@ def _ks_approach( baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) # Train models - logger = get_logger(__name__) logger.info("Training models") target_dataset_idx = config.data.continuous_names.index(task_config.target_dataset) @@ -693,7 +702,6 @@ def save_results( extra_cols: extra data when calling the approach function extra_colnames: names for the extra data columns """ - logger = get_logger(__name__) logger.info(f"Significant hits found: {sig_ids.size}") task_config = cast(IdentifyAssociationsConfig, config.task) task_type = _get_task_type(task_config) @@ -710,11 +718,15 @@ def save_results( target_dataset_idx = config.data.continuous_names.index( task_config.target_dataset ) + # This creates a DataFrame named a_df with one column + # named "feature_a_name". The values in this column are + # taken from con_names using the target_dataset_idx index. 
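# Illustrative example, not part of the patch, of the DataFrame described in the
# comment above; con_names_demo and its feature names are made up.
import pandas as pd

con_names_demo = [["prot_1", "prot_2", "prot_3"]]
target_dataset_idx_demo = 0
a_df_demo = pd.DataFrame(dict(feature_a_name=con_names_demo[target_dataset_idx_demo]))
a_df_demo.index.name = "feature_a_id"
a_df_demo.reset_index(inplace=True)
# a_df_demo now has the columns ["feature_a_id", "feature_a_name"], one row per
# feature of the perturbed (target) dataset.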
a_df = pd.DataFrame(dict(feature_a_name=con_names[target_dataset_idx])) else: target_dataset_idx = config.data.categorical_names.index( task_config.target_dataset ) + a_df = pd.DataFrame(dict(feature_a_name=cat_names[target_dataset_idx])) a_df.index.name = "feature_a_id" a_df.reset_index(inplace=True) @@ -750,7 +762,6 @@ def identify_associations(config: MOVEConfig) -> None: # DATA PREPARATION ###################### # Read original data and create perturbed datasets#### - logger = get_logger(__name__) task_config = cast(IdentifyAssociationsConfig, config.task) task_type = _get_task_type(task_config) _validate_task_config(task_config, task_type) @@ -771,9 +782,12 @@ def identify_associations(config: MOVEConfig) -> None: config.data.continuous_names, ) + logger.debug( + "Making train dataloader in main function identify_associations_selected" + ) train_dataloader = make_dataloader( - cat_list, - con_list, + cat_list, # List of categorical datasets + con_list, # List of continuous datasets shuffle=True, batch_size=task_config.batch_size, drop_last=True, @@ -790,6 +804,13 @@ def identify_associations(config: MOVEConfig) -> None: cat_list, con_list, shuffle=False, batch_size=task_config.batch_size ) + # nan_mask based on baseline dataset + baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset) + assert baseline_dataset.con_all is not None + orig_con = baseline_dataset.con_all + nan_mask = (orig_con == 0).numpy() # NaN values encoded as 0s + logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}") + # Indentify associations between continuous features: logger.info(f"Perturbing dataset: '{task_config.target_dataset}'") if task_config.target_value in CONTINUOUS_TARGET_VALUE: @@ -797,50 +818,88 @@ def identify_associations(config: MOVEConfig) -> None: logger.info(f"Perturbation type: {task_config.target_value}") output_subpath = Path(output_path) / "perturbation_visualization" output_subpath.mkdir(exist_ok=True, parents=True) - ( - dataloaders, - nan_mask, - feature_mask, - ) = prepare_for_continuous_perturbation( - config, output_subpath, baseline_dataloader - ) + if not task_config.multiprocess: + dataloaders = prepare_for_continuous_perturbation( + config, output_subpath, baseline_dataloader + ) + feature_mask = nan_mask + con_dataset_names = config.data.continuous_names + target_idx = con_dataset_names.index( + task_config.target_dataset + ) # dataset index + logger.debug(f"Cont. 
shapes: {baseline_dataset.con_shapes} [take {target_idx}]") + num_perturbed = baseline_dataset.con_shapes[target_idx] # Identify associations between categorical and continuous features: else: logger.info("Beginning task: identify associations categorical") - ( - dataloaders, - nan_mask, - feature_mask, - ) = prepare_for_categorical_perturbation( - config, interim_path, baseline_dataloader, cat_list + task_config = cast(IdentifyAssociationsConfig, config.task) + target_dataset_idx = config.data.categorical_names.index( + task_config.target_dataset ) - - num_perturbed = len(dataloaders) - 1 # P - logger.debug(f"# perturbed features: {num_perturbed}") + target_dataset = cat_list[target_dataset_idx] + mappings = io.load_mappings(interim_path / "mappings.json") + target_mapping = mappings[task_config.target_dataset] + target_value = one_hot_encode_single(target_mapping, task_config.target_value) + feature_mask = np.all(target_dataset == target_value, axis=2) # 2D: N x P + feature_mask |= np.sum(target_dataset, axis=2) == 0 + if not task_config.multiprocess: + dataloaders = prepare_for_categorical_perturbation( + config, interim_path, baseline_dataloader + ) + num_perturbed = target_dataset.shape[-1] + logger.info( + f"Cat. shapes: {baseline_dataset.cat_shapes}" + f" [take {target_dataset_idx}]" + ) + target_shape = baseline_dataset.cat_shapes[target_dataset_idx] + num_perturbed = target_shape[0] # APPROACH EVALUATION ########################## + # num_perturbed = len(dataloaders) - 1 # P + # logger.debug(f"# perturbed features: {num_perturbed}") if task_type == "bayes": task_config = cast(IdentifyAssociationsBayesConfig, task_config) - sig_ids, *extra_cols = _bayes_approach( - config, - task_config, - train_dataloader, - baseline_dataloader, - dataloaders, - models_path, - num_perturbed, - num_samples, - num_continuous, - nan_mask, - feature_mask, - ) + if task_config.multiprocess: + sig_ids, *extra_cols = _bayes_approach_parallel( + config=config, + task_config=task_config, + train_dataloader=train_dataloader, + baseline_dataloader=baseline_dataloader, + # perturbed dataloaders created in worker function + models_path=models_path, + num_perturbed=num_perturbed, + num_samples=num_samples, + num_continuous=num_continuous, + nan_mask=nan_mask, + feature_mask=feature_mask, + ) + logger.debug( + "Completed bayes task (parallel function in main function " + "(identify_associations_selected))" + ) + else: + sig_ids, *extra_cols = _bayes_approach( + config, + task_config, + train_dataloader=train_dataloader, + baseline_dataloader=baseline_dataloader, + dataloaders=dataloaders, + models_path=models_path, + num_perturbed=num_perturbed, + num_samples=num_samples, + num_continuous=num_continuous, + nan_mask=nan_mask, + feature_mask=feature_mask, + ) extra_colnames = ["proba", "fdr", "bayes_k"] elif task_type == "ttest": task_config = cast(IdentifyAssociationsTTestConfig, task_config) + if task_config.multiprocess: + raise NotImplementedError("Multiprocessing is not supported for T-test.") sig_ids, *extra_cols = _ttest_approach( task_config, train_dataloader, @@ -859,6 +918,8 @@ def identify_associations(config: MOVEConfig) -> None: elif task_type == "ks": task_config = cast(IdentifyAssociationsKSConfig, task_config) + if task_config.multiprocess: + raise NotImplementedError("Multiprocessing is not supported for KS.") sig_ids, *extra_cols = _ks_approach( config, task_config, diff --git a/tutorial/config/data/random_continuous.yaml b/tutorial/config/data/random_continuous.yaml index b33402fa..5071e1e4 
100755 --- a/tutorial/config/data/random_continuous.yaml +++ b/tutorial/config/data/random_continuous.yaml @@ -16,9 +16,10 @@ sample_names: categorical_inputs: [] # no categorical inputs continuous_inputs: # a list of continuous datasets - - name: random.continuous.proteomics - log2: true - scale: true - - name: random.continuous.metagenomics - log2: true - scale: true + - name: random.continuous.proteomics # filename in raw_data_path + log2: true # log2 transform data + scale: true # scale data + - name: random.continuous.metagenomics # filename in raw_data_path + log2: true # log2 transform data + scale: true # scale data + \ No newline at end of file diff --git a/tutorial/config/data/random_small.yaml b/tutorial/config/data/random_small.yaml index d9e32564..9dd59c4d 100644 --- a/tutorial/config/data/random_small.yaml +++ b/tutorial/config/data/random_small.yaml @@ -17,9 +17,9 @@ categorical_inputs: # a list of categorical datasets - name: random.small.drugs continuous_inputs: # a list of continuous datasets - - name: random.small.proteomics + - name: random.small.proteomics # filename in raw_data_path log2: true #apply log2 before scaling scale: true #scale data (z-score normalize) - - name: random.small.metagenomics - log2: true - scale: true + - name: random.small.metagenomics # filename in raw_data_path + scale: true # scale data + log2: true # log2 transform data diff --git a/tutorial/config/task/random_continuous__id_assoc_bayes.yaml b/tutorial/config/task/random_continuous__id_assoc_bayes.yaml index 6a004992..78b3f788 100644 --- a/tutorial/config/task/random_continuous__id_assoc_bayes.yaml +++ b/tutorial/config/task/random_continuous__id_assoc_bayes.yaml @@ -1,6 +1,8 @@ defaults: - identify_associations_bayes +multiprocess: False + batch_size: 10 num_refits: 10 diff --git a/tutorial/config/task/random_small__id_assoc_bayes.yaml b/tutorial/config/task/random_small__id_assoc_bayes.yaml index ad7e67f7..c1252251 100644 --- a/tutorial/config/task/random_small__id_assoc_bayes.yaml +++ b/tutorial/config/task/random_small__id_assoc_bayes.yaml @@ -1,6 +1,8 @@ defaults: - identify_associations_bayes +multiprocess: False + batch_size: 10 num_refits: 40
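The new _categorical_importance_worker, _continuous_importance_worker and _bayes_approach_worker functions all follow the same pattern: a torch.multiprocessing pool whose workers pin themselves to a single thread and return their feature index together with their result, so the caller can fill the output matrix in order. The sketch below is an illustrative reduction of that pattern; the demo worker and its placeholder arithmetic are hypothetical, not code from the patch.

import numpy as np
import torch
import torch.multiprocessing
from torch.multiprocessing import Pool


def _demo_worker(args):
    # One thread per worker process avoids CPU oversubscription when many
    # workers run PyTorch code concurrently.
    torch.set_num_threads(1)
    index_pert_feat, num_samples = args
    # Placeholder for the real work (perturb one feature, then project or
    # reconstruct it with the trained model).
    diff = np.full(num_samples, float(index_pert_feat))
    return index_pert_feat, diff


if __name__ == "__main__":
    num_features, num_samples = 4, 10
    worker_args = [(i, num_samples) for i in range(num_features)]
    diffs = np.empty((num_samples, num_features))
    with Pool(processes=max(1, torch.multiprocessing.cpu_count() - 1)) as pool:
        # Each result carries its feature index, so the matrix is filled in the
        # right order regardless of how the pool schedules the work.
        for j, diff in pool.map(_demo_worker, worker_args):
            diffs[:, j] = diff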