diff --git a/.gitignore b/.gitignore index ac36ecd..b804a61 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .devcontainer/ data/ .DS_Store +cache/ +__pycache__/ \ No newline at end of file diff --git a/dataset.py b/dataset.py deleted file mode 100644 index 5b3ff7f..0000000 --- a/dataset.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import pandas as pd -from PIL import Image -import json - -class WLASLDataset(torch.utils.data.Dataset): - def __init__(self, csv_file: str, video_dir: str, subset:str="train", keypoints_file: str = "keypoints.csv", transform=None): - self.df = pd.read_csv(csv_file) - # filter wlasl data by subset - self.df = self.df[self.df["subset"] == subset] - self.video_dir = video_dir - self.transform = transform - self.subset = subset - self.keypoints_file = keypoints_file - - def __len__(self): - return len(self.df) - - def __getitem__(self, index): - video_id = self.df.iloc[index]["video_id"] - - # check if keypoints file exists - if not os.path.exists(self.keypoints_file): - # create empty dataframe - keypoints_df = pd.DataFrame(columns=["video_id", "keypoints"]) - - # check if keypoints are available else extract from video - \ No newline at end of file diff --git a/identifiers.py b/identifiers.py deleted file mode 100644 index 916aadf..0000000 --- a/identifiers.py +++ /dev/null @@ -1,61 +0,0 @@ -# Pose Landmarks -POSE_LANDMARKS = { - "nose": 0, - "left_eye_inner": 1, - "left_eye": 2, - "left_eye_outer": 3, - "right_eye_inner": 4, - "right_eye": 5, - "right_eye_outer": 6, - "left_ear": 7, - "right_ear": 8, - "mouth_left": 9, - "mouth_right": 10, - "left_shoulder": 11, - "right_shoulder": 12, - "left_elbow": 13, - "right_elbow": 14, - "left_wrist": 15, - "right_wrist": 16, - "left_pinky": 17, - "right_pinky": 18, - "left_index": 19, - "right_index": 20, - "left_thumb": 21, - "right_thumb": 22, - "left_hip": 23, - "right_hip": 24, - "left_knee": 25, - "right_knee": 26, - "left_ankle": 27, - "right_ankle": 28, - "left_heel": 29, - 
"right_heel": 30, - "left_foot_index": 31, - "right_foot_index": 32, -} - -# Hand Landmarks -HAND_LANDMARKS = { - "wrist": 0, - "thumb_cmc": 1, - "thumb_mcp": 2, - "thumb_ip": 3, - "thumb_tip": 4, - "index_finger_mcp": 5, - "index_finger_pip": 6, - "index_finger_dip": 7, - "index_finger_tip": 8, - "middle_finger_mcp": 9, - "middle_finger_pip": 10, - "middle_finger_dip": 11, - "middle_finger_tip": 12, - "ring_finger_mcp": 13, - "ring_finger_pip": 14, - "ring_finger_dip": 15, - "ring_finger_tip": 16, - "pinky_mcp": 17, - "pinky_pip": 18, - "pinky_dip": 19, - "pinky_tip": 20, -} diff --git a/keypoint_extractor.py b/keypoint_extractor.py deleted file mode 100644 index 15938f6..0000000 --- a/keypoint_extractor.py +++ /dev/null @@ -1,58 +0,0 @@ -import mediapipe as mp -import cv2 - -class KeypointExtractor: - def __init__(self): - self.mp_drawing = mp.solutions.drawing_utils - - # hands extractor - self.hands = mp.solutions.hands.Hands( - min_detection_confidence=0.5, - min_tracking_confidence=0.5, - max_num_hands=2 - ) - - # pose extractor - self.pose = mp.solutions.pose.Pose( - min_detection_confidence=0.5, - min_tracking_confidence=0.5, - model_complexity=2 - ) - - def extract(self, image, video): - # load video - pass - - - def extract_from_frame(self, image): - # Convert the BGR image to RGB and process it with MediaPipe Pose. - hand_results = self.hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - - # Draw the hand annotations on the image. 
import json
from collections import OrderedDict

import torch


