-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
120 lines (87 loc) · 3.95 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import csv
import copy
import argparse
import itertools
import cv2 as cv
import numpy as np
import mediapipe as mp
from utils import CvFpsCalc
from model import KeyPointClassifier
def get_args(argv=None):
    """Parse the command-line options for camera capture and hand detection.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``device``, ``width``, ``height``,
        ``use_static_image_mode``, ``min_detection_confidence`` and
        ``min_tracking_confidence`` attributes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--width", help='cap width', type=int, default=960)
    parser.add_argument("--height", help='cap height', type=int, default=540)

    parser.add_argument('--use_static_image_mode', action='store_true')
    parser.add_argument("--min_detection_confidence",
                        help='min_detection_confidence',
                        type=float,
                        default=0.7)
    # Was type=int: the 0.5 default is a float, and any fractional value
    # given on the command line (e.g. "0.6") was rejected by int().
    parser.add_argument("--min_tracking_confidence",
                        help='min_tracking_confidence',
                        type=float,
                        default=0.5)

    args = parser.parse_args(argv)
    return args
def calc_landmark_list(image, landmarks):
    """Convert landmark coordinates into integer pixel positions.

    Args:
        image: Image array; ``image.shape[0]`` is taken as the height and
            ``image.shape[1]`` as the width.
        landmarks: Object with a ``landmark`` sequence whose items expose
            ``x`` and ``y`` attributes that are scaled by the image width
            and height respectively.

    Returns:
        A list of ``[x, y]`` integer pairs, one per landmark, in input order.
        Each coordinate is truncated with ``int()`` and capped at the last
        pixel index (``width - 1`` / ``height - 1``); there is no lower clamp.
    """
    height, width = image.shape[0], image.shape[1]
    last_col, last_row = width - 1, height - 1

    # The depth component (point.z) is intentionally not used.
    return [
        [min(int(point.x * width), last_col),
         min(int(point.y * height), last_row)]
        for point in landmarks.landmark
    ]
def pre_process_landmark(landmark_list):
    """Turn pixel landmarks into a flat, scale-normalized feature vector.

    Every point is first made relative to the first point (presumably the
    wrist, landmark 0 in MediaPipe's hand model). The points are then
    flattened to ``[x0, y0, x1, y1, ...]``, and every value is divided by
    the largest absolute value so the features lie in ``[-1, 1]``.

    Args:
        landmark_list: Sequence of ``[x, y]`` points. It is not modified.
            Any elements beyond the first two of a point are kept
            unshifted, as in the original implementation.

    Returns:
        A flat list of floats. It is empty for empty input. It is all zeros
        when every point coincides with the first one; previously that case
        raised ``ZeroDivisionError``.
    """
    if not landmark_list:
        return []

    # Convert to coordinates relative to the first point
    base_x, base_y = landmark_list[0][0], landmark_list[0][1]
    relative = [
        [point[0] - base_x, point[1] - base_y, *point[2:]]
        for point in landmark_list
    ]

    # Convert to a one-dimensional list
    flat = list(itertools.chain.from_iterable(relative))

    # Normalization; guard the degenerate case where all offsets are zero
    max_value = max(map(abs, flat))
    if max_value == 0:
        return [0.0] * len(flat)
    return [value / max_value for value in flat]
class Predict:
    """Classify the hand gesture visible in a video frame.

    Combines a MediaPipe Hands detector (configured from the options
    returned by ``get_args``) with a ``KeyPointClassifier``. The
    classifier's output id is mapped to a name from a labels CSV file.
    """

    def __init__(self):
        # NOTE(review): get_args() parses sys.argv. Constructing Predict
        # inside a program with its own, different CLI flags may make
        # argparse exit on unrecognized arguments.
        args = get_args()
        self.use_static_image_mode = args.use_static_image_mode
        self.min_detection_confidence = args.min_detection_confidence
        self.min_tracking_confidence = args.min_tracking_confidence

        self.mp_hands = mp.solutions.hands
        self.hands = self.mp_hands.Hands(
            static_image_mode=self.use_static_image_mode,
            max_num_hands=1,
            min_detection_confidence=self.min_detection_confidence,
            min_tracking_confidence=self.min_tracking_confidence,
        )

        self.keypoint_classifier = KeyPointClassifier()

        # The first column of row i is the label for classifier id i.
        # The path is resolved relative to the current working directory.
        with open('model/keypoint_classifier/keypoint_classifier_label.csv',
                  encoding='utf-8-sig') as f:
            self.keypoint_classifier_labels = [row[0] for row in csv.reader(f)]

    def get_hand_gesture_label(self, frame):
        """Return the gesture label for the hand detected in ``frame``.

        Args:
            frame: BGR image, as passed to ``cv.cvtColor(..., COLOR_BGR2RGB)``.
                It is mirrored horizontally before detection.

        Returns:
            The label string for the detected hand, or ``""`` when no hand
            is found. Because ``max_num_hands=1``, at most one hand is
            reported. If several were, the last hand's label would win.
        """
        image = cv.flip(frame, 1)  # Mirror display
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        # Marking the buffer read-only lets MediaPipe pass it by reference
        # (as in MediaPipe's own examples). Nothing below writes to it.
        image.flags.writeable = False
        results = self.hands.process(image)

        label_name = ""
        if results.multi_hand_landmarks is not None:
            for hand_landmarks in results.multi_hand_landmarks:
                # Only the image shape is needed here, so the RGB frame is
                # passed directly instead of a per-frame deep copy.
                landmark_list = calc_landmark_list(image, hand_landmarks)
                pre_processed_landmark_list = pre_process_landmark(landmark_list)
                hand_sign_id = self.keypoint_classifier(pre_processed_landmark_list)
                label_name = self.keypoint_classifier_labels[hand_sign_id]

        return label_name