# Real-time sign-language letter recognition demo:
# webcam frames -> MediaPipe Holistic keypoints -> SPOTER model -> top-3 letter (A-Z) overlay.
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
import torch

from src.identifiers import LANDMARKS
from src.keypoint_extractor import KeypointExtractor
from src.model import SPOTER
from src.normalizations import normalize_hand_bohaecek, normalize_pose

# MediaPipe Holistic model (pose + both hands).
# model_complexity=2 selects the most accurate (and slowest) pose model.
holistic = mp.solutions.holistic.Holistic(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
    model_complexity=2,
)
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

# Default webcam.
cap = cv2.VideoCapture(0)

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Sliding window of the most recent per-frame keypoint vectors (max 8 frames).
keypoints = []

# SPOTER classifier: 26 output classes (mapped to letters A-Z below),
# input dimension = 2 coordinates per selected landmark. Weights load on CPU.
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
spoter_model.load_state_dict(torch.load('models/spoter_76.pth', map_location=torch.device('cpu')))

# Flat (x, y) index pairs for the selected landmarks: landmark i occupies
# positions 2*i (x) and 2*i + 1 (y) in the concatenated keypoint vector.
# Built vectorized instead of via a Python append loop.
landmark_ids = np.fromiter(LANDMARKS.values(), dtype=np.int64)
values = np.empty(landmark_ids.size * 2, dtype=np.int64)
values[0::2] = landmark_ids * 2
values[1::2] = landmark_ids * 2 + 1
def extract_keypoints(landmarks):
    """Flatten a MediaPipe landmark list into [x0, y0, x1, y1, ...].

    Returns None when no landmarks were detected; callers substitute
    zeros for a missing hand.
    """
    if landmarks:
        return np.array([coord for landmark in landmarks.landmark for coord in (landmark.x, landmark.y)])
    return None


while True:
    # Read frame from camera; skip iterations where the grab failed so a
    # flaky camera does not crash cvtColor with a None frame.
    success, frame = cap.read()
    if not success:
        continue

    # MediaPipe expects RGB input.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Detect pose and hand landmarks in the frame.
    results = holistic.process(frame)

    k1 = extract_keypoints(results.pose_landmarks)        # 33 pose landmarks
    k2 = extract_keypoints(results.left_hand_landmarks)   # 21 hand landmarks
    k3 = extract_keypoints(results.right_hand_landmarks)  # 21 hand landmarks

    # Require a pose and at least one hand; a missing hand becomes zeros.
    if k1 is not None and (k2 is not None or k3 is not None):
        k2 = k2 if k2 is not None else np.zeros(42)
        k3 = k3 if k3 is not None else np.zeros(42)

        # Scale normalized [0, 1] coordinates to pixel coordinates.
        k1 = k1 * np.array([frame_width, frame_height] * 33)
        k2 = k2 * np.array([frame_width, frame_height] * 21)
        k3 = k3 * np.array([frame_width, frame_height] * 21)

        # Project-specific normalizations; each also returns the bounding
        # box it used (or None).
        k1, bbox_pose = normalize_pose(k1)
        k2, bbox_left = normalize_hand_bohaecek(k2)
        k3, bbox_right = normalize_hand_bohaecek(k3)

        # Draw the normalization bounding boxes.
        for bbox in (bbox_pose, bbox_left, bbox_right):
            if bbox is not None:
                frame = cv2.rectangle(frame, bbox, (0, 255, 0), 2)

        # Keep only the coordinate positions the model was trained on.
        k = np.concatenate((k1, k2, k3))
        filtered = k[values]

        # Maintain a sliding window of the last 8 frames.
        while len(keypoints) >= 8:
            keypoints.pop(0)
        keypoints.append(filtered)

        if len(keypoints) == 8:
            # Stack into one ndarray first: torch.tensor on a Python list
            # of ndarrays is slow and warns.
            keypoints_tensor = torch.tensor(np.stack(keypoints)).float()
            outputs = spoter_model(keypoints_tensor).expand(1, -1, -1)
            outputs = torch.nn.functional.softmax(outputs, dim=2)
            topk = torch.topk(outputs, k=3, dim=2)

            # Overlay the top-3 predictions with confidences at the top right.
            for i, (label, score) in enumerate(zip(topk.indices[0][0], topk.values[0][0])):
                # Map class index 0-25 to the letter A-Z.
                l = label.item()
                if l < 26:
                    l = chr(l + 65)
                cv2.putText(frame, f"{l} {score.item():.2f}", (frame.shape[1] - 200, 50 + i * 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

    mp_drawing.draw_landmarks(frame, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
    mp_drawing.draw_landmarks(frame, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
    mp_drawing.draw_landmarks(frame, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)

    # Back to BGR for OpenCV display.
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Show the annotated frame.
    cv2.imshow('MediaPipe Hands', frame)

    # Exit on ESC (key code 27).
    if cv2.waitKey(5) & 0xFF == 27:
        break
# Tear down: free the camera handle, then close every OpenCV window.
cap.release()
cv2.destroyAllWindows()