class FingerSpellingDataset(torch.utils.data.Dataset):
    """Dataset of finger-spelling videos named ``<label>!<anything>.mp4``.

    Each item is a ``(keypoints, label)`` pair where ``keypoints`` is a
    float64 tensor of shape ``(frames, num_keypoints, 2)`` holding the
    (x, y) coordinates extracted from one video, and ``label`` is the
    integer class id of the spelled letter.
    """

    def __init__(self, data_folder: str, keypoint_extractor: "KeypointExtractor", subset: str = "train", keypoints_identifier: dict = None, transform=None):
        """Index the videos in ``data_folder`` and build a stratified split.

        Args:
            data_folder: Directory containing the ``.mp4`` recordings.
            keypoint_extractor: Extractor whose ``extract_keypoints_from_video``
                returns a DataFrame with alternating ``<id>_x`` / ``<id>_y`` columns.
            subset: ``"train"`` or ``"val"``.
            keypoints_identifier: Optional mapping whose values select which
                keypoint ids to keep; ``None`` keeps all columns.
            transform: Optional callable applied to the keypoint tensor.

        Raises:
            ValueError: If ``subset`` is not ``"train"`` or ``"val"``.
        """
        self.data_folder = data_folder

        # Every video in the folder; the label is encoded before the "!".
        files = [f for f in os.listdir(self.data_folder) if f.endswith(".mp4")]
        labels = [f.split("!")[0] for f in files]

        # np.unique returns the sorted unique labels; the index is the class id.
        self.label_mapping = np.unique(labels)

        # Persist the label <-> id mapping so predictions can be decoded later.
        # Fix: the original write omitted the newline, producing one long line.
        with open(os.path.join(self.data_folder, "label_mapping.txt"), "w") as f:
            for i, label in enumerate(self.label_mapping):
                f.write(f"{label} {i}\n")

        # Map the string labels to their integer class ids.
        labels = [np.where(self.label_mapping == label)[0][0] for label in labels]

        # TODO: add a separate test split once enough data is available.
        # Stratified split keeps the class balance in both subsets.
        x_train, x_test, y_train, y_test = train_test_split(
            files, labels, test_size=0.4, random_state=1, stratify=labels
        )

        if subset == "train":
            self.data = x_train
            self.labels = y_train
        elif subset == "val":
            self.data = x_test
            self.labels = y_test
        else:
            # Fix: the original silently left self.data/self.labels unset for an
            # unknown subset, causing a confusing AttributeError much later.
            raise ValueError(f"Unknown subset: {subset!r} (expected 'train' or 'val')")

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor

        # Fix: always define the attribute; the original only assigned it when
        # keypoints_identifier was given, so __getitem__ crashed otherwise.
        if keypoints_identifier:
            self.keypoints_to_keep = [f"{i}_{j}" for i in keypoints_identifier.values() for j in ("x", "y")]
        else:
            self.keypoints_to_keep = None

    def __len__(self):
        """Number of videos in this subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the ``index``-th video.

        ``keypoints`` is a float64 tensor of shape (frames, num_keypoints, 2).
        """
        video_name = self.data[index]

        # DataFrame of shape (frames, 2 * num_keypoints), columns alternating x/y.
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)

        # Restrict to the requested keypoint subset, if one was configured.
        if self.keypoints_to_keep is not None:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        # Columns alternate x, y, so a row-major reshape pairs them up exactly
        # like the original per-column copy loop, without the Python loop.
        coords = keypoints_df.to_numpy(dtype=np.float64).reshape(keypoints_df.shape[0], -1, 2)

        data = torch.from_numpy(coords)

        # Fix: the original accepted ``transform`` but never applied it.
        if self.transform is not None:
            data = self.transform(data)

        return data, self.labels[index]
--git a/src/keypoint_extractor.py b/src/keypoint_extractor.py index a0ee964..7dee850 100644 --- a/src/keypoint_extractor.py +++ b/src/keypoint_extractor.py @@ -1,12 +1,14 @@ -import mediapipe as mp -import cv2 -import time -from typing import Dict, List, Tuple -import numpy as np import logging import os +import time +from typing import Dict, List, Tuple + +import cv2 +import mediapipe as mp +import numpy as np import pandas as pd + class KeypointExtractor: def __init__(self, video_folder: str, cache_folder: str = "cache"): self.mp_drawing = mp.solutions.drawing_utils @@ -52,7 +54,18 @@ class KeypointExtractor: keypoints_df = pd.DataFrame(columns=self.columns) + # extract frames from video so we extract 5 frames per second + frame_rate = int(cap.get(cv2.CAP_PROP_FPS)) + frame_skip = frame_rate // 5 + while cap.isOpened(): + + # skip frames + for _ in range(frame_skip): + success, image = cap.read() + if not success: + break + success, image = cap.read() if not success: break diff --git a/src/train.py b/src/train.py index 245f6c5..e58d284 100644 --- a/src/train.py +++ b/src/train.py @@ -13,7 +13,7 @@ import torch.optim as optim from torch.utils.data import DataLoader from torchvision import transforms -from dataset import WLASLDataset +from datasets.wlasl_dataset import WLASLDataset from identifiers import LANDMARKS from keypoint_extractor import KeypointExtractor from model import SPOTER