# Split data/processed/spoter.csv into train/test CSV files, keeping each recording path in a single split.
import json
import random

import pandas as pd

from normalization.blazepose_mapping import map_blazepose_df
# Input dataset and output locations for the split.
DATASET = "data/processed/spoter.csv"
TRAIN_OUTPUT = "data/processed/spoter_train.csv"
TEST_OUTPUT = "data/processed/spoter_test.csv"


def split_train_test(df, train_frac=0.8, min_samples=5, seed=42):
    """Split *df* into train/test frames grouped by ``path``, stratified by ``sign``.

    A "sample" is one distinct ``path``. Every row sharing a path lands in the
    same split, so no recording leaks between train and test. Signs with fewer
    than *min_samples* distinct paths are dropped first. The paths of each
    remaining sign are shuffled with a seeded RNG. The first
    ``int(n * train_frac)`` of them go to train and the rest to test. With the
    defaults, every kept sign appears in both splits.

    Args:
        df: Frame with at least ``path`` and ``sign`` columns.
        train_frac: Fraction of each sign's paths assigned to train; must be
            strictly between 0 and 1.
        min_samples: Minimum number of distinct paths a sign needs to be kept.
            Defaults to 5, i.e. a sign needs more than 4 samples.
        seed: Seed for the per-sign shuffle, making the split reproducible.

    Returns:
        ``(train_df, test_df)``: row subsets of *df*, in their original order.

    Raises:
        ValueError: if *train_frac* is not strictly between 0 and 1.
    """
    if not 0 < train_frac < 1:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac!r}")

    # Count distinct paths (not rows) per sign: one path may span several
    # rows, and paths are the unit that gets split.
    samples = df[["path", "sign"]].drop_duplicates()
    paths_per_sign = samples.groupby("sign")["path"].nunique()
    kept_signs = paths_per_sign[paths_per_sign >= min_samples].index
    samples = samples[samples["sign"].isin(kept_signs)]

    # Split each sign separately. A single cut over file-ordered paths can
    # leave whole signs only in train or only in test.
    rng = random.Random(seed)
    train_paths = set()
    for _, group in samples.groupby("sign"):
        # Sort first so the result does not depend on row order in the CSV.
        sign_paths = sorted(group["path"].unique())
        rng.shuffle(sign_paths)
        train_paths.update(sign_paths[: int(len(sign_paths) * train_frac)])

    kept = df[df["sign"].isin(kept_signs)]
    in_train = kept["path"].isin(train_paths)
    return kept[in_train], kept[~in_train]


def main():
    """Read DATASET, split it, and write the train and test CSV files."""
    df = pd.read_csv(DATASET)
    train_df, test_df = split_train_test(df)
    train_df.to_csv(TRAIN_OUTPUT, index=False)
    test_df.to_csv(TEST_OUTPUT, index=False)


if __name__ == "__main__":
    main()