class WLASLDataset(torch.utils.data.Dataset):
    """WLASL sign-language dataset serving per-frame 2-D keypoints per video.

    Reads the WLASL ``nslt`` json index, drops videos listed in the
    missing-videos file, keeps only the requested subset, and delegates
    per-video keypoint extraction (with caching) to a ``KeypointExtractor``.
    """

    def __init__(self, json_file: str, missing: str, keypoint_extractor,
                 subset: str = "train", keypoints_identifier: dict = None,
                 transform=None):
        """
        :param json_file: path to the WLASL nslt json index
        :param missing: path to a text file with one missing video id per line
        :param keypoint_extractor: extractor with an
            ``extract_keypoints_from_video(video_name)`` method returning a
            DataFrame (one row per frame)
        :param subset: which split to keep ("train", "val" or "test")
        :param keypoints_identifier: optional {name: landmark_index} map; when
            given, only those landmarks' x/y columns are kept
        :param transform: optional callable applied to the keypoint tensor
        """
        # read the missing video file
        with open(missing) as f:
            missing_ids = f.read().splitlines()

        # read the json index
        with open(json_file) as f:
            data = json.load(f)

        # remove the missing videos
        for video_id in missing_ids:
            data.pop(video_id, None)

        # filter wlasl data by subset, preserving json order
        self.data = OrderedDict(
            (k, v) for k, v in data.items() if v["subset"] == subset
        )
        # BUG FIX: cache the key list once; the original rebuilt
        # list(self.data.keys()) on every __getitem__ call (O(n) per access).
        self.video_ids = list(self.data)

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor
        # BUG FIX: always define the attribute. The original only assigned it
        # when keypoints_identifier was truthy, so __getitem__ raised
        # AttributeError in the default (None) case.
        self.keypoints_to_keep = None
        if keypoints_identifier:
            self.keypoints_to_keep = [
                f"{i}_{j}"
                for i in keypoints_identifier.values()
                for j in ("x", "y")
            ]

    def __len__(self):
        # number of videos in the selected subset
        return len(self.data)

    def __getitem__(self, index):
        # i-th video id in insertion order
        video_id = self.video_ids[index]
        video_name = f"{video_id}.mp4"

        # get the keypoints for the video (cached by the extractor)
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)

        # filter the keypoints down to the identified subset of landmarks
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        # convert keypoints to a (frames, features) float tensor
        keypoints = torch.tensor(keypoints_df.to_numpy(), dtype=torch.float32)
        if self.transform:
            keypoints = self.transform(keypoints)
        return keypoints


if __name__ == "__main__":
    # Manual smoke test. BUG FIX: the original ran this at import time as a
    # module-level side effect; imports are local so importing stays clean.
    from identifiers import LANDMARKS
    from keypoint_extractor import KeypointExtractor

    k = KeypointExtractor("data/videos/")
    d = WLASLDataset("data/nslt_100.json", "data/missing.txt", k,
                     keypoints_identifier=LANDMARKS)
    print(d[0])
