Skip to content

Commit

Permalink
Move trivial check to _get_peak_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
rfezzani committed May 22, 2020
1 parent 32d70c1 commit db38358
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions skimage/feature/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def _get_peak_mask(image, min_distance, footprint, threshold_abs,
Return the mask containing all peak candidates above thresholds.
"""
if footprint is not None:
if footprint.size == 1 and min_distance == 1:
return np.ones_like(image, dtype=bool)
image_max = ndi.maximum_filter(image, footprint=footprint,
mode='constant')
else:
Expand All @@ -37,6 +39,10 @@ def _get_peak_mask(image, min_distance, footprint, threshold_abs,
else:
threshold = threshold_abs
mask &= image > threshold

# no peak for a trivial image
if np.count_nonzero(mask) == image.size:
mask[:] = False
return mask


Expand Down Expand Up @@ -159,8 +165,6 @@ def peak_local_max(image, min_distance=1, threshold_abs=None,
array([[10, 10, 10]])
"""
out = np.zeros_like(image, dtype=np.bool)

threshold_abs = threshold_abs if threshold_abs is not None else image.min()

if isinstance(exclude_border, bool):
Expand Down Expand Up @@ -188,13 +192,6 @@ def peak_local_max(image, min_distance=1, threshold_abs=None,
"`exclude_border` must be bool, int, or tuple with the same "
"length as the dimensionality of the image.")

# no peak for a trivial image
if np.all(image == image.flat[0]):
if indices is True:
return np.empty((0, image.ndim), np.int)
else:
return out

# In the case of labels, call ndi on each label
if labels is not None:
label_values = np.unique(labels)
Expand All @@ -211,6 +208,8 @@ def peak_local_max(image, min_distance=1, threshold_abs=None,
# For each label, extract a smaller image enclosing the object of
# interest, identify num_peaks_per_label peaks and mark them in
# variable out.
out = np.zeros_like(image, dtype=np.bool)

for label_idx, obj in enumerate(ndi.find_objects(labels)):
img_object = image[obj] * (labels[obj] == label_idx + 1)
mask = _get_peak_mask(img_object, min_distance, footprint,
Expand Down Expand Up @@ -250,6 +249,7 @@ def peak_local_max(image, min_distance=1, threshold_abs=None,
return coordinates
else:
nd_indices = tuple(coordinates.T)
out = np.zeros_like(image, dtype=np.bool)
out[nd_indices] = True
return out

Expand Down

0 comments on commit db38358

Please sign in to comment.