First training
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,3 +6,5 @@ cache/
|
||||
cache_wlasl/
|
||||
|
||||
__pycache__/
|
||||
|
||||
checkpoints/
|
||||
BIN
models/spoter_40.pth
Normal file
BIN
models/spoter_40.pth
Normal file
Binary file not shown.
@@ -3,3 +3,4 @@ torchvision==0.14.1
|
||||
pandas==1.5.3
|
||||
mediapipe==0.9.1.0
|
||||
tensorboard==2.12.0
|
||||
mediapy==1.1.6
|
||||
11
src/augmentations.py
Normal file
11
src/augmentations.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import random
|
||||
|
||||
|
||||
class MirrorKeypoints:
    """Random augmentation that replaces a sample by ``1 - sample``.

    A single random number is drawn per call. When it is greater than 0.5 the
    sample is returned untouched, otherwise every value ``v`` in the sample is
    replaced by ``1 - v``.
    """

    def __call__(self, sample):
        """Return ``sample`` either unchanged or inverted around 1.

        :param sample: the keypoints to augment (anything supporting ``1 - sample``)
        :return: the original sample or ``1 - sample``
        """
        keep_original = random.random() > 0.5

        # NOTE(review): this inverts *every* value of the sample, not only the
        # x coordinates - confirm that this is the intended mirroring.
        return sample if keep_original else 1 - sample
