-
Notifications
You must be signed in to change notification settings - Fork 3
/
algorithms.py
100 lines (68 loc) · 2.17 KB
/
algorithms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import time
import numpy as np
import scipy_python
import _lap
import lapsolver
import lap
import lapjv
import hungarian
def algorithm(func, supports_rectangular):
    """Build a class decorator that attaches solver metadata to a class.

    The returned decorator sets three class attributes on its target and
    hands the same class object back:

    * ``name`` -- the decorated class's own ``__name__``;
    * ``func`` -- the solver callable passed here (may be ``None``, as for a
      subclass that supplies its own ``time_algorithm``);
    * ``supports_rectangular`` -- the flag passed here; presumably consulted
      by the benchmark driver to decide whether non-square matrices are fed
      to this solver (TODO confirm against the caller).
    """
    def _tag(target):
        metadata = (
            ("name", target.__name__),
            ("func", func),
            ("supports_rectangular", supports_rectangular),
        )
        for attr, value in metadata:
            setattr(target, attr, value)
        return target

    return _tag
class Algorithm:
    """Base class for a benchmarked linear-sum-assignment solver.

    Subclasses must provide a class attribute ``func``: a callable that takes
    a cost matrix (the ``algorithm`` decorator in this module sets it). Its
    raw return value is normalised to scipy's ``(row_ind, col_ind)`` form by
    ``adapt_result`` before being validated and scored.
    """

    def adapt_result(self, result):
        """Convert the solver's raw output into ``(row_ind, col_ind)``.

        The default passes ``result`` through unchanged, i.e. it assumes
        ``func`` already returns scipy-style output. Subclasses override this
        for solvers with a different return convention.
        """
        return result

    @classmethod
    def time_algorithm(cls, cost_matrix):
        """Call ``cls.func(cost_matrix)`` and time it.

        Returns:
            ``(result, dt)``: the solver's raw return value and the elapsed
            time in seconds.
        """
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # (clock adjustments) and is coarse on some platforms, which distorts
        # benchmark numbers.
        t0 = time.perf_counter()
        result = cls.func(cost_matrix)
        dt = time.perf_counter() - t0
        return result, dt

    def _check_assignment(self, shape, rows, cols):
        """Raise AssertionError unless (rows, cols) is a complete assignment.

        Complete means every index of the smaller matrix dimension is used
        exactly once and the indices along the larger dimension are distinct
        and in range.
        """
        who = type(self).__name__
        if len(rows) != len(cols):
            raise AssertionError(
                f"{who}: row/column index arrays differ in length "
                f"({len(rows)} vs {len(cols)})"
            )
        n_rows, n_cols = shape
        # "full" indexes the smaller dimension and must be a permutation of
        # it; "partial" indexes the larger one and need only be distinct.
        if n_cols > n_rows:
            full, partial = rows, cols
        else:
            full, partial = cols, rows
        # Compare against the matrix size rather than len(full): otherwise an
        # empty or partial assignment would pass with an artificially low cost.
        if not np.array_equal(np.sort(full), np.arange(min(shape))):
            raise AssertionError(
                f"{who}: assignment does not cover every index of the "
                f"smaller dimension exactly once"
            )
        if len(np.unique(partial)) != len(partial):
            raise AssertionError(f"{who}: assignment reuses an index")
        # Negative indices would silently wrap around in numpy indexing.
        if not np.all((partial >= 0) & (partial < max(shape))):
            raise AssertionError(f"{who}: assignment has out-of-range indices")

    def run(self, cost_matrix):
        """Solve ``cost_matrix``, validate the assignment and score it.

        The solver gets a copy, so in-place modification by ``func`` cannot
        affect the cost computed from the caller's matrix.

        Returns:
            ``(dt, obj)``: solver time in seconds and the summed cost of the
            selected entries.

        Raises:
            AssertionError: if the returned assignment is invalid. Raised
                explicitly so the check survives ``python -O``.
        """
        result, dt = self.time_algorithm(np.copy(cost_matrix))
        rows, cols = self.adapt_result(result)
        rows, cols = np.asarray(rows), np.asarray(cols)
        self._check_assignment(cost_matrix.shape, rows, cols)
        obj = cost_matrix[rows, cols].sum()
        return dt, obj
