Upgrade fbcode/pytorch to Python Scientific Stack 2 (#3845)
Summary: Pull Request resolved: pytorch/audio#3845

Differential Revision: D64008689
igorsugak authored and facebook-github-bot committed Oct 18, 2024
1 parent 8dc48a5 commit 8ee3699
Showing 3 changed files with 28 additions and 22 deletions.
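The change itself is mechanical: bare ndarray annotations are replaced with the numpy.typing.NDArray alias, imported as npt. A minimal sketch of the before/after pattern, using an illustrative helper that is not part of this diff:

    import numpy as np
    import numpy.typing as npt

    # Before: annotate with the runtime class directly.
    def scale_before(values: np.ndarray, factor: float) -> np.ndarray:
        return values * factor

    # After: annotate with the numpy.typing alias; it can also be parameterized,
    # e.g. npt.NDArray[np.float64], when the element type is worth asserting.
    def scale_after(values: npt.NDArray, factor: float) -> npt.NDArray:
        return values * factor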
44 changes: 24 additions & 20 deletions captum/attr/_utils/visualization.py
@@ -8,6 +8,7 @@
import matplotlib

import numpy as np
+import numpy.typing as npt
from matplotlib import cm, colors, pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
all = 4


-def _prepare_image(attr_visual: ndarray) -> ndarray:
+def _prepare_image(attr_visual: npt.NDArray) -> npt.NDArray:
return np.clip(attr_visual.astype(int), 0, 255)


-def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
+def _normalize_scale(attr: npt.NDArray, scale_factor: float) -> npt.NDArray:
assert scale_factor != 0, "Cannot normalize by scale factor = 0"
if abs(scale_factor) < 1e-5:
warnings.warn(
@@ -64,23 +65,26 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
return np.clip(attr_norm, -1, 1)


-def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float:
+def _cumulative_sum_threshold(
+    values: npt.NDArray, percentile: Union[int, float]
+) -> float:
# given values should be non-negative
assert percentile >= 0 and percentile <= 100, (
"Percentile for thresholding must be " "between 0 and 100 inclusive."
)
sorted_vals = np.sort(values.flatten())
cum_sums = np.cumsum(sorted_vals)
threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
+# pyre-fixme[7]: Expected `float` but got `ndarray[typing.Any, dtype[typing.Any]]`.
return sorted_vals[threshold_id]


def _normalize_attr(
-attr: ndarray,
+attr: npt.NDArray,
sign: str,
outlier_perc: Union[int, float] = 2,
reduction_axis: Optional[int] = None,
-) -> ndarray:
+) -> npt.NDArray:
attr_combined = attr
if reduction_axis is not None:
attr_combined = np.sum(attr, axis=reduction_axis)
@@ -130,7 +134,7 @@ def _initialize_cmap_and_vmin_vmax(

def _visualize_original_image(
plt_axis: Axes,
-original_image: Optional[ndarray],
+original_image: Optional[npt.NDArray],
**kwargs: Any,
) -> None:
assert (
@@ -143,7 +147,7 @@ def _visualize_original_image(

def _visualize_heat_map(
plt_axis: Axes,
-norm_attr: ndarray,
+norm_attr: npt.NDArray,
cmap: Union[str, Colormap],
vmin: float,
vmax: float,
@@ -155,8 +159,8 @@

def _visualize_blended_heat_map(
plt_axis: Axes,
-original_image: ndarray,
-norm_attr: ndarray,
+original_image: npt.NDArray,
+norm_attr: npt.NDArray,
cmap: Union[str, Colormap],
vmin: float,
vmax: float,
@@ -176,8 +180,8 @@ def _visualize_blended_heat_map(
def _visualize_masked_image(
plt_axis: Axes,
sign: str,
-original_image: ndarray,
-norm_attr: ndarray,
+original_image: npt.NDArray,
+norm_attr: npt.NDArray,
**kwargs: Any,
) -> None:
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -190,8 +194,8 @@ def _visualize_masked_image(
def _visualize_alpha_scaling(
plt_axis: Axes,
sign: str,
-original_image: ndarray,
-norm_attr: ndarray,
+original_image: npt.NDArray,
+norm_attr: npt.NDArray,
**kwargs: Any,
) -> None:
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
@@ -210,8 +214,8 @@ def _visualize_alpha_scaling(


def visualize_image_attr(
-attr: ndarray,
-original_image: Optional[ndarray] = None,
+attr: npt.NDArray,
+original_image: Optional[npt.NDArray] = None,
method: str = "heat_map",
sign: str = "absolute_value",
plt_fig_axis: Optional[Tuple[Figure, Axes]] = None,
@@ -417,8 +421,8 @@ def visualize_image_attr(


def visualize_image_attr_multiple(
-attr: ndarray,
-original_image: Union[None, ndarray],
+attr: npt.NDArray,
+original_image: Union[None, npt.NDArray],
methods: List[str],
signs: List[str],
titles: Optional[List[str]] = None,
@@ -526,9 +530,9 @@ def visualize_image_attr_multiple(


def visualize_timeseries_attr(
-attr: ndarray,
-data: ndarray,
-x_values: Optional[ndarray] = None,
+attr: npt.NDArray,
+data: npt.NDArray,
+x_values: Optional[npt.NDArray] = None,
method: str = "overlay_individual",
sign: str = "absolute_value",
channel_labels: Optional[List[str]] = None,
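The hunks above only touch annotations, so runtime behavior and call sites are unchanged. A hedged usage sketch of the updated visualize_image_attr, with placeholder random data standing in for real attributions (the parameter values shown are illustrative, not taken from this commit):

    import numpy as np
    from captum.attr import visualization as viz

    rng = np.random.default_rng(0)
    # H x W x C attribution scores and an original RGB image in the 0-255 range.
    attr = rng.standard_normal((224, 224, 3))
    image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

    # Inputs are still plain numpy arrays at runtime; only the static
    # annotations moved from ndarray to npt.NDArray.
    fig, axis = viz.visualize_image_attr(
        attr,
        original_image=image,
        method="blended_heat_map",
        sign="absolute_value",
        show_colorbar=True,
    )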
3 changes: 2 additions & 1 deletion tests/attr/test_gradient_shap.py
@@ -5,6 +5,7 @@
from typing import cast, Tuple

import numpy as np
+import numpy.typing as npt
import torch
from captum._utils.typing import Tensor
from captum.attr._core.gradient_shap import GradientShap
@@ -132,7 +133,7 @@ def generate_baselines_with_inputs(inputs: Tensor) -> Tensor:
inp_shape = cast(Tuple[int, ...], inputs.shape)
return torch.arange(0.0, inp_shape[1] * 2.0).reshape(2, inp_shape[1])

-def generate_baselines_returns_array() -> ndarray:
+def generate_baselines_returns_array() -> npt.NDArray:
return np.arange(0.0, num_in * 4.0).reshape(4, num_in)

# 10-class classification model
3 changes: 2 additions & 1 deletion tests/utils/models/linear_models/_test_linear_classifier.py
@@ -5,6 +5,7 @@

import captum._utils.models.linear_model.model as pytorch_model_module
import numpy as np
+import numpy.typing as npt
import sklearn.datasets as datasets
import torch
from tests.helpers.evaluate_linear_model import evaluate
@@ -107,7 +108,7 @@ def compare_to_sk_learn(
o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1)

rel_diff = cast(
-np.ndarray,
+npt.NDArray,
# pyre-fixme[6]: For 1st argument expected `int` but got `Union[int, Tensor]`.
(sum(o_sklearn.values()) - sum(o_pytorch.values())),
) / abs(sum(o_sklearn.values()))
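The last hunk only swaps the type passed to typing.cast, which is a no-op at runtime. A minimal sketch of that pattern, with made-up metric dictionaries standing in for the test's sklearn and pytorch outputs:

    from typing import cast

    import numpy as np
    import numpy.typing as npt

    o_sklearn = {"mse": np.float64(0.12), "l1_reg": np.float64(0.03)}
    o_pytorch = {"mse": np.float64(0.10), "l1_reg": np.float64(0.04)}

    # cast() only informs the type checker; the wrapped expression is returned
    # unchanged, so the arithmetic below behaves exactly as before the rename.
    rel_diff = cast(
        npt.NDArray,
        sum(o_sklearn.values()) - sum(o_pytorch.values()),
    ) / abs(sum(o_sklearn.values()))
    print(rel_diff)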
