28 lines
941 B
Python
28 lines
941 B
Python
import json
import os

import pandas as pd
import torch
from PIL import Image


class WLASLDataset(torch.utils.data.Dataset):
    """WLASL dataset backed by a metadata CSV and a keypoint cache CSV.

    Rows of ``csv_file`` are filtered to a single split via its ``subset``
    column. Keypoints for each row are looked up by ``video_id`` in
    ``keypoints_file``, a CSV with ``video_id`` and ``keypoints`` columns
    that may not exist yet.

    Args:
        csv_file: Path to the metadata CSV; it must contain ``video_id`` and
            ``subset`` columns.
        video_dir: Directory holding the videos. Stored for the keypoint
            extraction step, which is not implemented yet; currently unused.
        subset: Value of the ``subset`` column to keep (default ``"train"``).
        keypoints_file: Path of the keypoint cache CSV.
        transform: Optional transform; stored but not applied anywhere yet.
    """

    def __init__(
        self,
        csv_file: str,
        video_dir: str,
        subset: str = "train",
        keypoints_file: str = "keypoints.csv",
        transform=None,
    ):
        df = pd.read_csv(csv_file)
        # Keep only the rows belonging to the requested split.
        self.df = df[df["subset"] == subset]
        self.video_dir = video_dir
        self.transform = transform
        self.subset = subset
        self.keypoints_file = keypoints_file

    def __len__(self) -> int:
        """Return the number of videos in the selected subset."""
        return len(self.df)

    def _load_keypoints(self) -> pd.DataFrame:
        """Return the keypoint cache, or an empty frame if the file is absent."""
        if os.path.exists(self.keypoints_file):
            return pd.read_csv(self.keypoints_file)
        return pd.DataFrame(columns=["video_id", "keypoints"])

    def __getitem__(self, index):
        """Return the cached keypoints for the video at position ``index``.

        ``index`` is positional (``iloc``) within the filtered subset.

        Raises:
            IndexError: if ``index`` is out of range for the subset.
            NotImplementedError: if the video has no cached keypoints, because
                extracting keypoints from the video is not implemented yet.
        """
        video_id = self.df.iloc[index]["video_id"]
        keypoints_df = self._load_keypoints()

        # Compare as strings: the two CSVs may infer different dtypes for
        # video_id, and the empty fallback frame has object-dtype columns.
        cached = keypoints_df[keypoints_df["video_id"].astype(str) == str(video_id)]
        if not cached.empty:
            # NOTE(review): the value is returned exactly as read from the CSV
            # (a string unless pandas parsed it) -- decide on decoding and on
            # applying self.transform once extraction is written.
            return cached["keypoints"].iloc[0]

        # Fail loudly instead of returning None, which would only surface
        # later as a confusing error in a DataLoader's collate step.
        raise NotImplementedError(
            f"No cached keypoints for video {video_id!r} in "
            f"{self.keypoints_file!r}; extraction from video is not implemented yet."
        )