87 lines
2.6 KiB
Python
87 lines
2.6 KiB
Python
import json
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
def load_results(path="predictions/test_results/knn.json"):
    """Load the KNN test-results JSON document.

    Args:
        path: Location of the results file. Defaults to the KNN test-results
            file this module was originally hard-wired to, so existing
            zero-argument calls behave as before.

    Returns:
        The decoded JSON document. Callers in this module expect a dict that
        maps each label to a list of result records.
    """
    # Explicit encoding: JSON is UTF-8 by spec, and relying on the platform
    # locale would make decoding differ between machines.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def plot_all():
    """Load the results, report the mean detection time, and show all plots.

    The label-averaged accuracy plot is shown first, followed by one accuracy
    plot per label, in the order the labels appear in the results.
    """
    data = load_results()

    mean_time = get_general_elapsed_time(data)
    print(f"average elapsed time to detect a sign: {mean_time}")

    plot_general_accuracy(data)
    for name in data:
        plot_accuracy_per_label(data, name)
|
|
|
|
|
|
def general_accuracy(results, min_labels=5):
    """Average the per-label accuracy curves into one curve.

    For each prediction index, this takes the unweighted mean of the accuracy
    values of every label whose curve (from ``get_label_accuracy``) reaches
    that index. Each label counts equally, regardless of how many runs it has.
    The result is cut off at the first index backed by fewer than
    ``min_labels`` labels, so indices supported by only a few labels are
    dropped.

    Args:
        results: Raw results mapping, passed straight to ``get_label_accuracy``.
        min_labels: Minimum number of contributing labels an index needs in
            order to be kept. Defaults to 5, the original hard-coded cut-off.

    Returns:
        List of mean accuracies, one entry per retained prediction index.
    """
    totals = []
    counts = []
    for curve in get_label_accuracy(results).values():
        for index, value in enumerate(curve):
            # Curves have different lengths; grow the accumulators on demand.
            if index >= len(totals):
                totals.append(0)
                counts.append(0)
            totals[index] += value
            counts[index] += 1

    response = []
    for total, count in zip(totals, counts):
        if count < min_labels:
            break
        response.append(total / count)
    return response
|
|
def plot_general_accuracy(results):
    """Plot the label-averaged accuracy curve and show it.

    Args:
        results: Raw results mapping, passed straight to ``general_accuracy``.
    """
    accuracy = general_accuracy(results)
    plt.plot(accuracy)
    # Set the title on the current Axes (what ``plt.title(...)`` does
    # internally). Assigning ``plt.title = ...`` would not set a title; it
    # would replace the pyplot function for the whole process.
    plt.gca().set_title("General accuracy")
    plt.ylabel('accuracy')
    plt.xlabel('buffer')
    plt.show()
|
|
|
|
|
|
def plot_accuracy_per_label(results, label):
    """Plot and show the accuracy curve of a single label.

    Args:
        results: Raw results mapping of label -> list of runs.
        label: Key of ``results`` whose curve is drawn; also used as the
            legend entry. A missing label raises ``KeyError``.
    """
    # Compute accuracy for this label only; the per-label computation is
    # independent of other labels, so this avoids rescanning all results on
    # every call.
    accuracy = get_label_accuracy({label: results[label]})
    plt.plot(accuracy[label], label=label)
    # Set the title through the current Axes rather than ``plt.title(...)``,
    # so it still works if ``plt.title`` has been rebound elsewhere.
    plt.gca().set_title(f"Accuracy per label {label}")
    plt.ylabel('accuracy')
    plt.xlabel('prediction')
    plt.legend()
    plt.show()
|
|
|
|
def get_label_accuracy(results, min_samples=2):
    """Compute, for each label, the accuracy at every prediction index.

    ``results`` maps each label to a list of runs. Each run has a
    ``"predictions"`` list whose items carry a ``"correct"`` flag, which is
    evaluated for truthiness. Entry ``i`` of a label's curve is the fraction of
    that label's runs whose ``i``-th prediction was correct, taken over the
    runs that have at least ``i + 1`` predictions.

    Because later indices are reached by no more runs than earlier ones, the
    curve is cut off at the first index backed by fewer than ``min_samples``
    runs.

    Args:
        results: Mapping of label -> list of runs, as described above.
        min_samples: Minimum number of runs an index needs in order to be kept.
            Defaults to 2, the original hard-coded cut-off.

    Returns:
        Dict mapping each label, in input order, to its list of per-index
        accuracies. The list may be empty.
    """
    response = {}
    for label, runs in results.items():
        correct = []
        seen = []
        for run in runs:
            for index, prediction in enumerate(run["predictions"]):
                # Runs have different lengths; grow the counters on demand.
                if index >= len(seen):
                    correct.append(0)
                    seen.append(0)
                if prediction["correct"]:
                    correct[index] += 1
                seen[index] += 1

        curve = []
        for hits, total in zip(correct, seen):
            if total < min_samples:
                break
            curve.append(hits / total)
        response[label] = curve
    return response
|
|
|
|
def get_general_elapsed_time(results):
    """Return the mean, over labels, of each label's mean elapsed time.

    Every label counts equally, regardless of how many runs it has.
    An empty ``results`` raises ``ZeroDivisionError``.
    """
    per_label = get_label_elapsed_time(results)
    total = sum(per_label[name] for name in results)
    return total / len(results)
|
|
|
|
def get_label_elapsed_time(results):
    """Return a dict mapping each label to the mean ``"elapsed_time"`` of its runs.

    A label with an empty run list raises ``ZeroDivisionError``.
    """
    means = {}
    for name, runs in results.items():
        times = [run["elapsed_time"] for run in runs]
        means[name] = sum(times) / len(times)
    return means
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build and show every plot.
    plot_all()
|