@algorithm(func=scipy_python.linear_sum_assignment, supports_rectangular=True)
class alg_scipy(Algorithm):
    """Benchmark wrapper around ``scipy_python.linear_sum_assignment``.

    No ``adapt_result`` override: the solver's return value is used as-is,
    i.e. it is expected to already be ``(row_ind, col_ind)``.
    """
@algorithm(func=_lap.linear_sum_assignment, supports_rectangular=True)
class alg_jonkervolgenant(Algorithm):
    """Benchmark wrapper around ``_lap.linear_sum_assignment``.

    The class name suggests a Jonker-Volgenant implementation (unverified
    here). No ``adapt_result`` override: the raw result is expected to
    already be ``(row_ind, col_ind)``.
    """
@algorithm(func=lapsolver.solve_dense, supports_rectangular=True)
class alg_lapsolver(Algorithm):
    """Benchmark wrapper around ``lapsolver.solve_dense``.

    No ``adapt_result`` override: the raw result is expected to already be
    ``(row_ind, col_ind)``.
    """
@algorithm(func=hungarian.lap, supports_rectangular=False)
class alg_hungarian(Algorithm):
    """Benchmark wrapper around ``hungarian.lap`` (square matrices only)."""

    def adapt_result(self, result):
        """Build ``(row_ind, col_ind)`` from ``result[0]``.

        ``result[0]`` is treated as a per-row array of assigned column
        indices; presumably hungarian.lap's row-to-column mapping (TODO
        confirm against that library's docs).
        """
        assigned_cols = result[0]
        row_ids = np.arange(len(assigned_cols))
        return row_ids, assigned_cols
@algorithm(None, True)
class alg_gatagat_lapjv(Algorithm):
    """Benchmark wrapper around ``lap.lapjv``.

    ``func`` is ``None`` because ``time_algorithm`` is overridden to call
    ``lap.lapjv`` directly, enabling ``extend_cost`` for non-square input.
    """

    def adapt_result(self, result):
        """Convert ``lap.lapjv``'s ``(cost, x, y)`` into ``(row_ind, col_ind)``.

        ``x[i]`` is the column assigned to row ``i``. With ``extend_cost=True``
        on a matrix with more rows than columns, lap reports unassigned rows
        as ``-1`` (per lap's documentation; confirm for the installed
        version). Those rows are dropped. Otherwise ``-1`` would index the
        last column and corrupt both the cost and the validity check. For
        square or wide input every row is assigned, so this returns the
        same as ``(np.arange(len(x)), x)``.
        """
        _, x, _ = result
        x = np.asarray(x)
        rows = np.flatnonzero(x >= 0)
        return (rows, x[rows])

    @classmethod
    def time_algorithm(cls, cost_matrix):
        """Time ``lap.lapjv`` on ``cost_matrix``; return ``(result, dt)``.

        ``extend_cost`` is enabled only for non-square matrices, since
        ``lap.lapjv`` needs it to accept them.
        """
        square = cost_matrix.shape[0] == cost_matrix.shape[1]
        # Same monotonic, high-resolution clock the other solvers should use,
        # so timings are comparable across wrappers.
        t0 = time.perf_counter()
        result = lap.lapjv(cost_matrix, extend_cost=not square)
        dt = time.perf_counter() - t0
        return result, dt
@algorithm(func=lapjv.lapjv, supports_rectangular=False)
class alg_srcd_lapjv(Algorithm):
    """Benchmark wrapper around ``lapjv.lapjv`` (square matrices only)."""

    def adapt_result(self, result):
        """Build ``(row_ind, col_ind)`` from a 3-element solver result.

        Only the first element is used, as a per-row array of assigned column
        indices; presumably lapjv's row-to-column mapping (TODO confirm
        against that library's docs). Exactly three elements are required.
        """
        col_for_row, _unused, _extra = result
        return np.arange(len(col_for_row)), col_for_row
# All solver wrapper classes defined in this module, in a fixed order.
algs = [
    alg_scipy,
    alg_jonkervolgenant,
    alg_lapsolver,
    alg_gatagat_lapjv,
    alg_srcd_lapjv,
    alg_hungarian,
]