|
||||
@@ -4,8 +4,8 @@ import numpy as np
|
||||
import torch
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.keypoint_extractor import KeypointExtractor
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
|
||||
|
||||
class FingerSpellingDataset(torch.utils.data.Dataset):
|
||||
@@ -33,7 +33,7 @@ class FingerSpellingDataset(torch.utils.data.Dataset):
|
||||
# TODO: make split for train and val and test when enough data is available
|
||||
|
||||
# split the data into train and val and test and make them balanced
|
||||
x_train, x_test, y_train, y_test = train_test_split(files, labels, test_size=0.4, random_state=1, stratify=labels)
|
||||
x_train, x_test, y_train, y_test = train_test_split(files, labels, test_size=0.3, random_state=1, stratify=labels)
|
||||
|
||||
if subset == "train":
|
||||
self.data = x_train
|
||||
@@ -57,7 +57,7 @@ class FingerSpellingDataset(torch.utils.data.Dataset):
|
||||
video_name = self.data[index]
|
||||
|
||||
# get the keypoints for the video
|
||||
keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)
|
||||
keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name, normalize=True)
|
||||
|
||||
# filter the keypoints by the identified subset
|
||||
if self.keypoints_to_keep:
|
||||
@@ -73,4 +73,7 @@ class FingerSpellingDataset(torch.utils.data.Dataset):
|
||||
# data to tensor
|
||||
data = torch.from_numpy(current_row)
|
||||
|
||||
if self.transform:
|
||||
data = self.transform(data)
|
||||
|
||||
return data, label
|
||||
@@ -27,6 +27,8 @@ class KeypointExtractor:
|
||||
|
||||
def extract_keypoints_from_video(self,
|
||||
video: str,
|
||||
normalize: bool = False,
|
||||
draw: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""extract_keypoints_from_video this function extracts keypoints from a video and stores them in a dataframe
|
||||
|
||||
@@ -35,6 +37,8 @@ class KeypointExtractor:
|
||||
:return: dataframe with keypoints
|
||||
:rtype: pd.DataFrame
|
||||
"""
|
||||
|
||||
if not draw:
|
||||
# check if video exists
|
||||
if not os.path.exists(self.video_folder + video):
|
||||
logging.error("Video does not exist at path: " + self.video_folder + video)
|
||||
@@ -47,7 +51,10 @@ class KeypointExtractor:
|
||||
# check if cache file exists and return
|
||||
if os.path.exists(self.cache_folder + "/" + video + ".npy"):
|
||||
# create dataframe from cache
|
||||
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy", allow_pickle=True), columns=self.columns)
|
||||
df = pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy", allow_pickle=True), columns=self.columns)
|
||||
if normalize:
|
||||
df = self.normalize_hands(df)
|
||||
return df
|
||||
|
||||
# open video
|
||||
cap = cv2.VideoCapture(self.video_folder + video)
|
||||
@@ -56,7 +63,9 @@ class KeypointExtractor:
|
||||
|
||||
# extract frames from video so we extract 5 frames per second
|
||||
frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
frame_skip = frame_rate // 5
|
||||
frame_skip = frame_rate // 10
|
||||
|
||||
output_frames = []
|
||||
|
||||
while cap.isOpened():
|
||||
|
||||
@@ -70,6 +79,10 @@ class KeypointExtractor:
|
||||
if not success:
|
||||
break
|
||||
# extract keypoints of frame
|
||||
if draw:
|
||||
results, draw_image = self.extract_keypoints_from_frame(image, draw=True)
|
||||
output_frames.append(draw_image)
|
||||
else:
|
||||
results = self.extract_keypoints_from_frame(image)
|
||||
|
||||
def extract_keypoints(landmarks):
|
||||
@@ -80,8 +93,10 @@ class KeypointExtractor:
|
||||
k1 = extract_keypoints(results.pose_landmarks)
|
||||
k2 = extract_keypoints(results.left_hand_landmarks)
|
||||
k3 = extract_keypoints(results.right_hand_landmarks)
|
||||
if k1 and k2 and k3:
|
||||
keypoints_df = pd.concat([keypoints_df, pd.DataFrame([k1+k2+k3], columns=self.columns)])
|
||||
if k1 and (k2 or k3):
|
||||
data = [k1 + (k2 or [0] * 42) + (k3 or [0] * 42)]
|
||||
new_df = pd.DataFrame(data, columns=self.columns)
|
||||
keypoints_df = pd.concat([keypoints_df, new_df], ignore_index=True)
|
||||
|
||||
# close video
|
||||
cap.release()
|
||||
@@ -89,6 +104,12 @@ class KeypointExtractor:
|
||||
# save keypoints to cache
|
||||
np.save(self.cache_folder + "/" + video + ".npy", keypoints_df.to_numpy())
|
||||
|
||||
if normalize:
|
||||
keypoints_df = self.normalize_hands(keypoints_df)
|
||||
|
||||
if draw:
|
||||
return keypoints_df, output_frames
|
||||
|
||||
return keypoints_df
|
||||
|
||||
|
||||
@@ -108,11 +129,81 @@ class KeypointExtractor:
|
||||
if draw:
|
||||
# Draw the pose annotations on the image
|
||||
draw_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.face_landmarks, self.mp_holistic.FACEMESH_CONTOURS)
|
||||
# self.mp_drawing.draw_landmarks(draw_image, results.face_landmarks, self.mp_holistic.FACEMESH_CONTOURS)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.left_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.right_hand_landmarks, self.mp_holistic.HAND_CONNECTIONS)
|
||||
|
||||
# create bounding box around hands
|
||||
if results.left_hand_landmarks:
|
||||
x = [landmark.x for landmark in results.left_hand_landmarks.landmark]
|
||||
y = [landmark.y for landmark in results.left_hand_landmarks.landmark]
|
||||
draw_image = cv2.rectangle(draw_image, (int(min(x) * 640), int(min(y) * 480)), (int(max(x) * 640), int(max(y) * 480)), (255, 0, 0), 2)
|
||||
|
||||
if results.right_hand_landmarks:
|
||||
x = [landmark.x for landmark in results.right_hand_landmarks.landmark]
|
||||
y = [landmark.y for landmark in results.right_hand_landmarks.landmark]
|
||||
draw_image = cv2.rectangle(draw_image, (int(min(x) * 640), int(min(y) * 480)), (int(max(x) * 640), int(max(y) * 480)), (255, 0, 0), 2)
|
||||
|
||||
self.mp_drawing.draw_landmarks(draw_image, results.pose_landmarks, self.mp_holistic.POSE_CONNECTIONS)
|
||||
|
||||
return results, draw_image
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def normalize_hands(self, dataframe: pd.DataFrame) -> pd.DataFrame:
    """Normalize the keypoints of both hands in a dataframe.

    The work is delegated to ``normalize_hand_helper``, first for the left
    hand and then for the right hand; the result of the first call is the
    input of the second.

    :param dataframe: the dataframe to normalize
    :type dataframe: pd.DataFrame
    :return: the normalized dataframe
    :rtype: pd.DataFrame
    """
    for hand_name in ("left_hand", "right_hand"):
        dataframe = self.normalize_hand_helper(dataframe, hand_name)

    return dataframe
|
||||
|
||||
def normalize_hand_helper(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame:
    """Normalize one hand's keypoints by that hand's bounding box, frame by frame.

    Every row of ``dataframe`` is treated as one frame. The 42 columns of the
    selected hand are read as 21 (x, y) pairs. For each frame the pairs are
    shifted by the centre of the frame's bounding box and divided by the box
    width (x) and height (y), so normalized values lie in ``[-0.5, 0.5]``.

    Frames whose bounding box has zero width or height (for example a hand
    stored as all zeros) cannot be scaled and are left unchanged; all other
    frames are still normalized. The dataframe is modified in place and
    returned.

    :param dataframe: the dataframe to normalize
    :type dataframe: pd.DataFrame
    :param hand: the hand to normalize; ``"right_hand"`` selects columns
        108 - 149, any other value selects the left hand columns 66 - 107
    :type hand: str
    :return: the normalized dataframe
    :rtype: pd.DataFrame
    """
    # left hand lives in columns 66 - 107, right hand in columns 108 - 149
    first_column = 108 if hand == "right_hand" else 66
    hand_columns = np.arange(first_column, first_column + 42)

    # (frames, 21 keypoints, 2 coordinates)
    hand_coords = dataframe.iloc[:, hand_columns].to_numpy(dtype=float).reshape(-1, 21, 2)

    # per-frame bounding box; keepdims gives shape (frames, 1, 2) so the values
    # broadcast against hand_coords without mixing up different frames
    lower = hand_coords.min(axis=1, keepdims=True)
    upper = hand_coords.max(axis=1, keepdims=True)
    center = (lower + upper) / 2
    extent = upper - lower

    # only frames with a non-degenerate box can be divided by its size
    valid = np.all(extent != 0, axis=(1, 2))
    if not valid.any():
        return dataframe

    norm_hand_coords = hand_coords.copy()
    norm_hand_coords[valid] = (hand_coords[valid] - center[valid]) / extent[valid]

    # flatten back to 42 columns and write the result into the dataframe
    dataframe.iloc[:, hand_columns] = norm_hand_coords.reshape(-1, 42)

    return dataframe
|
||||
52
src/train.py
52
src/train.py
@@ -13,6 +13,8 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from augmentations import MirrorKeypoints
|
||||
from datasets.finger_spelling_dataset import FingerSpellingDataset
|
||||
from datasets.wlasl_dataset import WLASLDataset
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
@@ -32,30 +34,28 @@ def train():
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2)
|
||||
spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
|
||||
spoter_model.train(True)
|
||||
spoter_model.to(device)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
|
||||
optimizer = optim.SGD(spoter_model.parameters(), lr=0.0001, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
|
||||
|
||||
# TODO: create paths for checkpoints
|
||||
|
||||
# TODO: transformations + augmentations
|
||||
|
||||
k = KeypointExtractor("data/videos/")
|
||||
k = KeypointExtractor("data/fingerspelling/data/")
|
||||
|
||||
train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
|
||||
transform = transforms.Compose([MirrorKeypoints()])
|
||||
|
||||
train_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
||||
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
||||
|
||||
val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
|
||||
val_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="val")
|
||||
val_loader = DataLoader(val_set, shuffle=True, generator=g)
|
||||
|
||||
test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
|
||||
test_loader = DataLoader(test_set, shuffle=True, generator=g)
|
||||
|
||||
|
||||
train_acc, val_acc = 0, 0
|
||||
lr_progress = []
|
||||
top_train_acc, top_val_acc = 0, 0
|
||||
@@ -82,31 +82,39 @@ def train():
|
||||
pred_correct += 1
|
||||
pred_all += 1
|
||||
|
||||
if i % 100 == 0:
|
||||
print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
||||
# if i % 100 == 0:
|
||||
# print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
||||
|
||||
if scheduler:
|
||||
scheduler.step(running_loss.item() / len(train_loader))
|
||||
|
||||
# validate
|
||||
# validate and print val acc
|
||||
val_pred_correct, val_pred_all = 0, 0
|
||||
with torch.no_grad():
|
||||
for i, (inputs, labels) in enumerate(val_loader):
|
||||
inputs = inputs.squeeze(0).to(device)
|
||||
labels = labels.to(device)
|
||||
labels = labels.to(device, dtype=torch.long)
|
||||
|
||||
outputs = spoter_model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
val_acc = (predicted == labels).sum().item() / labels.size(0)
|
||||
outputs = spoter_model(inputs).expand(1, -1, -1)
|
||||
|
||||
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
||||
val_pred_correct += 1
|
||||
val_pred_all += 1
|
||||
|
||||
val_acc = (val_pred_correct / val_pred_all)
|
||||
|
||||
print(f"Epoch: {epoch} | Train Acc: {(pred_correct / pred_all)} | Val Acc: {val_acc}")
|
||||
|
||||
|
||||
# save checkpoint
|
||||
# if val_acc > top_val_acc:
|
||||
# top_val_acc = val_acc
|
||||
# top_train_acc = train_acc
|
||||
# checkpoint_index = epoch
|
||||
# torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||
if val_acc > top_val_acc:
|
||||
top_val_acc = val_acc
|
||||
top_train_acc = train_acc
|
||||
checkpoint_index = epoch
|
||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||
|
||||
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
|
||||
lr_progress.append(optimizer.param_groups[0]['lr'])
|
||||
|
||||
print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
|
||||
|
||||
train()
|
||||
145
visualize_data.ipynb
Normal file
145
visualize_data.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user