diff --git a/script.py b/script.py index a1c80e2..bbcc5c8 100644 --- a/script.py +++ b/script.py @@ -354,25 +354,6 @@ def filter_det_dont_care(self): self.eval_log += " (" + str(len(self.det_dont_care_indices)) + " don't care)\n" if len(self.det_dont_care_indices) > 0 else "\n" - def one_to_one_match(self, row, col): - """One-to-One match condition""" - cont = 0 - for j in range(len(self.area_precision_matrix[0])): - if sum(self.pcc_count_matrix[row][j]) > 0 and self.area_precision_matrix[row, j] >= PARAMS.AREA_PRECISION_CONSTRAINT: - cont = cont + 1 - if cont != 1: - return False - cont = 0 - for i in range(len(self.area_precision_matrix)): - if sum(self.pcc_count_matrix[i][col]) > 0 and self.area_precision_matrix[i, col] >= PARAMS.AREA_PRECISION_CONSTRAINT: - cont = cont + 1 - if cont != 1: - return False - - if sum(self.pcc_count_matrix[row][col]) > 0 and self.area_precision_matrix[row, col] >= PARAMS.AREA_PRECISION_CONSTRAINT: - return True - return False - def one_to_many_match(self, gt_id): """One-to-Many match condition""" @@ -408,13 +389,19 @@ def many_to_one_match(self, det_id): def calc_match_matrix(self): """Calculate match matrix with PCC counting matrix information.""" self.eval_log += "Find one-to-one matches\n" - for gt_id in range(len(self.gt_boxes)): - for det_id in range(len(self.det_boxes)): - if gt_id not in self.gt_dont_care_indices and det_id not in self.det_dont_care_indices: - match = self.one_to_one_match(gt_id, det_id) - if match: - self.pairs.append({'gt': [gt_id], 'det': [det_id], 'type': 'OO'}) - self.eval_log += "Match GT #{} with Det #{}\n".format(gt_id, det_id) + single_gt_ids = [gt_id for gt_id in range(len(self.gt_boxes)) \ + if gt_id not in self.gt_dont_care_indices \ + and np.sum((np.sum(self.pcc_count_matrix[gt_id], axis=-1) > 0) + & (self.area_precision_matrix[gt_id] >= PARAMS.AREA_PRECISION_CONSTRAINT)) == 1] + single_det_ids = [det_id for det_id in range(len(self.det_boxes)) \ + if det_id not in self.det_dont_care_indices \ + and np.sum((np.array([sum(i_arr[det_id]) for i_arr in self.pcc_count_matrix]) > 0) + * (self.area_precision_matrix[:, det_id] >= PARAMS.AREA_PRECISION_CONSTRAINT)) == 1] + for gt_id in single_gt_ids: + for det_id in single_det_ids: + if sum(self.pcc_count_matrix[gt_id][det_id]) > 0 and self.area_precision_matrix[gt_id, det_id] >= PARAMS.AREA_PRECISION_CONSTRAINT: + self.pairs.append({'gt': [gt_id], 'det': [det_id], 'type': 'OO'}) + self.eval_log += "Match GT #{} with Det #{}\n".format(gt_id, det_id) # one-to-many match self.eval_log += "Find one-to-many matches\n"