# MediaPipe Holistic landmark name -> flat index (0..74).
# Layout: Pose (0-32), Left Hand (33-53), Right Hand (54-74).
#
# BUG FIX: the original dict defined "left_wrist" and "right_wrist" twice
# (pose 15/16 AND hand 33/54). Python keeps only the later entry, so the pose
# wrist indices were silently lost and len(LANDMARKS) was 73 instead of 75.
# The hand-section wrists are now named "left_hand_wrist"/"right_hand_wrist".
LANDMARKS = {
    # Pose Landmarks
    "nose": 0,
    "left_eye_inner": 1,
    "left_eye": 2,
    "left_eye_outer": 3,
    "right_eye_inner": 4,
    "right_eye": 5,
    "right_eye_outer": 6,
    "left_ear": 7,
    "right_ear": 8,
    "mouth_left": 9,
    "mouth_right": 10,
    "left_shoulder": 11,
    "right_shoulder": 12,
    "left_elbow": 13,
    "right_elbow": 14,
    "left_wrist": 15,
    "right_wrist": 16,
    "left_pinky": 17,
    "right_pinky": 18,
    "left_index": 19,
    "right_index": 20,
    "left_thumb": 21,
    "right_thumb": 22,
    "left_hip": 23,
    "right_hip": 24,
    "left_knee": 25,
    "right_knee": 26,
    "left_ankle": 27,
    "right_ankle": 28,
    "left_heel": 29,
    "right_heel": 30,
    "left_foot_index": 31,
    "right_foot_index": 32,

    # Left Hand Landmarks
    "left_hand_wrist": 33,
    "left_thumb_cmc": 34,
    "left_thumb_mcp": 35,
    "left_thumb_ip": 36,
    "left_thumb_tip": 37,
    "left_index_finger_mcp": 38,
    "left_index_finger_pip": 39,
    "left_index_finger_dip": 40,
    "left_index_finger_tip": 41,
    "left_middle_finger_mcp": 42,
    "left_middle_finger_pip": 43,
    "left_middle_finger_dip": 44,
    "left_middle_finger_tip": 45,
    "left_ring_finger_mcp": 46,
    "left_ring_finger_pip": 47,
    "left_ring_finger_dip": 48,
    "left_ring_finger_tip": 49,
    "left_pinky_mcp": 50,
    "left_pinky_pip": 51,
    "left_pinky_dip": 52,
    "left_pinky_tip": 53,

    # Right Hand Landmarks
    "right_hand_wrist": 54,
    "right_thumb_cmc": 55,
    "right_thumb_mcp": 56,
    "right_thumb_ip": 57,
    "right_thumb_tip": 58,
    "right_index_finger_mcp": 59,
    "right_index_finger_pip": 60,
    "right_index_finger_dip": 61,
    "right_index_finger_tip": 62,
    "right_middle_finger_mcp": 63,
    "right_middle_finger_pip": 64,
    "right_middle_finger_dip": 65,
    "right_middle_finger_tip": 66,
    "right_ring_finger_mcp": 67,
    "right_ring_finger_pip": 68,
    "right_ring_finger_dip": 69,
    "right_ring_finger_tip": 70,
    "right_pinky_mcp": 71,
    "right_pinky_pip": 72,
    "right_pinky_dip": 73,
    "right_pinky_tip": 74,
}
class KeypointExtractor:
    """Extracts MediaPipe Holistic (pose + both hands) 2-D keypoints from videos.

    Per-video results are cached as ``.npy`` files under ``cache_folder`` so
    each video is only processed once.
    """

    # landmark counts per group (MediaPipe Holistic: 33 pose, 21 per hand)
    N_POSE = 33
    N_HAND = 21

    def __init__(self, video_folder: str, cache_folder: str = "cache"):
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_holistic = mp.solutions.holistic
        self.video_folder = video_folder
        self.cache_folder = cache_folder

        # We store the keypoints of each frame as a row in the dataframe.
        # Columns: Pose (33), Left Hand (21), Right Hand (21).
        # FIX: each keypoint has 2 values (x, y) — the old comment said 3.
        self.columns = [
            f"{i}_{j}"
            for i in range(self.N_POSE + 2 * self.N_HAND)
            for j in ("x", "y")
        ]

        # holistic extractor
        self.holistic = mp.solutions.holistic.Holistic(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def close(self):
        """Release the underlying MediaPipe graph resources."""
        self.holistic.close()

    @staticmethod
    def _landmarks_to_xy(landmarks, count: int):
        """Flatten a landmark list to ``[x0, y0, x1, y1, ...]``.

        BUG FIX: the original dropped an entire frame whenever any landmark
        group (pose / left hand / right hand) was undetected, which left
        temporal gaps and could cache a permanently-empty array. Missing
        groups are now zero-filled so every decoded frame yields one row.
        """
        if landmarks is None:
            return [0.0] * (2 * count)
        return [v for lm in landmarks.landmark for v in (lm.x, lm.y)]

    def extract_keypoints_from_video(self, video: str) -> pd.DataFrame:
        """Extract keypoints from a video (one row per frame), with caching.

        :param video: file name of the video inside ``video_folder``
        :return: DataFrame of shape (frames, 150) with ``self.columns``,
            or None when the video file does not exist
        """
        video_path = os.path.join(self.video_folder, video)
        # check if video exists
        if not os.path.exists(video_path):
            logging.error("Video does not exist at path: %s", video_path)
            return None

        # make sure cache folder exists (no-op when already present)
        os.makedirs(self.cache_folder, exist_ok=True)

        # serve from cache when available
        cache_path = os.path.join(self.cache_folder, video + ".npy")
        if os.path.exists(cache_path):
            return pd.DataFrame(np.load(cache_path), columns=self.columns)

        cap = cv2.VideoCapture(video_path)
        rows = []
        try:
            while cap.isOpened():
                success, image = cap.read()
                if not success:
                    break
                # extract keypoints of this frame
                results = self.extract_keypoints_from_frame(image)
                rows.append(
                    self._landmarks_to_xy(results.pose_landmarks, self.N_POSE)
                    + self._landmarks_to_xy(results.left_hand_landmarks, self.N_HAND)
                    + self._landmarks_to_xy(results.right_hand_landmarks, self.N_HAND)
                )
        finally:
            # FIX: release the capture even if extraction raises
            cap.release()

        # FIX: build the DataFrame once; per-frame pd.concat was O(n^2)
        keypoints_df = pd.DataFrame(rows, columns=self.columns)

        # save keypoints to cache
        np.save(cache_path, keypoints_df.to_numpy())

        return keypoints_df

    def extract_keypoints_from_frame(self, image: np.ndarray, draw: bool = False):
        """Extract holistic keypoints from one BGR frame.

        :param image: the frame to extract keypoints from
        :param draw: when True, also return an RGB copy of the frame with the
            landmarks drawn on it
        :return: MediaPipe results, or ``(results, annotated_image)``
            when ``draw`` is True
        """
        # FIX: was a bare print() in the per-frame hot path
        logging.debug("Extracting keypoints from frame")
        # Convert the BGR image to RGB and process it with MediaPipe Holistic.
        results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if draw:
            # Draw the annotations on an RGB copy of the image.
            # NOTE(review): the annotated image is returned in RGB, not
            # converted back to BGR — confirm callers expect that.
            draw_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.mp_drawing.draw_landmarks(
                draw_image, results.face_landmarks, self.mp_holistic.FACEMESH_CONTOURS)
            self.mp_drawing.draw_landmarks(
                draw_image, results.left_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
            self.mp_drawing.draw_landmarks(
                draw_image, results.right_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
            self.mp_drawing.draw_landmarks(
                draw_image, results.pose_landmarks, self.mp_holistic.POSE_CONNECTIONS)

            return results, draw_image

        return results