Dev
This commit is contained in:
0
visualizations/__init__.py
Normal file
0
visualizations/__init__.py
Normal file
146
visualizations/analyze_model.ipynb
Normal file
146
visualizations/analyze_model.ipynb
Normal file
File diff suppressed because one or more lines are too long
1781
visualizations/visualize_data.ipynb
Normal file
1781
visualizations/visualize_data.ipynb
Normal file
File diff suppressed because one or more lines are too long
116
visualizations/webcam_view.py
Normal file
116
visualizations/webcam_view.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.keypoint_extractor import KeypointExtractor
|
||||
from src.model import SPOTER
|
||||
from src.normalizations import normalize_hand_bohaecek, normalize_pose
|
||||
|
||||
# MediaPipe Holistic model: tracks body pose and both hands in a single pass.
holistic = mp.solutions.holistic.Holistic(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
    model_complexity=2,  # most accurate (and slowest) model variant
)
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

# Webcam capture (device 0). Frame size is needed later to convert
# MediaPipe's normalized [0, 1] landmark coordinates to pixels.
cap = cv2.VideoCapture(0)

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Sliding window of the most recent per-frame keypoint vectors (max 8).
keypoints = []

# SPOTER classifier over 2-D keypoints; 26 classes (indices mapped to
# letters A-Z in the overlay below — presumably fingerspelling; confirm
# against the training label set).
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
spoter_model.load_state_dict(
    torch.load('models/spoter_76.pth', map_location=torch.device('cpu'))
)

# Flat (x, y) index pairs for every landmark of interest: landmark i
# occupies positions 2*i (x) and 2*i + 1 (y) in the concatenated
# pose+hands vector built inside the loop.
values = np.array(
    [coord for i in LANDMARKS.values() for coord in (i * 2, i * 2 + 1)]
)
|
||||
|
||||
def _extract_xy(landmarks):
    """Flatten a MediaPipe landmark list into a 1-D array
    [x0, y0, x1, y1, ...] of normalized coordinates.

    Returns None when nothing was detected (landmarks is falsy).
    """
    if landmarks:
        return np.array(
            [c for lm in landmarks.landmark for c in (lm.x, lm.y)]
        )
    return None


try:
    while True:
        # Read a frame from the camera. Guard the success flag: a failed
        # grab yields frame=None, which would crash cv2.cvtColor below.
        success, frame = cap.read()
        if not success:
            continue

        # MediaPipe expects RGB input (OpenCV delivers BGR).
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect pose and hand landmarks in the frame.
        results = holistic.process(frame)

        k1 = _extract_xy(results.pose_landmarks)
        k2 = _extract_xy(results.left_hand_landmarks)
        k3 = _extract_xy(results.right_hand_landmarks)

        # Classify only when the pose and at least one hand are visible.
        if k1 is not None and (k2 is not None or k3 is not None):
            # A missing hand contributes 21 zeroed (x, y) pairs.
            k2 = k2 if k2 is not None else np.zeros(42)
            k3 = k3 if k3 is not None else np.zeros(42)

            # Denormalize to pixel coordinates
            # (33 pose landmarks, 21 per hand).
            k1 = k1 * np.array([frame_width, frame_height] * 33)
            k2 = k2 * np.array([frame_width, frame_height] * 21)
            k3 = k3 * np.array([frame_width, frame_height] * 21)

            # Project-specific normalization; each call also returns a
            # bounding box usable for visual debugging.
            k1, bbox_pose = normalize_pose(k1)
            k2, bbox_left = normalize_hand_bohaecek(k2)
            k3, bbox_right = normalize_hand_bohaecek(k3)

            # Draw the normalization bounding boxes.
            for bbox in (bbox_pose, bbox_left, bbox_right):
                if bbox is not None:
                    frame = cv2.rectangle(frame, bbox, (0, 255, 0), 2)

            # Keep only the landmark coordinates the model was trained on.
            k = np.concatenate((k1, k2, k3))
            filtered = k[values]

            # Maintain a sliding window of the last 8 frames.
            while len(keypoints) >= 8:
                keypoints.pop(0)
            keypoints.append(filtered)

            if len(keypoints) == 8:
                # Window -> float tensor; expand adds the batch dimension
                # the downstream softmax/topk indexing expects.
                keypoints_tensor = torch.tensor(keypoints).float()
                outputs = spoter_model(keypoints_tensor).expand(1, -1, -1)
                outputs = torch.nn.functional.softmax(outputs, dim=2)
                topk = torch.topk(outputs, k=3, dim=2)

                # Overlay the top-3 predictions with confidence scores
                # at the top right of the frame.
                for i, (label, score) in enumerate(
                        zip(topk.indices[0][0], topk.values[0][0])):
                    l = label.item()
                    if l < 26:
                        l = chr(l + 65)  # class index -> letter A-Z
                    cv2.putText(
                        frame,
                        f"{l} {score.item():.2f}",
                        (frame.shape[1] - 200, 50 + i * 50),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (0, 255, 0),
                        2,
                    )

        # Always draw the detected landmarks, even when no prediction
        # was made this frame.
        mp_drawing.draw_landmarks(
            frame, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS
        )
        mp_drawing.draw_landmarks(
            frame, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS
        )
        mp_drawing.draw_landmarks(
            frame, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS
        )

        # Convert back to BGR for OpenCV display.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Show the frame.
        cv2.imshow('MediaPipe Hands', frame)

        # ESC exits.
        if cv2.waitKey(5) & 0xFF == 27:
            break
finally:
    # Release the camera and close windows even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()
|
||||
Reference in New Issue
Block a user