Files
spoterembedding/predictions/plotting.py

87 lines
2.6 KiB
Python

import json
from matplotlib import pyplot as plt
def load_results():
    """Read the stored KNN test results from disk.

    Returns:
        The object decoded from ``predictions/test_results/knn.json``
        (the path is relative to the current working directory).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # Pin the encoding: without it the file is decoded with the platform's
    # locale default, which breaks non-ASCII labels on non-UTF-8 systems.
    with open("predictions/test_results/knn.json", "r", encoding="utf-8") as f:
        return json.load(f)
def plot_all():
    """Load the results, print the average elapsed time and show every plot.

    The overall accuracy plot is shown first, followed by one accuracy plot
    for each label present in the loaded results.
    """
    data = load_results()
    mean_elapsed = get_general_elapsed_time(data)
    print(f"average elapsed time to detect a sign: {mean_elapsed}")
    plot_general_accuracy(data)
    for name in data:
        plot_accuracy_per_label(data, name)
def general_accuracy(results):
    """Average the per-label accuracy curves position by position.

    Args:
        results: the loaded results, passed straight to
            ``get_label_accuracy``.

    Returns:
        A list whose i-th entry is the mean of the i-th accuracy value over
        all labels that have one. The list stops at the first position for
        which fewer than 5 labels contribute a value.
    """
    per_label = get_label_accuracy(results)
    totals = []
    counts = []
    for curve in per_label.values():
        for position, score in enumerate(curve):
            # Grow the accumulators the first time a position is reached.
            if position >= len(totals):
                totals.append(0)
                counts.append(0)
            totals[position] += score
            counts[position] += 1

    averaged = []
    for total, count in zip(totals, counts):
        # Too few labels reach this position for the mean to be kept.
        if count < 5:
            break
        averaged.append(total / count)
    return averaged
def plot_general_accuracy(results):
    """Plot the accuracy averaged over all labels and show the figure.

    Args:
        results: the loaded results, passed to ``general_accuracy``.
    """
    accuracy = general_accuracy(results)
    plt.plot(accuracy)
    # ``plt.title`` must be called; assigning to it would replace pyplot's
    # ``title`` function with a string and leave the figure untitled.
    plt.title("General accuracy")
    plt.ylabel('accuracy')
    plt.xlabel('buffer')
    plt.show()


def plot_accuracy_per_label(results, label):
    """Plot the accuracy curve of a single label and show the figure.

    Args:
        results: the loaded results, passed to ``get_label_accuracy``.
        label: key of the label to plot; also used as the legend entry and
            in the figure title.

    Raises:
        KeyError: if ``label`` is not a key of the computed accuracies.
    """
    accuracy = get_label_accuracy(results)
    plt.plot(accuracy[label], label=label)
    # Call ``plt.title`` (the original assigned to a misspelled attribute,
    # so no title was ever set).
    plt.title(f"Accuracy per label {label}")
    plt.ylabel('accuracy')
    plt.xlabel('prediction')
    plt.legend()
    plt.show()
def get_label_accuracy(results):
    """Compute, per label, the fraction of correct predictions at each position.

    Args:
        results: mapping of label to a list of entries; each entry has a
            ``"predictions"`` list whose items carry a ``"correct"`` flag.

    Returns:
        A dict mapping every label to a list of ratios, where the i-th ratio
        is the share of entries whose i-th prediction was correct. Each list
        stops at the first position reached by fewer than 2 entries.
    """
    response = {}
    for label, runs in results.items():
        hits = []
        totals = []
        for run in runs:
            for position, outcome in enumerate(run["predictions"]):
                # Grow the counters the first time a position is reached.
                if position >= len(hits):
                    hits.append(0)
                    totals.append(0)
                if outcome["correct"]:
                    hits[position] += 1
                totals[position] += 1

        curve = []
        for hit_count, total in zip(hits, totals):
            # A single sample is not enough to report a ratio.
            if total < 2:
                break
            curve.append(hit_count / total)
        response[label] = curve
    return response
def get_general_elapsed_time(results):
    """Return the unweighted mean of the per-label average elapsed times.

    Args:
        results: the loaded results, passed to ``get_label_elapsed_time``.

    Raises:
        ZeroDivisionError: if ``results`` is empty.
    """
    per_label = get_label_elapsed_time(results)
    total = 0
    for label in results:
        total += per_label[label]
    return total / len(results)
def get_label_elapsed_time(results):
    """Return the mean ``"elapsed_time"`` of the entries of every label.

    Args:
        results: mapping of label to a list of entries, each carrying an
            ``"elapsed_time"`` value.

    Returns:
        A dict mapping each label to the average elapsed time of its entries.

    Raises:
        ZeroDivisionError: if a label has no entries.
    """
    averages = {}
    for label, runs in results.items():
        total = 0
        for run in runs:
            total += run["elapsed_time"]
        averages[label] = total / len(runs)
    return averages
# Script entry point: draw every plot only when this file is run directly.
if __name__ == "__main__":
    plot_all()