Files
spoterembedding/predictions/validation.py

138 lines
4.6 KiB
Python

import json
import os
import time
import cv2
import numpy as np
from matplotlib import pyplot as plt
from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM
# Number of consecutive poses in each sliding window that get_embeddings passes to
# predictor.get_embedding (one embedding per window).
buffer_size = 15
def predict_video(predictor, path_video, desired_fps=15):
    """Read a video and collect its pose keypoints, subsampled towards ``desired_fps``.

    Every decoded frame is passed to ``predictor.extract_keypoints``; a pose is kept
    only when it is not ``None`` and the frame index is a multiple of the frame-skip
    rate ``max(1, original_fps // desired_fps)``.

    :param predictor: object exposing ``extract_keypoints(img)``.
    :param path_video: path of the video file, opened with ``cv2.VideoCapture``.
    :param desired_fps: target sampling rate in frames per second (default 15, the
        value that used to be hard-coded here).
    :return: list of the kept poses, in frame order.
    :raises ValueError: if ``desired_fps`` is not positive.
    """
    if desired_fps <= 0:
        raise ValueError("desired_fps must be positive, got %r" % (desired_fps,))

    buffer = []
    cap = cv2.VideoCapture(path_video)
    try:
        # NOTE(review): int() truncates the reported FPS, so e.g. 29.97 becomes 29 and
        # gives frame_skip 1 (no subsampling) — confirm this matches how the reference
        # embeddings were produced before changing it.
        original_fps = int(cap.get(cv2.CAP_PROP_FPS))
        print("Original FPS: ", original_fps)
        # Keep one frame out of every `frame_skip`. The floor of 1 matters whenever the
        # source FPS is below desired_fps (integer division would give 0).
        frame_skip = max(1, original_fps // desired_fps)
        print("Frame skip: ", frame_skip)

        frame_number = 0
        _, img = cap.read()  # img is (H, W, C) when a frame was read, None at the end
        while img is not None:
            # NOTE(review): keypoints are extracted for every frame, even ones the frame
            # skip discards; left as-is in case extract_keypoints keeps state across
            # consecutive frames — confirm before moving the skip check first.
            pose = predictor.extract_keypoints(img)
            if pose is not None and frame_number % frame_skip == 0:
                buffer.append(pose)
            frame_number += 1
            _, img = cap.read()
    finally:
        # Always free the OpenCV capture handle, even if keypoint extraction raises.
        cap.release()
    print(len(buffer))
    return buffer
def get_embeddings(predictor, buffer, name, window_size=None,
                   output_dir="predictions/test_embeddings"):
    """Embed every complete sliding window of ``buffer`` and save the list as JSON.

    Window ``i`` is ``buffer[i:i + window_size]``. One embedding is produced for each
    position where a full window fits, including the window that ends on the last
    pose. A buffer shorter than the window yields no embeddings.

    The list is written with ``json.dump`` to ``<output_dir>/<name>.csv`` (the
    extension is historical; the content is JSON). ``output_dir`` is created if
    missing.

    :param predictor: object exposing ``get_embedding(window)``.
    :param buffer: sequence of poses, e.g. as returned by ``predict_video``.
    :param name: base name of the output file.
    :param window_size: poses per window; defaults to the module-level ``buffer_size``.
    :param output_dir: directory for the output file (default keeps the old location).
    :return: list of embeddings, one per window, in window order.
    :raises ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size is None:
        window_size = buffer_size
    if window_size < 1:
        raise ValueError("window_size must be >= 1, got %r" % (window_size,))

    # `+ 1` so the final window buffer[-window_size:] is included; the previous
    # range(window_size, len(buffer)) dropped it (and produced nothing when
    # len(buffer) == window_size).
    embeddings = [
        predictor.get_embedding(buffer[start:start + window_size])
        for start in range(len(buffer) - window_size + 1)
    ]

    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): json.dump only accepts JSON-native values; if get_embedding returns
    # e.g. NumPy arrays this raises TypeError — confirm the embedding type.
    with open(os.path.join(output_dir, name + ".csv"), "w", encoding="utf-8") as f:
        json.dump(embeddings, f)
    return embeddings
def compare_embeddings(predictor, embeddings, label_video):
    """Classify each embedding and score it against the video's true label.

    :param predictor: object whose ``predict(embedding)`` returns ``(label, score)``.
    :param embeddings: iterable of embeddings for one video.
    :param label_video: ground-truth label of that video.
    :return: one dict per embedding with keys ``label``, ``score``, ``label_video``
        and ``correct`` (whether the predicted label equals ``label_video``).
    """
    def _score_one(embedding):
        predicted, confidence = predictor.predict(embedding)
        return {
            "label": predicted,
            "score": confidence,
            "label_video": label_video,
            "correct": predicted == label_video,
        }

    return [_score_one(embedding) for embedding in embeddings]
def predict_video_files(predictor, path_video, label_video):
    """Run the whole pipeline on one video: poses -> window embeddings -> predictions.

    The embeddings file is named after the video's base name with only its final
    extension removed (``/data/A!test!1.mp4`` -> ``A!test!1``).

    :param predictor: predictor used for keypoints, embeddings and classification.
    :param path_video: path of the video file.
    :param label_video: ground-truth label of the video.
    :return: the per-embedding result dicts from ``compare_embeddings``.
    """
    buffer = predict_video(predictor, path_video)
    # splitext strips only the last extension, so dotted names such as "a.1.mp4" and
    # "a.2.mp4" no longer collapse onto the same "a" embeddings file.
    name = os.path.splitext(os.path.basename(path_video))[0]
    embeddings = get_embeddings(predictor, buffer, name)
    return compare_embeddings(predictor, embeddings, label_video)
def get_test_data(data_folder, subset="test"):
    """List the ``.mp4`` videos of one split in ``data_folder`` together with their labels.

    File names are split on ``"!"``: the first field is the label and the second the
    split name (e.g. ``A!test!1.mp4``). Files without a second field are skipped
    rather than raising ``IndexError``.

    :param data_folder: directory to scan (with or without a trailing separator).
    :param subset: split name to select (default ``"test"``).
    :return: ``(files, labels)`` where ``files`` is a NumPy string array of full paths
        and ``labels`` the matching list of labels. Both are in sorted file-name order,
        so runs are reproducible.
    """
    files = []
    labels = []
    # os.listdir order is arbitrary; sort so results come out in a stable order.
    for file_name in sorted(os.listdir(data_folder)):
        if not file_name.endswith(".mp4"):
            continue
        fields = file_name.split("!")
        if len(fields) < 2 or fields[1] != subset:
            continue
        # os.path.join works whether or not data_folder ends with a separator.
        files.append(os.path.join(data_folder, file_name))
        labels.append(fields[0])
    # dtype=str keeps a string array even when nothing matched.
    return np.array(files, dtype=str), labels
def test_data(predictor, data_folder):
    """Run ``predict_video_files`` on every test video found in ``data_folder``.

    :param predictor: predictor used for every video.
    :param data_folder: directory scanned by ``get_test_data``.
    :return: dict mapping each label to a list with one entry per video:
        ``{"predictions": ..., "elapsed_time": ..., "video": ...}``. ``elapsed_time`` is
        the video's wall-clock time divided by its number of predictions, or the total
        time when it produced none.
    """
    results = {}
    test_files, test_labels = get_test_data(data_folder)
    for path_video, label_video in zip(test_files, test_labels):
        print(path_video, label_video)
        started_at = time.time()
        prediction = predict_video_files(predictor, path_video, label_video)
        elapsed_time = time.time() - started_at

        # Report an average per prediction instead of a per-video total.
        n_predictions = len(prediction)
        if n_predictions:
            elapsed_time /= n_predictions

        entry = {"predictions": prediction, "elapsed_time": elapsed_time, "video": path_video}
        results.setdefault(label_video, []).append(entry)
    print("DONE")
    return results
def _iter_prediction_lists(results):
    """Yield each video's list of prediction dicts from either supported ``results`` layout."""
    if isinstance(results, dict):
        # Layout returned by test_data: label -> [{"predictions": [...], ...}, ...]
        for entries in results.values():
            for entry in entries:
                yield entry["predictions"]
    else:
        # Legacy layout: iterable of sequences whose first element is the prediction list.
        for result in results:
            yield result[0]


def plot_general_accuracy(results):
    """Plot the fraction of correct predictions at each window index.

    Index ``i`` aggregates the ``i``-th prediction of every video.

    :param results: either the dict returned by ``test_data`` or an iterable of
        sequences whose first element is a list of prediction dicts. Every prediction
        dict needs a boolean ``"correct"`` key.
    :return: list of per-index accuracies in ``[0, 1]`` (the plotted values).
    """
    correct = []
    amount = []
    for predictions in _iter_prediction_lists(results):
        for index, value in enumerate(predictions):
            if len(correct) <= index:
                correct.append(0)
                amount.append(0)
            if value["correct"]:
                correct[index] += 1
            amount[index] += 1
    # Each index that exists was counted at least once, so amount is never 0 here.
    accuracy = [hits / total for hits, total in zip(correct, amount)]
    plt.plot(accuracy)
    plt.xlabel("window index")
    plt.ylabel("accuracy")
    plt.show()
    return accuracy
if __name__ == "__main__":
    # Command-line entry point: build a predictor, evaluate it on the test split of a
    # data folder and write the results to predictions/test_results/<predictor>.json.
    # Running without arguments reproduces the previously hard-coded configuration.
    import argparse

    parser = argparse.ArgumentParser(
        description="Validate a sign predictor on the test videos of a data folder."
    )
    parser.add_argument("--predictor", choices=("knn", "svm"), default="knn",
                        help="classifier applied to the embeddings (default: knn)")
    parser.add_argument("--k", type=int, default=1,
                        help="number of neighbours for the knn predictor (default: 1)")
    parser.add_argument("--embeddings", default="embeddings/fingerspelling/embeddings.csv",
                        help="reference embeddings file, e.g. "
                             "embeddings/basic-signs/embeddings.csv")
    parser.add_argument("--data-folder",
                        default="/home/tibe/Projects/design_project/sign-predictor/data/fingerspelling/data/",
                        help="folder containing the <label>!<split>...mp4 videos")
    args = parser.parse_args()

    type_predictor = args.predictor
    if type_predictor == "knn":
        predictor_type = KNearestNeighbours(args.k)
    else:  # "svm" — argparse rejects any other value, so there is no silent fallback
        predictor_type = SVM()

    predictor = Predictor(args.embeddings, predictor_type)
    results = test_data(predictor, args.data_folder)

    # Write results to a JSON file. Create the directory first so a long validation
    # run is not lost to a FileNotFoundError at the very end.
    results_dir = "predictions/test_results"
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, type_predictor + ".json"), "w", encoding="utf-8") as f:
        json.dump(results, f)
    print(results)
    # plot_general_accuracy(results)