"""Plot accuracy and timing statistics for k-NN sign-detection test results.

The code below reads the results JSON as a mapping ``label -> list of runs``,
where every run has an ``"elapsed_time"`` number and a ``"predictions"`` list
whose items carry a truthy/falsy ``"correct"`` field.
NOTE(review): schema inferred only from the field accesses in this file;
confirm against whatever produces ``knn.json``.
"""

import json

from matplotlib import pyplot as plt

# Default location of the k-NN test results consumed by plot_all().
RESULTS_PATH = "predictions/test_results/knn.json"


def load_results(path=RESULTS_PATH):
    """Load and return the parsed JSON test results stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def plot_all():
    """Print the average detection time, then show every accuracy plot.

    Shows the cross-label accuracy curve first, then one plot per label.
    """
    results = load_results()
    print(f"average elapsed time to detect a sign: {get_general_elapsed_time(results)}")
    plot_general_accuracy(results)
    # Compute the per-label curves once instead of once per plotted label.
    label_accuracy = get_label_accuracy(results)
    for label in results:
        _plot_label_curve(label_accuracy[label], label)


def general_accuracy(results, min_labels=5):
    """Average the per-label accuracy curves into one curve, position by position.

    Position ``i`` of the returned list is the mean of position ``i`` over all
    curves from get_label_accuracy() long enough to have that position.

    Args:
        results: parsed results mapping (see module docstring).
        min_labels: the curve is cut off at the first position covered by
            fewer than this many labels, so the tail is not dominated by a
            few labels.

    Returns:
        list of floats, possibly empty.
    """
    totals = []
    counts = []
    for curve in get_label_accuracy(results).values():
        for index, value in enumerate(curve):
            if index >= len(totals):
                totals.append(0)
                counts.append(0)
            totals[index] += value
            counts[index] += 1

    response = []
    for total, count in zip(totals, counts):
        if count < min_labels:
            break
        response.append(total / count)
    return response


def plot_general_accuracy(results):
    """Plot the cross-label accuracy curve from general_accuracy() and show it."""
    plt.plot(general_accuracy(results))
    # Call, don't assign: `plt.title = ...` would replace the pyplot function
    # with a string and never set a title.
    plt.title("General accuracy")
    plt.ylabel("accuracy")
    plt.xlabel("buffer")
    plt.show()


def plot_accuracy_per_label(results, label):
    """Plot the accuracy curve of the single *label* and show it.

    Raises:
        KeyError: if *label* is not a key of *results*.
    """
    _plot_label_curve(get_label_accuracy(results)[label], label)


def _plot_label_curve(accuracy, label):
    """Draw one label's precomputed accuracy curve with title/axes/legend and show it."""
    plt.plot(accuracy, label=label)
    plt.title(f"Accuracy per label {label}")
    plt.ylabel("accuracy")
    plt.xlabel("prediction")
    plt.legend()
    plt.show()


def get_label_accuracy(results, min_samples=2):
    """Return, for every label, the fraction of correct predictions per position.

    For a label, position ``i`` is (runs whose i-th prediction is correct) /
    (runs that have an i-th prediction).

    Args:
        results: parsed results mapping (see module docstring).
        min_samples: each label's curve stops at the first position present
            in fewer than this many runs.

    Returns:
        dict mapping label -> list of floats, in the key order of *results*.
    """
    response = {}
    for label, runs in results.items():
        correct = []
        seen = []
        for run in runs:
            for index, prediction in enumerate(run["predictions"]):
                if index >= len(correct):
                    correct.append(0)
                    seen.append(0)
                if prediction["correct"]:
                    correct[index] += 1
                seen[index] += 1

        curve = []
        for hits, total in zip(correct, seen):
            if total < min_samples:
                break
            curve.append(hits / total)
        response[label] = curve
    return response


def get_general_elapsed_time(results):
    """Return the mean of the per-label average elapsed times.

    Every label counts equally, regardless of how many runs it has.

    Raises:
        ZeroDivisionError: if *results* is empty or any label has no runs.
    """
    label_time = get_label_elapsed_time(results)
    return sum(label_time.values()) / len(label_time)


def get_label_elapsed_time(results):
    """Return a dict mapping each label to the mean ``elapsed_time`` of its runs.

    Raises:
        ZeroDivisionError: if any label has an empty list of runs.
    """
    return {
        label: sum(run["elapsed_time"] for run in runs) / len(runs)
        for label, runs in results.items()
    }


if __name__ == "__main__":
    plot_all()