Skip to content

Commit

Permalink
Merge pull request #10 from CallumWSimprints/fix-numpy-negation-bug
Browse files Browse the repository at this point in the history
Bugfix: Update negation syntax in compute_ground_truth_statistics
  • Loading branch information
jacobgil authored May 24, 2024
2 parents 51e8fd6 + 4e451d8 commit 5dd3e58
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion confidenceinterval/delong.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def calc_pvalue(aucs, sigma):

def compute_ground_truth_statistics(ground_truth, sample_weight):
assert np.array_equal(np.unique(ground_truth), [0, 1])
order = (-ground_truth).argsort()
order = (~ground_truth).argsort()
label_1_count = int(ground_truth.sum())
if sample_weight is None:
ordered_sample_weight = None
Expand Down
24 changes: 24 additions & 0 deletions tests/test_delong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from confidenceinterval.delong import compute_ground_truth_statistics
import numpy as np
import pytest


@pytest.mark.parametrize(
'ground_truth',
[
(np.array([0,0,1,1,1])), # Values are integers
(np.array([False, False, True, True, True])) # values are bools
]
)
def test_compute_ground_truth_statistics(ground_truth):
sample_weight = np.array([1,1,1,1,1])

expected_order = np.array([2,3,4,0,1])
expected_label_1_count = 3
expected_ordered_sample_weight = np.array([1,1,1,1,1])

order, label_1_count, ordered_sample_weight = compute_ground_truth_statistics(ground_truth=ground_truth, sample_weight=sample_weight)

assert np.array_equal(expected_order, order)
assert expected_label_1_count == label_1_count
assert np.array_equal(expected_ordered_sample_weight, ordered_sample_weight)

0 comments on commit 5dd3e58

Please sign in to comment.