Skip to content

Commit

Permalink
[wip] refactoring Utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Ledmington committed Apr 3, 2024
1 parent e2ab253 commit 51ef124
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
33 changes: 20 additions & 13 deletions lib/src/main/java/com/ledmington/gal/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,32 @@ public static <X> Supplier<X> weightedChoose(
throw new IllegalArgumentException("The list of values cannot be empty");
}

final Function<X, Double> safeWeight = x -> {
final double result = weight.apply(x);
if (result < 0.0) {
throw new IllegalArgumentException(String.format(
"Negative weights are not allowed: the object '%s' produced the weight %f",
x.toString(), result));
}
return result;
};
final double totalWeight =
values.stream().mapToDouble(safeWeight::apply).sum();
double minWeight = Double.POSITIVE_INFINITY;
double maxWeight = Double.NEGATIVE_INFINITY;
double totalWeight = 0.0;
for (final X x : values) {
final double w = weight.apply(x);
minWeight = Math.min(minWeight, w);
maxWeight = Math.max(maxWeight, w);
totalWeight += w;
}

if (minWeight == maxWeight) {
// if they all have the same weight, return a special function which treats all
// values equally
return () -> values.get(rng.nextInt(0, values.size()));
}

final double finalMinWeight = minWeight;
final double finalTotalWeight = totalWeight - finalMinWeight * values.size();

return () -> {
final double chosenWeight = rng.nextDouble(0.0, totalWeight);
final double chosenWeight = rng.nextDouble(0.0, finalTotalWeight);

double sum = 0.0;
for (int i = 0; i < values.size() - 1; i++) {
final X ith_element = values.get(i);
sum += safeWeight.apply(ith_element);
sum += (weight.apply(ith_element) - finalMinWeight);
if (sum >= chosenWeight) {
return ith_element;
}
Expand Down
30 changes: 27 additions & 3 deletions lib/src/test/java/com/ledmington/gal/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -86,8 +87,9 @@ public void weightsWork() {
count.put(x, 0);
}

final Supplier<Integer> weightedChoose = Utils.weightedChoose(values, w, rng);
for (int i = 0; i < 10_000; i++) {
final Integer chosen = Utils.weightedChoose(values, w, rng).get();
final Integer chosen = weightedChoose.get();
count.put(chosen, count.get(chosen) + 1);
}

Expand All @@ -106,11 +108,33 @@ public void weightsWork() {
}

@Test
public void negativeWeightsDoNotWork() {
public void negativeWeightsWorkAsWell() {
final List<Integer> values = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9);
final Function<Integer, Double> w = x -> -(double) x;

assertThrows(IllegalArgumentException.class, () -> Utils.weightedChoose(values, w, rng));
final Map<Integer, Integer> count = new HashMap<>();
for (final Integer x : values) {
count.put(x, 0);
}

final Supplier<Integer> weightedChoose = Utils.weightedChoose(values, w, rng);
for (int i = 0; i < 10_000; i++) {
final Integer chosen = weightedChoose.get();
count.put(chosen, count.get(chosen) + 1);
}

for (int i = 0; i < values.size() - 1; i++) {
final Integer first = values.get(i);
final Integer second = values.get(i + 1);
assertTrue(
count.get(first) > 0,
String.format("Value %d (with weight %f) did not appear once", first, w.apply(first)));
assertTrue(
count.get(first) < count.get(second),
String.format(
"Value %d (with weight %f) appeared more often than value %d (with weight %f): %,d > %,d",
first, w.apply(first), second, w.apply(second), count.get(first), count.get(second)));
}
}

@Test
Expand Down

0 comments on commit 51ef124

Please sign in to comment.