🎨 make lines 88 characters long, remove some exceptions from flake8
Henry committed Jun 4, 2024
1 parent 6207350 commit e4e2283
Showing 6 changed files with 53 additions and 39 deletions.
3 changes: 1 addition & 2 deletions setup.cfg
@@ -38,5 +38,4 @@ console_scripts =
 [flake8]
 max-line-length = 88
 aggressive = 2
-extend-select = B950
-extend-ignore = E203,E501,E701
+extend-ignore = E203
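Aside: E203 (whitespace before ':') stays ignored because Black-style formatting can insert exactly that whitespace in slices. An illustrative snippet, not from this repository:

```python
# Illustrative only: a slice formatted Black-style trips flake8's E203,
# which is why the config above keeps it in extend-ignore.
items = list(range(10))
start, offset, end = 1, 2, 8
chunk = items[start + offset : end]  # whitespace before ':' -> E203
```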
8 changes: 4 additions & 4 deletions src/move/data/perturbations.py
@@ -127,8 +127,8 @@ def perturb_continuous_data_extended(
 ) -> list[DataLoader]:
     """Add perturbations to continuous data. For each feature in the target
     dataset, change the feature's value in all samples (in rows):
-    1,2) substituting this feature in all samples by the feature's minimum/maximum value.
-    3,4) Adding/Substracting one standard deviation to the sample's feature value.
+    1,2) substituting this feature in all samples by the feature's minimum/maximum value
+    3,4) Adding/Subtracting one standard deviation to the sample's feature value

     Args:
         baseline_dataloader: Baseline dataloader
@@ -144,8 +144,8 @@ def perturb_continuous_data_extended(
     Note:
         This function was created so that it could generalize to non-normalized
-        datasets. Scaling is done per dataset, not per feature -> slightly different stds
-        feature to feature.
+        datasets. Scaling is done per dataset, not per feature -> slightly different
+        stds feature to feature.
     """

     baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
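For readers skimming the diff, a minimal NumPy sketch of the four perturbation modes named in this docstring (hypothetical helper; the actual function operates on MOVEDataset objects wrapped in DataLoaders):

```python
import numpy as np

def perturb_feature(values: np.ndarray, i: int, mode: str) -> np.ndarray:
    """Return a copy of `values` (samples x features) with feature `i`
    perturbed in all samples according to one of the four modes."""
    perturbed = values.copy()
    col = values[:, i]
    if mode == "minimum":
        perturbed[:, i] = col.min()        # 1) set all samples to the minimum
    elif mode == "maximum":
        perturbed[:, i] = col.max()        # 2) set all samples to the maximum
    elif mode == "plus_std":
        perturbed[:, i] = col + col.std()  # 3) shift up by one std
    elif mode == "minus_std":
        perturbed[:, i] = col - col.std()  # 4) shift down by one std
    else:
        raise ValueError(f"unknown mode: {mode}")
    return perturbed
```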
3 changes: 2 additions & 1 deletion src/move/tasks/encode_data.py
@@ -51,7 +51,8 @@ def encode_data(config: DataConfig):
         filepath = raw_data_path / f"{dataset_name}.tsv"
         names, values = io.read_tsv(filepath, sample_names)

-        # Plotting the value distribution for all continuous datasets before preprocessing:
+        # Plotting the value distribution for all continuous datasets
+        # before preprocessing:
         fig = plot_value_distributions(values)
         fig_path = str(
             output_path / "Value_distribution_{}_unprocessed.png".format(dataset_name)
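A minimal stand-in for the plotting step referenced in that comment (the real plot_value_distributions signature may differ; this assumes a samples-by-features array):

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_value_distributions_sketch(values: np.ndarray):
    """Overlay a histogram per feature so outlying scales are visible
    before preprocessing."""
    fig, ax = plt.subplots()
    for j in range(values.shape[1]):
        ax.hist(values[:, j], bins=50, alpha=0.3)
    ax.set_xlabel("value")
    ax.set_ylabel("count")
    return fig
```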
31 changes: 19 additions & 12 deletions src/move/tasks/identify_associations.py
@@ -437,26 +437,30 @@ def _ks_approach(
         task_config: IdentifyAssociationsKSConfig configuration.
         train_dataloader: training DataLoader.
         baseline_dataloader: unperturbed DataLoader.
-        dataloaders: list of DataLoaders where DataLoader[i] is obtained by perturbing feature i
-        in the target dataset.
+        dataloaders: list of DataLoaders where DataLoader[i] is obtained by perturbing
+            feature i in the target dataset.
         models_path: path to the models.
         num_perturbed: number of perturbed features.
         num_samples: total number of samples
-        num_continuous: number of continuous features (all continuous datasets concatenated).
-        con_names: list of lists where eah inner list contains the feature names of a specific continuous dataset
+        num_continuous: number of continuous features
+            (all continuous datasets concatenated).
+        con_names: list of lists where each inner list
+            contains the feature names of a specific continuous dataset
         output_path: path where QC summary metrics will be saved.

     Returns:
-        sort_ids: list with flattened IDs of the associations above the significance threshold.
-        ks_distance: Ordered list with signed KS scores. KS scores quantify the direction and
-        magnitude of the shift in feature B's reconstruction when perturbing feature A.
+        sort_ids: list with flattened IDs of the associations
+            above the significance threshold.
+        ks_distance: Ordered list with signed KS scores. KS scores quantify the
+            direction and magnitude of the shift in feature B's reconstruction
+            when perturbing feature A.

     !!! Note !!!:
         The sign of the KS score can be misleading: negative sign means positive shift,
-        since the cumulative distribution starts growing later and is found below the reference
-        (baseline). Hence:
+        since the cumulative distribution starts growing later and is found below
+        the reference (baseline). Hence:
         a) with plus_std, negative sign means a positive correlation.
         b) with minus_std, negative sign means a negative correlation.
     """
@@ -598,11 +602,13 @@ def _ks_approach(
                         edges,
                         hist_base,
                         hist_pert,
-                        f"Cumulative_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}",
+                        f"Cumulative_perturbed_{i}_measuring_"
+                        f"{k}_stats_{stats[j, i, k]}",
                     )
                     fig.savefig(
                         figure_path
-                        / f"Cumulative_refit_{j}_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}.png"
+                        / f"Cumulative_refit_{j}_perturbed_"
+                        f"{i}_measuring_{k}_stats_{stats[j, i, k]}.png"
                     )

# Feature changes:
@@ -640,7 +646,8 @@ def _ks_approach(
     sort_ids = np.argsort(abs(final_stats), axis=None)[::-1]  # 1D: N x C
     ks_distance = np.take(final_stats, sort_ids)  # 1D: N x C

-    # Writing Quality control csv file. Mean slope and correlation over refits as qc metrics.
+    # Writing Quality control csv file.
+    # Mean slope and correlation over refits as qc metrics.
     logger.info("Writing QC file")
     qc_df = pd.DataFrame({"Feature names": feature_names})
     qc_df["slope"] = np.nanmean(slope, axis=0)
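Since sort_ids holds indices into the flattened N x C score matrix, the (perturbed, measured) feature pairs can be recovered with np.unravel_index. A toy sketch with made-up numbers:

```python
import numpy as np

final_stats = np.array([[0.1, -0.7], [0.4, 0.0]])  # toy N x C score matrix
sort_ids = np.argsort(np.abs(final_stats), axis=None)[::-1]  # flattened, by |score|
rows, cols = np.unravel_index(sort_ids, final_stats.shape)
for r, c, s in zip(rows, cols, np.take(final_stats, sort_ids)):
    print(f"perturbed={r} measured={c} signed_ks={s}")
```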
26 changes: 15 additions & 11 deletions src/move/training/training_loop.py
@@ -44,22 +44,26 @@ def training_loop(
     Args:
         model (VAE): trained VAE model object
-        train_dataloader (DataLoader): An object feeding data to the VAE with training data
-        valid_dataloader (Optional[DataLoader], optional): An object feeding data to the VAE with validation data.
-        Defaults to None.
+        train_dataloader (DataLoader): An object feeding data to the VAE
+            with training data
+        valid_dataloader (Optional[DataLoader], optional): An object feeding data to the
+            VAE with validation data. Defaults to None.
         lr (float, optional): learning rate. Defaults to 1e-4.
         num_epochs (int, optional): number of epochs. Defaults to 100.
-        batch_dilation_steps (list[int], optional): a list with integers corresponding to epochs when batch size is
-        increased. Defaults to [].
-        kld_warmup_steps (list[int], optional): a list with integers corresponding to epochs when kld is decreased by
-        the selected rate. Defaults to [].
-        early_stopping (bool, optional): boolean if use early stopping . Defaults to False.
-        patience (int, optional): number of epochs to wait before early stop if no progress on the validation set.
-        Defaults to 0.
+        batch_dilation_steps (list[int], optional): a list with integers corresponding
+            to epochs when batch size is increased. Defaults to [].
+        kld_warmup_steps (list[int], optional): a list with integers corresponding to
+            epochs when kld is decreased by the selected rate. Defaults to [].
+        early_stopping (bool, optional): whether to use early stopping.
+            Defaults to False.
+        patience (int, optional): number of epochs to wait before early stop
+            if no progress on the validation set. Defaults to 0.

     Returns:
         (tuple): a tuple containing:
-            *outputs (*list): lists containing information of epoch loss, BCE loss, SSE loss, KLD loss
+            *outputs (*list): lists containing information of epoch loss, BCE loss,
+                SSE loss, KLD loss
             kld_weight (float): final KLD after dilations during the training
     """

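As a sketch of the patience semantics assumed in the docstring above (simplified; the real loop also trains the VAE and handles the KLD/batch schedules):

```python
def early_stopping_sketch(valid_losses: list[float], patience: int) -> int:
    """Return the epoch at which training would stop: after `patience`
    epochs with no improvement on the validation loss."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                return epoch
    return len(valid_losses) - 1
```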
21 changes: 12 additions & 9 deletions src/move/visualization/dataset_distributions.py
@@ -88,8 +88,8 @@ def plot_feature_association_graph(
 ) -> matplotlib.figure.Figure:
     """
     This function plots a graph where each node corresponds to a feature and the edges
-    represent the associations between features. Edge width represents the probability of
-    said association, not the association's effect size.
+    represent the associations between features. Edge width represents the probability
+    of said association, not the association's effect size.

     Input:
         association_df: pandas dataframe containing the following columns:
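A minimal sketch of this kind of graph, assuming networkx and hypothetical column names (feature_a, feature_b, proba; the real dataframe's columns may differ):

```python
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

# Toy association table: one row per feature pair, with a probability.
association_df = pd.DataFrame({
    "feature_a": ["metab_1", "metab_1", "prot_3"],
    "feature_b": ["prot_3", "gene_2", "gene_2"],
    "proba": [0.95, 0.70, 0.55],
})
graph = nx.from_pandas_edgelist(association_df, "feature_a", "feature_b", ["proba"])
widths = [5 * d["proba"] for _, _, d in graph.edges(data=True)]  # width ~ probability
nx.draw_networkx(graph, width=widths)
plt.show()
```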
@@ -190,7 +190,8 @@ def plot_feature_mean_median(
     array: FloatArray, axis=0, style: str = DEFAULT_PLOT_STYLE
 ) -> matplotlib.figure.Figure:
     """
-    Plot feature values together with the mean, median, min and max values at each array position.
+    Plot feature values together with the mean, median, min and max values
+    at each array position.
     """
     with style_settings(style):
         fig = plt.figure(figsize=(15, 3))
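A compact stand-in for the idea (assuming samples run along axis 0; the project's version also applies its style settings):

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_feature_mean_median_sketch(array: np.ndarray):
    """Plot per-position mean, median and min-max range across samples."""
    fig = plt.figure(figsize=(15, 3))
    positions = np.arange(array.shape[1])
    plt.plot(positions, array.mean(axis=0), label="mean")
    plt.plot(positions, np.median(array, axis=0), label="median")
    plt.fill_between(positions, array.min(axis=0), array.max(axis=0),
                     alpha=0.2, label="min-max range")
    plt.legend()
    return fig
```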
@@ -216,14 +217,16 @@ def plot_reconstruction_movement(
     style: str = DEFAULT_PLOT_STYLE,
 ) -> matplotlib.figure.Figure:
     """
-    Plot, for each sample, the change in value from the unperturbed reconstruction to the perturbed reconstruction.
-    Blue lines are left/negative shifts, red lines are right/positive shifts.
+    Plot, for each sample, the change in value from the unperturbed reconstruction to
+    the perturbed reconstruction. Blue lines are left/negative shifts,
+    red lines are right/positive shifts.

     Args:
-        baseline_recon: baseline reconstruction array with s samples and k features (s,k).
-        perturb_recon: perturbed " "
-        k: feature index. The shift (movement) of this feature's reconstruction will be plotted for all
-        samples s.
+        baseline_recon: baseline reconstruction array with s samples
+            and k features (s,k).
+        perturb_recon: perturbed reconstruction array with s samples
+            and k features (s,k).
+        k: feature index. The shift (movement) of this feature's reconstruction
+            will be plotted for all samples s.
     """
     with style_settings(style):
         # Feature changes
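A sketch of the shift plot described in that docstring (hypothetical rendering; the project's version uses its own style settings):

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_reconstruction_movement_sketch(baseline_recon: np.ndarray,
                                        perturb_recon: np.ndarray, k: int):
    """Draw one horizontal line per sample from the baseline to the
    perturbed reconstruction of feature k: blue = negative shift,
    red = positive shift."""
    fig, ax = plt.subplots(figsize=(15, 3))
    for s in range(baseline_recon.shape[0]):
        start, end = baseline_recon[s, k], perturb_recon[s, k]
        color = "red" if end >= start else "blue"
        ax.plot([start, end], [s, s], color=color, linewidth=0.5)
    ax.set_xlabel(f"reconstruction of feature {k}")
    ax.set_ylabel("sample index")
    return fig
```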
