"""Evaluate a sign predictor on the held-out test videos of a dataset folder.

For every test video: extract keypoints per (subsampled) frame, embed each
sliding window of ``buffer_size`` frames, classify each embedding, and record
whether the predicted label matches the label encoded in the file name.
"""
import json
import os
import time

import cv2
import numpy as np
from matplotlib import pyplot as plt

from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM

# Number of consecutive keypoint frames passed to ``predictor.get_embedding``.
buffer_size = 15


def predict_video(predictor, path_video, desired_fps=15):
    """Extract keypoints from a video, keeping roughly ``desired_fps`` frames per second.

    Args:
        predictor: object exposing ``extract_keypoints(img)``; frames for which
            it returns ``None`` are dropped.
        path_video: path of the video file, opened with ``cv2.VideoCapture``.
        desired_fps: target sampling rate. Every
            ``original_fps // desired_fps``-th frame is kept (at least every frame).

    Returns:
        List of the non-``None`` keypoint results of the kept frames, in frame order.
    """
    cap = cv2.VideoCapture(path_video)
    buffer = []
    try:
        original_fps = int(cap.get(cv2.CAP_PROP_FPS))
        print("Original FPS: ", original_fps)

        # Videos slower than desired_fps (or with an unknown FPS of 0) give a
        # quotient of 0; never step by less than one frame.
        frame_skip = max(1, original_fps // desired_fps)
        print("Frame skip: ", frame_skip)

        frame_number = 0
        _, img = cap.read()  # img is None once no more frames can be read
        while img is not None:
            # Keypoints are extracted on every decoded frame (as before), the
            # skip only decides which results are kept.
            pose = predictor.extract_keypoints(img)
            if pose is not None and frame_number % frame_skip == 0:
                buffer.append(pose)
            frame_number += 1
            _, img = cap.read()
    finally:
        # Release the decoder/file handle even if keypoint extraction raises.
        cap.release()

    print(len(buffer))
    return buffer


def get_embeddings(predictor, buffer, name, output_dir="predictions/test_embeddings"):
    """Embed every full window of ``buffer_size`` consecutive frames and dump them to disk.

    Args:
        predictor: object exposing ``get_embedding(frames)``.
        buffer: per-frame keypoints as returned by ``predict_video``.
        name: file stem used for the dump file.
        output_dir: directory the dump is written to (created if missing).

    Returns:
        One embedding per window; ``len(buffer) - buffer_size + 1`` of them, or
        none when the buffer is shorter than ``buffer_size``.
    """
    # `end` is exclusive, so it must reach len(buffer) to include the last window.
    embeddings = [
        predictor.get_embedding(buffer[end - buffer_size:end])
        for end in range(buffer_size, len(buffer) + 1)
    ]

    os.makedirs(output_dir, exist_ok=True)
    # NOTE: the content is JSON despite the ".csv" extension; the name is kept so
    # existing dumps stay where they were. Assumes get_embedding returns
    # JSON-serialisable values (e.g. lists rather than numpy arrays) - confirm.
    with open(os.path.join(output_dir, name + ".csv"), "w") as f:
        json.dump(embeddings, f)
    return embeddings


def compare_embeddings(predictor, embeddings, label_video):
    """Classify each embedding and record whether it matches ``label_video``.

    Returns:
        A list of dicts with keys "label", "score", "label_video" and "correct".
    """
    results = []
    for embedding in embeddings:
        label, score = predictor.predict(embedding)
        results.append({
            "label": label,
            "score": score,
            "label_video": label_video,
            # bool() guards against numpy scalar booleans, which json cannot dump.
            "correct": bool(label == label_video),
        })
    return results


def predict_video_files(predictor, path_video, label_video):
    """Run the full keypoints -> embeddings -> predictions pipeline on one video."""
    buffer = predict_video(predictor, path_video)
    # splitext keeps dots inside the stem, so "a.1.mp4" and "a.2.mp4" no longer
    # collide on the same embeddings file.
    name = os.path.splitext(os.path.basename(path_video))[0]
    embeddings = get_embeddings(predictor, buffer, name)
    return compare_embeddings(predictor, embeddings, label_video)


def get_test_data(data_folder, split="test"):
    """List the ``.mp4`` videos of ``data_folder`` that belong to ``split``.

    File names are expected to look like ``<label>!<split>!...mp4``; files that
    do not follow that layout are skipped.

    Returns:
        ``(files, labels)``: a numpy string array of video paths (sorted by file
        name for reproducible runs) and the matching list of labels.
    """
    test_files = []
    test_labels = []
    for file_name in sorted(os.listdir(data_folder)):
        if not file_name.endswith(".mp4"):
            continue
        parts = file_name.split("!")
        if len(parts) < 2 or parts[1] != split:
            continue
        test_files.append(os.path.join(data_folder, file_name))
        test_labels.append(parts[0])
    return np.array(test_files, dtype=str), test_labels


def test_data(predictor, data_folder):
    """Evaluate ``predictor`` on every test video in ``data_folder``.

    Returns:
        Dict mapping each label to a list of per-video dicts with keys
        "predictions" (see ``compare_embeddings``), "elapsed_time" and "video".
        "elapsed_time" is the processing time in seconds divided by the number
        of predictions, or the total time when there were no predictions.
    """
    results = {}
    for path_video, label_video in zip(*get_test_data(data_folder)):
        print(path_video, label_video)
        start_time = time.perf_counter()  # monotonic, suited to measuring durations
        prediction = predict_video_files(predictor, path_video, label_video)
        elapsed_time = time.perf_counter() - start_time

        # Average over the predictions so it represents time per prediction.
        if prediction:
            elapsed_time /= len(prediction)

        results.setdefault(label_video, []).append({
            "predictions": prediction,
            "elapsed_time": elapsed_time,
            "video": path_video,
        })
    print("DONE")
    return results


def _iter_prediction_lists(results):
    """Yield each video's list of prediction dicts.

    Accepts the dict returned by ``test_data`` (label -> list of per-video
    dicts holding a "predictions" list). Any other iterable is treated the way
    the previous implementation indexed it: each item holds a prediction list at [0].
    """
    if isinstance(results, dict):
        for runs in results.values():
            for run in runs:
                yield run["predictions"]
    else:
        for result in results:
            yield result[0]


def plot_general_accuracy(results):
    """Plot accuracy as a function of the prediction's position within its video.

    Point ``i`` is the fraction of videos whose ``i``-th prediction was correct,
    computed over the videos that have at least ``i + 1`` predictions.
    """
    correct = []
    amount = []
    for predictions in _iter_prediction_lists(results):
        for index, value in enumerate(predictions):
            if len(correct) <= index:
                correct.append(0)
                amount.append(0)
            correct[index] += 1 if value["correct"] else 0
            amount[index] += 1

    # amount[i] >= 1 for every slot that was created, so no division by zero.
    accuracy = [hits / total for hits, total in zip(correct, amount)]

    plt.plot(accuracy)
    plt.xlabel("prediction index within video")
    plt.ylabel("accuracy")
    plt.show()


def main():
    """Evaluate the configured predictor on the fingerspelling test set and save the results."""
    type_predictor = "knn"
    if type_predictor == "svm":
        predictor_type = SVM()
    else:
        # "knn" and any unrecognised type both use 1-nearest-neighbour.
        k = 1
        predictor_type = KNearestNeighbours(k)

    # Alternative dataset: 'embeddings/basic-signs/embeddings.csv'
    embeddings_path = 'embeddings/fingerspelling/embeddings.csv'
    predictor = Predictor(embeddings_path, predictor_type)

    data_folder = '/home/tibe/Projects/design_project/sign-predictor/data/fingerspelling/data/'
    results = test_data(predictor, data_folder)

    # Write the results to a JSON file named after the predictor type.
    results_dir = "predictions/test_results"
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, type_predictor + ".json"), 'w') as f:
        json.dump(results, f)

    print(results)
    # Uncomment to plot accuracy per prediction index (blocks on plt.show()).
    # plot_general_accuracy(results)


if __name__ == "__main__":
    main()