Finished KeypointExtractor
This commit is contained in:
parent
ad7b160c92
commit
e7a7329d6f
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,3 +1,5 @@
|
||||
.devcontainer/
|
||||
data/
|
||||
.DS_Store
|
||||
cache/
|
||||
__pycache__/
|
||||
28
dataset.py
28
dataset.py
@ -1,28 +0,0 @@
|
||||
import torch
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
class WLASLDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, csv_file: str, video_dir: str, subset:str="train", keypoints_file: str = "keypoints.csv", transform=None):
|
||||
self.df = pd.read_csv(csv_file)
|
||||
# filter wlasl data by subset
|
||||
self.df = self.df[self.df["subset"] == subset]
|
||||
self.video_dir = video_dir
|
||||
self.transform = transform
|
||||
self.subset = subset
|
||||
self.keypoints_file = keypoints_file
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, index):
|
||||
video_id = self.df.iloc[index]["video_id"]
|
||||
|
||||
# check if keypoints file exists
|
||||
if not os.path.exists(self.keypoints_file):
|
||||
# create empty dataframe
|
||||
keypoints_df = pd.DataFrame(columns=["video_id", "keypoints"])
|
||||
|
||||
# check if keypoints are available else extract from video
|
||||
|
||||
@ -1,61 +0,0 @@
|
||||
# Pose Landmarks
|
||||
POSE_LANDMARKS = {
|
||||
"nose": 0,
|
||||
"left_eye_inner": 1,
|
||||
"left_eye": 2,
|
||||
"left_eye_outer": 3,
|
||||
"right_eye_inner": 4,
|
||||
"right_eye": 5,
|
||||
"right_eye_outer": 6,
|
||||
"left_ear": 7,
|
||||
"right_ear": 8,
|
||||
"mouth_left": 9,
|
||||
"mouth_right": 10,
|
||||
"left_shoulder": 11,
|
||||
"right_shoulder": 12,
|
||||
"left_elbow": 13,
|
||||
"right_elbow": 14,
|
||||
"left_wrist": 15,
|
||||
"right_wrist": 16,
|
||||
"left_pinky": 17,
|
||||
"right_pinky": 18,
|
||||
"left_index": 19,
|
||||
"right_index": 20,
|
||||
"left_thumb": 21,
|
||||
"right_thumb": 22,
|
||||
"left_hip": 23,
|
||||
"right_hip": 24,
|
||||
"left_knee": 25,
|
||||
"right_knee": 26,
|
||||
"left_ankle": 27,
|
||||
"right_ankle": 28,
|
||||
"left_heel": 29,
|
||||
"right_heel": 30,
|
||||
"left_foot_index": 31,
|
||||
"right_foot_index": 32,
|
||||
}
|
||||
|
||||
# Hand Landmarks
|
||||
HAND_LANDMARKS = {
|
||||
"wrist": 0,
|
||||
"thumb_cmc": 1,
|
||||
"thumb_mcp": 2,
|
||||
"thumb_ip": 3,
|
||||
"thumb_tip": 4,
|
||||
"index_finger_mcp": 5,
|
||||
"index_finger_pip": 6,
|
||||
"index_finger_dip": 7,
|
||||
"index_finger_tip": 8,
|
||||
"middle_finger_mcp": 9,
|
||||
"middle_finger_pip": 10,
|
||||
"middle_finger_dip": 11,
|
||||
"middle_finger_tip": 12,
|
||||
"ring_finger_mcp": 13,
|
||||
"ring_finger_pip": 14,
|
||||
"ring_finger_dip": 15,
|
||||
"ring_finger_tip": 16,
|
||||
"pinky_mcp": 17,
|
||||
"pinky_pip": 18,
|
||||
"pinky_dip": 19,
|
||||
"pinky_tip": 20,
|
||||
}
|
||||
@ -1,58 +0,0 @@
|
||||
import mediapipe as mp
|
||||
import cv2
|
||||
|
||||
class KeypointExtractor:
|
||||
def __init__(self):
|
||||
self.mp_drawing = mp.solutions.drawing_utils
|
||||
|
||||
# hands extractor
|
||||
self.hands = mp.solutions.hands.Hands(
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5,
|
||||
max_num_hands=2
|
||||
)
|
||||
|
||||
# pose extractor
|
||||
self.pose = mp.solutions.pose.Pose(
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5,
|
||||
model_complexity=2
|
||||
)
|
||||
|
||||
def extract(self, image, video):
|
||||
# load video
|
||||
pass
|
||||
|
||||
|
||||
def extract_from_frame(self, image):
|
||||
# Convert the BGR image to RGB and process it with MediaPipe Pose.
|
||||
hand_results = self.hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
# Draw the hand annotations on the image.
|
||||
draw_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
draw_image.flags.writeable = False
|
||||
|
||||
for hand_landmarks in hand_results.multi_hand_landmarks:
|
||||
self.mp_drawing.draw_landmarks(
|
||||
draw_image, hand_landmarks, mp.solutions.hands.HAND_CONNECTIONS)
|
||||
|
||||
pose_results = self.pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
self.mp_drawing.draw_landmarks(
|
||||
draw_image, pose_results.pose_landmarks, mp.solutions.pose.POSE_CONNECTIONS)
|
||||
|
||||
|
||||
draw_image.flags.writeable = True
|
||||
draw_image = cv2.cvtColor(draw_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
return draw_image
|
||||
|
||||
|
||||
ke = KeypointExtractor()
|
||||
image = cv2.imread('data/test_photo.jpg')
|
||||
|
||||
image = ke.extract_from_frame(image)
|
||||
|
||||
# save image
|
||||
cv2.imwrite('test_output.jpg', image)
|
||||
@ -1,4 +1,5 @@
|
||||
torch
|
||||
torchvision
|
||||
pandas
|
||||
mediapipe
|
||||
mediapipe
|
||||
joblib
|
||||
59
src/dataset.py
Normal file
59
src/dataset.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import json
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from collections import OrderedDict
|
||||
from identifiers import LANDMARKS
|
||||
|
||||
class WLASLDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None):
|
||||
|
||||
# read the missing video file
|
||||
with open(missing) as f:
|
||||
missing = f.read().splitlines()
|
||||
|
||||
# read the json file
|
||||
with open(json_file) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# remove the missing videos
|
||||
for m in missing:
|
||||
if m in data:
|
||||
del data[m]
|
||||
|
||||
new_data = OrderedDict()
|
||||
for k, v in data.items():
|
||||
if v["subset"] == subset:
|
||||
new_data[k] = v
|
||||
|
||||
self.data = new_data
|
||||
|
||||
# filter wlasl data by subset
|
||||
self.transform = transform
|
||||
self.subset = subset
|
||||
self.keypoint_extractor = keypoint_extractor
|
||||
if keypoints_identifier:
|
||||
self.keypoints_to_keep = [f"{i}_{j}" for i in keypoints_identifier.values() for j in ["x", "y"]]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# get i th element from ordered dict
|
||||
video_id = list(self.data.keys())[index]
|
||||
video_name = f"{video_id}.mp4"
|
||||
|
||||
# get the keypoints for the video
|
||||
keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)
|
||||
|
||||
# filter the keypoints by the identified subset
|
||||
if self.keypoints_to_keep:
|
||||
keypoints_df = keypoints_df[self.keypoints_to_keep]
|
||||
|
||||
# TODO: convert keypoints to tensor and return
|
||||
|
||||
k = KeypointExtractor("data/videos/")
|
||||
d = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS)
|
||||
|
||||
d.__getitem__(0)
|
||||
82
src/identifiers.py
Normal file
82
src/identifiers.py
Normal file
@ -0,0 +1,82 @@
|
||||
LANDMARKS = {
|
||||
# Pose Landmarks
|
||||
"nose": 0,
|
||||
"left_eye_inner": 1,
|
||||
"left_eye": 2,
|
||||
"left_eye_outer": 3,
|
||||
"right_eye_inner": 4,
|
||||
"right_eye": 5,
|
||||
"right_eye_outer": 6,
|
||||
"left_ear": 7,
|
||||
"right_ear": 8,
|
||||
"mouth_left": 9,
|
||||
"mouth_right": 10,
|
||||
"left_shoulder": 11,
|
||||
"right_shoulder": 12,
|
||||
"left_elbow": 13,
|
||||
"right_elbow": 14,
|
||||
"left_wrist": 15,
|
||||
"right_wrist": 16,
|
||||
"left_pinky": 17,
|
||||
"right_pinky": 18,
|
||||
"left_index": 19,
|
||||
"right_index": 20,
|
||||
"left_thumb": 21,
|
||||
"right_thumb": 22,
|
||||
"left_hip": 23,
|
||||
"right_hip": 24,
|
||||
"left_knee": 25,
|
||||
"right_knee": 26,
|
||||
"left_ankle": 27,
|
||||
"right_ankle": 28,
|
||||
"left_heel": 29,
|
||||
"right_heel": 30,
|
||||
"left_foot_index": 31,
|
||||
"right_foot_index": 32,
|
||||
|
||||
# Left Hand Landmarks
|
||||
"left_wrist": 33,
|
||||
"left_thumb_cmc": 34,
|
||||
"left_thumb_mcp": 35,
|
||||
"left_thumb_ip": 36,
|
||||
"left_thumb_tip": 37,
|
||||
"left_index_finger_mcp": 38,
|
||||
"left_index_finger_pip": 39,
|
||||
"left_index_finger_dip": 40,
|
||||
"left_index_finger_tip": 41,
|
||||
"left_middle_finger_mcp": 42,
|
||||
"left_middle_finger_pip": 43,
|
||||
"left_middle_finger_dip": 44,
|
||||
"left_middle_finger_tip": 45,
|
||||
"left_ring_finger_mcp": 46,
|
||||
"left_ring_finger_pip": 47,
|
||||
"left_ring_finger_dip": 48,
|
||||
"left_ring_finger_tip": 49,
|
||||
"left_pinky_mcp": 50,
|
||||
"left_pinky_pip": 51,
|
||||
"left_pinky_dip": 52,
|
||||
"left_pinky_tip": 53,
|
||||
|
||||
# Right Hand Landmarks
|
||||
"right_wrist": 54,
|
||||
"right_thumb_cmc": 55,
|
||||
"right_thumb_mcp": 56,
|
||||
"right_thumb_ip": 57,
|
||||
"right_thumb_tip": 58,
|
||||
"right_index_finger_mcp": 59,
|
||||
"right_index_finger_pip": 60,
|
||||
"right_index_finger_dip": 61,
|
||||
"right_index_finger_tip": 62,
|
||||
"right_middle_finger_mcp": 63,
|
||||
"right_middle_finger_pip": 64,
|
||||
"right_middle_finger_dip": 65,
|
||||
"right_middle_finger_tip": 66,
|
||||
"right_ring_finger_mcp": 67,
|
||||
"right_ring_finger_pip": 68,
|
||||
"right_ring_finger_dip": 69,
|
||||
"right_ring_finger_tip": 70,
|
||||
"right_pinky_mcp": 71,
|
||||
"right_pinky_pip": 72,
|
||||
"right_pinky_dip": 73,
|
||||
"right_pinky_tip": 74,
|
||||
}
|
||||
106
src/keypoint_extractor.py
Normal file
106
src/keypoint_extractor.py
Normal file
@ -0,0 +1,106 @@
|
||||
import mediapipe as mp
|
||||
import cv2
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
class KeypointExtractor:
|
||||
def __init__(self, video_folder: str, cache_folder: str = "cache"):
|
||||
self.mp_drawing = mp.solutions.drawing_utils
|
||||
self.mp_holistic = mp.solutions.holistic
|
||||
self.video_folder = video_folder
|
||||
self.cache_folder = cache_folder
|
||||
|
||||
# we will store the keypoints of each frame as a row in the dataframe. The columns are the keypoints: Pose (33), Left Hand (21), Right Hand (21). Each keypoint has 3 values: x, y
|
||||
self.columns = [f"{i}_{j}" for i in range(33+21*2) for j in ["x", "y"]]
|
||||
|
||||
# holistic extractor
|
||||
self.holistic = mp.solutions.holistic.Holistic(
|
||||
min_detection_confidence=0.5,
|
||||
min_tracking_confidence=0.5,
|
||||
)
|
||||
|
||||
def extract_keypoints_from_video(self,
|
||||
video: str,
|
||||
) -> pd.DataFrame:
|
||||
"""extract_keypoints_from_video this function extracts keypoints from a video and stores them in a dataframe
|
||||
|
||||
:param video: the video to extract keypoints from
|
||||
:type video: str
|
||||
:return: dataframe with keypoints
|
||||
:rtype: pd.DataFrame
|
||||
"""
|
||||
# check if video exists
|
||||
if not os.path.exists(self.video_folder + video):
|
||||
logging.error("Video does not exist at path: " + self.video_folder + video)
|
||||
return None
|
||||
|
||||
# check if cache exists
|
||||
if not os.path.exists(self.cache_folder):
|
||||
os.makedirs(self.cache_folder)
|
||||
|
||||
# check if cache file exists and return
|
||||
if os.path.exists(self.cache_folder + "/" + video + ".npy"):
|
||||
# create dataframe from cache
|
||||
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy"), columns=self.columns)
|
||||
|
||||
# open video
|
||||
cap = cv2.VideoCapture(self.video_folder + video)
|
||||
|
||||
keypoints_df = pd.DataFrame(columns=self.columns)
|
||||
|
||||
while cap.isOpened():
|
||||
success, image = cap.read()
|
||||
if not success:
|
||||
break
|
||||
# extract keypoints of frame
|
||||
results = self.extract_keypoints_from_frame(image)
|
||||
|
||||
def extract_keypoints(landmarks):
|
||||
if landmarks:
|
||||
return [i for landmark in landmarks.landmark for i in [landmark.x, landmark.y]]
|
||||
|
||||
# store keypoints in dataframe
|
||||
k1 = extract_keypoints(results.pose_landmarks)
|
||||
k2 = extract_keypoints(results.left_hand_landmarks)
|
||||
k3 = extract_keypoints(results.right_hand_landmarks)
|
||||
if k1 and k2 and k3:
|
||||
keypoints_df = pd.concat([keypoints_df, pd.DataFrame([k1+k2+k3], columns=self.columns)])
|
||||
|
||||
# close video
|
||||
cap.release()
|
||||
|
||||
# save keypoints to cache
|
||||
np.save(self.cache_folder + "/" + video + ".npy", keypoints_df.to_numpy())
|
||||
|
||||
return keypoints_df
|
||||
|
||||
|
||||
def extract_keypoints_from_frame(self, image: np.ndarray, draw: bool = False):
|
||||
"""extract_keypoints_from_frame this function extracts keypoints from a frame and draws them on the frame if draw is set to True
|
||||
|
||||
:param image: the frame to extract keypoints from
|
||||
:type image: np.ndarray
|
||||
:param draw: indicates if frame with keypoints on must be returned, defaults to False
|
||||
:type draw: bool, optional
|
||||
:return: the keypoints and the frame with keypoints on if draw is set to True
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
print("Extracting keypoints from frame")
|
||||
# Convert the BGR image to RGB and process it with MediaPipe Pose.
|
||||
results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
if draw:
|
||||
# Draw the pose annotations on the image
|
||||
draw_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.face_landmarks, self.mp_holistic.FACEMESH_CONTOURS)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.left_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.right_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.pose_landmarks, self.mp_holistic.POSE_CONNECTIONS)
|
||||
|
||||
return results, draw_image
|
||||
|
||||
return results
|
||||
BIN
test_output.jpg
BIN
test_output.jpg
Binary file not shown.
|
Before Width: | Height: | Size: 1.0 MiB After Width: | Height: | Size: 1.2 MiB |
Loading…
x
Reference in New Issue
Block a user