diff --git a/setup.cfg b/setup.cfg
index 549ca96a..0e2ebec7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,5 +38,4 @@ console_scripts =
 [flake8]
 max-line-length = 88
 aggressive = 2
-extend-select = B950
-extend-ignore = E203,E501,E701
\ No newline at end of file
+extend-ignore = E203
\ No newline at end of file
diff --git a/src/move/data/perturbations.py b/src/move/data/perturbations.py
index 59825e8d..bb91fc0d 100644
--- a/src/move/data/perturbations.py
+++ b/src/move/data/perturbations.py
@@ -127,8 +127,8 @@ def perturb_continuous_data_extended(
 ) -> list[DataLoader]:
     """Add perturbations to continuous data. For each feature in the target dataset,
     change the feature's value in all samples (in rows):
-    1,2) substituting this feature in all samples by the feature's minimum/maximum value.
-    3,4) Adding/Substracting one standard deviation to the sample's feature value.
+    1,2) Substituting this feature in all samples with its minimum/maximum value
+    3,4) Adding/subtracting one standard deviation to/from the sample's feature value

     Args:
         baseline_dataloader: Baseline dataloader
@@ -144,8 +144,8 @@ def perturb_continuous_data_extended(

     Note:
         This function was created so that it could generalize to non-normalized
-        datasets. Scaling is done per dataset, not per feature -> slightly different stds
-        feature to feature.
+        datasets. Scaling is done per dataset, not per feature, so standard
+        deviations differ slightly from feature to feature.
     """

     baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
diff --git a/src/move/tasks/encode_data.py b/src/move/tasks/encode_data.py
index f243518b..f0a83633 100644
--- a/src/move/tasks/encode_data.py
+++ b/src/move/tasks/encode_data.py
@@ -51,7 +51,8 @@ def encode_data(config: DataConfig):
         filepath = raw_data_path / f"{dataset_name}.tsv"
         names, values = io.read_tsv(filepath, sample_names)

-        # Plotting the value distribution for all continuous datasets before preprocessing:
+        # Plotting the value distribution for all continuous datasets
+        # before preprocessing:
         fig = plot_value_distributions(values)
         fig_path = str(
             output_path / "Value_distribution_{}_unprocessed.png".format(dataset_name)
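The extended perturbation scheme documented above has four modes per target feature. As a rough illustration of what the reworded docstring describes, here is a minimal NumPy sketch; `values` and `target_idx` are hypothetical stand-ins rather than MOVE's actual API, and the per-feature std used here is a simplification of the per-dataset scaling the note mentions:

```python
import numpy as np

# Toy data: 100 samples x 5 continuous features (illustrative only).
rng = np.random.default_rng(seed=0)
values = rng.normal(size=(100, 5))
target_idx = 2  # index of the feature to perturb

perturbations = {}
for mode in ("minimum", "maximum", "plus_std", "minus_std"):
    perturbed = values.copy()
    col = values[:, target_idx]
    if mode == "minimum":
        perturbed[:, target_idx] = col.min()  # 1) substitute the minimum
    elif mode == "maximum":
        perturbed[:, target_idx] = col.max()  # 2) substitute the maximum
    elif mode == "plus_std":
        perturbed[:, target_idx] = col + col.std()  # 3) add one std
    else:
        perturbed[:, target_idx] = col - col.std()  # 4) subtract one std
    perturbations[mode] = perturbed
```

Each perturbed matrix would then back one DataLoader in the returned list, matching the function's `list[DataLoader]` signature.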
diff --git a/src/move/tasks/identify_associations.py b/src/move/tasks/identify_associations.py
index bbb8c99a..95fa56b3 100644
--- a/src/move/tasks/identify_associations.py
+++ b/src/move/tasks/identify_associations.py
@@ -437,26 +437,30 @@ def _ks_approach(
         task_config: IdentifyAssociationsKSConfig configuration.
         train_dataloader: training DataLoader.
         baseline_dataloader: unperturbed DataLoader.
-        dataloaders: list of DataLoaders where DataLoader[i] is obtained by perturbing feature i
-            in the target dataset.
+        dataloaders: list of DataLoaders where DataLoader[i] is obtained by perturbing
+            feature i in the target dataset.
         models_path: path to the models.
         num_perturbed: number of perturbed features.
         num_samples: total number of samples
-        num_continuous: number of continuous features (all continuous datasets concatenated).
-        con_names: list of lists where eah inner list contains the feature names of a specific continuous dataset
+        num_continuous: number of continuous features (all continuous datasets
+            concatenated).
+        con_names: list of lists where each inner list contains the feature
+            names of a specific continuous dataset.
         output_path: path where QC summary metrics will be saved.

     Returns:
-        sort_ids: list with flattened IDs of the associations above the significance threshold.
-        ks_distance: Ordered list with signed KS scores. KS scores quantify the direction and
-            magnitude of the shift in feature B's reconstruction when perturbing feature A.
+        sort_ids: list with flattened IDs of the associations above the
+            significance threshold.
+        ks_distance: Ordered list with signed KS scores. KS scores quantify the
+            direction and magnitude of the shift in feature B's reconstruction
+            when perturbing feature A.

     !!! Note !!!: The sign of the KS score can be misleading: negative sign means
     positive shift.
-    since the cumulative distribution starts growing later and is found below the reference
-    (baseline). Hence:
+    since the cumulative distribution starts growing later and is found below
+    the reference (baseline). Hence:
         a) with plus_std, negative sign means a positive correlation.
         b) with minus_std, negative sign means a negative correlation.
     """
@@ -598,11 +602,13 @@ def _ks_approach(
                         edges,
                         hist_base,
                         hist_pert,
-                        f"Cumulative_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}",
+                        f"Cumulative_perturbed_{i}_measuring_{k}"
+                        f"_stats_{stats[j, i, k]}",
                     )
                     fig.savefig(
                         figure_path
-                        / f"Cumulative_refit_{j}_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}.png"
+                        / f"Cumulative_refit_{j}_perturbed_{i}_measuring_{k}"
+                        f"_stats_{stats[j, i, k]}.png"
                     )

     # Feature changes:
@@ -640,7 +646,8 @@ def _ks_approach(
     sort_ids = np.argsort(abs(final_stats), axis=None)[::-1]  # 1D: N x C
     ks_distance = np.take(final_stats, sort_ids)  # 1D: N x C

-    # Writing Quality control csv file. Mean slope and correlation over refits as qc metrics.
+    # Writing quality control CSV file.
+    # Mean slope and correlation over refits as QC metrics.
     logger.info("Writing QC file")
     qc_df = pd.DataFrame({"Feature names": feature_names})
     qc_df["slope"] = np.nanmean(slope, axis=0)
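The sign convention in the note above is worth a toy demonstration: a positive shift makes the perturbed cumulative distribution start growing later and sit below the baseline's, which the signed score records as negative. The sketch below assumes a median-based sign rule and uses `scipy.stats.ks_2samp` for the magnitude; the actual rule lives in `_ks_approach` and may differ in detail:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(seed=0)
baseline_recon = rng.normal(size=1000)  # unperturbed reconstruction
perturbed_recon = baseline_recon + 0.5  # positive (rightward) shift

# Magnitude: two-sample KS statistic between the two reconstructions.
distance = ks_2samp(baseline_recon, perturbed_recon).statistic

# Sign (assumed rule): negative when the perturbed ECDF lies below the
# baseline ECDF, i.e. when the values shifted to the right.
sign = -1.0 if np.median(perturbed_recon) > np.median(baseline_recon) else 1.0

signed_ks = sign * distance
print(signed_ks)  # ~ -0.2: a negative score despite the positive shift
```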
diff --git a/src/move/training/training_loop.py b/src/move/training/training_loop.py
index 9a54fdf6..2405d676 100644
--- a/src/move/training/training_loop.py
+++ b/src/move/training/training_loop.py
@@ -44,22 +44,25 @@ def training_loop(

     Args:
         model (VAE): trained VAE model object
-        train_dataloader (DataLoader): An object feeding data to the VAE with training data
-        valid_dataloader (Optional[DataLoader], optional): An object feeding data to the VAE with validation data.
-            Defaults to None.
+        train_dataloader (DataLoader): An object feeding data to the VAE with
+            training data
+        valid_dataloader (Optional[DataLoader], optional): An object feeding data
+            to the VAE with validation data. Defaults to None.
         lr (float, optional): learning rate. Defaults to 1e-4.
         num_epochs (int, optional): number of epochs. Defaults to 100.
-        batch_dilation_steps (list[int], optional): a list with integers corresponding to epochs when batch size is
-            increased. Defaults to [].
-        kld_warmup_steps (list[int], optional): a list with integers corresponding to epochs when kld is decreased by
-            the selected rate. Defaults to [].
-        early_stopping (bool, optional): boolean if use early stopping . Defaults to False.
-        patience (int, optional): number of epochs to wait before early stop if no progress on the validation set.
-            Defaults to 0.
+        batch_dilation_steps (list[int], optional): a list with integers corresponding
+            to epochs when batch size is increased. Defaults to [].
+        kld_warmup_steps (list[int], optional): a list with integers corresponding to
+            epochs when kld is decreased by the selected rate. Defaults to [].
+        early_stopping (bool, optional): whether to use early stopping.
+            Defaults to False.
+        patience (int, optional): number of epochs to wait before early stopping
+            if there is no progress on the validation set. Defaults to 0.

     Returns:
         (tuple): a tuple containing:
-            *outputs (*list): lists containing information of epoch loss, BCE loss, SSE loss, KLD loss
+            *outputs (*list): lists containing information of epoch loss, BCE loss,
+                SSE loss, KLD loss
             kld_weight (float): final KLD after dilations during the training

     """
diff --git a/src/move/visualization/dataset_distributions.py b/src/move/visualization/dataset_distributions.py
index ff4e8e59..39c7f895 100644
--- a/src/move/visualization/dataset_distributions.py
+++ b/src/move/visualization/dataset_distributions.py
@@ -88,8 +88,8 @@ def plot_feature_association_graph(
 ) -> matplotlib.figure.Figure:
     """
     This function plots a graph where each node corresponds to a feature and the edges
-    represent the associations between features. Edge width represents the probability of
-    said association, not the association's effect size.
+    represent the associations between features. Edge width represents the probability
+    of said association, not the association's effect size.

     Input:
         association_df: pandas dataframe containing the following columns:
@@ -190,7 +190,8 @@ def plot_feature_mean_median(
     array: FloatArray, axis=0, style: str = DEFAULT_PLOT_STYLE
 ) -> matplotlib.figure.Figure:
     """
-    Plot feature values together with the mean, median, min and max values at each array position.
+    Plot feature values together with the mean, median, min and max values
+    at each array position.
     """
     with style_settings(style):
         fig = plt.figure(figsize=(15, 3))
@@ -216,14 +217,16 @@ def plot_reconstruction_movement(
     style: str = DEFAULT_PLOT_STYLE,
 ) -> matplotlib.figure.Figure:
     """
-    Plot, for each sample, the change in value from the unperturbed reconstruction to the perturbed reconstruction.
-    Blue lines are left/negative shifts, red lines are right/positive shifts.
+    Plot, for each sample, the change in value from the unperturbed reconstruction to
+    the perturbed reconstruction. Blue lines are left/negative shifts,
+    red lines are right/positive shifts.

     Args:
-        baseline_recon: baseline reconstruction array with s samples and k features (s,k).
-        perturb_recon: perturbed " "
-        k: feature index. The shift (movement) of this feature's reconstruction will be plotted for all
-            samples s.
+        baseline_recon: baseline reconstruction array with s samples
+            and k features (s,k).
+        perturb_recon: perturbed reconstruction array of the same shape (s,k).
+        k: feature index. The shift (movement) of this feature's reconstruction
+            will be plotted for all samples s.
     """
     with style_settings(style):
         # Feature changes
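One behaviour the reworded training_loop docstring pins down is the `early_stopping`/`patience` pair: training stops once the validation loss has gone `patience` consecutive epochs without improving. Below is a self-contained sketch of that rule; the `validate` callback is a hypothetical stand-in, not MOVE's actual training loop:

```python
def run_with_early_stopping(validate, num_epochs=100, patience=10):
    """Stop when `patience` epochs pass with no validation improvement."""
    best_loss = float("inf")
    epochs_without_progress = 0
    for epoch in range(num_epochs):
        val_loss = validate(epoch)  # hypothetical: one validation pass
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_without_progress = 0
        else:
            epochs_without_progress += 1
        if epochs_without_progress >= patience:
            break  # no progress on the validation set for `patience` epochs
    return best_loss


# Example: a loss that plateaus after epoch 20 triggers the stop at epoch 30.
best = run_with_early_stopping(lambda e: max(1.0 - 0.05 * e, 0.0), patience=10)
```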