# sign-predictor/dataset.py
# (GitHub listing metadata: 28 lines, 941 B, Python)

import json
import os

import pandas as pd
import torch
from PIL import Image

class WLASLDataset(torch.utils.data.Dataset):
    """Dataset over one split of a WLASL-style annotation CSV.

    Rows of ``csv_file`` are filtered to those whose ``subset`` column equals
    ``subset``; items are addressed by position in that filtered table.

    Args:
        csv_file: Path to the annotation CSV. Must contain at least the
            ``subset`` and ``video_id`` columns.
        video_dir: Directory of the source videos (stored; not read yet).
        subset: Value of the ``subset`` column to keep (default ``"train"``).
        keypoints_file: Path of the keypoints cache CSV. Presumably it has
            ``video_id`` and ``keypoints`` columns, matching the empty frame
            built in ``__getitem__`` -- TODO confirm against the writer.
        transform: Optional transform (stored; not applied yet).
    """

    def __init__(self, csv_file: str, video_dir: str, subset: str = "train",
                 keypoints_file: str = "keypoints.csv", transform=None):
        df = pd.read_csv(csv_file)
        # Keep only the requested split, and reset the index so it runs
        # 0..len-1 instead of carrying gaps from the filtered-out rows.
        self.df = df[df["subset"] == subset].reset_index(drop=True)
        self.video_dir = video_dir
        self.transform = transform
        self.subset = subset
        self.keypoints_file = keypoints_file

    def __len__(self) -> int:
        """Return the number of rows in the selected subset."""
        return len(self.df)

    def __getitem__(self, index: int):
        """Look up the sample at position ``index`` of the filtered table.

        NOTE(review): unfinished -- the keypoint lookup/extraction step is not
        implemented yet, so this currently returns ``None``.
        """
        video_id = self.df.iloc[index]["video_id"]
        # Load the keypoints cache if it exists. Otherwise start from an empty
        # table with the same columns, so ``keypoints_df`` is bound on both
        # paths. Previously it was unbound whenever the file existed.
        if os.path.exists(self.keypoints_file):
            keypoints_df = pd.read_csv(self.keypoints_file)
        else:
            keypoints_df = pd.DataFrame(columns=["video_id", "keypoints"])
        # TODO: check whether keypoints for ``video_id`` are in
        # ``keypoints_df``; otherwise extract them from the video
        # under ``self.video_dir``.