105 lines
4.3 KiB
Python
105 lines
4.3 KiB
Python
import mediapipe as mp
|
|
import cv2
|
|
import time
|
|
from typing import Dict, List, Tuple
|
|
import numpy as np
|
|
import logging
|
|
import os
|
|
import pandas as pd
|
|
|
|
class KeypointExtractor:
    """Extract pose and hand keypoints from videos with MediaPipe Holistic.

    Each processed frame becomes one row of a DataFrame whose columns cover
    Pose (33), Left Hand (21) and Right Hand (21) landmarks, with two values
    (x, y) per landmark. Per-video results are cached as ``.npy`` files in
    ``cache_folder`` and reloaded on subsequent calls.
    """

    # Pose (33) + left hand (21) + right hand (21) landmarks per frame.
    NUM_KEYPOINTS = 33 + 21 * 2

    def __init__(self, video_folder: str, cache_folder: str = "cache"):
        """Initialize the extractor.

        :param video_folder: folder containing the input videos
        :type video_folder: str
        :param cache_folder: folder where per-video keypoint caches are stored,
            defaults to "cache"
        :type cache_folder: str, optional
        """
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_holistic = mp.solutions.holistic
        self.video_folder = video_folder
        self.cache_folder = cache_folder

        # Column layout: "<keypoint-index>_<coordinate>" (e.g. "0_x", "0_y", ...).
        self.columns = [
            f"{i}_{j}" for i in range(self.NUM_KEYPOINTS) for j in ["x", "y"]
        ]

        # One shared holistic extractor: reusing it across frames lets MediaPipe
        # track landmarks between frames (min_tracking_confidence applies).
        self.holistic = mp.solutions.holistic.Holistic(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    @staticmethod
    def _flatten_landmarks(landmarks):
        """Flatten a MediaPipe landmark list into ``[x0, y0, x1, y1, ...]``.

        :param landmarks: a MediaPipe landmark container, or None when the
            body part was not detected in the frame
        :return: flat list of x/y coordinates, or None if ``landmarks`` is falsy
        :rtype: list | None
        """
        if landmarks:
            return [coord for lm in landmarks.landmark for coord in (lm.x, lm.y)]
        return None

    def extract_keypoints_from_video(self, video: str) -> pd.DataFrame:
        """extract_keypoints_from_video this function extracts keypoints from a video and stores them in a dataframe

        :param video: the video to extract keypoints from
        :type video: str
        :return: dataframe with keypoints, or None if the video file is missing
        :rtype: pd.DataFrame
        """
        # os.path.join is robust to video_folder with/without a trailing slash
        # (plain concatenation silently produced a wrong path without one).
        video_path = os.path.join(self.video_folder, video)
        if not os.path.exists(video_path):
            logging.error("Video does not exist at path: " + video_path)
            return None

        # ensure the cache folder exists (exist_ok avoids a check/create race)
        os.makedirs(self.cache_folder, exist_ok=True)

        # return cached keypoints when a previous run already processed this video
        cache_path = os.path.join(self.cache_folder, video + ".npy")
        if os.path.exists(cache_path):
            return pd.DataFrame(
                np.load(cache_path, allow_pickle=True), columns=self.columns
            )

        cap = cv2.VideoCapture(video_path)

        # Collect rows in a plain list and build the DataFrame once at the end:
        # pd.concat per frame copies the whole frame history every iteration
        # (quadratic) and leaves every row indexed 0.
        rows = []
        try:
            while cap.isOpened():
                success, image = cap.read()
                if not success:
                    break

                # extract keypoints of frame
                results = self.extract_keypoints_from_frame(image)

                pose = self._flatten_landmarks(results.pose_landmarks)
                left = self._flatten_landmarks(results.left_hand_landmarks)
                right = self._flatten_landmarks(results.right_hand_landmarks)

                # Keep only frames where pose AND both hands were detected;
                # frames with any missing group are dropped entirely.
                if pose and left and right:
                    rows.append(pose + left + right)
        finally:
            # release the capture even if frame processing raises
            cap.release()

        keypoints_df = pd.DataFrame(rows, columns=self.columns)

        # save keypoints to cache for the fast path above
        np.save(cache_path, keypoints_df.to_numpy())

        return keypoints_df

    def extract_keypoints_from_frame(self, image: np.ndarray, draw: bool = False):
        """extract_keypoints_from_frame this function extracts keypoints from a frame and draws them on the frame if draw is set to True

        :param image: the frame to extract keypoints from (BGR, as read by OpenCV)
        :type image: np.ndarray
        :param draw: indicates if frame with keypoints on must be returned, defaults to False
        :type draw: bool, optional
        :return: the MediaPipe results, plus the annotated frame when draw is True
        :rtype: np.ndarray
        """
        # Convert the BGR image to RGB and process it with MediaPipe Holistic.
        results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if draw:
            # Draw the pose annotations on a copy of the frame.
            # NOTE(review): draw_image is returned in RGB channel order;
            # displaying it with cv2.imshow will swap red/blue — confirm
            # whether callers expect RGB here.
            draw_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.mp_drawing.draw_landmarks(draw_image, results.face_landmarks, self.mp_holistic.FACEMESH_CONTOURS)
            self.mp_drawing.draw_landmarks(draw_image, results.left_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
            self.mp_drawing.draw_landmarks(draw_image, results.right_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
            self.mp_drawing.draw_landmarks(draw_image, results.pose_landmarks, self.mp_holistic.POSE_CONNECTIONS)

            return results, draw_image

        return results