-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathyolo.py
143 lines (112 loc) · 5.15 KB
/
yolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import cv2
import numpy as np
from database import log_detection
from collections import defaultdict
class CentroidTracker:
    """Give detections persistent integer IDs by nearest-centroid matching.

    Each call to :meth:`update` receives the bounding boxes found in one
    frame.  Boxes are reduced to their integer centre points and greedily
    paired with the closest already-tracked centroid.  Tracked objects that
    go unmatched for more than ``max_disappeared`` consecutive updates are
    dropped; detections that match nothing receive a fresh ID.
    """

    def __init__(self, max_disappeared=50):
        """Create an empty tracker.

        Args:
            max_disappeared: Number of consecutive updates an object may go
                unmatched before it is deregistered.
        """
        self.next_object_id = 0
        # object_id -> latest centroid, a length-2 integer array (cX, cY).
        self.objects = {}
        # object_id -> consecutive updates without a matching detection.
        self.disappeared = defaultdict(int)
        self.max_disappeared = max_disappeared

    def register(self, centroid):
        """Start tracking ``centroid`` under the next free object ID."""
        object_id = self.next_object_id
        self.objects[object_id] = centroid
        # Seed the counter explicitly so every tracked object is visible to
        # the ageing logic, even if it is never matched again.
        self.disappeared[object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        """Stop tracking ``object_id``.

        Raises:
            KeyError: If ``object_id`` is not currently tracked.
        """
        del self.objects[object_id]
        # pop() rather than del: the counter may be absent, and that must
        # not turn a valid deregistration into an error.
        self.disappeared.pop(object_id, None)

    def _mark_missing(self, object_id):
        """Age ``object_id`` by one update and drop it once it is too old."""
        self.disappeared[object_id] += 1
        if self.disappeared[object_id] > self.max_disappeared:
            self.deregister(object_id)

    def update(self, rects):
        """Feed one frame's bounding boxes to the tracker.

        Args:
            rects: Sequence of ``(startX, startY, endX, endY)`` boxes.

        Returns:
            The tracker's live ``objects`` dict (object_id -> centroid).
            It is the internal mapping, not a copy.
        """
        if len(rects) == 0:
            # Nothing detected: every tracked object ages by one update.
            # Iterate over a snapshot because ageing may deregister.
            for object_id in list(self.objects):
                self._mark_missing(object_id)
            return self.objects

        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for i, (start_x, start_y, end_x, end_y) in enumerate(rects):
            input_centroids[i] = (
                int((start_x + end_x) / 2.0),
                int((start_y + end_y) / 2.0),
            )

        if not self.objects:
            for centroid in input_centroids:
                self.register(centroid)
            return self.objects

        object_ids = list(self.objects)
        object_centroids = np.array(list(self.objects.values()))
        # distances[r, c] = Euclidean distance from tracked object r to
        # detection c.
        distances = np.linalg.norm(
            object_centroids[:, np.newaxis] - input_centroids, axis=2
        )

        # Visit tracked objects in order of their best-match distance and
        # pair each with its closest detection, first come first served.
        rows = distances.min(axis=1).argsort()
        cols = distances.argmin(axis=1)[rows]

        used_rows = set()
        used_cols = set()
        for row, col in zip(rows.tolist(), cols.tolist()):
            if row in used_rows or col in used_cols:
                continue
            object_id = object_ids[row]
            self.objects[object_id] = input_centroids[col]
            self.disappeared[object_id] = 0
            used_rows.add(row)
            used_cols.add(col)

        # Greedy pairing can leave BOTH tracked objects and detections
        # unmatched regardless of which side is larger, so handle both.
        for row in sorted(set(range(distances.shape[0])) - used_rows):
            self._mark_missing(object_ids[row])
        for col in sorted(set(range(distances.shape[1])) - used_cols):
            self.register(input_centroids[col])

        return self.objects
class YOLO:
    """YOLOv5 detector that tracks, annotates and logs objects per frame."""

    def __init__(self, model=None, tracker=None):
        """Set up the detector.

        Args:
            model: Optional pre-loaded detection model.  It must be callable
                on a frame and expose ``names``.  When omitted, the
                ``yolov5s`` model is fetched through ``torch.hub``.
            tracker: Optional tracker exposing ``update(rects)``.  When
                omitted, a ``CentroidTracker`` is created.
        """
        if model is None:
            model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
        self.model = model
        self.classes = self.model.names
        self.tracker = CentroidTracker() if tracker is None else tracker
        self.detected_boxes = []
        self.logged_objects = set()  # Set to keep track of logged object IDs

        # Per-class on/off switches used by toggle_detection() and detect().
        # ``names`` may be a list or an index->name dict, so handle both.
        if isinstance(self.classes, dict):
            class_names = self.classes.values()
        else:
            class_names = self.classes
        self.detected_classes = {name: True for name in class_names}

    def detect(self, frame):
        """Detect, track and annotate objects in ``frame``.

        Runs the model, feeds the resulting boxes to the tracker, draws a
        labelled rectangle for every detection that can be tied to a tracked
        ID, and calls ``log_detection`` the first time each ID is seen.
        Classes switched off via :meth:`toggle_detection` are ignored.

        Args:
            frame: Image passed to the model and drawn on with OpenCV.

        Returns:
            The annotated frame.
        """
        results = self.model(frame)
        # Move to host memory first: .numpy() fails on a CUDA tensor.
        detections = results.xyxy[0].cpu().numpy()
        self.detected_boxes = []

        # Validate each detection once and keep only what is needed later.
        valid = []
        for *xyxy, conf, cls in detections:
            class_index = int(cls)
            if class_index < 0 or class_index >= len(self.classes):
                print(f"Invalid class index detected: {class_index}")
                continue
            class_name = self.classes[class_index]
            if not self.detected_classes.get(class_name, True):
                continue
            valid.append((xyxy, conf, class_name))

        objects = self.tracker.update([xyxy for xyxy, _, _ in valid])

        num_objects = 0
        for (startX, startY, endX, endY), conf, class_name in valid:
            centroid = ((startX + endX) / 2, (startY + endY) / 2)
            # The tracker only returns centroids, so recover this box's ID
            # by finding the tracked centroid within one pixel of its own.
            object_id = next(
                (
                    tracked_id
                    for tracked_id, tracked_centroid in objects.items()
                    if np.allclose(centroid, tracked_centroid, atol=1.0)
                ),
                None,
            )
            if object_id is None:
                continue

            num_objects += 1
            self.detected_boxes.append([startX, startY, endX, endY, object_id])
            label = f'{class_name} {conf:.2f} ID: {object_id}'
            frame = cv2.rectangle(frame, (int(startX), int(startY)), (int(endX), int(endY)), (255, 0, 0), 2)
            frame = cv2.putText(frame, label, (int(startX), int(startY) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            # Log detection only if it's a new object
            # NOTE(review): the values passed are NumPy scalars; confirm
            # that log_detection accepts them.
            if object_id not in self.logged_objects:
                log_detection(class_name, conf, startX, startY, endX, endY, object_id)
                self.logged_objects.add(object_id)

        cv2.putText(frame, f'Counting: {num_objects}', (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)
        return frame

    def toggle_detection(self, cls_name, enable):
        """Enable or disable detection of the class named ``cls_name``.

        Unknown class names are reported on stdout and otherwise ignored.
        """
        if cls_name in self.detected_classes:
            self.detected_classes[cls_name] = enable
        else:
            print(f"Class {cls_name} not found in detected_classes")

    def get_detected_boxes(self):
        """Return the ``[startX, startY, endX, endY, object_id]`` boxes from the last ``detect`` call."""
        return self.detected_boxes